Commit 218796ed by Ting PAN

Remove the deprecated DALI API

Summary:
This commit removes the deprecated API for DALI 0.24.
Besides, variable length keyword arguments are added for forward compatibility.
1 parent c40eaf7b
Showing with 1749 additions and 921 deletions
...@@ -21,6 +21,7 @@ from dragon.vm.caffe.core.solver import AdamSolver ...@@ -21,6 +21,7 @@ from dragon.vm.caffe.core.solver import AdamSolver
from dragon.vm.caffe.core.solver import NesterovSolver from dragon.vm.caffe.core.solver import NesterovSolver
from dragon.vm.caffe.core.solver import RMSPropSolver from dragon.vm.caffe.core.solver import RMSPropSolver
from dragon.vm.caffe.core.solver import SGDSolver from dragon.vm.caffe.core.solver import SGDSolver
from dragon.vm.caffe.core.solver import Solver
# Functions # Functions
from dragon.vm.caffe.core.net_spec import to_proto from dragon.vm.caffe.core.net_spec import to_proto
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The solver to update parameters.""" """The solver to optimize parameters."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -28,7 +28,7 @@ from dragon.vm.caffe.core.proto import caffe_pb2 ...@@ -28,7 +28,7 @@ from dragon.vm.caffe.core.proto import caffe_pb2
class Solver(object): class Solver(object):
"""The abstraction ``caffe.Solver``.""" """The base solver class to optimize parameters."""
def __init__(self, solver_file, is_root=True): def __init__(self, solver_file, is_root=True):
"""Create a ``Solver``. """Create a ``Solver``.
...@@ -330,7 +330,7 @@ class AdamSolver(Solver): ...@@ -330,7 +330,7 @@ class AdamSolver(Solver):
momentum=0.9, momentum=0.9,
momentum2=0.999, momentum2=0.999,
delta=1e-8, delta=1e-8,
) }
``` ```
""" """
...@@ -397,7 +397,7 @@ class RMSPropSolver(Solver): ...@@ -397,7 +397,7 @@ class RMSPropSolver(Solver):
base_lr=0.01, base_lr=0.01,
rms_decay=0.99, rms_decay=0.99,
delta=1e-8, delta=1e-8,
) }
``` ```
""" """
...@@ -430,13 +430,13 @@ class SGDSolver(Solver): ...@@ -430,13 +430,13 @@ class SGDSolver(Solver):
solver { solver {
base_lr=0.01, base_lr=0.01,
momentum=0.9, momentum=0.9,
) }
``` ```
""" """
def __init__(self, solver_file, is_root=True): def __init__(self, solver_file, is_root=True):
"""Create a `SGDSolver``. """Create a ``SGDSolver``.
Parameters Parameters
---------- ----------
......
...@@ -12,7 +12,7 @@ if (MSVC) ...@@ -12,7 +12,7 @@ if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}
/wd4003 /wd4114 /wd4003 /wd4114
/wd4244 /wd4251 /wd4273 /wd4275 /wd4244 /wd4251 /wd4267 /wd4273 /wd4275
/wd4800 /wd4819 /wd4996") /wd4800 /wd4819 /wd4996")
string(REPLACE "/W3" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") string(REPLACE "/W3" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
......
...@@ -22,32 +22,32 @@ import sys as _sys ...@@ -22,32 +22,32 @@ import sys as _sys
from dragon.vm.dali._api import ops from dragon.vm.dali._api import ops
# Classes # Classes
from dragon.vm.dali.core.iterator import Iterator from dragon.vm.dali.core.framework.iterator import Iterator
from dragon.vm.dali.core.pipeline import Pipeline from dragon.vm.dali.core.framework.pipeline import Pipeline
# Functions # Functions
from dragon.vm.dali.core.context import device from dragon.vm.dali.core.framework.context import device
from dragon.vm.dali.core.context import get_device_type from dragon.vm.dali.core.framework.context import get_device_type
from dragon.vm.dali.core.context import get_distributed_info from dragon.vm.dali.core.framework.context import get_distributed_info
# Enums # Enums
from dragon.vm.dali.core.types import BOOL from dragon.vm.dali.core.framework.types import BOOL
from dragon.vm.dali.core.types import BGR from dragon.vm.dali.core.framework.types import BGR
from dragon.vm.dali.core.types import FLOAT from dragon.vm.dali.core.framework.types import FLOAT
from dragon.vm.dali.core.types import FLOAT32 from dragon.vm.dali.core.framework.types import FLOAT32
from dragon.vm.dali.core.types import FLOAT64 from dragon.vm.dali.core.framework.types import FLOAT64
from dragon.vm.dali.core.types import INT8 from dragon.vm.dali.core.framework.types import INT8
from dragon.vm.dali.core.types import INT32 from dragon.vm.dali.core.framework.types import INT32
from dragon.vm.dali.core.types import INT64 from dragon.vm.dali.core.framework.types import INT64
from dragon.vm.dali.core.types import INTERP_TRIANGULAR from dragon.vm.dali.core.framework.types import INTERP_TRIANGULAR
from dragon.vm.dali.core.types import NCHW from dragon.vm.dali.core.framework.types import NCHW
from dragon.vm.dali.core.types import NHWC from dragon.vm.dali.core.framework.types import NHWC
from dragon.vm.dali.core.types import RGB from dragon.vm.dali.core.framework.types import RGB
from dragon.vm.dali.core.types import STRING from dragon.vm.dali.core.framework.types import STRING
from dragon.vm.dali.core.types import UINT8 from dragon.vm.dali.core.framework.types import UINT8
from dragon.vm.dali.core.types import UINT16 from dragon.vm.dali.core.framework.types import UINT16
from dragon.vm.dali.core.types import UINT32 from dragon.vm.dali.core.framework.types import UINT32
from dragon.vm.dali.core.types import UINT64 from dragon.vm.dali.core.framework.types import UINT64
# Attributes # Attributes
_API_MODULE = ops _API_MODULE = ops
......
...@@ -13,24 +13,24 @@ from __future__ import absolute_import as _absolute_import ...@@ -13,24 +13,24 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
from dragon.vm.dali.core.ops.array import Cast from dragon.vm.dali.core.ops.bbox_ops import BbFlip
from dragon.vm.dali.core.ops.array import Pad from dragon.vm.dali.core.ops.bbox_ops import BBoxPaste
from dragon.vm.dali.core.ops.array import Reshape from dragon.vm.dali.core.ops.builtin_ops import ExternalSource
from dragon.vm.dali.core.ops.builtin import ExternalSource from dragon.vm.dali.core.ops.decoder_ops import ImageDecoder
from dragon.vm.dali.core.ops.color import BrightnessContrast from dragon.vm.dali.core.ops.decoder_ops import ImageDecoderRandomCrop
from dragon.vm.dali.core.ops.color import Hsv from dragon.vm.dali.core.ops.generic_ops import Cast
from dragon.vm.dali.core.ops.crop import RandomBBoxCrop from dragon.vm.dali.core.ops.generic_ops import Pad
from dragon.vm.dali.core.ops.crop import Slice from dragon.vm.dali.core.ops.generic_ops import Reshape
from dragon.vm.dali.core.ops.decoder import ImageDecoder from dragon.vm.dali.core.ops.generic_ops import Slice
from dragon.vm.dali.core.ops.decoder import ImageDecoderRandomCrop from dragon.vm.dali.core.ops.image_ops import BrightnessContrast
from dragon.vm.dali.core.ops.fused import CropMirrorNormalize from dragon.vm.dali.core.ops.image_ops import CropMirrorNormalize
from dragon.vm.dali.core.ops.geometric import BbFlip from dragon.vm.dali.core.ops.image_ops import Hsv
from dragon.vm.dali.core.ops.paste import BBoxPaste from dragon.vm.dali.core.ops.image_ops import Paste
from dragon.vm.dali.core.ops.paste import Paste from dragon.vm.dali.core.ops.image_ops import RandomBBoxCrop
from dragon.vm.dali.core.ops.random import CoinFlip from dragon.vm.dali.core.ops.image_ops import Resize
from dragon.vm.dali.core.ops.random import Uniform from dragon.vm.dali.core.ops.random_ops import CoinFlip
from dragon.vm.dali.core.ops.reader import KPLRecordReader from dragon.vm.dali.core.ops.random_ops import Uniform
from dragon.vm.dali.core.ops.reader import TFRecordReader from dragon.vm.dali.core.ops.reader_ops import KPLRecordReader
from dragon.vm.dali.core.ops.resize import Resize from dragon.vm.dali.core.ops.reader_ops import TFRecordReader
__all__ = [_s for _s in dir() if not _s.startswith('_')] __all__ = [_s for _s in dir() if not _s.startswith('_')]
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
...@@ -26,7 +26,7 @@ from dragon.core.device import cuda ...@@ -26,7 +26,7 @@ from dragon.core.device import cuda
from dragon.core.eager.tensor import EagerTensor from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import device_spec from dragon.core.framework import device_spec
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.vm.dali.core import types from dragon.vm.dali.core.framework import types
class Iterator(object): class Iterator(object):
......
...@@ -15,7 +15,7 @@ from __future__ import print_function ...@@ -15,7 +15,7 @@ from __future__ import print_function
try: try:
from nvidia.dali import pipeline from nvidia.dali import pipeline
from dragon.vm.dali.core import context from dragon.vm.dali.core.framework import context
class Pipeline(pipeline.Pipeline): class Pipeline(pipeline.Pipeline):
"""The base pipeline class to define operations. """The base pipeline class to define operations.
...@@ -151,6 +151,8 @@ except ImportError: ...@@ -151,6 +151,8 @@ except ImportError:
""" """
self._batch_size = batch_size self._batch_size = batch_size
self._num_threads = num_threads self._num_threads = num_threads
self._seed = seed
self._prefetch_queue_depth = prefetch_queue_depth
@property @property
def batch_size(self): def batch_size(self):
......
...@@ -16,85 +16,65 @@ from __future__ import print_function ...@@ -16,85 +16,65 @@ from __future__ import print_function
try: try:
from nvidia.dali import ops from nvidia.dali import ops
except ImportError: except ImportError:
ops = None from dragon.core.util import deprecation
ops = deprecation.not_installed('nvidia.dali')
from dragon.vm.dali.core import context from dragon.vm.dali.core.framework import context
class BBoxPaste(object): class BbFlip(object):
"""Transform bounding boxes to match the ``Paste`` operator. """Flip the bounding boxes.
Examples: Examples:
```python ```python
bbox_paste = dali.ops.BBoxPaste() bbox_flip = dali.ops.BbFlip()
paste_pos = dali.ops.Uniform((0., 1.)) flip_rng = dali.ops.CoinFlip(0.5)
paste_ratio = dali.ops.Uniform((0., 3.)) bbox = bbox_flip(inputs['bbox'], horizontal=flip_rng())
paste_prob = dali.ops.CoinFlip(0.5) ```
bbox = bbox_paste(
inputs['bbox'],
# Expand ratio
ratio=paste_ratio() * paste_prob() + 1.,
# PosX, PosY
paste_x=paste_pos(),
paste_y=paste_pos(),
)
""" """
def __new__( def __new__(cls, horizontal=None, vertical=None, ltrb=True, **kwargs):
cls, """Create a ``BbFlip`` operator.
ltrb=True,
ratio=None,
paste_x=None,
paste_y=None,
):
"""Create a ``BBoxPaste`` operator.
Parameters Parameters
---------- ----------
horizontal : int, optional
Whether to apply the horizontal flip.
vertical : int, optional
Whether to apply the vertical flip.
ltrb : bool, optional, default=True ltrb : bool, optional, default=True
Indicate the bbox is ``ltrb`` or ``xywh`` format. Indicate the bbox is ``ltrb`` or ``xywh`` format.
ratio : int, optional
The expand ratio.
paste_x : int, optional
The paste position at x-axis.
paste_y : int, optional
The paste position at y-axis.
Returns Returns
------- -------
nvidia.dali.ops.BBoxPaste nvidia.dali.ops.BbFlip
The operator. The operator.
""" """
return ops.BBoxPaste( return ops.BbFlip(
horizontal=horizontal,
vertical=vertical,
ltrb=ltrb, ltrb=ltrb,
ratio=ratio, device=context.get_device_type(),
paste_x=paste_x, **kwargs
paste_y=paste_y,
device='cpu',
) )
class Paste(object): class BBoxPaste(object):
"""Copy image into a larger canvas. """Transform bounding boxes to match the ``Paste`` operator.
Examples: Examples:
```python ```python
paste = dali.ops.Paste( bbox_paste = dali.ops.BBoxPaste()
# The image channels
n_channels=3,
# Historical values before mean subtraction
fill_value=(102., 115., 122.),
)
paste_pos = dali.ops.Uniform((0., 1.)) paste_pos = dali.ops.Uniform((0., 1.))
paste_ratio = dali.ops.Uniform((0., 3.)) paste_ratio = dali.ops.Uniform((0., 3.))
paste_prob = dali.ops.CoinFlip(0.5) paste_prob = dali.ops.CoinFlip(0.5)
y = paste( bbox = bbox_paste(
inputs['x'], inputs['bbox'],
# Expand ratio # Expand ratio
ratio=paste_ratio() * paste_prob() + 1., ratio=paste_ratio() * paste_prob() + 1.,
# PosX, PosY # PosX, PosY
...@@ -107,20 +87,18 @@ class Paste(object): ...@@ -107,20 +87,18 @@ class Paste(object):
def __new__( def __new__(
cls, cls,
n_channels=3, ltrb=True,
fill_value=(0., 0., 0.),
ratio=None, ratio=None,
paste_x=None, paste_x=None,
paste_y=None, paste_y=None,
**kwargs
): ):
"""Create a ``Paste`` operator. """Create a ``BBoxPaste`` operator.
Parameters Parameters
---------- ----------
n_channels : int, optional, default=3 ltrb : bool, optional, default=True
The image channels. Indicate the bbox is ``ltrb`` or ``xywh`` format.
fill_value : Sequence[number], optional
The value(s) to fill for the canvas.
ratio : int, optional ratio : int, optional
The expand ratio. The expand ratio.
paste_x : int, optional paste_x : int, optional
...@@ -130,15 +108,15 @@ class Paste(object): ...@@ -130,15 +108,15 @@ class Paste(object):
Returns Returns
------- -------
nvidia.dali.ops.Paste nvidia.dali.ops.BBoxPaste
The operator. The operator.
""" """
return ops.Paste( return ops.BBoxPaste(
n_channels=n_channels, ltrb=ltrb,
fill_value=fill_value,
ratio=ratio, ratio=ratio,
paste_x=paste_x, paste_x=paste_x,
paste_y=paste_y, paste_y=paste_y,
device=context.get_device_type(), device='cpu',
**kwargs
) )
...@@ -16,9 +16,10 @@ from __future__ import print_function ...@@ -16,9 +16,10 @@ from __future__ import print_function
try: try:
from nvidia.dali import ops from nvidia.dali import ops
except ImportError: except ImportError:
ops = None from dragon.core.util import deprecation
ops = deprecation.not_installed('nvidia.dali')
from dragon.vm.dali.core import context from dragon.vm.dali.core.framework import context
class ExternalSource(object): class ExternalSource(object):
...@@ -42,7 +43,7 @@ class ExternalSource(object): ...@@ -42,7 +43,7 @@ class ExternalSource(object):
""" """
def __new__(cls): def __new__(cls, **kwargs):
"""Create a ``ExternalSource`` operator. """Create a ``ExternalSource`` operator.
Returns Returns
...@@ -51,4 +52,4 @@ class ExternalSource(object): ...@@ -51,4 +52,4 @@ class ExternalSource(object):
The operator. The operator.
""" """
return ops.ExternalSource(device=context.get_device_type()) return ops.ExternalSource(device=context.get_device_type(), **kwargs)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from nvidia.dali import ops
except ImportError:
ops = None
from dragon.vm.dali.core import context
class BrightnessContrast(object):
"""Adjust the brightness and contrast.
Examples:
```python
# Historical jitter range for brightness and contrast
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
bc = dali.ops.BrightnessContrast()
y = bc(inputs['x'], brightness=twist_rng(), contrast=twist_rng())
```
"""
def __new__(cls):
"""Create a ``BrightnessContrastBrightnessContrast`` operator.
Returns
-------
nvidia.dali.ops.BrightnessContrast
The operator.
"""
return ops.BrightnessContrast(device=context.get_device_type())
class Hsv(object):
"""Adjust the hue and saturation.
Examples:
```python
# Historical jitter range for saturation
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
hsv = dali.ops.Hsv()
y = hsv(inputs['x'], saturation=twist_rng())
```
"""
def __new__(cls):
"""Create a ``Hsv`` operator.
Returns
-------
nvidia.dali.ops.Hsv
The operator.
"""
return ops.Hsv(device=context.get_device_type())
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from nvidia.dali import ops
except ImportError:
ops = None
from dragon.vm.dali.core import context
class RandomBBoxCrop(object):
"""Return a valid crop restricted by bounding boxes.
Examples:
```python
bbox_crop = dali.ops.RandomBBoxCrop(
# Range of scale
scaling=[0.3, 1.0],
# Range of aspect ratio
aspect_ratio=[0.5, 2.0],
# Minimum IoUs to satisfy
thresholds=[0.0, 0.1, 0.3, 0.5, 0.7, 0.9],
)
crop_begin, crop_size, bbox, label = bbox_crop(inputs['bbox'], inputs['label'])
```
"""
def __new__(
cls,
scaling=(0.3, 1.0),
aspect_ratio=(0.5, 2.0),
thresholds=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9),
allow_no_crop=True,
ltrb=True,
num_attempts=10,
):
"""Create a ``RandomBBoxCrop`` operator.
Parameters
----------
scaling : Sequence[float], optional, default=(0.3, 1.0)
The range of scale for sampling regions.
aspect_ratio : Sequence[float], optional, default=(0.5, 2.0)
The range of aspect ratio for sampling regions.
thresholds : Sequence[float], optional
The minimum IoU(s) to satisfy.
allow_no_crop : bool, optional, default=True
**True** to include the no-cropping as a option.
ltrb : bool, optional, default=True
Indicate the bbox is ``ltrb`` or ``xywh`` format.
num_attempts : int, optional, default=10
The max number of sampling trails.
Returns
-------
nvidia.dali.ops.RandomBBoxCrop
The operator.
"""
return ops.RandomBBoxCrop(
scaling=scaling,
aspect_ratio=aspect_ratio,
thresholds=thresholds,
allow_no_crop=allow_no_crop,
ltrb=ltrb,
num_attempts=num_attempts,
device='cpu',
)
class Slice(object):
"""Select an interval of elements from input.
Examples:
```python
slice = dali.ops.Slice(
# Axis of intervals
axes=[1, 0],
# Whether the begin of interval is normalized
# in a range of [0.0, 1.0]
normalized_anchor=True,
# Whether the size of interval is normalized
# in a range of [0.0, 1.0]
normalized_shape=True,
)
y = slice(inputs['x'], crop_begin, crop_size)
```
"""
def __new__(
cls,
axes=(1, 0),
normalized_anchor=True,
normalized_shape=True,
):
"""Create a ``Slice`` operator.
Parameters
----------
axes : Sequence[int], optional
The axis to select.
normalized_anchor : bool, optional, default=True
Whether the begin of interval is normalized.
normalized_shape : bool, optional, default=True
Whether the size of interval is normalized.
Returns
-------
nvidia.dali.ops.Slice
The operator.
"""
return ops.Slice(
axes=axes,
normalized_anchor=normalized_anchor,
device=context.get_device_type(),
)
...@@ -16,11 +16,12 @@ from __future__ import print_function ...@@ -16,11 +16,12 @@ from __future__ import print_function
try: try:
from nvidia.dali import ops from nvidia.dali import ops
except ImportError: except ImportError:
ops = None from dragon.core.util import deprecation
ops = deprecation.not_installed('nvidia.dali')
from dragon.core.util import six from dragon.core.util import six
from dragon.vm.dali.core import context from dragon.vm.dali.core.framework import context
from dragon.vm.dali.core import types from dragon.vm.dali.core.framework import types
class ImageDecoder(object): class ImageDecoder(object):
...@@ -40,6 +41,7 @@ class ImageDecoder(object): ...@@ -40,6 +41,7 @@ class ImageDecoder(object):
output_type='BGR', output_type='BGR',
host_memory_padding=8388608, host_memory_padding=8388608,
device_memory_padding=16777216, device_memory_padding=16777216,
**kwargs
): ):
"""Create a ``ImageDecoder`` operator. """Create a ``ImageDecoder`` operator.
...@@ -65,6 +67,7 @@ class ImageDecoder(object): ...@@ -65,6 +67,7 @@ class ImageDecoder(object):
host_memory_padding=host_memory_padding, host_memory_padding=host_memory_padding,
device_memory_padding=device_memory_padding, device_memory_padding=device_memory_padding,
device=context.get_device_type(mixed=True), device=context.get_device_type(mixed=True),
**kwargs
) )
...@@ -93,6 +96,7 @@ class ImageDecoderRandomCrop(object): ...@@ -93,6 +96,7 @@ class ImageDecoderRandomCrop(object):
random_area=(0.08, 1.), random_area=(0.08, 1.),
random_aspect_ratio=(0.75, 1.33), random_aspect_ratio=(0.75, 1.33),
num_attempts=10, num_attempts=10,
**kwargs
): ):
"""Create a ``ImageDecoderRandomCrop`` operator. """Create a ``ImageDecoderRandomCrop`` operator.
...@@ -127,4 +131,5 @@ class ImageDecoderRandomCrop(object): ...@@ -127,4 +131,5 @@ class ImageDecoderRandomCrop(object):
random_aspect_ratio=random_aspect_ratio, random_aspect_ratio=random_aspect_ratio,
num_attempts=num_attempts, num_attempts=num_attempts,
device=context.get_device_type(mixed=True), device=context.get_device_type(mixed=True),
**kwargs
) )
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from nvidia.dali import ops
except ImportError:
ops = None
from dragon.core.util import six
from dragon.vm.dali.core import context
from dragon.vm.dali.core import types
class CropMirrorNormalize(object):
"""Crop and normalize input with the horizontal flip.
Examples:
```python
flip_rng = dali.ops.CoinFlip(0.5)
cmn = dali.ops.CropMirrorNormalize(
# Match the number of spatial dims
# (H, W) for 2d input
# (D, H, W) for 3d input
crop=(224, 224),
# Historical values to normalize input
mean=(102., 115., 122.),
std=(1., 1., 1.),
# ``BGR``, ``RGB``, or ``GRAY``
image_type='BGR',
# Or ``float16`` for fp16 training
output_dtype='float32',
# Or ``NHWC``
output_layout='NCHW'
)
y = cmn(inputs['x'], mirror=flip_rng())
```
"""
def __new__(
cls,
crop=None,
mirror=None,
mean=0.,
std=1.,
image_type='BGR',
output_dtype='float32',
output_layout='NCHW',
):
"""Create a ``CropMirrorNormalize`` operator.
Parameters
----------
crop : Sequence[int], optional
The cropped spatial dimensions for output.
mirror : {0, 1}, optional
Whether to apply the horizontal flip.
mean : Union[float, Sequence[float]], optional
The values to subtract.
std : Union[float, Sequence[float]], optional
The values to divide after subtraction.
image_type : {'BGR', 'RGB'}, optional
The color space of input.
output_dtype : {'float16', 'float32'}, optional
The data type of output.
output_layout : {'NCHW', 'NHWC'}, optional
The data format of output.
Returns
-------
nvidia.dali.ops.CropMirrorNormalize
The operator.
"""
if isinstance(output_dtype, six.string_types):
output_dtype = getattr(types, output_dtype.upper())
if isinstance(output_layout, six.string_types):
output_layout = getattr(types, output_layout.upper())
if isinstance(image_type, six.string_types):
image_type = getattr(types, image_type.upper())
return ops.CropMirrorNormalize(
crop=crop,
mirror=mirror,
mean=mean,
std=std,
output_dtype=output_dtype,
output_layout=output_layout,
image_type=image_type,
device=context.get_device_type(),
)
...@@ -16,11 +16,12 @@ from __future__ import print_function ...@@ -16,11 +16,12 @@ from __future__ import print_function
try: try:
from nvidia.dali import ops from nvidia.dali import ops
except ImportError: except ImportError:
ops = None from dragon.core.util import deprecation
ops = deprecation.not_installed('nvidia.dali')
from dragon.core.util import six from dragon.core.util import six
from dragon.vm.dali.core import context from dragon.vm.dali.core.framework import context
from dragon.vm.dali.core import types from dragon.vm.dali.core.framework import types
class Cast(object): class Cast(object):
...@@ -35,7 +36,7 @@ class Cast(object): ...@@ -35,7 +36,7 @@ class Cast(object):
""" """
def __new__(cls, dtype): def __new__(cls, dtype, **kwargs):
"""Create a ``Cast`` operator. """Create a ``Cast`` operator.
Parameters Parameters
...@@ -54,6 +55,7 @@ class Cast(object): ...@@ -54,6 +55,7 @@ class Cast(object):
return ops.Cast( return ops.Cast(
dtype=dtype, dtype=dtype,
device=context.get_device_type(), device=context.get_device_type(),
**kwargs
) )
...@@ -74,7 +76,7 @@ class Pad(object): ...@@ -74,7 +76,7 @@ class Pad(object):
""" """
def __new__(cls, axes=(0, 1), fill_value=0., align=None): def __new__(cls, axes=(0, 1), fill_value=0., align=None, **kwargs):
"""Create a ``Pad`` operator. """Create a ``Pad`` operator.
Parameters Parameters
...@@ -97,6 +99,7 @@ class Pad(object): ...@@ -97,6 +99,7 @@ class Pad(object):
fill_value=fill_value, fill_value=fill_value,
align=align, align=align,
device=context.get_device_type(), device=context.get_device_type(),
**kwargs
) )
...@@ -117,7 +120,7 @@ class Reshape(object): ...@@ -117,7 +120,7 @@ class Reshape(object):
""" """
def __new__(cls, shape=None): def __new__(cls, shape=None, **kwargs):
"""Create a ``Reshape`` operator. """Create a ``Reshape`` operator.
Parameters Parameters
...@@ -134,4 +137,59 @@ class Reshape(object): ...@@ -134,4 +137,59 @@ class Reshape(object):
return ops.Reshape( return ops.Reshape(
shape=shape, shape=shape,
device=context.get_device_type(), device=context.get_device_type(),
**kwargs
)
class Slice(object):
"""Select an interval of elements from input.
Examples:
```python
slice = dali.ops.Slice(
# Axis of intervals
axes=[1, 0],
# Whether the begin of interval is normalized
# in a range of [0.0, 1.0]
normalized_anchor=True,
# Whether the size of interval is normalized
# in a range of [0.0, 1.0]
normalized_shape=True,
)
y = slice(inputs['x'], crop_begin, crop_size)
```
"""
def __new__(
cls,
axes=(1, 0),
normalized_anchor=True,
normalized_shape=True,
**kwargs
):
"""Create a ``Slice`` operator.
Parameters
----------
axes : Sequence[int], optional
The axis to select.
normalized_anchor : bool, optional, default=True
Whether the begin of interval is normalized.
normalized_shape : bool, optional, default=True
Whether the size of interval is normalized.
Returns
-------
nvidia.dali.ops.Slice
The operator.
"""
return ops.Slice(
axes=axes,
normalized_anchor=normalized_anchor,
device=context.get_device_type(),
**kwargs
) )
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from nvidia.dali import ops
except ImportError:
ops = None
from dragon.vm.dali.core import context
class BbFlip(object):
"""Flip the bounding boxes.
Examples:
```python
bbox_flip = dali.ops.BbFlip()
flip_rng = dali.ops.CoinFlip(0.5)
bbox = bbox_flip(inputs['bbox'], horizontal=flip_rng())
```
"""
def __new__(cls, horizontal=None, vertical=None, ltrb=True):
"""Create a ``BbFlip`` operator.
Parameters
----------
horizontal : int, optional
Whether to apply the horizontal flip.
vertical : int, optional
Whether to apply the vertical flip.
ltrb : bool, optional, default=True
Indicate the bbox is ``ltrb`` or ``xywh`` format.
Returns
-------
nvidia.dali.ops.BbFlip
The operator.
"""
return ops.BbFlip(
horizontal=horizontal,
vertical=vertical,
ltrb=ltrb,
device=context.get_device_type(),
)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from nvidia.dali import ops
except ImportError:
from dragon.core.util import deprecation
ops = deprecation.not_installed('nvidia.dali')
from dragon.core.util import six
from dragon.vm.dali.core.framework import context
from dragon.vm.dali.core.framework import types
class BrightnessContrast(object):
"""Adjust the brightness and contrast of image.
Examples:
```python
# Historical jitter range for brightness and contrast
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
bc = dali.ops.BrightnessContrast()
y = bc(inputs['x'], brightness=twist_rng(), contrast=twist_rng())
```
"""
def __new__(cls, **kwargs):
"""Create a ``BrightnessContrastBrightnessContrast`` operator.
Returns
-------
nvidia.dali.ops.BrightnessContrast
The operator.
"""
return ops.BrightnessContrast(device=context.get_device_type(), **kwargs)
class CropMirrorNormalize(object):
"""Crop and normalize image with the horizontal flip.
Examples:
```python
flip_rng = dali.ops.CoinFlip(0.5)
cmn = dali.ops.CropMirrorNormalize(
# Match the number of spatial dims
# (H, W) for 2d input
# (D, H, W) for 3d input
crop=(224, 224),
# Historical values to normalize input
mean=(102., 115., 122.),
std=(1., 1., 1.),
# Or ``float16`` for fp16 training
dtype='float32',
# Or ``NHWC``
output_layout='NCHW'
)
y = cmn(inputs['x'], mirror=flip_rng())
```
"""
def __new__(
cls,
crop=None,
mirror=None,
mean=0.,
std=1.,
dtype='float32',
output_layout='NCHW',
**kwargs
):
"""Create a ``CropMirrorNormalize`` operator.
Parameters
----------
crop : Sequence[int], optional
The cropped spatial dimensions for output.
mirror : {0, 1}, optional
Whether to apply the horizontal flip.
mean : Union[float, Sequence[float]], optional
The values to subtract.
std : Union[float, Sequence[float]], optional
The values to divide after subtraction.
dtype : {'float16', 'float32'}, optional
The data type of output.
output_layout : {'NCHW', 'NHWC'}, optional
The data format of output.
Returns
-------
nvidia.dali.ops.CropMirrorNormalize
The operator.
"""
if isinstance(dtype, six.string_types):
dtype = getattr(types, dtype.upper())
if isinstance(output_layout, six.string_types):
output_layout = getattr(types, output_layout.upper())
return ops.CropMirrorNormalize(
crop=crop,
mirror=mirror,
mean=mean,
std=std,
dtype=dtype,
output_layout=output_layout,
device=context.get_device_type(),
**kwargs
)
class Hsv(object):
"""Adjust the hue and saturation.
Examples:
```python
# Historical jitter range for saturation
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
hsv = dali.ops.Hsv()
y = hsv(inputs['x'], saturation=twist_rng())
```
"""
def __new__(cls, **kwargs):
"""Create a ``Hsv`` operator.
Returns
-------
nvidia.dali.ops.Hsv
The operator.
"""
return ops.Hsv(device=context.get_device_type(), **kwargs)
class Paste(object):
"""Copy image into a larger canvas.
Examples:
```python
paste = dali.ops.Paste(
# The image channels
n_channels=3,
# Historical values before mean subtraction
fill_value=(102., 115., 122.),
)
paste_pos = dali.ops.Uniform((0., 1.))
paste_ratio = dali.ops.Uniform((0., 3.))
paste_prob = dali.ops.CoinFlip(0.5)
y = paste(
inputs['x'],
# Expand ratio
ratio=paste_ratio() * paste_prob() + 1.,
# PosX, PosY
paste_x=paste_pos(),
paste_y=paste_pos(),
)
```
"""
def __new__(
cls,
n_channels=3,
fill_value=(0., 0., 0.),
ratio=None,
paste_x=None,
paste_y=None,
**kwargs
):
"""Create a ``Paste`` operator.
Parameters
----------
n_channels : int, optional, default=3
The image channels.
fill_value : Sequence[number], optional
The value(s) to fill for the canvas.
ratio : int, optional
The expand ratio.
paste_x : int, optional
The paste position at x-axis.
paste_y : int, optional
The paste position at y-axis.
Returns
-------
nvidia.dali.ops.Paste
The operator.
"""
return ops.Paste(
n_channels=n_channels,
fill_value=fill_value,
ratio=ratio,
paste_x=paste_x,
paste_y=paste_y,
device=context.get_device_type(),
**kwargs
)
class RandomBBoxCrop(object):
"""Return an valid image crop restricted by bounding boxes.
Examples:
```python
bbox_crop = dali.ops.RandomBBoxCrop(
# Range of scale
scaling=[0.3, 1.0],
# Range of aspect ratio
aspect_ratio=[0.5, 2.0],
# Minimum IoUs to satisfy
thresholds=[0.0, 0.1, 0.3, 0.5, 0.7, 0.9],
)
crop_begin, crop_size, bbox, label = bbox_crop(inputs['bbox'], inputs['label'])
```
"""
def __new__(
cls,
scaling=(0.3, 1.0),
aspect_ratio=(0.5, 2.0),
thresholds=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9),
allow_no_crop=True,
ltrb=True,
num_attempts=10,
**kwargs
):
"""Create a ``RandomBBoxCrop`` operator.
Parameters
----------
scaling : Sequence[float], optional, default=(0.3, 1.0)
The range of scale for sampling regions.
aspect_ratio : Sequence[float], optional, default=(0.5, 2.0)
The range of aspect ratio for sampling regions.
thresholds : Sequence[float], optional
The minimum IoU(s) to satisfy.
allow_no_crop : bool, optional, default=True
**True** to include the no-cropping as a option.
ltrb : bool, optional, default=True
Indicate the bbox is ``ltrb`` or ``xywh`` format.
num_attempts : int, optional, default=10
The max number of sampling trails.
Returns
-------
nvidia.dali.ops.RandomBBoxCrop
The operator.
"""
return ops.RandomBBoxCrop(
scaling=scaling,
aspect_ratio=aspect_ratio,
thresholds=thresholds,
allow_no_crop=allow_no_crop,
ltrb=ltrb,
num_attempts=num_attempts,
device='cpu',
**kwargs
)
class Resize(object):
"""Resize the image.
Examples:
```python
# Resize to a fixed area
resize1 = dali.ops.Resize(resize_x=300, resize_y=300)
# Resize along the shorter side
resize2 = dali.ops.Resize(resize_shorter=600, max_size=1000)
# Resize along the longer side
resize3 = dali.ops.Resize(resize_longer=512)
```
"""
def __new__(
cls,
resize_x=None,
resize_y=None,
resize_shorter=None,
resize_longer=None,
max_size=None,
interp_type='TRIANGULAR',
):
"""Create a ``Resize`` operator.
Parameters
----------
resize_x : int, optional
The output image width.
resize_y : int, optional
The output image height.
resize_shorter : int, optional
Resize along the shorter side and limited by ``max_size``.
resize_longer : int, optional
Resize along the longer side.
max_size : int, optional, default=0
The limited size for ``resize_shorter``.
interp_type : {'NN', 'LINEAR', 'TRIANGULAR', 'CUBIC', 'GAUSSIAN', 'LANCZOS3'}, optional
The interpolation method.
"""
if isinstance(interp_type, six.string_types):
interp_type = getattr(types, 'INTERP_' + interp_type.upper())
return ops.Resize(
resize_x=resize_x,
resize_y=resize_y,
resize_shorter=resize_shorter,
resize_longer=resize_longer,
max_size=max_size,
interp_type=interp_type,
device=context.get_device_type(),
)
...@@ -16,7 +16,8 @@ from __future__ import print_function ...@@ -16,7 +16,8 @@ from __future__ import print_function
try: try:
from nvidia.dali import ops from nvidia.dali import ops
except ImportError: except ImportError:
ops = None from dragon.core.util import deprecation
ops = deprecation.not_installed('nvidia.dali')
class CoinFlip(object): class CoinFlip(object):
...@@ -31,7 +32,7 @@ class CoinFlip(object): ...@@ -31,7 +32,7 @@ class CoinFlip(object):
""" """
def __new__(cls, probability=0.5): def __new__(cls, probability=0.5, **kwargs):
"""Create a ``CoinFlip`` operator. """Create a ``CoinFlip`` operator.
Parameters Parameters
...@@ -45,11 +46,11 @@ class CoinFlip(object): ...@@ -45,11 +46,11 @@ class CoinFlip(object):
The operator. The operator.
""" """
return ops.CoinFlip(probability=probability) return ops.CoinFlip(probability=probability, **kwargs)
class Uniform(object): class Uniform(object):
"""Sample values from a uniform distribution. """Sample values from an uniform distribution.
Examples: Examples:
...@@ -60,7 +61,7 @@ class Uniform(object): ...@@ -60,7 +61,7 @@ class Uniform(object):
""" """
def __new__(cls, range=(-1., 1.)): def __new__(cls, range=(-1., 1.), **kwargs):
"""Create an ``Uniform`` operator. """Create an ``Uniform`` operator.
Parameters Parameters
...@@ -74,4 +75,4 @@ class Uniform(object): ...@@ -74,4 +75,4 @@ class Uniform(object):
The operator. The operator.
""" """
return ops.Uniform(range=range) return ops.Uniform(range=range, **kwargs)
...@@ -19,16 +19,16 @@ import os ...@@ -19,16 +19,16 @@ import os
try: try:
from nvidia.dali import ops from nvidia.dali import ops
from nvidia.dali import tfrecord as tfrec from nvidia.dali import tfrecord
except ImportError: except ImportError:
from dragon.core.util import deprecation from dragon.core.util import deprecation
ops = deprecation.NotInstalled('nvidia.dali') ops = deprecation.NotInstalled('nvidia.dali')
tfrec = deprecation.NotInstalled('nvidia.dali') tfrecord = deprecation.NotInstalled('nvidia.dali')
from dragon.core.io import reader from dragon.core.io import reader
from dragon.core.io import kpl_record from dragon.core.io import kpl_record
from dragon.vm.dali.core import context from dragon.vm.dali.core.framework import context
from dragon.vm.dali.core.ops.builtin import ExternalSource from dragon.vm.dali.core.ops.builtin_ops import ExternalSource
class KPLRecordReader(object): class KPLRecordReader(object):
...@@ -73,6 +73,7 @@ class KPLRecordReader(object): ...@@ -73,6 +73,7 @@ class KPLRecordReader(object):
num_shards=1, num_shards=1,
shuffle_after_epoch=False, shuffle_after_epoch=False,
shuffle_chunks=0, shuffle_chunks=0,
**kwargs
): ):
"""Create a ``KPLRecordReader``. """Create a ``KPLRecordReader``.
...@@ -100,6 +101,7 @@ class KPLRecordReader(object): ...@@ -100,6 +101,7 @@ class KPLRecordReader(object):
num_parts=num_shards, num_parts=num_shards,
shuffle=shuffle_after_epoch, shuffle=shuffle_after_epoch,
num_chunks=shuffle_chunks, num_chunks=shuffle_chunks,
**kwargs
) )
self._buffer = self._reader.q_out = mp.Queue( self._buffer = self._reader.q_out = mp.Queue(
self._prefetch_depth * self._batch_size) self._prefetch_depth * self._batch_size)
...@@ -186,6 +188,7 @@ class TFRecordReader(object): ...@@ -186,6 +188,7 @@ class TFRecordReader(object):
num_shards=1, num_shards=1,
random_shuffle=False, random_shuffle=False,
initial_fill=1024, initial_fill=1024,
**kwargs
): ):
"""Create a ``TFRecordReader``. """Create a ``TFRecordReader``.
...@@ -217,6 +220,7 @@ class TFRecordReader(object): ...@@ -217,6 +220,7 @@ class TFRecordReader(object):
features=features, features=features,
random_shuffle=random_shuffle, random_shuffle=random_shuffle,
initial_fill=initial_fill, initial_fill=initial_fill,
**kwargs
) )
@staticmethod @staticmethod
...@@ -232,7 +236,10 @@ class TFRecordReader(object): ...@@ -232,7 +236,10 @@ class TFRecordReader(object):
if features_file is None: if features_file is None:
raise FileNotFoundError('File <FEATURES> is missing.') raise FileNotFoundError('File <FEATURES> is missing.')
with open(os.path.join(path, features_file), 'r') as f: with open(os.path.join(path, features_file), 'r') as f:
features = eval(f.read().replace('tf.', 'tfrec.')) features = f.read()
features = features.replace('tf.', 'tfrecord.')
features = features.replace('tf.io.', 'tfrecord.')
features = eval(features)
data_files.sort() data_files.sort()
index_files.sort() index_files.sort()
data = [os.path.join(path, e) for e in data_files] data = [os.path.join(path, e) for e in data_files]
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from nvidia.dali import ops
except ImportError:
ops = None
from dragon.core.util import six
from dragon.vm.dali.core import context
from dragon.vm.dali.core import types
class Resize(object):
"""Resize the image.
Examples:
```python
# Resize to a fixed area
resize1 = dali.ops.Resize(resize_x=300, resize_y=300)
# Resize along the shorter side
resize2 = dali.ops.Resize(resize_shorter=600, max_size=1000)
# Resize along the longer side
resize3 = dali.ops.Resize(resize_longer=512)
```
"""
def __new__(
cls,
resize_x=None,
resize_y=None,
resize_shorter=None,
resize_longer=None,
max_size=None,
interp_type='TRIANGULAR',
):
"""Create a ``Resize`` operator.
Parameters
----------
resize_x : int, optional
The output image width.
resize_y : int, optional
The output image height.
resize_shorter : int, optional
Resize along the shorter side and limited by ``max_size``.
resize_longer : int, optional
Resize along the longer side.
max_size : int, optional, default=0
The limited size for ``resize_shorter``.
interp_type : {'NN', 'LINEAR', 'TRIANGULAR', 'CUBIC', 'GAUSSIAN', 'LANCZOS3'}, optional
The interpolation method.
"""
if isinstance(interp_type, six.string_types):
interp_type = getattr(types, 'INTERP_' + interp_type.upper())
return ops.Resize(
resize_x=resize_x,
resize_y=resize_y,
resize_shorter=resize_shorter,
resize_longer=resize_longer,
max_size=max_size,
interp_type=interp_type,
device=context.get_device_type(),
)
...@@ -35,14 +35,7 @@ master_doc = 'index' ...@@ -35,14 +35,7 @@ master_doc = 'index'
source_suffix = '.rst' source_suffix = '.rst'
# Extension # Extension
extensions = [ extensions = ['sphinx.ext.autodoc', 'sphinxcontrib.katex', 'breathe']
'sphinx.ext.autodoc',
'sphinx.ext.viewcode',
'sphinx.ext.napoleon',
'sphinxcontrib.katex',
'breathe',
]
napoleon_use_rtype = False
# Project # Project
project = 'dragon' project = 'dragon'
......
...@@ -22,7 +22,6 @@ NUMBER_OF_PROCESSORS:=$(shell getconf _NPROCESSORS_ONLN) ...@@ -22,7 +22,6 @@ NUMBER_OF_PROCESSORS:=$(shell getconf _NPROCESSORS_ONLN)
help: help:
@echo "Please use \`make <target>' where <target> is one of" @echo "Please use \`make <target>' where <target> is one of"
@echo " html to make standalone HTML files" @echo " html to make standalone HTML files"
@echo " debughtml to make debugging HTML files"
@echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
@echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdf to make LaTeX files and run them through pdflatex"
...@@ -30,11 +29,6 @@ clean: ...@@ -30,11 +29,6 @@ clean:
rm -rf $(BUILDDIR)/* rm -rf $(BUILDDIR)/*
html: html:
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)."
debughtml:
$(SPHINXBUILD) -b html -j ${NUMBER_OF_PROCESSORS} $(ALLSPHINXOPTS) $(BUILDDIR) $(SPHINXBUILD) -b html -j ${NUMBER_OF_PROCESSORS} $(ALLSPHINXOPTS) $(BUILDDIR)
@echo @echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)." @echo "Build finished. The HTML pages are in $(BUILDDIR)."
......
...@@ -25,6 +25,9 @@ vm.caffe ...@@ -25,6 +25,9 @@ vm.caffe
: The Momentum-SGD solver. : The Momentum-SGD solver.
`[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_. `[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_.
`class Solver <caffe/Solver.html>`_
: The base solver class to optimize parameters.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -33,6 +36,7 @@ vm.caffe ...@@ -33,6 +36,7 @@ vm.caffe
caffe/Net caffe/Net
caffe/RMSPropSolver caffe/RMSPropSolver
caffe/SGDSolver caffe/SGDSolver
caffe/Solver
.. raw:: html .. raw:: html
......
...@@ -12,30 +12,36 @@ Properties ...@@ -12,30 +12,36 @@ Properties
base_lr base_lr
####### #######
.. autoattribute:: dragon.vm.caffe.AdamSolver.base_lr .. autoattribute:: dragon.vm.caffe.Solver.base_lr
:noindex:
iter iter
#### ####
.. autoattribute:: dragon.vm.caffe.AdamSolver.iter .. autoattribute:: dragon.vm.caffe.Solver.iter
:noindex:
net net
### ###
.. autoattribute:: dragon.vm.caffe.AdamSolver.net .. autoattribute:: dragon.vm.caffe.Solver.net
:noindex:
test_nets test_nets
######### #########
.. autoattribute:: dragon.vm.caffe.AdamSolver.test_nets .. autoattribute:: dragon.vm.caffe.Solver.test_nets
:noindex:
Methods Methods
------- -------
snapshot snapshot
######## ########
.. automethod:: dragon.vm.caffe.AdamSolver.snapshot .. automethod:: dragon.vm.caffe.Solver.snapshot
:noindex:
step step
######## ########
.. automethod:: dragon.vm.caffe.AdamSolver.step .. automethod:: dragon.vm.caffe.Solver.step
:noindex:
.. raw:: html .. raw:: html
......
...@@ -12,30 +12,36 @@ Properties ...@@ -12,30 +12,36 @@ Properties
base_lr base_lr
####### #######
.. autoattribute:: dragon.vm.caffe.NesterovSolver.base_lr .. autoattribute:: dragon.vm.caffe.Solver.base_lr
:noindex:
iter iter
#### ####
.. autoattribute:: dragon.vm.caffe.NesterovSolver.iter .. autoattribute:: dragon.vm.caffe.Solver.iter
:noindex:
net net
### ###
.. autoattribute:: dragon.vm.caffe.NesterovSolver.net .. autoattribute:: dragon.vm.caffe.Solver.net
:noindex:
test_nets test_nets
######### #########
.. autoattribute:: dragon.vm.caffe.NesterovSolver.test_nets .. autoattribute:: dragon.vm.caffe.Solver.test_nets
:noindex:
Methods Methods
------- -------
snapshot snapshot
######## ########
.. automethod:: dragon.vm.caffe.NesterovSolver.snapshot .. automethod:: dragon.vm.caffe.Solver.snapshot
:noindex:
step step
######## ########
.. automethod:: dragon.vm.caffe.NesterovSolver.step .. automethod:: dragon.vm.caffe.Solver.step
:noindex:
.. raw:: html .. raw:: html
......
...@@ -12,30 +12,36 @@ Properties ...@@ -12,30 +12,36 @@ Properties
base_lr base_lr
####### #######
.. autoattribute:: dragon.vm.caffe.RMSPropSolver.base_lr .. autoattribute:: dragon.vm.caffe.Solver.base_lr
:noindex:
iter iter
#### ####
.. autoattribute:: dragon.vm.caffe.RMSPropSolver.iter .. autoattribute:: dragon.vm.caffe.Solver.iter
:noindex:
net net
### ###
.. autoattribute:: dragon.vm.caffe.RMSPropSolver.net .. autoattribute:: dragon.vm.caffe.Solver.net
:noindex:
test_nets test_nets
######### #########
.. autoattribute:: dragon.vm.caffe.RMSPropSolver.test_nets .. autoattribute:: dragon.vm.caffe.Solver.test_nets
:noindex:
Methods Methods
------- -------
snapshot snapshot
######## ########
.. automethod:: dragon.vm.caffe.RMSPropSolver.snapshot .. automethod:: dragon.vm.caffe.Solver.snapshot
:noindex:
step step
######## ########
.. automethod:: dragon.vm.caffe.RMSPropSolver.step .. automethod:: dragon.vm.caffe.Solver.step
:noindex:
.. raw:: html .. raw:: html
......
...@@ -12,30 +12,36 @@ Properties ...@@ -12,30 +12,36 @@ Properties
base_lr base_lr
####### #######
.. autoattribute:: dragon.vm.caffe.SGDSolver.base_lr .. autoattribute:: dragon.vm.caffe.Solver.base_lr
:noindex:
iter iter
#### ####
.. autoattribute:: dragon.vm.caffe.SGDSolver.iter .. autoattribute:: dragon.vm.caffe.Solver.iter
:noindex:
net net
### ###
.. autoattribute:: dragon.vm.caffe.SGDSolver.net .. autoattribute:: dragon.vm.caffe.Solver.net
:noindex:
test_nets test_nets
######### #########
.. autoattribute:: dragon.vm.caffe.SGDSolver.test_nets .. autoattribute:: dragon.vm.caffe.Solver.test_nets
:noindex:
Methods Methods
------- -------
snapshot snapshot
######## ########
.. automethod:: dragon.vm.caffe.SGDSolver.snapshot .. automethod:: dragon.vm.caffe.Solver.snapshot
:noindex:
step step
######## ########
.. automethod:: dragon.vm.caffe.SGDSolver.step .. automethod:: dragon.vm.caffe.Solver.step
:noindex:
.. raw:: html .. raw:: html
......
Solver
======
.. autoclass:: dragon.vm.caffe.Solver
__init__
--------
.. automethod:: dragon.vm.caffe.Solver.__init__
Properties
----------
base_lr
#######
.. autoattribute:: dragon.vm.caffe.Solver.base_lr
iter
####
.. autoattribute:: dragon.vm.caffe.Solver.iter
net
###
.. autoattribute:: dragon.vm.caffe.Solver.net
test_nets
#########
.. autoattribute:: dragon.vm.caffe.Solver.test_nets
Methods
-------
snapshot
########
.. automethod:: dragon.vm.caffe.Solver.snapshot
step
########
.. automethod:: dragon.vm.caffe.Solver.step
.. raw:: html
<style>
h1:before {
content: "caffe.";
color: #103d3e;
}
</style>
...@@ -33,9 +33,9 @@ source_suffix = '.rst' ...@@ -33,9 +33,9 @@ source_suffix = '.rst'
# Extension # Extension
extensions = [ extensions = [
'sphinx.ext.autodoc', 'sphinx.ext.autodoc',
'sphinx.ext.viewcode',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinxcontrib.katex', 'sphinxcontrib.katex',
'sphinx_seeta_theme.ext.viewcode',
] ]
napoleon_use_rtype = False napoleon_use_rtype = False
......
...@@ -13,19 +13,22 @@ vm.dali.ops ...@@ -13,19 +13,22 @@ vm.dali.ops
: Transform bounding boxes to match the ``Paste`` operator. : Transform bounding boxes to match the ``Paste`` operator.
`class BrightnessContrast <ops/BrightnessContrast.html>`_ `class BrightnessContrast <ops/BrightnessContrast.html>`_
: Adjust the brightness and contrast. : Adjust the brightness and contrast of image.
`class Cast <ops/Cast.html>`_ `class Cast <ops/Cast.html>`_
: Cast the data type of input. : Cast the data type of input.
`class CoinFlip <ops/CoinFlip.html>`_
: Sample values from a bernoulli distribution.
`class CropMirrorNormalize <ops/CropMirrorNormalize.html>`_ `class CropMirrorNormalize <ops/CropMirrorNormalize.html>`_
: Crop and normalize input with the horizontal flip. : Crop and normalize image with the horizontal flip.
`class ExternalSource <ops/Cast.html>`_ `class ExternalSource <ops/Cast.html>`_
: Create a placeholder providing data from feeding. : Create a placeholder providing data from feeding.
`class Hsv <ops/Hsv.html>`_ `class Hsv <ops/Hsv.html>`_
: Adjust the hue and saturation. : Adjust the hue and saturation of image.
`class ImageDecoder <ops/ImageDecoder.html>`_ `class ImageDecoder <ops/ImageDecoder.html>`_
: Decode image from bytes. : Decode image from bytes.
...@@ -40,7 +43,7 @@ vm.dali.ops ...@@ -40,7 +43,7 @@ vm.dali.ops
: Copy image into a larger canvas. : Copy image into a larger canvas.
`class RandomBBoxCrop <ops/RandomBBoxCrop.html>`_ `class RandomBBoxCrop <ops/RandomBBoxCrop.html>`_
: Return a valid crop restricted by bounding boxes. : Return an valid image crop restricted by bounding boxes.
`class Reshape <ops/Reshape.html>`_ `class Reshape <ops/Reshape.html>`_
: Change the dimensions of input. : Change the dimensions of input.
...@@ -58,7 +61,7 @@ vm.dali.ops ...@@ -58,7 +61,7 @@ vm.dali.ops
: Read examples from the tf-record file. : Read examples from the tf-record file.
`class Uniform <ops/Uniform.html>`_ `class Uniform <ops/Uniform.html>`_
: Select an interval of elements from input. : Sample values from an uniform distribution.
.. toctree:: .. toctree::
:hidden: :hidden:
......
...@@ -175,8 +175,23 @@ __truediv__ ...@@ -175,8 +175,23 @@ __truediv__
.. _dragon.assign(...): assign.html .. _dragon.assign(...): assign.html
.. _dragon.cast(...): cast.html .. _dragon.cast(...): cast.html
.. _dragon.copy(...): copy.html .. _dragon.copy(...): copy.html
.. _dragon.fill(...): fill.html
.. _dragon.masked_assign(...): masked_assign.html .. _dragon.masked_assign(...): masked_assign.html
.. _dragon.masked_select(...): masked_select.html .. _dragon.masked_select(...): masked_select.html
.. _dragon.math.add(...): math/add.html
.. _dragon.math.div(...): math/div.html
.. _dragon.math.greater(...): math/greater.html
.. _dragon.math.greater_equal(...): math/greater_equal.html
.. _dragon.math.less(...): math/less.html
.. _dragon.math.less_equal(...): math/less_equal.html
.. _dragon.math.mul(...): math/mul.html
.. _dragon.math.negative(...): math/negative.html
.. _dragon.math.sub(...): math/sub.html
.. _dragon.random.glorot_normal(...): random/glorot_normal.html
.. _dragon.random.glorot_uniform(...): random/glorot_uniform.html
.. _dragon.random.normal(...): random/normal.html
.. _dragon.random.truncated_normal(...): random/truncated_normal.html
.. _dragon.random.uniform(...): random/uniform.html
.. _dragon.reshape(...): reshape.html .. _dragon.reshape(...): reshape.html
.. _dragon.slice(...): slice.html .. _dragon.slice(...): slice.html
......
...@@ -155,8 +155,23 @@ __truediv__ ...@@ -155,8 +155,23 @@ __truediv__
.. _dragon.assign(...): assign.html .. _dragon.assign(...): assign.html
.. _dragon.cast(...): cast.html .. _dragon.cast(...): cast.html
.. _dragon.copy(...): copy.html .. _dragon.copy(...): copy.html
.. _dragon.fill(...): fill.html
.. _dragon.masked_assign(...): masked_assign.html .. _dragon.masked_assign(...): masked_assign.html
.. _dragon.masked_select(...): masked_select.html .. _dragon.masked_select(...): masked_select.html
.. _dragon.math.add(...): math/add.html
.. _dragon.math.div(...): math/div.html
.. _dragon.math.greater(...): math/greater.html
.. _dragon.math.greater_equal(...): math/greater_equal.html
.. _dragon.math.less(...): math/less.html
.. _dragon.math.less_equal(...): math/less_equal.html
.. _dragon.math.mul(...): math/mul.html
.. _dragon.math.negative(...): math/negative.html
.. _dragon.math.sub(...): math/sub.html
.. _dragon.random.glorot_normal(...): random/glorot_normal.html
.. _dragon.random.glorot_uniform(...): random/glorot_uniform.html
.. _dragon.random.normal(...): random/normal.html
.. _dragon.random.truncated_normal(...): random/truncated_normal.html
.. _dragon.random.uniform(...): random/uniform.html
.. _dragon.reshape(...): reshape.html .. _dragon.reshape(...): reshape.html
.. _dragon.slice(...): slice.html .. _dragon.slice(...): slice.html
......
...@@ -20,7 +20,6 @@ if "%1" == "help" ( ...@@ -20,7 +20,6 @@ if "%1" == "help" (
:help :help
echo.Please use `make ^<target^>` where ^<target^> is one of echo.Please use `make ^<target^>` where ^<target^> is one of
echo. html to make standalone HTML files echo. html to make standalone HTML files
echo. debughtml to make debugging HTML files
echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
echo. latexpdf to make LaTeX files and run them through pdflatex echo. latexpdf to make LaTeX files and run them through pdflatex
goto end goto end
...@@ -56,14 +55,6 @@ if errorlevel 9009 ( ...@@ -56,14 +55,6 @@ if errorlevel 9009 (
:sphinx_ok :sphinx_ok
if "%1" == "html" ( if "%1" == "html" (
%SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%.
goto end
)
if "%1" == "debughtml" (
%SPHINXBUILD% -b html -j %NUMBER_OF_PROCESSORS% %ALLSPHINXOPTS% %BUILDDIR% %SPHINXBUILD% -b html -j %NUMBER_OF_PROCESSORS% %ALLSPHINXOPTS% %BUILDDIR%
if errorlevel 1 exit /b 1 if errorlevel 1 exit /b 1
echo. echo.
......
...@@ -12,7 +12,8 @@ Methods ...@@ -12,7 +12,8 @@ Methods
__call__ __call__
######## ########
.. automethod:: dragon.vm.tensorflow.keras.initializers.GlorotNormal.__call__ .. automethod:: dragon.vm.tensorflow.keras.initializers.VarianceScaling.__call__
:noindex:
.. raw:: html .. raw:: html
......
...@@ -12,7 +12,8 @@ Methods ...@@ -12,7 +12,8 @@ Methods
__call__ __call__
######## ########
.. automethod:: dragon.vm.tensorflow.keras.initializers.GlorotUniform.__call__ .. automethod:: dragon.vm.tensorflow.keras.initializers.VarianceScaling.__call__
:noindex:
.. raw:: html .. raw:: html
......
...@@ -479,6 +479,7 @@ zero\_ ...@@ -479,6 +479,7 @@ zero\_
.. _torch.mul(...): mul.html .. _torch.mul(...): mul.html
.. _torch.ne(...): ne.html .. _torch.ne(...): ne.html
.. _torch.neg(...): neg.html .. _torch.neg(...): neg.html
.. _torch.nonzero(...): nonzero.html
.. _torch.pow(...): pow.html .. _torch.pow(...): pow.html
.. _torch.reciprocal(...): reciprocal.html .. _torch.reciprocal(...): reciprocal.html
.. _torch.reshape(...): reshape.html .. _torch.reshape(...): reshape.html
...@@ -487,8 +488,12 @@ zero\_ ...@@ -487,8 +488,12 @@ zero\_
.. _torch.sign(...): sign.html .. _torch.sign(...): sign.html
.. _torch.sin(...): sin.html .. _torch.sin(...): sin.html
.. _torch.sqrt(...): sqrt.html .. _torch.sqrt(...): sqrt.html
.. _torch.squeeze(...): squeeze.html
.. _torch.sub(...): sub.html .. _torch.sub(...): sub.html
.. _torch.sum(...): sum.html
.. _torch.topk(...): topk.html .. _torch.topk(...): topk.html
.. _torch.unsqueeze(...): unsqueeze.html
.. _torch.where(...): where.html
.. raw:: html .. raw:: html
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Affine.__init__ .. automethod:: dragon.vm.torch.nn.Affine.__init__
.. _torch.nn.functional.affine(...): functional/affine.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.AvgPool2d.__init__ .. automethod:: dragon.vm.torch.nn.AvgPool2d.__init__
.. _torch.nn.functional.avg_pool2d(...): functional/avg_pool2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.BCEWithLogitsLoss.__init__ .. automethod:: dragon.vm.torch.nn.BCEWithLogitsLoss.__init__
.. _torch.nn.functional.binary_cross_entropy_with_logits(...): functional/binary_cross_entropy_with_logits.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.BatchNorm1d.__init__ .. automethod:: dragon.vm.torch.nn.BatchNorm1d.__init__
.. _torch.nn.functional.batch_norm(...): functional/batch_norm.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.BatchNorm2d.__init__ .. automethod:: dragon.vm.torch.nn.BatchNorm2d.__init__
.. _torch.nn.functional.batch_norm(...): functional/batch_norm.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.BatchNorm3d.__init__ .. automethod:: dragon.vm.torch.nn.BatchNorm3d.__init__
.. _torch.nn.functional.batch_norm(...): functional/batch_norm.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.CTCLoss.__init__ .. automethod:: dragon.vm.torch.nn.CTCLoss.__init__
.. _torch.nn.functional.ctc_loss(...): functional/ctc_loss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ConstantPad1d.__init__ .. automethod:: dragon.vm.torch.nn.ConstantPad1d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ConstantPad2d.__init__ .. automethod:: dragon.vm.torch.nn.ConstantPad2d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ConstantPad3d.__init__ .. automethod:: dragon.vm.torch.nn.ConstantPad3d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Conv2d.__init__ .. automethod:: dragon.vm.torch.nn.Conv2d.__init__
.. _torch.nn.functional.conv2d(...): functional/conv2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ConvTranspose2d.__init__ .. automethod:: dragon.vm.torch.nn.ConvTranspose2d.__init__
.. _torch.nn.functional.conv_transpose2d(...): functional/conv_transpose2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.CrossEntropyLoss.__init__ .. automethod:: dragon.vm.torch.nn.CrossEntropyLoss.__init__
.. _torch.nn.functional.cross_entropy(...): functional/cross_entropy.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.DepthwiseConv2d.__init__ .. automethod:: dragon.vm.torch.nn.DepthwiseConv2d.__init__
.. _torch.nn.functional.depthwise_conv2d(...): functional/depthwise_conv2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.DropBlock2d.__init__ .. automethod:: dragon.vm.torch.nn.DropBlock2d.__init__
.. _torch.nn.functional.drop_block2d(...): functional/drop_block2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.DropPath.__init__ .. automethod:: dragon.vm.torch.nn.DropPath.__init__
.. _torch.nn.functional.drop_path(...): functional/drop_path.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Dropout.__init__ .. automethod:: dragon.vm.torch.nn.Dropout.__init__
.. _torch.nn.functional.dropout(...): functional/dropout.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,7 @@ __init__ ...@@ -7,6 +7,7 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.GroupNorm.__init__ .. automethod:: dragon.vm.torch.nn.GroupNorm.__init__
.. _torch.nn.functional.group_norm(...): functional/group_norm.html
.. raw:: html .. raw:: html
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.L1Loss.__init__ .. automethod:: dragon.vm.torch.nn.L1Loss.__init__
.. _torch.nn.functional.l1_loss(...): functional/l1_loss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Linear.__init__ .. automethod:: dragon.vm.torch.nn.Linear.__init__
.. _torch.nn.functional.linear(...): functional/linear.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.LogSoftmax.__init__ .. automethod:: dragon.vm.torch.nn.LogSoftmax.__init__
.. _torch.nn.functional.log_softmax(...): functional/log_softmax.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.MSELoss.__init__ .. automethod:: dragon.vm.torch.nn.MSELoss.__init__
.. _torch.nn.functional.mse_loss(...): functional/mse_loss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.MaxPool2d.__init__ .. automethod:: dragon.vm.torch.nn.MaxPool2d.__init__
.. _torch.nn.functional.max_pool2d(...): functional/max_pool2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.NLLLoss.__init__ .. automethod:: dragon.vm.torch.nn.NLLLoss.__init__
.. _torch.nn.functional.nll_loss(...): functional/nll_loss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.PReLU.__init__ .. automethod:: dragon.vm.torch.nn.PReLU.__init__
.. _torch.nn.functional.prelu(...): functional/prelu.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ReLU.__init__ .. automethod:: dragon.vm.torch.nn.ReLU.__init__
.. _torch.nn.functional.relu(...): functional/relu.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ReLU6.__init__ .. automethod:: dragon.vm.torch.nn.ReLU6.__init__
.. _torch.nn.functional.relu6(...): functional/relu6.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ReflectionPad1d.__init__ .. automethod:: dragon.vm.torch.nn.ReflectionPad1d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ReflectionPad2d.__init__ .. automethod:: dragon.vm.torch.nn.ReflectionPad2d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ReflectionPad3d.__init__ .. automethod:: dragon.vm.torch.nn.ReflectionPad3d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ReplicationPad1d.__init__ .. automethod:: dragon.vm.torch.nn.ReplicationPad1d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ReplicationPad2d.__init__ .. automethod:: dragon.vm.torch.nn.ReplicationPad2d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ReplicationPad3d.__init__ .. automethod:: dragon.vm.torch.nn.ReplicationPad3d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.SELU.__init__ .. automethod:: dragon.vm.torch.nn.SELU.__init__
.. _torch.nn.functional.selu(...): functional/selu.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Sigmoid.__init__ .. automethod:: dragon.vm.torch.nn.Sigmoid.__init__
.. _torch.nn.functional.sigmoid(...): functional/sigmoid.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.SigmoidFocalLoss.__init__ .. automethod:: dragon.vm.torch.nn.SigmoidFocalLoss.__init__
.. _torch.nn.functional.sigmoid_focal_loss(...): functional/sigmoid_focal_loss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.SmoothL1Loss.__init__ .. automethod:: dragon.vm.torch.nn.SmoothL1Loss.__init__
.. _torch.nn.functional.smooth_l1_loss(...): functional/smooth_l1_loss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Softmax.__init__ .. automethod:: dragon.vm.torch.nn.Softmax.__init__
.. _torch.nn.functional.softmax(...): functional/softmax.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.SyncBatchNorm.__init__ .. automethod:: dragon.vm.torch.nn.SyncBatchNorm.__init__
.. _torch.nn.functional.batch_norm(...): functional/batch_norm.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Tanh.__init__ .. automethod:: dragon.vm.torch.nn.Tanh.__init__
.. _torch.nn.functional.tanh(...): functional/tanh.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Upsample.__init__ .. automethod:: dragon.vm.torch.nn.Upsample.__init__
.. _torch.nn.functional.interpolate(...): functional/interpolate.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.UpsamplingBilinear2d.__init__ .. automethod:: dragon.vm.torch.nn.UpsamplingBilinear2d.__init__
.. _torch.nn.functional.interpolate(...): functional/interpolate.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.UpsamplingNearest2d.__init__ .. automethod:: dragon.vm.torch.nn.UpsamplingNearest2d.__init__
.. _torch.nn.functional.interpolate(...): functional/interpolate.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,6 +7,8 @@ __init__ ...@@ -7,6 +7,8 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.ZeroPad2d.__init__ .. automethod:: dragon.vm.torch.nn.ZeroPad2d.__init__
.. _torch.nn.functional.pad(...): functional/pad.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ affine ...@@ -3,6 +3,8 @@ affine
.. autofunction:: dragon.vm.torch.nn.functional.affine .. autofunction:: dragon.vm.torch.nn.functional.affine
.. _torch.nn.Affine(...): ../Affine.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ avg_pool2d ...@@ -3,6 +3,8 @@ avg_pool2d
.. autofunction:: dragon.vm.torch.nn.functional.avg_pool2d .. autofunction:: dragon.vm.torch.nn.functional.avg_pool2d
.. _torch.nn.AvgPool2d(...): ../AvgPool2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ batch_norm ...@@ -3,6 +3,8 @@ batch_norm
.. autofunction:: dragon.vm.torch.nn.functional.batch_norm .. autofunction:: dragon.vm.torch.nn.functional.batch_norm
.. _torch.nn.BatchNorm2d(...): ../BatchNorm2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ binary_cross_entropy_with_logits ...@@ -3,6 +3,8 @@ binary_cross_entropy_with_logits
.. autofunction:: dragon.vm.torch.nn.functional.binary_cross_entropy_with_logits .. autofunction:: dragon.vm.torch.nn.functional.binary_cross_entropy_with_logits
.. _torch.nn.BCEWithLogitsLoss(...): ../BCEWithLogitsLoss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ conv2d ...@@ -3,6 +3,8 @@ conv2d
.. autofunction:: dragon.vm.torch.nn.functional.conv2d .. autofunction:: dragon.vm.torch.nn.functional.conv2d
.. _torch.nn.Conv2d(...): ../Conv2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ conv_transpose2d ...@@ -3,6 +3,8 @@ conv_transpose2d
.. autofunction:: dragon.vm.torch.nn.functional.conv_transpose2d .. autofunction:: dragon.vm.torch.nn.functional.conv_transpose2d
.. _torch.nn.ConvTranspose2d(...): ../ConvTranspose2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ cross_entropy ...@@ -3,6 +3,8 @@ cross_entropy
.. autofunction:: dragon.vm.torch.nn.functional.cross_entropy .. autofunction:: dragon.vm.torch.nn.functional.cross_entropy
.. _torch.nn.CrossEntropyLoss(...): ../CrossEntropyLoss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ ctc_loss ...@@ -3,6 +3,8 @@ ctc_loss
.. autofunction:: dragon.vm.torch.nn.functional.ctc_loss .. autofunction:: dragon.vm.torch.nn.functional.ctc_loss
.. _torch.nn.CTCLoss(...): ../CTCLoss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ depthwise_conv2d ...@@ -3,6 +3,8 @@ depthwise_conv2d
.. autofunction:: dragon.vm.torch.nn.functional.depthwise_conv2d .. autofunction:: dragon.vm.torch.nn.functional.depthwise_conv2d
.. _torch.nn.DepthwiseConv2d(...): ../DepthwiseConv2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ drop_block2d ...@@ -3,6 +3,8 @@ drop_block2d
.. autofunction:: dragon.vm.torch.nn.functional.drop_block2d .. autofunction:: dragon.vm.torch.nn.functional.drop_block2d
.. _torch.nn.DropBlock2d(...): ../DropBlock2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ drop_path ...@@ -3,6 +3,8 @@ drop_path
.. autofunction:: dragon.vm.torch.nn.functional.drop_path .. autofunction:: dragon.vm.torch.nn.functional.drop_path
.. _torch.nn.DropPath(...): ../DropPath.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ dropout ...@@ -3,6 +3,8 @@ dropout
.. autofunction:: dragon.vm.torch.nn.functional.dropout .. autofunction:: dragon.vm.torch.nn.functional.dropout
.. _torch.nn.Dropout(...): ../Dropout.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,7 +3,7 @@ elu ...@@ -3,7 +3,7 @@ elu
.. autofunction:: dragon.vm.torch.nn.functional.elu .. autofunction:: dragon.vm.torch.nn.functional.elu
.. _torch.nn.ELU: ../ELU.html .. _torch.nn.ELU(...): ../ELU.html
.. raw:: html .. raw:: html
......
...@@ -3,6 +3,8 @@ group_norm ...@@ -3,6 +3,8 @@ group_norm
.. autofunction:: dragon.vm.torch.nn.functional.group_norm .. autofunction:: dragon.vm.torch.nn.functional.group_norm
.. _torch.nn.GroupNorm(...): ../GroupNorm.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ interpolate ...@@ -3,6 +3,8 @@ interpolate
.. autofunction:: dragon.vm.torch.nn.functional.interpolate .. autofunction:: dragon.vm.torch.nn.functional.interpolate
.. _torch.nn.Upsample(...): ../Upsample.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ l1_loss ...@@ -3,6 +3,8 @@ l1_loss
.. autofunction:: dragon.vm.torch.nn.functional.l1_loss .. autofunction:: dragon.vm.torch.nn.functional.l1_loss
.. _torch.nn.L1Loss(...): ../L1Loss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,7 +3,7 @@ leaky_relu ...@@ -3,7 +3,7 @@ leaky_relu
.. autofunction:: dragon.vm.torch.nn.functional.leaky_relu .. autofunction:: dragon.vm.torch.nn.functional.leaky_relu
.. _torch.nn.LeakyReLU: ../LeakyReLU.html .. _torch.nn.LeakyReLU(...): ../LeakyReLU.html
.. raw:: html .. raw:: html
......
...@@ -3,6 +3,8 @@ linear ...@@ -3,6 +3,8 @@ linear
.. autofunction:: dragon.vm.torch.nn.functional.linear .. autofunction:: dragon.vm.torch.nn.functional.linear
.. _torch.nn.Linear(...): ../Linear.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,7 +3,7 @@ local_response_norm ...@@ -3,7 +3,7 @@ local_response_norm
.. autofunction:: dragon.vm.torch.nn.functional.local_response_norm .. autofunction:: dragon.vm.torch.nn.functional.local_response_norm
.. _torch.nn.LocalResponseNorm: ../LocalResponseNorm.html .. _torch.nn.LocalResponseNorm(...): ../LocalResponseNorm.html
.. raw:: html .. raw:: html
......
...@@ -3,6 +3,8 @@ log_softmax ...@@ -3,6 +3,8 @@ log_softmax
.. autofunction:: dragon.vm.torch.nn.functional.log_softmax .. autofunction:: dragon.vm.torch.nn.functional.log_softmax
.. _torch.nn.LogSoftmax(...): ../LogSoftmax.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ max_pool2d ...@@ -3,6 +3,8 @@ max_pool2d
.. autofunction:: dragon.vm.torch.nn.functional.max_pool2d .. autofunction:: dragon.vm.torch.nn.functional.max_pool2d
.. _torch.nn.MaxPool2d(...): ../MaxPool2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ mse_loss ...@@ -3,6 +3,8 @@ mse_loss
.. autofunction:: dragon.vm.torch.nn.functional.mse_loss .. autofunction:: dragon.vm.torch.nn.functional.mse_loss
.. _torch.nn.MSELoss(...): ../MSELoss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ nll_loss ...@@ -3,6 +3,8 @@ nll_loss
.. autofunction:: dragon.vm.torch.nn.functional.nll_loss .. autofunction:: dragon.vm.torch.nn.functional.nll_loss
.. _torch.nn.NLLLoss(...): ../NLLLoss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,10 @@ pad ...@@ -3,6 +3,10 @@ pad
.. autofunction:: dragon.vm.torch.nn.functional.pad .. autofunction:: dragon.vm.torch.nn.functional.pad
.. _torch.nn.ConstantPad2d(...): ../ConstantPad2d.html
.. _torch.nn.ReflectionPad2d(...): ../ReflectionPad2d.html
.. _torch.nn.ReplicationPad2d(...): ../ReplicationPad2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ prelu ...@@ -3,6 +3,8 @@ prelu
.. autofunction:: dragon.vm.torch.nn.functional.prelu .. autofunction:: dragon.vm.torch.nn.functional.prelu
.. _torch.nn.PReLU(...): ../PReLU.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ relu ...@@ -3,6 +3,8 @@ relu
.. autofunction:: dragon.vm.torch.nn.functional.relu .. autofunction:: dragon.vm.torch.nn.functional.relu
.. _torch.nn.ReLU(...): ../ReLU.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ relu6 ...@@ -3,6 +3,8 @@ relu6
.. autofunction:: dragon.vm.torch.nn.functional.relu6 .. autofunction:: dragon.vm.torch.nn.functional.relu6
.. _torch.nn.ReLU6(...): ../ReLU6.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ selu ...@@ -3,6 +3,8 @@ selu
.. autofunction:: dragon.vm.torch.nn.functional.selu .. autofunction:: dragon.vm.torch.nn.functional.selu
.. _torch.nn.SELU(...): ../SELU.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ sigmoid ...@@ -3,6 +3,8 @@ sigmoid
.. autofunction:: dragon.vm.torch.nn.functional.sigmoid .. autofunction:: dragon.vm.torch.nn.functional.sigmoid
.. _torch.nn.Sigmoid(...): ../Sigmoid.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ sigmoid_focal_loss ...@@ -3,6 +3,8 @@ sigmoid_focal_loss
.. autofunction:: dragon.vm.torch.nn.functional.sigmoid_focal_loss .. autofunction:: dragon.vm.torch.nn.functional.sigmoid_focal_loss
.. _torch.nn.SigmoidFocalLoss(...): ../SigmoidFocalLoss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ smooth_l1_loss ...@@ -3,6 +3,8 @@ smooth_l1_loss
.. autofunction:: dragon.vm.torch.nn.functional.smooth_l1_loss .. autofunction:: dragon.vm.torch.nn.functional.smooth_l1_loss
.. _torch.nn.SmoothL1Loss(...): ../SmoothL1Loss.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ softmax ...@@ -3,6 +3,8 @@ softmax
.. autofunction:: dragon.vm.torch.nn.functional.softmax .. autofunction:: dragon.vm.torch.nn.functional.softmax
.. _torch.nn.Softmax(...): ../Softmax.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ sync_batch_norm ...@@ -3,6 +3,8 @@ sync_batch_norm
.. autofunction:: dragon.vm.torch.nn.functional.sync_batch_norm .. autofunction:: dragon.vm.torch.nn.functional.sync_batch_norm
.. _torch.nn.SyncBatchNorm(...): ../SyncBatchNorm.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ tanh ...@@ -3,6 +3,8 @@ tanh
.. autofunction:: dragon.vm.torch.nn.functional.tanh .. autofunction:: dragon.vm.torch.nn.functional.tanh
.. _torch.nn.Tanh(...): ../Tanh.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ upsample ...@@ -3,6 +3,8 @@ upsample
.. autofunction:: dragon.vm.torch.nn.functional.upsample .. autofunction:: dragon.vm.torch.nn.functional.upsample
.. _torch.nn.Upsample(...): ../Upsample.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ upsample_bilinear ...@@ -3,6 +3,8 @@ upsample_bilinear
.. autofunction:: dragon.vm.torch.nn.functional.upsample_bilinear .. autofunction:: dragon.vm.torch.nn.functional.upsample_bilinear
.. _torch.nn.UpsamplingBilinear2d(...): ../UpsamplingBilinear2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -3,6 +3,8 @@ upsample_nearest ...@@ -3,6 +3,8 @@ upsample_nearest
.. autofunction:: dragon.vm.torch.nn.functional.upsample_nearest .. autofunction:: dragon.vm.torch.nn.functional.upsample_nearest
.. _torch.nn.UpsamplingNearest2d(...): ../UpsamplingNearest2d.html
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -40,32 +40,30 @@ class CUDAObjects { ...@@ -40,32 +40,30 @@ class CUDAObjects {
/*! \brief Destructor */ /*! \brief Destructor */
~CUDAObjects() { ~CUDAObjects() {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) { for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
for (int j = 0; j < cuda_streams_[i].size(); j++) { #ifdef USE_NCCL
auto& stream = cuda_streams_[i][j]; for (auto& comm : nccl_comms_[i]) {
/*! /*!
* Do not check the stream destroying, * Temporarily disable the comm destroying,
* error code 29 (driver shutting down) is inevitable. * to avoid an unhandled error.
*/ */
if (stream) cudaStreamDestroy(stream);
}
for (auto& handle : cublas_handles_[i])
if (handle) {
CUBLAS_CHECK(cublasDestroy(handle));
} }
#endif
#ifdef USE_CUDNN #ifdef USE_CUDNN
for (auto& handle : cudnn_handles_[i]) for (auto& handle : cudnn_handles_[i]) {
if (handle) { if (handle) CUDNN_CHECK(cudnnDestroy(handle));
CUDNN_CHECK(cudnnDestroy(handle));
} }
#endif #endif
#ifdef USE_NCCL for (auto& handle : cublas_handles_[i]) {
for (auto& comm : nccl_comms_[i]) { if (handle) CUBLAS_CHECK(cublasDestroy(handle));
}
for (int j = 0; j < cuda_streams_[i].size(); j++) {
auto& stream = cuda_streams_[i][j];
/*! /*!
* Temporarily disable the comm destroying, * Do not check the stream destroying,
* to avoid an unhandled error. * error code 29 (driver shutting down) is inevitable.
*/ */
if (stream) cudaStreamDestroy(stream);
} }
#endif
} }
} }
......
...@@ -179,10 +179,10 @@ OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) { ...@@ -179,10 +179,10 @@ OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) {
} }
OperatorDef mutable_def(def); OperatorDef mutable_def(def);
// Heuristically make each random seed slightly different // Heuristically make each random seed slightly different
static unsigned int op_seed_uuid = 0; static unsigned int seed_offset = 0;
mutable_def.mutable_device_option()->set_random_seed( mutable_def.mutable_device_option()->set_random_seed(
op_seed_uuid + def.device_option().random_seed()); seed_offset + def.device_option().random_seed());
op_seed_uuid = (op_seed_uuid + 1) % UINT32_MAX; seed_offset = (seed_offset + 1) % UINT32_MAX;
return TryCreateOperator(def.type(), mutable_def, ws); return TryCreateOperator(def.type(), mutable_def, ws);
} }
......
...@@ -47,29 +47,29 @@ class DRAGON_API Workspace { ...@@ -47,29 +47,29 @@ class DRAGON_API Workspace {
} }
/*! \brief Create the tensor */ /*! \brief Create the tensor */
Tensor* CreateTensor(const string&, FillerInfo* = nullptr); Tensor* CreateTensor(const string& name, FillerInfo* filler = nullptr);
/*! \brief Try to return the tensor */ /*! \brief Try to return the tensor */
Tensor* TryGetTensor(const string&, bool = true) const; Tensor* TryGetTensor(const string& name, bool external = true) const;
/*! \brief Return the tensor */ /*! \brief Return the tensor */
Tensor* GetTensor(const string&, bool = true) const; Tensor* GetTensor(const string& name, bool external = true) const;
/*! \brief Reset the tensor */ /*! \brief Reset the tensor */
void ResetTensor(const string&); void ResetTensor(const string& name);
/*! \brief Return the filler info */ /*! \brief Return the filler info */
FillerInfo* GetFillerInfo(const string&); FillerInfo* GetFillerInfo(const string& name);
/*! \brief Run the operator */ /*! \brief Run the operator */
void RunOperator(const OperatorDef&); void RunOperator(const OperatorDef& def);
/*! \brief Create the graph */ /*! \brief Create the graph */
GraphBase* CreateGraph(const GraphDef&); GraphBase* CreateGraph(const GraphDef& def);
/*! \brief Run the graph */ /*! \brief Run the graph */
void RunGraph( void RunGraph(
const string& graph_name, const string& name,
const string& include = "", const string& include = "",
const string& exclude = "", const string& exclude = "",
const int stream = 0); const int stream = 0);
...@@ -88,28 +88,30 @@ class DRAGON_API Workspace { ...@@ -88,28 +88,30 @@ class DRAGON_API Workspace {
/*! \brief Return a group of the shared raw data */ /*! \brief Return a group of the shared raw data */
template <class Context> template <class Context>
vector<void*> data(const vector<size_t>& segments) { vector<void*> data(const vector<size_t>& segments) {
int64_t nbytes = 0; vector<void*> group(segments.size());
vector<void*> ret(segments.size()); auto total_bytes = std::accumulate(segments.begin(), segments.end(), 0);
for (auto& segment : segments) group[0] = CreateTensor("/share/data")
nbytes += (int64_t)segment; ->Reshape({(int64_t)total_bytes})
auto* T = CreateTensor("/share/data")->Reshape({nbytes}); ->template mutable_data<uint8_t, Context>();
ret[0] = T->template mutable_data<uint8_t, Context>(); for (int i = 1; i < segments.size(); ++i) {
for (int i = 1; i < segments.size(); i++) group[i] = (uint8_t*)group[i - 1] + segments[i - 1];
ret[i] = (uint8_t*)ret[i - 1] + segments[i - 1]; }
return ret; return group;
} }
/*! \brief Return a group of shared typed data */ /*! \brief Return a group of shared typed data */
template <typename T, class Context> template <typename T, class Context>
vector<T*> data(const vector<int64_t>& segments) { vector<T*> data(const vector<int64_t>& segments) {
vector<size_t> segments_in_byte; vector<T*> group(segments.size());
vector<T*> ret(segments.size()); vector<size_t> segments_v2;
for (const auto& e : segments) for (const auto size : segments) {
segments_in_byte.emplace_back(e * sizeof(T)); segments_v2.push_back(size * sizeof(T));
auto ret_in_byte = data<Context>(segments_in_byte); }
for (int i = 0; i < segments.size(); i++) auto group_v2 = data<Context>(segments_v2);
ret[i] = (T*)ret_in_byte[i]; for (int i = 0; i < segments.size(); ++i) {
return ret; group[i] = (T*)group_v2[i];
}
return group;
} }
private: private:
......
...@@ -71,7 +71,7 @@ class ConversionContext { ...@@ -71,7 +71,7 @@ class ConversionContext {
const int opset_version_; const int opset_version_;
}; };
typedef struct { struct ONNXImporterReturns {
vector<OperatorDef> ops; vector<OperatorDef> ops;
OperatorDef* AddOp() { OperatorDef* AddOp() {
...@@ -83,7 +83,7 @@ typedef struct { ...@@ -83,7 +83,7 @@ typedef struct {
CHECK_LT(index, ops.size()); CHECK_LT(index, ops.size());
return &ops[index]; return &ops[index];
} }
} ONNXImporterReturns; };
class ONNXAttributes { class ONNXAttributes {
public: public:
......
...@@ -189,7 +189,7 @@ class Tensor(types.TensorMetaclass): ...@@ -189,7 +189,7 @@ class Tensor(types.TensorMetaclass):
See Also See Also
-------- --------
`dragon.cast(...)`_ : Cast the data type of input. `dragon.cast(...)`_
""" """
...@@ -208,6 +208,10 @@ class Tensor(types.TensorMetaclass): ...@@ -208,6 +208,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The self. The self.
See Also
--------
`dragon.fill(...)`_
""" """
return self._register_as('constant', value=value) return self._register_as('constant', value=value)
...@@ -221,7 +225,7 @@ class Tensor(types.TensorMetaclass): ...@@ -221,7 +225,7 @@ class Tensor(types.TensorMetaclass):
See Also See Also
-------- --------
`dragon.copy(...)`_ : Copy the input. `dragon.copy(...)`_
""" """
...@@ -252,6 +256,10 @@ class Tensor(types.TensorMetaclass): ...@@ -252,6 +256,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The self. The self.
See Also
--------
`dragon.random.glorot_normal(...)`_
""" """
return self._register_as('glorot_normal', mode=mode, scale=scale) return self._register_as('glorot_normal', mode=mode, scale=scale)
...@@ -273,6 +281,10 @@ class Tensor(types.TensorMetaclass): ...@@ -273,6 +281,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The self. The self.
See Also
--------
`dragon.random.glorot_uniform(...)`_
""" """
return self._register_as('glorot_uniform', mode=mode, scale=scale) return self._register_as('glorot_uniform', mode=mode, scale=scale)
...@@ -293,6 +305,10 @@ class Tensor(types.TensorMetaclass): ...@@ -293,6 +305,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The self. The self.
See Also
--------
`dragon.random.normal(...)`_
""" """
return self._register_as('normal', mean=mean, std=std) return self._register_as('normal', mean=mean, std=std)
...@@ -311,7 +327,7 @@ class Tensor(types.TensorMetaclass): ...@@ -311,7 +327,7 @@ class Tensor(types.TensorMetaclass):
See Also See Also
-------- --------
`dragon.reshape(...)`_ : Change the dimensions of input. `dragon.reshape(...)`_
""" """
...@@ -347,6 +363,10 @@ class Tensor(types.TensorMetaclass): ...@@ -347,6 +363,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The self. The self.
See Also
--------
`dragon.random.truncated_normal(...)`_
""" """
return self._register_as('truncated_normal', mean=mean, std=std) return self._register_as('truncated_normal', mean=mean, std=std)
...@@ -367,6 +387,10 @@ class Tensor(types.TensorMetaclass): ...@@ -367,6 +387,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The self. The self.
See Also
--------
`dragon.random.uniform(...)`_
""" """
return self._register_as('uniform', low=low, high=high) return self._register_as('uniform', low=low, high=high)
...@@ -441,6 +465,10 @@ class Tensor(types.TensorMetaclass): ...@@ -441,6 +465,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.add(...)`_
""" """
def __float__(self): def __float__(self):
...@@ -467,6 +495,10 @@ class Tensor(types.TensorMetaclass): ...@@ -467,6 +495,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.greater_equal(...)`_
""" """
def __getitem__(self, item): def __getitem__(self, item):
...@@ -492,6 +524,10 @@ class Tensor(types.TensorMetaclass): ...@@ -492,6 +524,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.greater(...)`_
""" """
def __hash__(self): def __hash__(self):
...@@ -521,6 +557,10 @@ class Tensor(types.TensorMetaclass): ...@@ -521,6 +557,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.less_equal(...)`_
""" """
def __lt__(self, other): def __lt__(self, other):
...@@ -536,6 +576,10 @@ class Tensor(types.TensorMetaclass): ...@@ -536,6 +576,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.less(...)`_
""" """
def __mul__(self, other): def __mul__(self, other):
...@@ -551,6 +595,10 @@ class Tensor(types.TensorMetaclass): ...@@ -551,6 +595,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.mul(...)`_
""" """
def __neg__(self): def __neg__(self):
...@@ -561,6 +609,10 @@ class Tensor(types.TensorMetaclass): ...@@ -561,6 +609,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.negative(...)`_
""" """
def __radd__(self, other): def __radd__(self, other):
...@@ -576,6 +628,10 @@ class Tensor(types.TensorMetaclass): ...@@ -576,6 +628,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.add(...)`_
""" """
def __repr__(self): def __repr__(self):
...@@ -600,6 +656,10 @@ class Tensor(types.TensorMetaclass): ...@@ -600,6 +656,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.div(...)`_
""" """
def __rmul__(self, other): def __rmul__(self, other):
...@@ -615,6 +675,10 @@ class Tensor(types.TensorMetaclass): ...@@ -615,6 +675,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.mul(...)`_
""" """
def __rsub__(self, other): def __rsub__(self, other):
...@@ -630,6 +694,10 @@ class Tensor(types.TensorMetaclass): ...@@ -630,6 +694,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.sub(...)`_
""" """
def __setitem__(self, key, value): def __setitem__(self, key, value):
...@@ -648,6 +716,10 @@ class Tensor(types.TensorMetaclass): ...@@ -648,6 +716,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.sub(...)`_
""" """
def __truediv__(self, other): def __truediv__(self, other):
...@@ -663,6 +735,10 @@ class Tensor(types.TensorMetaclass): ...@@ -663,6 +735,10 @@ class Tensor(types.TensorMetaclass):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.div(...)`_
""" """
......
...@@ -188,7 +188,7 @@ class EagerTensor(Tensor): ...@@ -188,7 +188,7 @@ class EagerTensor(Tensor):
See Also See Also
-------- --------
`dragon.cast(...)`_ : Cast the data type of input. `dragon.cast(...)`_
""" """
...@@ -207,6 +207,10 @@ class EagerTensor(Tensor): ...@@ -207,6 +207,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.fill(...)`_
""" """
def copy(self): def copy(self):
...@@ -219,7 +223,7 @@ class EagerTensor(Tensor): ...@@ -219,7 +223,7 @@ class EagerTensor(Tensor):
See Also See Also
-------- --------
`dragon.copy(...)`_ : Copy the input. `dragon.copy(...)`_
""" """
...@@ -251,6 +255,10 @@ class EagerTensor(Tensor): ...@@ -251,6 +255,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.glorot_normal(...)`_
""" """
def glorot_uniform(self, mode='fan_in', scale=3.0): def glorot_uniform(self, mode='fan_in', scale=3.0):
...@@ -271,6 +279,10 @@ class EagerTensor(Tensor): ...@@ -271,6 +279,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.glorot_uniform(...)`_
""" """
def numpy(self, readonly=True): def numpy(self, readonly=True):
...@@ -306,6 +318,10 @@ class EagerTensor(Tensor): ...@@ -306,6 +318,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.normal(...)`_
""" """
def reshape(self, shape): def reshape(self, shape):
...@@ -323,7 +339,7 @@ class EagerTensor(Tensor): ...@@ -323,7 +339,7 @@ class EagerTensor(Tensor):
See Also See Also
-------- --------
`dragon.reshape(...)`_ : Change the dimensions of input. `dragon.reshape(...)`_
""" """
...@@ -363,6 +379,10 @@ class EagerTensor(Tensor): ...@@ -363,6 +379,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.truncated_normal(...)`_
""" """
def uniform(self, low=0, high=1): def uniform(self, low=0, high=1):
...@@ -382,6 +402,10 @@ class EagerTensor(Tensor): ...@@ -382,6 +402,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.uniform(...)`_
""" """
def _from_numpy(self, array, copy): def _from_numpy(self, array, copy):
...@@ -413,6 +437,10 @@ class EagerTensor(Tensor): ...@@ -413,6 +437,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.add(...)`_
""" """
def __del__(self): def __del__(self):
...@@ -445,6 +473,10 @@ class EagerTensor(Tensor): ...@@ -445,6 +473,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.greater_equal(...)`_
""" """
def __getitem__(self, item): def __getitem__(self, item):
...@@ -470,6 +502,10 @@ class EagerTensor(Tensor): ...@@ -470,6 +502,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.greater(...)`_
""" """
def __hash__(self): def __hash__(self):
...@@ -488,6 +524,10 @@ class EagerTensor(Tensor): ...@@ -488,6 +524,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.math.add(...)`_
""" """
def __imul__(self, other): def __imul__(self, other):
...@@ -503,6 +543,10 @@ class EagerTensor(Tensor): ...@@ -503,6 +543,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.math.mul(...)`_
""" """
def __int__(self): def __int__(self):
...@@ -529,6 +573,10 @@ class EagerTensor(Tensor): ...@@ -529,6 +573,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.math.sub(...)`_
""" """
def __le__(self, other): def __le__(self, other):
...@@ -544,6 +592,10 @@ class EagerTensor(Tensor): ...@@ -544,6 +592,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.less_equal(...)`_
""" """
def __lt__(self, other): def __lt__(self, other):
...@@ -559,6 +611,10 @@ class EagerTensor(Tensor): ...@@ -559,6 +611,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.less(...)`_
""" """
def __mul__(self, other): def __mul__(self, other):
...@@ -574,6 +630,10 @@ class EagerTensor(Tensor): ...@@ -574,6 +630,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.mul(...)`_
""" """
def __neg__(self): def __neg__(self):
...@@ -584,6 +644,10 @@ class EagerTensor(Tensor): ...@@ -584,6 +644,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.negative(...)`_
""" """
def __radd__(self, other): def __radd__(self, other):
...@@ -599,6 +663,10 @@ class EagerTensor(Tensor): ...@@ -599,6 +663,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.add(...)`_
""" """
def __repr__(self): def __repr__(self):
...@@ -629,6 +697,10 @@ class EagerTensor(Tensor): ...@@ -629,6 +697,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.div(...)`_
""" """
def __rmul__(self, other): def __rmul__(self, other):
...@@ -644,6 +716,10 @@ class EagerTensor(Tensor): ...@@ -644,6 +716,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.mul(...)`_
""" """
def __rsub__(self, other): def __rsub__(self, other):
...@@ -677,6 +753,10 @@ class EagerTensor(Tensor): ...@@ -677,6 +753,10 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.sub(...)`_
""" """
def __truediv__(self, other): def __truediv__(self, other):
...@@ -692,4 +772,8 @@ class EagerTensor(Tensor): ...@@ -692,4 +772,8 @@ class EagerTensor(Tensor):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.div(...)`_
""" """
...@@ -1554,7 +1554,7 @@ def where(inputs, **kwargs): ...@@ -1554,7 +1554,7 @@ def where(inputs, **kwargs):
See Also See Also
-------- --------
`dragon.nonzero(...)`_ : Return the index of non-zero elements. dragon.nonzero()
""" """
if types.is_tensor(inputs) or len(inputs) == 1: if types.is_tensor(inputs) or len(inputs) == 1:
......
...@@ -58,7 +58,7 @@ def astype(self, dtype, inplace=False): ...@@ -58,7 +58,7 @@ def astype(self, dtype, inplace=False):
See Also See Also
-------- --------
`dragon.cast(...)`_ : Cast the data type of input. `dragon.cast(...)`_
""" """
return array_ops_lib.Cast \ return array_ops_lib.Cast \
...@@ -80,6 +80,10 @@ def constant(self, value=0): ...@@ -80,6 +80,10 @@ def constant(self, value=0):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.fill(...)`_
""" """
shape = self.shape shape = self.shape
return init_ops_lib.Fill \ return init_ops_lib.Fill \
...@@ -100,7 +104,7 @@ def copy(self): ...@@ -100,7 +104,7 @@ def copy(self):
See Also See Also
-------- --------
`dragon.copy(...)`_ : Copy the value to ref. `dragon.copy(...)`_
""" """
return control_flow_ops_lib.Copy \ return control_flow_ops_lib.Copy \
...@@ -120,6 +124,10 @@ def div(self, other): ...@@ -120,6 +124,10 @@ def div(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.div(...)`_
""" """
return _binary_op(self, other, 'Div') return _binary_op(self, other, 'Div')
...@@ -137,6 +145,10 @@ def ge(self, other): ...@@ -137,6 +145,10 @@ def ge(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.greater_equal(...)`_
""" """
return _binary_op(self, other, 'GreaterEqual') return _binary_op(self, other, 'GreaterEqual')
...@@ -179,6 +191,10 @@ def glorot_normal(self, mode='fan_in', scale=2.0): ...@@ -179,6 +191,10 @@ def glorot_normal(self, mode='fan_in', scale=2.0):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.glorot_normal(...)`_
""" """
shape = self.shape shape = self.shape
return init_ops_lib.GlorotNormal \ return init_ops_lib.GlorotNormal \
...@@ -208,6 +224,10 @@ def glorot_uniform(self, mode='fan_in', scale=3.0): ...@@ -208,6 +224,10 @@ def glorot_uniform(self, mode='fan_in', scale=3.0):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.glorot_uniform(...)`_
""" """
shape = self.shape shape = self.shape
return init_ops_lib.GlorotUniform \ return init_ops_lib.GlorotUniform \
...@@ -232,6 +252,10 @@ def gt(self, other): ...@@ -232,6 +252,10 @@ def gt(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.greater(...)`_
""" """
return _binary_op(self, other, 'Greater') return _binary_op(self, other, 'Greater')
...@@ -249,6 +273,10 @@ def iadd(self, other): ...@@ -249,6 +273,10 @@ def iadd(self, other):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.math.add(...)`_
""" """
return _binary_op(self, other, 'Add', [self]) return _binary_op(self, other, 'Add', [self])
...@@ -266,6 +294,10 @@ def idiv(self, other): ...@@ -266,6 +294,10 @@ def idiv(self, other):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.math.div(...)`_
""" """
return _binary_op(self, other, 'Div', [self]) return _binary_op(self, other, 'Div', [self])
...@@ -283,6 +315,10 @@ def imul(self, other): ...@@ -283,6 +315,10 @@ def imul(self, other):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.math.mul(...)`_
""" """
return _binary_op(self, other, 'Mul', [self]) return _binary_op(self, other, 'Mul', [self])
...@@ -300,6 +336,10 @@ def isub(self, other): ...@@ -300,6 +336,10 @@ def isub(self, other):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.math.sub(...)`_
""" """
return _binary_op(self, other, 'Sub', [self]) return _binary_op(self, other, 'Sub', [self])
...@@ -317,6 +357,10 @@ def le(self, other): ...@@ -317,6 +357,10 @@ def le(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.less_equal(...)`_
""" """
return _binary_op(self, other, 'LessEqual') return _binary_op(self, other, 'LessEqual')
...@@ -334,6 +378,10 @@ def lt(self, other): ...@@ -334,6 +378,10 @@ def lt(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.less(...)`_
""" """
return _binary_op(self, other, 'Less') return _binary_op(self, other, 'Less')
...@@ -351,6 +399,10 @@ def mul(self, other): ...@@ -351,6 +399,10 @@ def mul(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.mul(...)`_
""" """
return _binary_op(self, other, 'Mul') return _binary_op(self, other, 'Mul')
...@@ -363,6 +415,10 @@ def neg(self): ...@@ -363,6 +415,10 @@ def neg(self):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.negative(...)`_
""" """
return _unary_op(self, 'Neg') return _unary_op(self, 'Neg')
...@@ -384,6 +440,10 @@ def normal(self, mean=0, std=1): ...@@ -384,6 +440,10 @@ def normal(self, mean=0, std=1):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.normal(...)`_
""" """
shape = self.shape shape = self.shape
return init_ops_lib.RandomNormal \ return init_ops_lib.RandomNormal \
...@@ -408,6 +468,10 @@ def radd(self, other): ...@@ -408,6 +468,10 @@ def radd(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.add(...)`_
""" """
return _binary_op(other, self, 'Add') return _binary_op(other, self, 'Add')
...@@ -425,6 +489,10 @@ def rdiv(self, other): ...@@ -425,6 +489,10 @@ def rdiv(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.div(...)`_
""" """
return _binary_op(other, self, 'Div') return _binary_op(other, self, 'Div')
...@@ -444,7 +512,7 @@ def reshape(self, shape): ...@@ -444,7 +512,7 @@ def reshape(self, shape):
See Also See Also
-------- --------
`dragon.reshape(...)`_ : Change the dimensions of input. `dragon.reshape(...)`_
""" """
with context.eager_mode(): with context.eager_mode():
...@@ -464,6 +532,10 @@ def rmul(self, other): ...@@ -464,6 +532,10 @@ def rmul(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.mul(...)`_
""" """
return _binary_op(other, self, 'Mul') return _binary_op(other, self, 'Mul')
...@@ -481,6 +553,10 @@ def rsub(self, other): ...@@ -481,6 +553,10 @@ def rsub(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.sub(...)`_
""" """
return _binary_op(other, self, 'Sub') return _binary_op(other, self, 'Sub')
...@@ -516,6 +592,10 @@ def sub(self, other): ...@@ -516,6 +592,10 @@ def sub(self, other):
dragon.EagerTensor dragon.EagerTensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.sub(...)`_
""" """
return _binary_op(self, other, 'Sub') return _binary_op(self, other, 'Sub')
...@@ -537,6 +617,10 @@ def truncated_normal(self, mean=0, std=1): ...@@ -537,6 +617,10 @@ def truncated_normal(self, mean=0, std=1):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.truncated_normal(...)`_
""" """
shape = self.shape shape = self.shape
return init_ops_lib.TruncatedNormal \ return init_ops_lib.TruncatedNormal \
...@@ -565,6 +649,10 @@ def uniform(self, low=0, high=1): ...@@ -565,6 +649,10 @@ def uniform(self, low=0, high=1):
dragon.EagerTensor dragon.EagerTensor
The self. The self.
See Also
--------
`dragon.random.uniform(...)`_
""" """
shape = self.shape shape = self.shape
return init_ops_lib.RandomUniform \ return init_ops_lib.RandomUniform \
......
...@@ -35,6 +35,10 @@ def add(self, other): ...@@ -35,6 +35,10 @@ def add(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.add(...)`_
""" """
return _binary_op(self, other, 'Add') return _binary_op(self, other, 'Add')
...@@ -56,7 +60,7 @@ def astype(self, dtype, inplace=False): ...@@ -56,7 +60,7 @@ def astype(self, dtype, inplace=False):
See Also See Also
-------- --------
`dragon.cast(...)`_ : Cast the data type of input. `dragon.cast(...)`_
""" """
if self.dtype == dtype: if self.dtype == dtype:
...@@ -75,7 +79,7 @@ def copy(self): ...@@ -75,7 +79,7 @@ def copy(self):
See Also See Also
-------- --------
`dragon.copy(...)`_ : Copy the value to ref. `dragon.copy(...)`_
""" """
outputs = [Tensor(shape=self.shape, dtype=self.dtype)] outputs = [Tensor(shape=self.shape, dtype=self.dtype)]
...@@ -95,6 +99,10 @@ def div(self, other): ...@@ -95,6 +99,10 @@ def div(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.div(...)`_
""" """
return _binary_op(self, other, 'Div') return _binary_op(self, other, 'Div')
...@@ -112,6 +120,10 @@ def ge(self, other): ...@@ -112,6 +120,10 @@ def ge(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.greater_equal(...)`_
""" """
return _binary_op(self, other, 'GreaterEqual') return _binary_op(self, other, 'GreaterEqual')
...@@ -162,6 +174,10 @@ def gt(self, other): ...@@ -162,6 +174,10 @@ def gt(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.greater(...)`_
""" """
return _binary_op(self, other, 'Greater') return _binary_op(self, other, 'Greater')
...@@ -179,6 +195,10 @@ def le(self, other): ...@@ -179,6 +195,10 @@ def le(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.less_equal(...)`_
""" """
return _binary_op(self, other, 'LessEqual') return _binary_op(self, other, 'LessEqual')
...@@ -196,6 +216,10 @@ def lt(self, other): ...@@ -196,6 +216,10 @@ def lt(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.less(...)`_
""" """
return _binary_op(self, other, 'Less') return _binary_op(self, other, 'Less')
...@@ -213,6 +237,10 @@ def mul(self, other): ...@@ -213,6 +237,10 @@ def mul(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.mul(...)`_
""" """
return _binary_op(self, other, 'Mul') return _binary_op(self, other, 'Mul')
...@@ -225,6 +253,10 @@ def neg(self): ...@@ -225,6 +253,10 @@ def neg(self):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.negative(...)`_
""" """
return _unary_op(self, 'Neg') return _unary_op(self, 'Neg')
...@@ -242,6 +274,10 @@ def radd(self, other): ...@@ -242,6 +274,10 @@ def radd(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.add(...)`_
""" """
return _binary_op(other, self, 'Add') return _binary_op(other, self, 'Add')
...@@ -259,6 +295,10 @@ def rdiv(self, other): ...@@ -259,6 +295,10 @@ def rdiv(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.div(...)`_
""" """
return _binary_op(other, self, 'Div') return _binary_op(other, self, 'Div')
...@@ -278,7 +318,7 @@ def reshape(self, shape): ...@@ -278,7 +318,7 @@ def reshape(self, shape):
See Also See Also
-------- --------
`dragon.reshape(...)`_ : Change the dimensions of input. `dragon.reshape(...)`_
""" """
with context.graph_mode(): with context.graph_mode():
...@@ -298,6 +338,10 @@ def rmul(self, other): ...@@ -298,6 +338,10 @@ def rmul(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.mul(...)`_
""" """
return _binary_op(other, self, 'Mul') return _binary_op(other, self, 'Mul')
...@@ -315,6 +359,10 @@ def rsub(self, other): ...@@ -315,6 +359,10 @@ def rsub(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.sub(...)`_
""" """
return _binary_op(other, self, 'Sub') return _binary_op(other, self, 'Sub')
...@@ -368,6 +416,10 @@ def sub(self, other): ...@@ -368,6 +416,10 @@ def sub(self, other):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`dragon.math.sub(...)`_
""" """
return _binary_op(self, other, 'Sub') return _binary_op(self, other, 'Sub')
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Execute tensor operations. """ """Execute tensor operations."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -39,11 +39,12 @@ def affine(input, weight, bias=None): ...@@ -39,11 +39,12 @@ def affine(input, weight, bias=None):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Affine(...)`_
""" """
return _functions.Affine \ return _functions.Affine.instantiate(input.device).apply(input, weight, bias)
.instantiate(
input.device,
).apply(input, weight, bias)
def avg_pool2d( def avg_pool2d(
...@@ -82,6 +83,10 @@ def avg_pool2d( ...@@ -82,6 +83,10 @@ def avg_pool2d(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.AvgPool2d(...)`_
""" """
return _pool( return _pool(
_pool_mode='AVG', _pool_mode='AVG',
...@@ -138,6 +143,10 @@ def batch_norm( ...@@ -138,6 +143,10 @@ def batch_norm(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.BatchNorm2d(...)`_
""" """
return _functions.BatchNorm \ return _functions.BatchNorm \
.instantiate( .instantiate(
...@@ -181,6 +190,10 @@ def binary_cross_entropy_with_logits( ...@@ -181,6 +190,10 @@ def binary_cross_entropy_with_logits(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.BCEWithLogitsLoss(...)`_
""" """
if size_average is not None or reduce is not None: if size_average is not None or reduce is not None:
reduction = _reduction.legacy_get_string(size_average, reduce) reduction = _reduction.legacy_get_string(size_average, reduce)
...@@ -236,6 +249,10 @@ def conv2d( ...@@ -236,6 +249,10 @@ def conv2d(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Conv2d(...)`_
""" """
return _conv( return _conv(
_nd_util=utils._pair, _nd_util=utils._pair,
...@@ -290,6 +307,10 @@ def conv_transpose2d( ...@@ -290,6 +307,10 @@ def conv_transpose2d(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.ConvTranspose2d(...)`_
""" """
return _conv_transpose( return _conv_transpose(
_nd_util=utils._pair, _nd_util=utils._pair,
...@@ -335,6 +356,10 @@ def cross_entropy( ...@@ -335,6 +356,10 @@ def cross_entropy(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The loss. The loss.
See Also
--------
`torch.nn.CrossEntropyLoss(...)`_
""" """
if size_average is not None or reduce is not None: if size_average is not None or reduce is not None:
reduction = _reduction.legacy_get_string(size_average, reduce) reduction = _reduction.legacy_get_string(size_average, reduce)
...@@ -368,6 +393,10 @@ def ctc_loss(input, target, padding_mask=-1, reduction='mean'): ...@@ -368,6 +393,10 @@ def ctc_loss(input, target, padding_mask=-1, reduction='mean'):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The loss. The loss.
See Also
--------
`torch.nn.CTCLoss(...)`_
""" """
prob = softmax(input, 2) prob = softmax(input, 2)
return _functions.CTCLoss \ return _functions.CTCLoss \
...@@ -408,6 +437,10 @@ def depthwise_conv2d( ...@@ -408,6 +437,10 @@ def depthwise_conv2d(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.DepthwiseConv2d(...)`_
""" """
return _conv( return _conv(
_nd_util=utils._pair, _nd_util=utils._pair,
...@@ -440,6 +473,10 @@ def dropout(input, p=0.5, training=True, inplace=False): ...@@ -440,6 +473,10 @@ def dropout(input, p=0.5, training=True, inplace=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Dropout(...)`_
""" """
if not training: if not training:
return input return input
...@@ -494,6 +531,10 @@ def drop_block2d( ...@@ -494,6 +531,10 @@ def drop_block2d(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.DropBlock2d(...)`_
""" """
if not training: if not training:
return input return input
...@@ -543,6 +584,10 @@ def drop_path( ...@@ -543,6 +584,10 @@ def drop_path(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.DropPath(...)`_
""" """
if not training: if not training:
return input return input
...@@ -570,7 +615,7 @@ def elu(input, alpha=1., inplace=False): ...@@ -570,7 +615,7 @@ def elu(input, alpha=1., inplace=False):
See Also See Also
-------- --------
`torch.nn.ELU`_ - Apply the exponential linear unit. `torch.nn.ELU(...)`_
Parameters Parameters
---------- ----------
...@@ -588,10 +633,8 @@ def elu(input, alpha=1., inplace=False): ...@@ -588,10 +633,8 @@ def elu(input, alpha=1., inplace=False):
""" """
return _functions.Elu \ return _functions.Elu \
.instantiate( .instantiate(input.device, alpha=alpha) \
input.device, .apply(input, inplace=inplace)
alpha=alpha,
).apply(input, inplace=inplace)
def group_norm(input, weight, bias, groups=32, eps=1e-5): def group_norm(input, weight, bias, groups=32, eps=1e-5):
...@@ -621,6 +664,10 @@ def group_norm(input, weight, bias, groups=32, eps=1e-5): ...@@ -621,6 +664,10 @@ def group_norm(input, weight, bias, groups=32, eps=1e-5):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.GroupNorm(...)`_
""" """
return _functions.GroupNorm \ return _functions.GroupNorm \
.instantiate( .instantiate(
...@@ -679,6 +726,10 @@ def interpolate( ...@@ -679,6 +726,10 @@ def interpolate(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Upsample(...)`_
""" """
if size is not None: if size is not None:
size = nest.flatten(size) size = nest.flatten(size)
...@@ -725,16 +776,18 @@ def l1_loss( ...@@ -725,16 +776,18 @@ def l1_loss(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.L1Loss(...)`_
""" """
if size_average is not None or reduce is not None: if size_average is not None or reduce is not None:
reduction = _reduction.legacy_get_string(size_average, reduce) reduction = _reduction.legacy_get_string(size_average, reduce)
else: else:
reduction = reduction reduction = reduction
return _functions.L1Loss \ return _functions.L1Loss \
.instantiate( .instantiate(input.device, reduction=reduction) \
input.device, .apply([input, target])
reduction=reduction,
).apply([input, target])
def leaky_relu(input, negative_slope=0.01, inplace=False): def leaky_relu(input, negative_slope=0.01, inplace=False):
...@@ -749,10 +802,6 @@ def leaky_relu(input, negative_slope=0.01, inplace=False): ...@@ -749,10 +802,6 @@ def leaky_relu(input, negative_slope=0.01, inplace=False):
slope * x, & \text{ otherwise } slope * x, & \text{ otherwise }
\end{cases} \end{cases}
See Also
--------
`torch.nn.LeakyReLU`_ - Apply the leaky rectified linear unit.
Parameters Parameters
---------- ----------
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
...@@ -767,12 +816,14 @@ def leaky_relu(input, negative_slope=0.01, inplace=False): ...@@ -767,12 +816,14 @@ def leaky_relu(input, negative_slope=0.01, inplace=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.LeakyReLU(...)`_
""" """
return _functions.Relu \ return _functions.Relu \
.instantiate( .instantiate(input.device, alpha=float(negative_slope)) \
input.device, .apply(input, inplace=inplace)
alpha=float(negative_slope),
).apply(input, inplace=inplace)
def linear(input, weight, bias=None): def linear(input, weight, bias=None):
...@@ -794,11 +845,12 @@ def linear(input, weight, bias=None): ...@@ -794,11 +845,12 @@ def linear(input, weight, bias=None):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Linear(...)`_
""" """
return _functions.Linear \ return _functions.Linear.instantiate(input.device).apply(input, weight, bias)
.instantiate(
input.device,
).apply(input, weight, bias)
def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
...@@ -812,10 +864,6 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): ...@@ -812,10 +864,6 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
\sum_{j=\max(0, i-n/2)}^{\min(N-1,i+n/2)}x_{j}^2 \sum_{j=\max(0, i-n/2)}^{\min(N-1,i+n/2)}x_{j}^2
\right)^{-\beta} \right)^{-\beta}
See Also
--------
`torch.nn.LocalResponseNorm`_ - Apply the local response normalization.
Parameters Parameters
---------- ----------
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
...@@ -834,6 +882,10 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): ...@@ -834,6 +882,10 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.LocalResponseNorm(...)`_
""" """
return _functions.LocalResponseNorm \ return _functions.LocalResponseNorm \
.instantiate( .instantiate(
...@@ -864,6 +916,10 @@ def log_softmax(input, dim): ...@@ -864,6 +916,10 @@ def log_softmax(input, dim):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.LogSoftmax(...)`_
""" """
return input - input.logsumexp(dim, keepdim=True) return input - input.logsumexp(dim, keepdim=True)
...@@ -883,11 +939,12 @@ def lstm_cell(input, cx): ...@@ -883,11 +939,12 @@ def lstm_cell(input, cx):
sequence of dragon.vm.torch.Tensor sequence of dragon.vm.torch.Tensor
The **h** and **c**. The **h** and **c**.
See Also
--------
`torch.nn.LSTMCell(...)`_
""" """
return _functions.LSTMCell \ return _functions.LSTMCell.instantiate(input.device).apply(input, cx)
.instantiate(
input.device,
).apply(input, cx)
def max_pool2d( def max_pool2d(
...@@ -926,6 +983,10 @@ def max_pool2d( ...@@ -926,6 +983,10 @@ def max_pool2d(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.MaxPool2d(...)`_
""" """
return _pool( return _pool(
_pool_mode='MAX', _pool_mode='MAX',
...@@ -966,16 +1027,18 @@ def mse_loss( ...@@ -966,16 +1027,18 @@ def mse_loss(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.MSELoss(...)`_
""" """
if size_average is not None or reduce is not None: if size_average is not None or reduce is not None:
reduction = _reduction.legacy_get_string(size_average, reduce) reduction = _reduction.legacy_get_string(size_average, reduce)
else: else:
reduction = reduction reduction = reduction
return _functions.L2Loss \ return _functions.L2Loss \
.instantiate( .instantiate(input.device, reduction=reduction) \
input.device, .apply([input, target])
reduction=reduction,
).apply([input, target])
def nll_loss( def nll_loss(
...@@ -1015,6 +1078,10 @@ def nll_loss( ...@@ -1015,6 +1078,10 @@ def nll_loss(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The loss. The loss.
See Also
--------
`torch.nn.NLLLoss(...)`_
""" """
if size_average is not None or reduce is not None: if size_average is not None or reduce is not None:
reduction = _reduction.legacy_get_string(size_average, reduce) reduction = _reduction.legacy_get_string(size_average, reduce)
...@@ -1089,6 +1156,12 @@ def pad(input, pad, mode='constant', value=0): ...@@ -1089,6 +1156,12 @@ def pad(input, pad, mode='constant', value=0):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.ConstantPad2d(...)`_,
`torch.nn.ReflectionPad2d(...)`_,
`torch.nn.ReplicationPad2d(...)`_
""" """
ndim = input.ndimension() ndim = input.ndimension()
pads_begin, pads_end = [0] * ndim, [0] * ndim pads_begin, pads_end = [0] * ndim, [0] * ndim
...@@ -1132,11 +1205,12 @@ def prelu(input, weight): ...@@ -1132,11 +1205,12 @@ def prelu(input, weight):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.PReLU(...)`_
""" """
return _functions.PRelu \ return _functions.PRelu.instantiate(input.device).apply(input, weight)
.instantiate(
input.device,
).apply(input, weight)
def relu(input, inplace=False): def relu(input, inplace=False):
...@@ -1164,6 +1238,10 @@ def relu(input, inplace=False): ...@@ -1164,6 +1238,10 @@ def relu(input, inplace=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.ReLU(...)`_
""" """
return leaky_relu(input, 0., inplace=inplace) return leaky_relu(input, 0., inplace=inplace)
...@@ -1193,11 +1271,14 @@ def relu6(input, inplace=False): ...@@ -1193,11 +1271,14 @@ def relu6(input, inplace=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.ReLU6(...)`_
""" """
return _functions.Relu6 \ return _functions.Relu6 \
.instantiate( .instantiate(input.device) \
input.device, .apply(input, inplace=inplace)
).apply(input, inplace=inplace)
def selu(input, inplace=False): def selu(input, inplace=False):
...@@ -1225,6 +1306,10 @@ def selu(input, inplace=False): ...@@ -1225,6 +1306,10 @@ def selu(input, inplace=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.SELU(...)`_
""" """
return _activation(input, inplace, 'Selu') return _activation(input, inplace, 'Selu')
...@@ -1248,6 +1333,10 @@ def sigmoid(input, inplace=False): ...@@ -1248,6 +1333,10 @@ def sigmoid(input, inplace=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Sigmoid(...)`_
""" """
return _activation(input, inplace, 'Sigmoid') return _activation(input, inplace, 'Sigmoid')
...@@ -1296,6 +1385,10 @@ def sigmoid_focal_loss( ...@@ -1296,6 +1385,10 @@ def sigmoid_focal_loss(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.SigmoidFocalLoss(...)`_
""" """
if size_average is not None or reduce is not None: if size_average is not None or reduce is not None:
reduction = _reduction.legacy_get_string(size_average, reduce) reduction = _reduction.legacy_get_string(size_average, reduce)
...@@ -1351,6 +1444,10 @@ def smooth_l1_loss( ...@@ -1351,6 +1444,10 @@ def smooth_l1_loss(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.SmoothL1Loss(...)`_
""" """
if size_average is not None or reduce is not None: if size_average is not None or reduce is not None:
reduction = _reduction.legacy_get_string(size_average, reduce) reduction = _reduction.legacy_get_string(size_average, reduce)
...@@ -1385,12 +1482,14 @@ def softmax(input, dim, inplace=False): ...@@ -1385,12 +1482,14 @@ def softmax(input, dim, inplace=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Softmax(...)`_
""" """
return _functions.Softmax \ return _functions.Softmax \
.instantiate( .instantiate(input.device, axis=dim) \
input.device, .apply(input, inplace=inplace)
axis=dim,
).apply(input, inplace=inplace)
def sync_batch_norm( def sync_batch_norm(
...@@ -1447,6 +1546,10 @@ def sync_batch_norm( ...@@ -1447,6 +1546,10 @@ def sync_batch_norm(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.SyncBatchNorm(...)`_
""" """
if process_group is None: if process_group is None:
raise ValueError('<process_group> is required.') raise ValueError('<process_group> is required.')
...@@ -1479,6 +1582,10 @@ def tanh(input, inplace=False): ...@@ -1479,6 +1582,10 @@ def tanh(input, inplace=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Tanh(...)`_
""" """
return _activation(input, inplace, 'Tanh') return _activation(input, inplace, 'Tanh')
...@@ -1496,8 +1603,8 @@ def upsample( ...@@ -1496,8 +1603,8 @@ def upsample(
```python ```python
x = torch.ones((1, 2, 3, 4)) x = torch.ones((1, 2, 3, 4))
y = F.interpolate(x, size=6) # Shape: (1, 2, 6, 6) y = F.upsample(x, size=6) # Shape: (1, 2, 6, 6)
z = F.interpolate(x, scale_factor=2) # Shape: (1, 2, 6, 8) z = F.upsample(x, scale_factor=2) # Shape: (1, 2, 6, 8)
``` ```
Set ``align_corners`` to determine the input coordinates in linear ``mode``: Set ``align_corners`` to determine the input coordinates in linear ``mode``:
...@@ -1532,6 +1639,10 @@ def upsample( ...@@ -1532,6 +1639,10 @@ def upsample(
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.Upsample(...)`_
""" """
return interpolate(input, size, scale_factor, mode, align_corners) return interpolate(input, size, scale_factor, mode, align_corners)
...@@ -1561,6 +1672,10 @@ def upsample_bilinear(input, size=None, scale_factor=None): ...@@ -1561,6 +1672,10 @@ def upsample_bilinear(input, size=None, scale_factor=None):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.UpsamplingBilinear2d(...)`_
""" """
return interpolate(input, size, scale_factor, 'linear', align_corners=True) return interpolate(input, size, scale_factor, 'linear', align_corners=True)
...@@ -1590,16 +1705,18 @@ def upsample_nearest(input, size=None, scale_factor=None): ...@@ -1590,16 +1705,18 @@ def upsample_nearest(input, size=None, scale_factor=None):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nn.UpsamplingNearest2d(...)`_
""" """
return interpolate(input, size, scale_factor, 'nearest') return interpolate(input, size, scale_factor, 'nearest')
def _activation(input, inplace=False, _op_type=''): def _activation(input, inplace=False, _op_type=''):
return _functions._Activation \ return _functions._Activation \
.instantiate( .instantiate(input.device, op_type=_op_type) \
input.device, .apply(input, inplace=inplace)
op_type=_op_type,
).apply(input, inplace=inplace)
def _conv( def _conv(
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Activation modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -43,7 +44,7 @@ class ELU(Module): ...@@ -43,7 +44,7 @@ class ELU(Module):
See Also See Also
-------- --------
`torch.nn.functional.elu(...)`_ - Apply the exponential linear unit to input. `torch.nn.functional.elu(...)`_
""" """
...@@ -145,7 +146,7 @@ class LeakyReLU(Module): ...@@ -145,7 +146,7 @@ class LeakyReLU(Module):
See Also See Also
-------- --------
`torch.nn.functional.leaky_relu(...)`_ - Apply the leaky rectified linear unit. `torch.nn.functional.leaky_relu(...)`_
""" """
...@@ -187,6 +188,10 @@ class LogSoftmax(Module): ...@@ -187,6 +188,10 @@ class LogSoftmax(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.log_softmax(...)`_
""" """
def __init__(self, dim): def __init__(self, dim):
...@@ -234,6 +239,11 @@ class PReLU(Module): ...@@ -234,6 +239,11 @@ class PReLU(Module):
mm = torch.nn.PReLU(num_parameters=3) mm = torch.nn.PReLU(num_parameters=3)
z = mm(x) z = mm(x)
``` ```
See Also
--------
`torch.nn.functional.prelu(...)`_
""" """
def __init__(self, num_parameters=1, init=0.25): def __init__(self, num_parameters=1, init=0.25):
...@@ -279,6 +289,10 @@ class ReLU(Module): ...@@ -279,6 +289,10 @@ class ReLU(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.relu(...)`_
""" """
def __init__(self, inplace=False): def __init__(self, inplace=False):
...@@ -322,6 +336,10 @@ class ReLU6(Module): ...@@ -322,6 +336,10 @@ class ReLU6(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.relu6(...)`_
""" """
def __init__(self, inplace=False): def __init__(self, inplace=False):
...@@ -365,6 +383,10 @@ class SELU(Module): ...@@ -365,6 +383,10 @@ class SELU(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.selu(...)`_
""" """
def __init__(self, inplace=False): def __init__(self, inplace=False):
...@@ -402,6 +424,10 @@ class Sigmoid(Module): ...@@ -402,6 +424,10 @@ class Sigmoid(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.sigmoid(...)`_
""" """
def __init__(self, inplace=False): def __init__(self, inplace=False):
...@@ -439,6 +465,10 @@ class Softmax(Module): ...@@ -439,6 +465,10 @@ class Softmax(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.softmax(...)`_
""" """
def __init__(self, dim, inplace=False): def __init__(self, dim, inplace=False):
...@@ -479,6 +509,10 @@ class Tanh(Module): ...@@ -479,6 +509,10 @@ class Tanh(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.tanh(...)`_
""" """
def __init__(self, inplace=False): def __init__(self, inplace=False):
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Affine modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -25,8 +26,6 @@ class Affine(Module): ...@@ -25,8 +26,6 @@ class Affine(Module):
.. math:: y = Ax + b .. math:: y = Ax + b
This transform is often taken as a post-processing of normalization. This transform is often taken as a post-processing of normalization.
Specially, a trained ``BatchNorm`` can be fused to this under some
fine-tune settings, such as detection and segmentation.
Examples: Examples:
...@@ -46,6 +45,10 @@ class Affine(Module): ...@@ -46,6 +45,10 @@ class Affine(Module):
y4d = m(x4d) y4d = m(x4d)
``` ```
See Also
--------
`torch.nn.functional.affine(...)`_
""" """
def __init__( def __init__(
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""BatchNorm modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -89,10 +90,14 @@ class BatchNorm1d(_BatchNorm): ...@@ -89,10 +90,14 @@ class BatchNorm1d(_BatchNorm):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The moving average of stats are calculated as: The running average of statistics are calculated as:
.. math:: .. math::
x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{stat} x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
See Also
--------
`torch.nn.functional.batch_norm(...)`_
""" """
...@@ -136,10 +141,14 @@ class BatchNorm2d(_BatchNorm): ...@@ -136,10 +141,14 @@ class BatchNorm2d(_BatchNorm):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The moving average of stats are calculated as: The running average of statistics are calculated as:
.. math:: .. math::
x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{stat} x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
See Also
--------
`torch.nn.functional.batch_norm(...)`_
""" """
...@@ -183,10 +192,14 @@ class BatchNorm3d(_BatchNorm): ...@@ -183,10 +192,14 @@ class BatchNorm3d(_BatchNorm):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The moving average of stats are calculated as: The running average of statistics are calculated as:
.. math:: .. math::
x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{stat} x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
See Also
--------
`torch.nn.functional.batch_norm(...)`_
""" """
...@@ -230,15 +243,19 @@ class SyncBatchNorm(_BatchNorm): ...@@ -230,15 +243,19 @@ class SyncBatchNorm(_BatchNorm):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The moving average of stats are calculated as: The running average of statistics are calculated as:
.. math:: .. math::
x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{stat} x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
Additionally, you can specify ``process_group`` to perform synchronization. Additionally, specify ``process_group`` to perform synchronization.
If not, value returning from ``dragon.distributed.get_group(...)`` will be used. If not, value returning from ``dragon.distributed.get_group(...)`` will be used.
See Also
--------
`torch.nn.functional.batch_norm(...)`_
""" """
def __init__( def __init__(
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# <https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/container.py> # <https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/container.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Container modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Convolution modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -90,17 +91,7 @@ class _ConvNd(Module): ...@@ -90,17 +91,7 @@ class _ConvNd(Module):
class Conv2d(_ConvNd): class Conv2d(_ConvNd):
r"""Apply the 2d convolution. """Apply the 2d convolution.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} +
2 * pad - \text{DK}_{size}) / stride + 1
\end{cases}
Examples: Examples:
...@@ -110,6 +101,10 @@ class Conv2d(_ConvNd): ...@@ -110,6 +101,10 @@ class Conv2d(_ConvNd):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.conv2d(...)`_
""" """
def __init__( def __init__(
...@@ -171,17 +166,7 @@ class Conv2d(_ConvNd): ...@@ -171,17 +166,7 @@ class Conv2d(_ConvNd):
class ConvTranspose2d(_ConvNd): class ConvTranspose2d(_ConvNd):
r"""Apply the 2d deconvolution. """Apply the 2d deconvolution.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} - 1) *
stride + \text{DK}_{size} - 2 * pad
\end{cases}
Examples: Examples:
...@@ -191,6 +176,10 @@ class ConvTranspose2d(_ConvNd): ...@@ -191,6 +176,10 @@ class ConvTranspose2d(_ConvNd):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.conv_transpose2d(...)`_
""" """
def __init__( def __init__(
...@@ -256,17 +245,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -256,17 +245,7 @@ class ConvTranspose2d(_ConvNd):
class DepthwiseConv2d(Conv2d): class DepthwiseConv2d(Conv2d):
r"""Apply the 2d depthwise convolution. """Apply the 2d depthwise convolution.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} +
2 * pad - \text{DK}_{size}) / stride + 1
\end{cases}
Examples: Examples:
...@@ -276,6 +255,10 @@ class DepthwiseConv2d(Conv2d): ...@@ -276,6 +255,10 @@ class DepthwiseConv2d(Conv2d):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.depthwise_conv2d(...)`_
""" """
def __init__( def __init__(
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Dropout modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -38,6 +39,10 @@ class DropBlock2d(Module): ...@@ -38,6 +39,10 @@ class DropBlock2d(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.drop_block2d(...)`_
""" """
# Store the global unique slot index # Store the global unique slot index
...@@ -114,6 +119,10 @@ class Dropout(Module): ...@@ -114,6 +119,10 @@ class Dropout(Module):
z = m(x) z = m(x)
``` ```
See Also
--------
`torch.nn.functional.dropout(...)`_
""" """
def __init__(self, p=0.5, inplace=False): def __init__(self, p=0.5, inplace=False):
...@@ -155,6 +164,10 @@ class DropPath(Module): ...@@ -155,6 +164,10 @@ class DropPath(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.drop_path(...)`_
""" """
# Store the global unique slot index # Store the global unique slot index
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Linear modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -34,6 +35,10 @@ class Linear(Module): ...@@ -34,6 +35,10 @@ class Linear(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.linear(...)`_
""" """
def __init__(self, in_features, out_features, bias=True): def __init__(self, in_features, out_features, bias=True):
......
...@@ -7,11 +7,8 @@ ...@@ -7,11 +7,8 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Loss modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -68,6 +65,10 @@ class CTCLoss(_Loss): ...@@ -68,6 +65,10 @@ class CTCLoss(_Loss):
loss = m(logits, labels) loss = m(logits, labels)
``` ```
See Also
--------
`torch.nn.functional.ctc_loss(...)`_
""" """
def __init__(self, padding_mask=-1, reduction='mean'): def __init__(self, padding_mask=-1, reduction='mean'):
...@@ -107,6 +108,10 @@ class NLLLoss(_WeightedLoss): ...@@ -107,6 +108,10 @@ class NLLLoss(_WeightedLoss):
loss = m2(m1(torch.randn(2, 2)), torch.tensor([0, 1])) loss = m2(m1(torch.randn(2, 2)), torch.tensor([0, 1]))
``` ```
See Also
--------
`torch.nn.functional.nll_loss(...)`_
""" """
def __init__( def __init__(
...@@ -155,6 +160,10 @@ class BCEWithLogitsLoss(_WeightedLoss): ...@@ -155,6 +160,10 @@ class BCEWithLogitsLoss(_WeightedLoss):
loss = m(torch.randn(2, 1), torch.tensor([0., 1.], 'float32')) loss = m(torch.randn(2, 1), torch.tensor([0., 1.], 'float32'))
``` ```
See Also
--------
`torch.nn.functional.binary_cross_entropy_with_logits(...)`_
""" """
def __init__( def __init__(
...@@ -205,6 +214,10 @@ class CrossEntropyLoss(_WeightedLoss): ...@@ -205,6 +214,10 @@ class CrossEntropyLoss(_WeightedLoss):
loss = m(logits, targets) loss = m(logits, targets)
``` ```
See Also
--------
`torch.nn.functional.cross_entropy(...)`_
""" """
def __init__( def __init__(
...@@ -257,6 +270,10 @@ class L1Loss(_Loss): ...@@ -257,6 +270,10 @@ class L1Loss(_Loss):
loss = m(torch.ones(2, 3), torch.zeros(2, 3)) loss = m(torch.ones(2, 3), torch.zeros(2, 3))
``` ```
See Also
--------
`torch.nn.functional.l1_loss(...)`_
""" """
def __init__(self, size_average=None, reduce=None, reduction='mean'): def __init__(self, size_average=None, reduce=None, reduction='mean'):
...@@ -292,6 +309,10 @@ class MSELoss(_Loss): ...@@ -292,6 +309,10 @@ class MSELoss(_Loss):
loss = m(torch.ones(2, 3) * 2, torch.zeros(2, 3)) loss = m(torch.ones(2, 3) * 2, torch.zeros(2, 3))
``` ```
See Also
--------
`torch.nn.functional.mse_loss(...)`_
""" """
def __init__(self, size_average=None, reduce=None, reduction='mean'): def __init__(self, size_average=None, reduce=None, reduction='mean'):
...@@ -333,6 +354,10 @@ class SmoothL1Loss(_Loss): ...@@ -333,6 +354,10 @@ class SmoothL1Loss(_Loss):
loss = m(torch.ones(2, 3), torch.zeros(2, 3)) loss = m(torch.ones(2, 3), torch.zeros(2, 3))
``` ```
See Also
--------
`torch.nn.functional.smooth_l1_loss(...)`_
""" """
def __init__( def __init__(
...@@ -384,6 +409,10 @@ class SigmoidFocalLoss(_WeightedLoss): ...@@ -384,6 +409,10 @@ class SigmoidFocalLoss(_WeightedLoss):
loss = m(logits, targets) loss = m(logits, targets)
``` ```
See Also
--------
`torch.nn.functional.sigmoid_focal_loss(...)`_
""" """
def __init__( def __init__(
......
...@@ -7,11 +7,8 @@ ...@@ -7,11 +7,8 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Base module class."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Normalization modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -39,6 +40,10 @@ class GroupNorm(Module): ...@@ -39,6 +40,10 @@ class GroupNorm(Module):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.group_norm(...)`_
""" """
def __init__( def __init__(
...@@ -123,7 +128,7 @@ class LocalResponseNorm(Module): ...@@ -123,7 +128,7 @@ class LocalResponseNorm(Module):
See Also See Also
-------- --------
`torch.nn.functional.local_response_norm(...)`_ - Apply the local response normalization to input. `torch.nn.functional.local_response_norm(...)`_
""" """
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Padding modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -69,7 +70,12 @@ class ConstantPad1d(_ConstantPadNd): ...@@ -69,7 +70,12 @@ class ConstantPad1d(_ConstantPadNd):
y = m(x) # (1, 2) -> (1, 4) y = m(x) # (1, 2) -> (1, 4)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding, value): def __init__(self, padding, value):
"""Create a ``ConstantPad1d`` module. """Create a ``ConstantPad1d`` module.
...@@ -100,6 +106,10 @@ class ConstantPad2d(_ConstantPadNd): ...@@ -100,6 +106,10 @@ class ConstantPad2d(_ConstantPadNd):
y = m(x) # (1, 2, 2) -> (1, 4, 4) y = m(x) # (1, 2, 2) -> (1, 4, 4)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding, value): def __init__(self, padding, value):
...@@ -132,6 +142,10 @@ class ConstantPad3d(_ConstantPadNd): ...@@ -132,6 +142,10 @@ class ConstantPad3d(_ConstantPadNd):
y = m(x) # (1, 2, 2, 2) -> (1, 4, 4, 4) y = m(x) # (1, 2, 2, 2) -> (1, 4, 4, 4)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding, value): def __init__(self, padding, value):
...@@ -164,6 +178,10 @@ class ReflectionPad1d(_ReflectionPadNd): ...@@ -164,6 +178,10 @@ class ReflectionPad1d(_ReflectionPadNd):
y = m(x) # (1, 4) -> (1, 6) y = m(x) # (1, 4) -> (1, 6)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding): def __init__(self, padding):
...@@ -194,6 +212,10 @@ class ReflectionPad2d(_ReflectionPadNd): ...@@ -194,6 +212,10 @@ class ReflectionPad2d(_ReflectionPadNd):
y = m(x) # (1, 4, 4) -> (1, 6, 6) y = m(x) # (1, 4, 4) -> (1, 6, 6)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding): def __init__(self, padding):
...@@ -224,6 +246,10 @@ class ReflectionPad3d(_ReflectionPadNd): ...@@ -224,6 +246,10 @@ class ReflectionPad3d(_ReflectionPadNd):
y = m(x) # (1, 4, 4, 4) -> (1, 6, 6, 6) y = m(x) # (1, 4, 4, 4) -> (1, 6, 6, 6)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding): def __init__(self, padding):
...@@ -254,6 +280,10 @@ class ReplicationPad1d(_ReplicationPadNd): ...@@ -254,6 +280,10 @@ class ReplicationPad1d(_ReplicationPadNd):
y = m(x) # (1, 4) -> (1, 6) y = m(x) # (1, 4) -> (1, 6)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding): def __init__(self, padding):
...@@ -284,6 +314,10 @@ class ReplicationPad2d(_ReplicationPadNd): ...@@ -284,6 +314,10 @@ class ReplicationPad2d(_ReplicationPadNd):
y = m(x) # (1, 4, 4) -> (1, 6, 6) y = m(x) # (1, 4, 4) -> (1, 6, 6)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding): def __init__(self, padding):
...@@ -314,6 +348,10 @@ class ReplicationPad3d(_ReplicationPadNd): ...@@ -314,6 +348,10 @@ class ReplicationPad3d(_ReplicationPadNd):
y = m(x) # (1, 4, 4, 4) -> (1, 6, 6, 6) y = m(x) # (1, 4, 4, 4) -> (1, 6, 6, 6)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding): def __init__(self, padding):
...@@ -344,6 +382,10 @@ class ZeroPad2d(ConstantPad2d): ...@@ -344,6 +382,10 @@ class ZeroPad2d(ConstantPad2d):
y = m(x) # (1, 2, 2) -> (1, 4, 4) y = m(x) # (1, 2, 2) -> (1, 4, 4)
``` ```
See Also
--------
`torch.nn.functional.pad(...)`_
""" """
def __init__(self, padding): def __init__(self, padding):
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Pooling modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -59,6 +60,10 @@ class AvgPool2d(_PoolNd): ...@@ -59,6 +60,10 @@ class AvgPool2d(_PoolNd):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.avg_pool2d(...)`_
""" """
def __init__( def __init__(
...@@ -121,6 +126,10 @@ class MaxPool2d(_PoolNd): ...@@ -121,6 +126,10 @@ class MaxPool2d(_PoolNd):
y = m(x) y = m(x)
``` ```
See Also
--------
`torch.nn.functional.max_pool2d(...)`_
""" """
def __init__( def __init__(
......
...@@ -7,11 +7,8 @@ ...@@ -7,11 +7,8 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
"""RNN modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Upsampling modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -28,6 +29,10 @@ class Upsample(Module): ...@@ -28,6 +29,10 @@ class Upsample(Module):
z = torch.nn.UpSample(scale_factor=2)(x) # Shape: (1, 2, 6, 8) z = torch.nn.UpSample(scale_factor=2)(x) # Shape: (1, 2, 6, 8)
``` ```
See Also
--------
`torch.nn.functional.interpolate(...)`_
""" """
def __init__( def __init__(
...@@ -89,6 +94,10 @@ class UpsamplingBilinear2d(Upsample): ...@@ -89,6 +94,10 @@ class UpsamplingBilinear2d(Upsample):
z = torch.nn.UpsamplingBilinear2d(scale_factor=2)(x) # Shape: (1, 2, 6, 8) z = torch.nn.UpsamplingBilinear2d(scale_factor=2)(x) # Shape: (1, 2, 6, 8)
``` ```
See Also
--------
`torch.nn.functional.interpolate(...)`_
""" """
def __init__(self, size=None, scale_factor=None): def __init__(self, size=None, scale_factor=None):
...@@ -117,6 +126,10 @@ class UpsamplingNearest2d(Upsample): ...@@ -117,6 +126,10 @@ class UpsamplingNearest2d(Upsample):
z = torch.nn.UpsamplingNearest2d(scale_factor=2)(x) # Shape: (1, 2, 6, 8) z = torch.nn.UpsamplingNearest2d(scale_factor=2)(x) # Shape: (1, 2, 6, 8)
``` ```
See Also
--------
`torch.nn.functional.interpolate(...)`_
""" """
def __init__(self, size=None, scale_factor=None): def __init__(self, size=None, scale_factor=None):
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# <https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py> # <https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Module utilities."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -33,7 +33,7 @@ def abs(self): ...@@ -33,7 +33,7 @@ def abs(self):
See Also See Also
-------- --------
`torch.abs(...)`_ : Compute the absolute value of input. `torch.abs(...)`_
""" """
return math_funcs.abs(self) return math_funcs.abs(self)
...@@ -56,7 +56,7 @@ def add(self, other): ...@@ -56,7 +56,7 @@ def add(self, other):
See Also See Also
-------- --------
`torch.add(...)`_ : Compute the element-wise addition. `torch.add(...)`_
""" """
return math_funcs.add(self, other) return math_funcs.add(self, other)
...@@ -79,7 +79,7 @@ def add_(self, other): ...@@ -79,7 +79,7 @@ def add_(self, other):
See Also See Also
-------- --------
`torch.add(...)`_ : Compute the element-wise addition. `torch.add(...)`_
""" """
return math_funcs.add(self, other, self) return math_funcs.add(self, other, self)
...@@ -102,7 +102,7 @@ def argmax(self, dim=None, keepdim=False): ...@@ -102,7 +102,7 @@ def argmax(self, dim=None, keepdim=False):
See Also See Also
-------- --------
`torch.argmax(...)`_ : Return the index of maximum elements along the given dimension. `torch.argmax(...)`_
""" """
return array_funcs.argmax(self, dim, keepdim) return array_funcs.argmax(self, dim, keepdim)
...@@ -125,7 +125,7 @@ def argmin(self, dim=None, keepdim=False): ...@@ -125,7 +125,7 @@ def argmin(self, dim=None, keepdim=False):
See Also See Also
-------- --------
`torch.argmin(...)`_ : Return the index of minimum elements along the given dimension. `torch.argmin(...)`_
""" """
return array_funcs.argmin(self, dim, keepdim) return array_funcs.argmin(self, dim, keepdim)
...@@ -161,7 +161,7 @@ def bitwise_not(self): ...@@ -161,7 +161,7 @@ def bitwise_not(self):
See Also See Also
-------- --------
`torch.bitwise_not(...)`_ : Compute the element-wise NOT bitwise operation. `torch.bitwise_not(...)`_
""" """
return math_funcs.bitwise_not(self) return math_funcs.bitwise_not(self)
...@@ -179,7 +179,7 @@ def bitwise_not_(self): ...@@ -179,7 +179,7 @@ def bitwise_not_(self):
See Also See Also
-------- --------
`torch.bitwise_not(...)`_ : Compute the element-wise NOT bitwise operation. `torch.bitwise_not(...)`_
""" """
return math_funcs.bitwise_not(self, self) return math_funcs.bitwise_not(self, self)
...@@ -202,7 +202,7 @@ def bitwise_xor(self, other): ...@@ -202,7 +202,7 @@ def bitwise_xor(self, other):
See Also See Also
-------- --------
`torch.bitwise_xor(...)`_ : Compute the element-wise XOR bitwise operation. `torch.bitwise_xor(...)`_
""" """
return math_funcs.bitwise_xor(self, other) return math_funcs.bitwise_xor(self, other)
...@@ -225,7 +225,7 @@ def bitwise_xor_(self, other): ...@@ -225,7 +225,7 @@ def bitwise_xor_(self, other):
See Also See Also
-------- --------
`torch.bitwise_xor(...)`_ : Compute the element-wise XOR bitwise operation. `torch.bitwise_xor(...)`_
""" """
return math_funcs.bitwise_xor(self, other, self) return math_funcs.bitwise_xor(self, other, self)
...@@ -291,7 +291,7 @@ def ceil(self): ...@@ -291,7 +291,7 @@ def ceil(self):
See Also See Also
-------- --------
`torch.ceil(...)`_ : Compute the smallest integer not less than input. `torch.ceil(...)`_
""" """
return math_funcs.ceil(self) return math_funcs.ceil(self)
...@@ -309,7 +309,7 @@ def ceil_(self): ...@@ -309,7 +309,7 @@ def ceil_(self):
See Also See Also
-------- --------
`torch.ceil(...)`_ : Compute the smallest integer not less than input. `torch.ceil(...)`_
""" """
return math_funcs.ceil(self, self) return math_funcs.ceil(self, self)
...@@ -375,7 +375,7 @@ def clamp(self, min=None, max=None): ...@@ -375,7 +375,7 @@ def clamp(self, min=None, max=None):
See Also See Also
-------- --------
`torch.clamp(...)`_ : Compute the clipped input according to the given bounds. `torch.clamp(...)`_
""" """
return math_funcs.clamp(self, min, max) return math_funcs.clamp(self, min, max)
...@@ -398,7 +398,7 @@ def clamp_(self, min=None, max=None): ...@@ -398,7 +398,7 @@ def clamp_(self, min=None, max=None):
See Also See Also
-------- --------
`torch.clamp(...)`_ : Compute the clipped input according to the given bounds. `torch.clamp(...)`_
""" """
return math_funcs.clamp(self, min, max, self) return math_funcs.clamp(self, min, max, self)
...@@ -416,7 +416,7 @@ def cos(self): ...@@ -416,7 +416,7 @@ def cos(self):
See Also See Also
-------- --------
`torch.cos(...)`_ : Compute the cos of input. `torch.cos(...)`_
""" """
return math_funcs.cos(self) return math_funcs.cos(self)
...@@ -437,7 +437,7 @@ def cumsum(self, dim): ...@@ -437,7 +437,7 @@ def cumsum(self, dim):
See Also See Also
-------- --------
`torch.cumsum(...)`_ : Compute the cumulative sum of elements along the given dimension. `torch.cumsum(...)`_
""" """
return array_funcs.cumsum(self, dim) return array_funcs.cumsum(self, dim)
...@@ -460,7 +460,7 @@ def div(self, other): ...@@ -460,7 +460,7 @@ def div(self, other):
See Also See Also
-------- --------
`torch.div(...)`_ : Compute the element-wise division. `torch.div(...)`_
""" """
return math_funcs.div(self, other) return math_funcs.div(self, other)
...@@ -483,7 +483,7 @@ def div_(self, other): ...@@ -483,7 +483,7 @@ def div_(self, other):
See Also See Also
-------- --------
`torch.div(...)`_ : Compute the element-wise division. `torch.div(...)`_
""" """
return math_funcs.div(self, other, self) return math_funcs.div(self, other, self)
...@@ -530,7 +530,7 @@ def eq(self, other): ...@@ -530,7 +530,7 @@ def eq(self, other):
See Also See Also
-------- --------
`torch.eq(...)`_ : Compute the element-wise equal comparison. `torch.eq(...)`_
""" """
return math_funcs.eq(self, other) return math_funcs.eq(self, other)
...@@ -548,7 +548,7 @@ def exp(self): ...@@ -548,7 +548,7 @@ def exp(self):
See Also See Also
-------- --------
`torch.exp(...)`_ : Compute the exponential of input. `torch.exp(...)`_
""" """
return math_funcs.exp(self) return math_funcs.exp(self)
...@@ -569,7 +569,7 @@ def expand(self, *sizes): ...@@ -569,7 +569,7 @@ def expand(self, *sizes):
See Also See Also
-------- --------
`torch.expand(...)`_ : Broadcast input according to given sizes. `torch.expand(...)`_
""" """
return array_funcs.expand(self, sizes) return array_funcs.expand(self, sizes)
...@@ -630,7 +630,7 @@ def floor(self): ...@@ -630,7 +630,7 @@ def floor(self):
See Also See Also
-------- --------
`torch.floor(...)`_ : Compute the largest integer not greater than input. `torch.floor(...)`_
""" """
return math_funcs.floor(self) return math_funcs.floor(self)
...@@ -648,7 +648,7 @@ def floor_(self): ...@@ -648,7 +648,7 @@ def floor_(self):
See Also See Also
-------- --------
`torch.floor(...)`_ : Compute the largest integer not greater than input. `torch.floor(...)`_
""" """
return math_funcs.floor(self, self) return math_funcs.floor(self, self)
...@@ -671,7 +671,7 @@ def ge(self, other): ...@@ -671,7 +671,7 @@ def ge(self, other):
See Also See Also
-------- --------
`torch.ge(...)`_ : Compute the element-wise greater-equal comparison. `torch.ge(...)`_
""" """
return math_funcs.ge(self, other) return math_funcs.ge(self, other)
...@@ -715,7 +715,7 @@ def gt(self, other): ...@@ -715,7 +715,7 @@ def gt(self, other):
See Also See Also
-------- --------
`torch.gt(...)`_ : Compute the element-wise greater comparison. `torch.gt(...)`_
""" """
return math_funcs.gt(self, other) return math_funcs.gt(self, other)
...@@ -805,7 +805,7 @@ def le(self, other): ...@@ -805,7 +805,7 @@ def le(self, other):
See Also See Also
-------- --------
`torch.le(...)`_ : Compute the element-wise less-equal comparison. `torch.le(...)`_
""" """
return math_funcs.le(self, other) return math_funcs.le(self, other)
...@@ -887,7 +887,7 @@ def lt(self, other): ...@@ -887,7 +887,7 @@ def lt(self, other):
See Also See Also
-------- --------
`torch.lt(...)`_ : Compute the element-wise less comparison. `torch.lt(...)`_
""" """
return math_funcs.lt(self, other) return math_funcs.lt(self, other)
...@@ -1010,7 +1010,7 @@ def mul(self, other): ...@@ -1010,7 +1010,7 @@ def mul(self, other):
See Also See Also
-------- --------
`torch.mul(...)`_ : Compute the element-wise multiplication. `torch.mul(...)`_
""" """
return math_funcs.mul(self, other) return math_funcs.mul(self, other)
...@@ -1033,7 +1033,7 @@ def mul_(self, other): ...@@ -1033,7 +1033,7 @@ def mul_(self, other):
See Also See Also
-------- --------
`torch.mul(...)`_ : Compute the element-wise multiplication. `torch.mul(...)`_
""" """
return math_funcs.mul(self, other, self) return math_funcs.mul(self, other, self)
...@@ -1097,7 +1097,7 @@ def ne(self, other): ...@@ -1097,7 +1097,7 @@ def ne(self, other):
See Also See Also
-------- --------
`torch.ne(...)`_ : Compute the element-wise not-equal comparison. `torch.ne(...)`_
""" """
return math_funcs.ne(self, other) return math_funcs.ne(self, other)
...@@ -1115,20 +1115,26 @@ def neg(self): ...@@ -1115,20 +1115,26 @@ def neg(self):
See Also See Also
-------- --------
`torch.neg(...)`_ : Compute the element-wise negative. `torch.neg(...)`_
""" """
return math_funcs.neg(self) return math_funcs.neg(self)
def nonzero(self): def nonzero(self):
"""Return the index of non-zero elements. r"""Return the index of non-zero elements.
.. math:: \text{out} = \{i\}, \text{ if } \text{self}_{i} \neq 0
Returns Returns
------- -------
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nonzero(...)`_
""" """
return array_funcs.nonzero(self) return array_funcs.nonzero(self)
...@@ -1186,7 +1192,7 @@ def pow(self, exponent): ...@@ -1186,7 +1192,7 @@ def pow(self, exponent):
See Also See Also
-------- --------
`torch.pow(...)`_ : Compute the power of input. `torch.pow(...)`_
""" """
return math_funcs.pow(self, exponent) return math_funcs.pow(self, exponent)
...@@ -1204,7 +1210,7 @@ def reciprocal(self): ...@@ -1204,7 +1210,7 @@ def reciprocal(self):
See Also See Also
-------- --------
`torch.reciprocal(...)`_ : Compute the reciprocal of input. `torch.reciprocal(...)`_
""" """
return math_funcs.reciprocal(self) return math_funcs.reciprocal(self)
...@@ -1222,7 +1228,7 @@ def reciprocal_(self): ...@@ -1222,7 +1228,7 @@ def reciprocal_(self):
See Also See Also
-------- --------
`torch.reciprocal(...)`_ : Compute the reciprocal of input. `torch.reciprocal(...)`_
""" """
return math_funcs.reciprocal(self, self) return math_funcs.reciprocal(self, self)
...@@ -1260,7 +1266,7 @@ def reshape(self, shape): ...@@ -1260,7 +1266,7 @@ def reshape(self, shape):
See Also See Also
-------- --------
`torch.reshape(...)`_ : Change the shape of input. `torch.reshape(...)`_
""" """
return array_funcs.reshape(self, shape) return array_funcs.reshape(self, shape)
...@@ -1281,7 +1287,7 @@ def reshape_(self, shape): ...@@ -1281,7 +1287,7 @@ def reshape_(self, shape):
See Also See Also
-------- --------
`torch.reshape(...)`_ : Change the shape of input. `torch.reshape(...)`_
""" """
return array_funcs.reshape(self, shape, self) return array_funcs.reshape(self, shape, self)
...@@ -1299,7 +1305,7 @@ def round(self): ...@@ -1299,7 +1305,7 @@ def round(self):
See Also See Also
-------- --------
`torch.round(...)`_ : Compute the nearest integer of input. `torch.round(...)`_
""" """
return math_funcs.round(self) return math_funcs.round(self)
...@@ -1317,7 +1323,7 @@ def round_(self): ...@@ -1317,7 +1323,7 @@ def round_(self):
See Also See Also
-------- --------
`torch.round(...)`_ : Compute the nearest integer of input. `torch.round(...)`_
""" """
return math_funcs.round(self, self) return math_funcs.round(self, self)
...@@ -1335,7 +1341,7 @@ def rsqrt(self): ...@@ -1335,7 +1341,7 @@ def rsqrt(self):
See Also See Also
-------- --------
`torch.rsqrt(...)`_ : Compute the square root of input. `torch.rsqrt(...)`_
""" """
return math_funcs.rsqrt(self) return math_funcs.rsqrt(self)
...@@ -1353,7 +1359,7 @@ def rsqrt_(self): ...@@ -1353,7 +1359,7 @@ def rsqrt_(self):
See Also See Also
-------- --------
`torch.rsqrt(...)`_ : Compute the square root of input. `torch.rsqrt(...)`_
""" """
return math_funcs.rsqrt(self, self) return math_funcs.rsqrt(self, self)
...@@ -1395,7 +1401,7 @@ def sign(self): ...@@ -1395,7 +1401,7 @@ def sign(self):
See Also See Also
-------- --------
`torch.sign(...)`_ : Compute the sign indication of input. `torch.sign(...)`_
""" """
return math_funcs.sign(self) return math_funcs.sign(self)
...@@ -1419,7 +1425,7 @@ def sign_(self): ...@@ -1419,7 +1425,7 @@ def sign_(self):
See Also See Also
-------- --------
`torch.sign(...)`_ : Compute the sign indication of input. `torch.sign(...)`_
""" """
return math_funcs.sign(self, self) return math_funcs.sign(self, self)
...@@ -1437,7 +1443,7 @@ def sin(self): ...@@ -1437,7 +1443,7 @@ def sin(self):
See Also See Also
-------- --------
`torch.sin(...)`_ : Compute the sin of input. `torch.sin(...)`_
""" """
return math_funcs.sin(self) return math_funcs.sin(self)
...@@ -1455,7 +1461,7 @@ def sqrt(self): ...@@ -1455,7 +1461,7 @@ def sqrt(self):
See Also See Also
-------- --------
`torch.sqrt(...)`_ : Compute the square root of input. `torch.sqrt(...)`_
""" """
return math_funcs.sqrt(self) return math_funcs.sqrt(self)
...@@ -1473,7 +1479,7 @@ def sqrt_(self): ...@@ -1473,7 +1479,7 @@ def sqrt_(self):
See Also See Also
-------- --------
`torch.sqrt(...)`_ : Compute the square root of input. `torch.sqrt(...)`_
""" """
return math_funcs.sqrt(self, self) return math_funcs.sqrt(self, self)
...@@ -1492,12 +1498,16 @@ def squeeze(self, dim=None): ...@@ -1492,12 +1498,16 @@ def squeeze(self, dim=None):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.squeeze(...)`_
""" """
return array_funcs.squeeze(self, dim) return array_funcs.squeeze(self, dim)
def squeeze_(self, dim=None): def squeeze_(self, dim=None):
"""Inplace version of ``Tensor.squeeze()``. """Remove the dimensions with size 1.
Parameters Parameters
---------- ----------
...@@ -1509,6 +1519,10 @@ def squeeze_(self, dim=None): ...@@ -1509,6 +1519,10 @@ def squeeze_(self, dim=None):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The self. The self.
See Also
--------
`torch.squeeze(...)`_
""" """
return array_funcs.squeeze(self, dim, self) return array_funcs.squeeze(self, dim, self)
...@@ -1528,6 +1542,10 @@ def sum(self, dim=None, keepdim=False): ...@@ -1528,6 +1542,10 @@ def sum(self, dim=None, keepdim=False):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.sum(...)`_
""" """
return array_funcs.sum(self, dim, keepdim) return array_funcs.sum(self, dim, keepdim)
...@@ -1549,7 +1567,7 @@ def sub(self, other): ...@@ -1549,7 +1567,7 @@ def sub(self, other):
See Also See Also
-------- --------
`torch.sub(...)`_ : Compute the element-wise subtraction. `torch.sub(...)`_
""" """
return math_funcs.sub(self, other) return math_funcs.sub(self, other)
...@@ -1572,7 +1590,7 @@ def sub_(self, other): ...@@ -1572,7 +1590,7 @@ def sub_(self, other):
See Also See Also
-------- --------
`torch.sub(...)`_ : Compute the element-wise subtraction. `torch.sub(...)`_
""" """
return math_funcs.sub(self, other, self) return math_funcs.sub(self, other, self)
...@@ -1599,7 +1617,7 @@ def topk(self, k, dim=None, largest=True, sorted=True): ...@@ -1599,7 +1617,7 @@ def topk(self, k, dim=None, largest=True, sorted=True):
See Also See Also
-------- --------
`torch.topk(...)`_ : Return the top-K largest or smallest elements along the given dimension. `torch.topk(...)`_
""" """
return array_funcs.topk(self, k, dim, largest, sorted) return array_funcs.topk(self, k, dim, largest, sorted)
...@@ -1660,12 +1678,16 @@ def unsqueeze(self, dim): ...@@ -1660,12 +1678,16 @@ def unsqueeze(self, dim):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.unsqueeze(...)`_
""" """
return array_funcs.unsqueeze(self, dim) return array_funcs.unsqueeze(self, dim)
def unsqueeze_(self, dim): def unsqueeze_(self, dim):
"""In-place version of ``Tensor.unsqueeze()``. """Insert the dimensions of size 1.
Parameters Parameters
---------- ----------
...@@ -1677,6 +1699,10 @@ def unsqueeze_(self, dim): ...@@ -1677,6 +1699,10 @@ def unsqueeze_(self, dim):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The self. The self.
See Also
--------
`torch.unsqueeze(...)`_
""" """
return array_funcs.unsqueeze(self, dim, self) return array_funcs.unsqueeze(self, dim, self)
...@@ -1703,6 +1729,10 @@ def where(self, condition, y): ...@@ -1703,6 +1729,10 @@ def where(self, condition, y):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.where(...)`_
""" """
return array_funcs.where(condition, self, y) return array_funcs.where(condition, self, y)
......
...@@ -221,7 +221,7 @@ class Tensor(object): ...@@ -221,7 +221,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.abs(...)`_ : Compute the absolute value of input. `torch.abs(...)`_
""" """
...@@ -242,7 +242,7 @@ class Tensor(object): ...@@ -242,7 +242,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.add(...)`_ : Compute the element-wise addition. `torch.add(...)`_
""" """
...@@ -263,7 +263,7 @@ class Tensor(object): ...@@ -263,7 +263,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.add(...)`_ : Compute the element-wise addition. `torch.add(...)`_
""" """
...@@ -284,7 +284,7 @@ class Tensor(object): ...@@ -284,7 +284,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.argmax(...)`_ : Return the index of maximum elements along the given dimension. `torch.argmax(...)`_
""" """
...@@ -305,7 +305,7 @@ class Tensor(object): ...@@ -305,7 +305,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.argmin(...)`_ : Return the index of minimum elements along the given dimension. `torch.argmin(...)`_
""" """
...@@ -333,7 +333,7 @@ class Tensor(object): ...@@ -333,7 +333,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.bitwise_not(...)`_ : Compute the element-wise NOT bitwise operation. `torch.bitwise_not(...)`_
""" """
...@@ -349,7 +349,7 @@ class Tensor(object): ...@@ -349,7 +349,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.bitwise_not(...)`_ : Compute the element-wise NOT bitwise operation. `torch.bitwise_not(...)`_
""" """
...@@ -370,7 +370,7 @@ class Tensor(object): ...@@ -370,7 +370,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.bitwise_xor(...)`_ : Compute the element-wise XOR bitwise operation. `torch.bitwise_xor(...)`_
""" """
...@@ -391,7 +391,7 @@ class Tensor(object): ...@@ -391,7 +391,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.bitwise_xor(...)`_ : Compute the element-wise XOR bitwise operation. `torch.bitwise_xor(...)`_
""" """
...@@ -447,7 +447,7 @@ class Tensor(object): ...@@ -447,7 +447,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.ceil(...)`_ : Compute the smallest integer not less than input. `torch.ceil(...)`_
""" """
...@@ -463,7 +463,7 @@ class Tensor(object): ...@@ -463,7 +463,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.ceil(...)`_ : Compute the smallest integer not less than input. `torch.ceil(...)`_
""" """
...@@ -521,7 +521,7 @@ class Tensor(object): ...@@ -521,7 +521,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.clamp(...)`_ : Compute the clipped input according to the given bounds. `torch.clamp(...)`_
""" """
...@@ -542,7 +542,7 @@ class Tensor(object): ...@@ -542,7 +542,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.clamp(...)`_ : Compute the clipped input according to the given bounds. `torch.clamp(...)`_
""" """
...@@ -585,7 +585,7 @@ class Tensor(object): ...@@ -585,7 +585,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.cos(...)`_ : Compute the cos of input. `torch.cos(...)`_
""" """
...@@ -638,7 +638,7 @@ class Tensor(object): ...@@ -638,7 +638,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.cumsum(...)`_ : Compute the cumulative sum of elements along the given dimension. `torch.cumsum(...)`_
""" """
...@@ -681,7 +681,7 @@ class Tensor(object): ...@@ -681,7 +681,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.div(...)`_ : Compute the element-wise division. `torch.div(...)`_
""" """
...@@ -702,7 +702,7 @@ class Tensor(object): ...@@ -702,7 +702,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.div(...)`_ : Compute the element-wise division. `torch.div(...)`_
""" """
...@@ -743,7 +743,7 @@ class Tensor(object): ...@@ -743,7 +743,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.eq(...)`_ : Compute the element-wise equal comparison. `torch.eq(...)`_
""" """
...@@ -759,7 +759,7 @@ class Tensor(object): ...@@ -759,7 +759,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.exp(...)`_ : Compute the exponential of input. `torch.exp(...)`_
""" """
...@@ -778,7 +778,7 @@ class Tensor(object): ...@@ -778,7 +778,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.expand(...)`_ : Broadcast input according to given sizes. `torch.expand(...)`_
""" """
...@@ -797,7 +797,7 @@ class Tensor(object): ...@@ -797,7 +797,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.expand(...)`_ : Broadcast input according to given sizes. `torch.expand(...)`_
""" """
return self.expand(*other.size()) return self.expand(*other.size())
...@@ -851,7 +851,7 @@ class Tensor(object): ...@@ -851,7 +851,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.floor(...)`_ : Compute the largest integer not greater than input. `torch.floor(...)`_
""" """
...@@ -867,7 +867,7 @@ class Tensor(object): ...@@ -867,7 +867,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.floor(...)`_ : Compute the largest integer not greater than input. `torch.floor(...)`_
""" """
...@@ -888,7 +888,7 @@ class Tensor(object): ...@@ -888,7 +888,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.ge(...)`_ : Compute the element-wise greater-equal comparison. `torch.ge(...)`_
""" """
...@@ -909,7 +909,7 @@ class Tensor(object): ...@@ -909,7 +909,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.gt(...)`_ : Compute the element-wise greater comparison. `torch.gt(...)`_
""" """
...@@ -1000,7 +1000,7 @@ class Tensor(object): ...@@ -1000,7 +1000,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.le(...)`_ : Compute the element-wise less-equal comparison. `torch.le(...)`_
""" """
...@@ -1072,7 +1072,7 @@ class Tensor(object): ...@@ -1072,7 +1072,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.lt(...)`_ : Compute the element-wise less comparison. `torch.lt(...)`_
""" """
...@@ -1183,7 +1183,7 @@ class Tensor(object): ...@@ -1183,7 +1183,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.mul(...)`_ : Compute the element-wise multiplication. `torch.mul(...)`_
""" """
...@@ -1204,7 +1204,7 @@ class Tensor(object): ...@@ -1204,7 +1204,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.mul(...)`_ : Compute the element-wise multiplication. `torch.mul(...)`_
""" """
...@@ -1273,7 +1273,7 @@ class Tensor(object): ...@@ -1273,7 +1273,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.ne(...)`_ : Compute the element-wise not-equal comparison. `torch.ne(...)`_
""" """
...@@ -1289,18 +1289,24 @@ class Tensor(object): ...@@ -1289,18 +1289,24 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.neg(...)`_ : Compute the element-wise negative. `torch.neg(...)`_
""" """
def nonzero(self): def nonzero(self):
"""Return the index of non-zero elements. r"""Return the index of non-zero elements.
.. math:: \text{out} = \{i\}, \text{ if } \text{self}_{i} \neq 0
Returns Returns
------- -------
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.nonzero(...)`_
""" """
def normal_(self, mean=0, std=1): def normal_(self, mean=0, std=1):
...@@ -1392,7 +1398,7 @@ class Tensor(object): ...@@ -1392,7 +1398,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.pow(...)`_ : Compute the power of input. `torch.pow(...)`_
""" """
...@@ -1408,7 +1414,7 @@ class Tensor(object): ...@@ -1408,7 +1414,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.reciprocal(...)`_ : Compute the reciprocal of input. `torch.reciprocal(...)`_
""" """
...@@ -1424,7 +1430,7 @@ class Tensor(object): ...@@ -1424,7 +1430,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.reciprocal(...)`_ : Compute the reciprocal of input. `torch.reciprocal(...)`_
""" """
...@@ -1458,7 +1464,7 @@ class Tensor(object): ...@@ -1458,7 +1464,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.reshape(...)`_ : Change the shape of input. `torch.reshape(...)`_
""" """
...@@ -1477,7 +1483,7 @@ class Tensor(object): ...@@ -1477,7 +1483,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.reshape(...)`_ : Change the shape of input. `torch.reshape(...)`_
""" """
...@@ -1498,7 +1504,7 @@ class Tensor(object): ...@@ -1498,7 +1504,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.round(...)`_ : Compute the nearest integer of input. `torch.round(...)`_
""" """
...@@ -1514,7 +1520,7 @@ class Tensor(object): ...@@ -1514,7 +1520,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.round(...)`_ : Compute the nearest integer of input. `torch.round(...)`_
""" """
...@@ -1530,7 +1536,7 @@ class Tensor(object): ...@@ -1530,7 +1536,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.rsqrt(...)`_ : Compute the square root of input. `torch.rsqrt(...)`_
""" """
...@@ -1546,7 +1552,7 @@ class Tensor(object): ...@@ -1546,7 +1552,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.rsqrt(...)`_ : Compute the square root of input. `torch.rsqrt(...)`_
""" """
...@@ -1568,7 +1574,7 @@ class Tensor(object): ...@@ -1568,7 +1574,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.sign(...)`_ : Compute the sign indication of input. `torch.sign(...)`_
""" """
...@@ -1590,7 +1596,7 @@ class Tensor(object): ...@@ -1590,7 +1596,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.sign(...)`_ : Compute the sign indication of input. `torch.sign(...)`_
""" """
...@@ -1606,7 +1612,7 @@ class Tensor(object): ...@@ -1606,7 +1612,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.sin(...)`_ : Compute the sin of input. `torch.sin(...)`_
""" """
...@@ -1639,7 +1645,7 @@ class Tensor(object): ...@@ -1639,7 +1645,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.sqrt(...)`_ : Compute the square root of input. `torch.sqrt(...)`_
""" """
...@@ -1655,7 +1661,7 @@ class Tensor(object): ...@@ -1655,7 +1661,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.sqrt(...)`_ : Compute the square root of input. `torch.sqrt(...)`_
""" """
...@@ -1672,10 +1678,14 @@ class Tensor(object): ...@@ -1672,10 +1678,14 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.squeeze(...)`_
""" """
def squeeze_(self, dim=None): def squeeze_(self, dim=None):
"""Inplace version of ``Tensor.squeeze()``. """Remove the dimensions with size 1.
Parameters Parameters
---------- ----------
...@@ -1687,6 +1697,10 @@ class Tensor(object): ...@@ -1687,6 +1697,10 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The self. The self.
See Also
--------
`torch.squeeze(...)`_
""" """
def sum(self, dim=None, keepdim=False): def sum(self, dim=None, keepdim=False):
...@@ -1704,6 +1718,10 @@ class Tensor(object): ...@@ -1704,6 +1718,10 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.sum(...)`_
""" """
def sub(self, other): def sub(self, other):
...@@ -1723,7 +1741,7 @@ class Tensor(object): ...@@ -1723,7 +1741,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.sub(...)`_ : Compute the element-wise subtraction. `torch.sub(...)`_
""" """
...@@ -1744,7 +1762,7 @@ class Tensor(object): ...@@ -1744,7 +1762,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.sub(...)`_ : Compute the element-wise subtraction. `torch.sub(...)`_
""" """
...@@ -1769,7 +1787,7 @@ class Tensor(object): ...@@ -1769,7 +1787,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.topk(...)`_ : Return the top-K largest or smallest elements along the given dimension. `torch.topk(...)`_
""" """
...@@ -1822,10 +1840,14 @@ class Tensor(object): ...@@ -1822,10 +1840,14 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.unsqueeze(...)`_
""" """
def unsqueeze_(self, dim): def unsqueeze_(self, dim):
"""In-place version of ``Tensor.unsqueeze()``. """Insert the dimensions of size 1.
Parameters Parameters
---------- ----------
...@@ -1837,6 +1859,10 @@ class Tensor(object): ...@@ -1837,6 +1859,10 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The self. The self.
See Also
--------
`torch.unsqueeze(...)`_
""" """
def view(self, *shape): def view(self, *shape):
...@@ -1854,7 +1880,7 @@ class Tensor(object): ...@@ -1854,7 +1880,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.reshape(...)`_ : Change the shape of input. `torch.reshape(...)`_
""" """
return self.reshape(shape) return self.reshape(shape)
...@@ -1874,7 +1900,7 @@ class Tensor(object): ...@@ -1874,7 +1900,7 @@ class Tensor(object):
See Also See Also
-------- --------
`torch.reshape(...)`_ : Change the shape of input. `torch.reshape(...)`_
""" """
return self.reshape_(shape) return self.reshape_(shape)
...@@ -1918,6 +1944,10 @@ class Tensor(object): ...@@ -1918,6 +1944,10 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.where(...)`_
""" """
def zero_(self): def zero_(self):
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!