Commit fb47d86f by Ting PAN

Merge internal commits

1 parent 3dfb6ea5
Showing with 1302 additions and 724 deletions
[submodule "third_party/cub"] [submodule "third_party/cub"]
path = third_party/cub path = third_party/cub
url = https://github.com/NVlabs/cub url = https://github.com/NVIDIA/cub
[submodule "third_party/eigen"] [submodule "third_party/eigen"]
path = third_party/eigen path = third_party/eigen
url = https://gitlab.com/libeigen/eigen url = https://gitlab.com/libeigen/eigen
......
...@@ -30,7 +30,6 @@ from dragon.vm.caffe.core.layers.common import Reshape ...@@ -30,7 +30,6 @@ from dragon.vm.caffe.core.layers.common import Reshape
from dragon.vm.caffe.core.layers.common import Scale from dragon.vm.caffe.core.layers.common import Scale
from dragon.vm.caffe.core.layers.common import Slice from dragon.vm.caffe.core.layers.common import Slice
from dragon.vm.caffe.core.layers.common import Softmax from dragon.vm.caffe.core.layers.common import Softmax
from dragon.vm.caffe.core.layers.common import StopGradient
from dragon.vm.caffe.core.layers.common import Tile from dragon.vm.caffe.core.layers.common import Tile
from dragon.vm.caffe.core.layers.data import Data from dragon.vm.caffe.core.layers.data import Data
from dragon.vm.caffe.core.layers.loss import EuclideanLoss from dragon.vm.caffe.core.layers.loss import EuclideanLoss
......
...@@ -14,6 +14,8 @@ from __future__ import absolute_import ...@@ -14,6 +14,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
from dragon.core.framework.tensor import Tensor from dragon.core.framework.tensor import Tensor
from dragon.core.ops import activation_ops from dragon.core.ops import activation_ops
from dragon.core.ops import array_ops from dragon.core.ops import array_ops
...@@ -86,7 +88,7 @@ class ArgMax(Layer): ...@@ -86,7 +88,7 @@ class ArgMax(Layer):
self.call_args = {'axis': param.axis, 'keepdims': True} self.call_args = {'axis': param.axis, 'keepdims': True}
def __call__(self, bottom): def __call__(self, bottom):
return array_ops.argmax(bottom, **self.call_args) return math_ops.argmax(bottom, **self.call_args)
class BatchNorm(Layer): class BatchNorm(Layer):
...@@ -514,7 +516,9 @@ class Reduction(Layer): ...@@ -514,7 +516,9 @@ class Reduction(Layer):
raise ValueError('The negative axis can only be -1.') raise ValueError('The negative axis can only be -1.')
self.scale = param.coeff self.scale = param.coeff
self.call_args = {'axis': [param.axis]} self.call_args = {'axis': [param.axis]}
self.reduction = {1: array_ops.sum, 4: array_ops.mean}[param.operation] self.reduction = {1: math_ops.sum,
2: functools.partial(math_ops.norm, ord=1),
4: math_ops.mean}[param.operation]
def __call__(self, bottom): def __call__(self, bottom):
top = self.reduction(bottom, **self.call_args) top = self.reduction(bottom, **self.call_args)
...@@ -633,14 +637,16 @@ class Slice(Layer): ...@@ -633,14 +637,16 @@ class Slice(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(Slice, self).__init__(layer_param) super(Slice, self).__init__(layer_param)
param = layer_param.slice_param param = layer_param.slice_param
self.call_args = { self.axis = param.axis
'axis': param.axis, self.slice_points = param.slice_point
'num_or_size_splits': len(self.top),
'slice_point': [e for e in param.slice_point],
}
def __call__(self, bottom): def __call__(self, bottom):
return array_ops.split(bottom, **self.call_args) stride, size_splits = 0, []
for point in self.slice_points:
size_splits.append(point - stride)
stride = point
size_splits.append(bottom.shape[self.axis] - stride)
return array_ops.split(bottom, size_splits, axis=self.axis)
class Softmax(Layer): class Softmax(Layer):
...@@ -669,28 +675,6 @@ class Softmax(Layer): ...@@ -669,28 +675,6 @@ class Softmax(Layer):
return activation_ops.softmax(bottom, **self.call_args) return activation_ops.softmax(bottom, **self.call_args)
class StopGradient(Layer):
"""Return the identity of input with truncated gradient-flow.
Examples:
```python
layer {
type: "StopGradient"
bottom: "res2c"
top: "res2c/frozen"
}
```
"""
def __init__(self, layer_param):
super(StopGradient, self).__init__(layer_param)
def __call__(self, bottom):
return framework_ops.stop_gradient(bottom)
class Tile(Layer): class Tile(Layer):
"""Repeat the input according to the given axis. """Repeat the input according to the given axis.
......
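The new ``Slice.__call__`` above no longer passes ``slice_point`` through; it converts the points into explicit chunk sizes for ``array_ops.split``. A minimal sketch of that conversion, assuming a bottom blob whose sliced dimension has length 12 and ``slice_point: [3, 7]``:

```python
# Minimal sketch of the slice_point -> size_splits conversion performed by the
# new Slice.__call__ above; the dimension length 12 and points [3, 7] are
# assumed example values.
slice_points = [3, 7]
axis_dim = 12

stride, size_splits = 0, []
for point in slice_points:
    size_splits.append(point - stride)   # widths of the leading chunks: 3, then 4
    stride = point
size_splits.append(axis_dim - stride)    # remainder after the last point: 5

assert size_splits == [3, 4, 5]          # then: array_ops.split(bottom, size_splits, axis=axis)
```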
...@@ -96,9 +96,9 @@ class Data(Layer): ...@@ -96,9 +96,9 @@ class Data(Layer):
self.data_args = { self.data_args = {
'source': data_param.source, 'source': data_param.source,
'batch_size': data_param.batch_size, 'batch_size': data_param.batch_size,
'prefetch': data_param.prefetch, 'prefetch_depth': data_param.prefetch,
'shuffle': image_data_param.shuffle, 'shuffle': image_data_param.shuffle,
'phase': {0: 'TRAIN', 1: 'TEST'}[int(layer_param.phase)], 'training': {0: True, 1: False}[int(layer_param.phase)],
'crop_size': transform_param.crop_size, 'crop_size': transform_param.crop_size,
'mirror': transform_param.mirror, 'mirror': transform_param.mirror,
} }
......
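For reference, a sketch of the renamed data arguments; everything except the key names and the phase mapping is an illustrative placeholder:

```python
# Illustrative values only: 'prefetch' becomes 'prefetch_depth' and the Caffe
# phase enum (0 = TRAIN, 1 = TEST) is folded into a boolean 'training' flag.
layer_phase = 0  # TRAIN
data_args = {
    'source': '/path/to/lmdb',  # hypothetical dataset path
    'batch_size': 64,
    'prefetch_depth': 4,
    'shuffle': True,
    'training': {0: True, 1: False}[layer_phase],
    'crop_size': 224,
    'mirror': True,
}
```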
...@@ -21,7 +21,7 @@ import numpy ...@@ -21,7 +21,7 @@ import numpy
from dragon.core.autograph import backprop from dragon.core.autograph import backprop
from dragon.core.autograph import context as eager_context from dragon.core.autograph import context as eager_context
from dragon.core.autograph import function_impl from dragon.core.autograph import function_lib
from dragon.core.framework import context from dragon.core.framework import context
from dragon.core.util import nest from dragon.core.util import nest
from dragon.core.util import serialization from dragon.core.util import serialization
...@@ -136,8 +136,6 @@ class Net(object): ...@@ -136,8 +136,6 @@ class Net(object):
blobs = [] blobs = []
for blob in layer.blobs: for blob in layer.blobs:
blobs.append(Blob(blob['data'], blob['diff'])) blobs.append(Blob(blob['data'], blob['diff']))
if 'decay_mult' in blob:
blobs[-1].decay_mult = blob['decay_mult']
self._param_dict[layer.name] = blobs self._param_dict[layer.name] = blobs
return self._param_dict return self._param_dict
...@@ -246,7 +244,7 @@ class Net(object): ...@@ -246,7 +244,7 @@ class Net(object):
layer_param.phase = phase_dict[self._phase] layer_param.phase = phase_dict[self._phase]
return True return True
@function_impl.function @function_lib.function
def _compute_outputs(self, **kwargs): def _compute_outputs(self, **kwargs):
"""Compute network outputs.""" """Compute network outputs."""
return [self.blobs[key].data for key in self.outputs] return [self.blobs[key].data for key in self.outputs]
......
...@@ -19,11 +19,10 @@ import time ...@@ -19,11 +19,10 @@ import time
import google.protobuf.text_format import google.protobuf.text_format
from dragon.core.autograph import function_impl from dragon.core.autograph import function_lib
from dragon.core.training.adam import Adam from dragon.core.training.adam import Adam
from dragon.core.training.rmsprop import RMSprop from dragon.core.training.rmsprop import RMSprop
from dragon.core.training.sgd import SGD from dragon.core.training.sgd import SGD
from dragon.core.training.sgd import Nesterov
from dragon.core.util import logging from dragon.core.util import logging
from dragon.vm.caffe.core.net import Net from dragon.vm.caffe.core.net import Net
from dragon.vm.caffe.core.proto import caffe_pb2 from dragon.vm.caffe.core.proto import caffe_pb2
...@@ -50,8 +49,7 @@ class Solver(object): ...@@ -50,8 +49,7 @@ class Solver(object):
if self._proto.iter_size > 1: if self._proto.iter_size > 1:
raise NotImplementedError('Gradient accumulation is not supported.') raise NotImplementedError('Gradient accumulation is not supported.')
self._optimizer_args = { self._optimizer_args = {
'scale': 1. / self._proto.iter_size, 'grad_scale': 1. / self._proto.iter_size,
'clip_norm': float(self._proto.clip_gradients),
'weight_decay': float(self._proto.weight_decay) 'weight_decay': float(self._proto.weight_decay)
if str(self._proto.regularization_type) == 'L2' else 0, if str(self._proto.regularization_type) == 'L2' else 0,
} }
...@@ -269,7 +267,7 @@ class Solver(object): ...@@ -269,7 +267,7 @@ class Solver(object):
self.base_lr = (self._proto.base_lr * self.base_lr = (self._proto.base_lr *
pow(1. - float(self.iter) / max_iter, power)) pow(1. - float(self.iter) / max_iter, power))
@function_impl.function @function_lib.function
def _apply_update(self): def _apply_update(self):
"""Apply the weights update.""" """Apply the weights update."""
grads_and_vars = [(blob.diff, blob.data) grads_and_vars = [(blob.diff, blob.data)
...@@ -342,7 +340,8 @@ class NesterovSolver(Solver): ...@@ -342,7 +340,8 @@ class NesterovSolver(Solver):
super(NesterovSolver, self).__init__(solver_file, is_root) super(NesterovSolver, self).__init__(solver_file, is_root)
self._optimizer_args['lr'] = self._proto.base_lr self._optimizer_args['lr'] = self._proto.base_lr
self._optimizer_args['momentum'] = self._proto.momentum self._optimizer_args['momentum'] = self._proto.momentum
self._optimizer = Nesterov(**self._optimizer_args) self._optimizer_args['nesterov'] = True
self._optimizer = SGD(**self._optimizer_args)
class RMSPropSolver(Solver): class RMSPropSolver(Solver):
......
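With the standalone ``Nesterov`` optimizer removed, ``NesterovSolver`` now builds a plain ``SGD`` with ``nesterov=True``, and the gradient-scale key is renamed. A sketch of the resulting argument set, using illustrative solver values (base_lr=0.01, momentum=0.9, weight_decay=0.0005, L2 regularization):

```python
# Sketch only: the numeric values stand in for a hypothetical solver.prototxt.
optimizer_args = {
    'grad_scale': 1.0 / 1,     # iter_size must be 1; accumulation is unsupported
    'weight_decay': 0.0005,    # zero unless regularization_type == 'L2'
    'lr': 0.01,                # base_lr
    'momentum': 0.9,
    'nesterov': True,          # replaces the removed Nesterov optimizer class
}
# optimizer = SGD(**optimizer_args)   # as constructed by NesterovSolver above
```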
...@@ -30,7 +30,7 @@ endif() ...@@ -30,7 +30,7 @@ endif()
set(CUDA_KNOWN_GPU_ARCHITECTURES "Kepler" "Maxwell") set(CUDA_KNOWN_GPU_ARCHITECTURES "Kepler" "Maxwell")
# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default) # This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0") set(CUDA_COMMON_GPU_ARCHITECTURES "5.0")
if(CUDA_VERSION VERSION_LESS "7.0") if(CUDA_VERSION VERSION_LESS "7.0")
set(CUDA_LIMIT_GPU_ARCHITECTURE "5.2") set(CUDA_LIMIT_GPU_ARCHITECTURE "5.2")
......
...@@ -20,6 +20,7 @@ import sys as _sys ...@@ -20,6 +20,7 @@ import sys as _sys
# Modules # Modules
from dragon.vm.dali._api import ops from dragon.vm.dali._api import ops
from dragon.vm.dali._api import types
# Classes # Classes
from dragon.vm.dali.core.framework.iterator import Iterator from dragon.vm.dali.core.framework.iterator import Iterator
...@@ -30,25 +31,6 @@ from dragon.vm.dali.core.framework.context import device ...@@ -30,25 +31,6 @@ from dragon.vm.dali.core.framework.context import device
from dragon.vm.dali.core.framework.context import get_device_type from dragon.vm.dali.core.framework.context import get_device_type
from dragon.vm.dali.core.framework.context import get_distributed_info from dragon.vm.dali.core.framework.context import get_distributed_info
# Enums
from dragon.vm.dali.core.framework.types import BOOL
from dragon.vm.dali.core.framework.types import BGR
from dragon.vm.dali.core.framework.types import FLOAT
from dragon.vm.dali.core.framework.types import FLOAT32
from dragon.vm.dali.core.framework.types import FLOAT64
from dragon.vm.dali.core.framework.types import INT8
from dragon.vm.dali.core.framework.types import INT32
from dragon.vm.dali.core.framework.types import INT64
from dragon.vm.dali.core.framework.types import INTERP_TRIANGULAR
from dragon.vm.dali.core.framework.types import NCHW
from dragon.vm.dali.core.framework.types import NHWC
from dragon.vm.dali.core.framework.types import RGB
from dragon.vm.dali.core.framework.types import STRING
from dragon.vm.dali.core.framework.types import UINT8
from dragon.vm.dali.core.framework.types import UINT16
from dragon.vm.dali.core.framework.types import UINT32
from dragon.vm.dali.core.framework.types import UINT64
# Attributes # Attributes
_API_MODULE = ops _API_MODULE = ops
_current_module = _sys.modules[__name__] _current_module = _sys.modules[__name__]
......
...@@ -20,13 +20,17 @@ from dragon.vm.dali.core.ops.decoder_ops import ImageDecoder ...@@ -20,13 +20,17 @@ from dragon.vm.dali.core.ops.decoder_ops import ImageDecoder
from dragon.vm.dali.core.ops.decoder_ops import ImageDecoderRandomCrop from dragon.vm.dali.core.ops.decoder_ops import ImageDecoderRandomCrop
from dragon.vm.dali.core.ops.generic_ops import Cast from dragon.vm.dali.core.ops.generic_ops import Cast
from dragon.vm.dali.core.ops.generic_ops import Erase from dragon.vm.dali.core.ops.generic_ops import Erase
from dragon.vm.dali.core.ops.generic_ops import Flip
from dragon.vm.dali.core.ops.generic_ops import Pad from dragon.vm.dali.core.ops.generic_ops import Pad
from dragon.vm.dali.core.ops.generic_ops import Reshape from dragon.vm.dali.core.ops.generic_ops import Reshape
from dragon.vm.dali.core.ops.generic_ops import Slice from dragon.vm.dali.core.ops.generic_ops import Slice
from dragon.vm.dali.core.ops.image_ops import Brightness from dragon.vm.dali.core.ops.image_ops import Brightness
from dragon.vm.dali.core.ops.image_ops import BrightnessContrast from dragon.vm.dali.core.ops.image_ops import BrightnessContrast
from dragon.vm.dali.core.ops.image_ops import Contrast from dragon.vm.dali.core.ops.image_ops import Contrast
from dragon.vm.dali.core.ops.image_ops import ColorSpaceConversion
from dragon.vm.dali.core.ops.image_ops import ColorTwist
from dragon.vm.dali.core.ops.image_ops import CropMirrorNormalize from dragon.vm.dali.core.ops.image_ops import CropMirrorNormalize
from dragon.vm.dali.core.ops.image_ops import GaussianBlur
from dragon.vm.dali.core.ops.image_ops import Hsv from dragon.vm.dali.core.ops.image_ops import Hsv
from dragon.vm.dali.core.ops.image_ops import Paste from dragon.vm.dali.core.ops.image_ops import Paste
from dragon.vm.dali.core.ops.image_ops import RandomBBoxCrop from dragon.vm.dali.core.ops.image_ops import RandomBBoxCrop
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
# Classes
from dragon.vm.dali.core.framework.types import Constant
from dragon.vm.dali.core.framework.types import ScalarConstant
# Enums
from dragon.vm.dali.core.framework.types import BOOL
from dragon.vm.dali.core.framework.types import BGR
from dragon.vm.dali.core.framework.types import FLOAT
from dragon.vm.dali.core.framework.types import FLOAT32
from dragon.vm.dali.core.framework.types import FLOAT64
from dragon.vm.dali.core.framework.types import INT8
from dragon.vm.dali.core.framework.types import INT32
from dragon.vm.dali.core.framework.types import INT64
from dragon.vm.dali.core.framework.types import INTERP_CUBIC
from dragon.vm.dali.core.framework.types import INTERP_GAUSSIAN
from dragon.vm.dali.core.framework.types import INTERP_LANCZOS3
from dragon.vm.dali.core.framework.types import INTERP_LINEAR
from dragon.vm.dali.core.framework.types import INTERP_NN
from dragon.vm.dali.core.framework.types import INTERP_TRIANGULAR
from dragon.vm.dali.core.framework.types import NCHW
from dragon.vm.dali.core.framework.types import NHWC
from dragon.vm.dali.core.framework.types import PIPELINE_API_BASIC
from dragon.vm.dali.core.framework.types import PIPELINE_API_ITERATOR
from dragon.vm.dali.core.framework.types import PIPELINE_API_SCHEDULED
from dragon.vm.dali.core.framework.types import RGB
from dragon.vm.dali.core.framework.types import STRING
from dragon.vm.dali.core.framework.types import UINT8
from dragon.vm.dali.core.framework.types import UINT16
from dragon.vm.dali.core.framework.types import UINT32
from dragon.vm.dali.core.framework.types import UINT64
__all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -47,9 +47,10 @@ class Iterator(object): ...@@ -47,9 +47,10 @@ class Iterator(object):
# Build pipeline and cache the first batch. # Build pipeline and cache the first batch.
with self._api_scope(): with self._api_scope():
self._pipe.build() self._pipe.build()
# Enforce the correct device of current process if self._pipe.device_id is not None:
# to initialize cuda handles instead of device 0. # Enforce the correct device of current process
cuda.set_device(self._pipe.device_id) # to initialize cuda handles instead of device 0.
cuda.set_device(self._pipe.device_id)
self._pipe.schedule_run() self._pipe.schedule_run()
self._copies = None self._copies = None
self._first_batch = None self._first_batch = None
...@@ -86,14 +87,13 @@ class Iterator(object): ...@@ -86,14 +87,13 @@ class Iterator(object):
if self._copies is None: if self._copies is None:
self._copies = [] self._copies = []
for tensor in tensors: for tensor in tensors:
self._copies.append( self._copies.append(self.new_tensor(
self.new_tensor( shape=tensor.shape(),
shape=tensor.shape(), dtype=str(types.np_dtype(tensor.dtype())),
dtype=str(types.np_dtype(tensor.dtype())), device=self.new_device(
device=self.new_device( device_type=('cuda' if isinstance(tensor, TensorGPU)
device_type=('cuda' if isinstance(tensor, TensorGPU) else 'cpu'),
else 'cpu'), device_index=self._pipe.device_id)))
device_index=self._pipe.device_id)))
# Transfer the data: DALI => Storage # Transfer the data: DALI => Storage
for i, tensor in enumerate(tensors): for i, tensor in enumerate(tensors):
self._transfer_tensor(tensor, self._copies[i]) self._transfer_tensor(tensor, self._copies[i])
......
...@@ -53,13 +53,16 @@ try: ...@@ -53,13 +53,16 @@ try:
The number of workers to process external source. The number of workers to process external source.
""" """
device_id = context.get_device()['device_index'] device = context.get_device()
if device['device_type'] == 'cpu':
device['device_index'] = None
super(Pipeline, self).__init__( super(Pipeline, self).__init__(
batch_size=batch_size, batch_size=batch_size,
num_threads=num_threads, num_threads=num_threads,
device_id=device_id, device_id=device['device_index'],
seed=seed, seed=seed,
prefetch_queue_depth=prefetch_queue_depth, prefetch_queue_depth=prefetch_queue_depth,
py_num_workers=py_num_workers,
**kwargs **kwargs
) )
......
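The new device handling keeps CPU pipelines off the GPU: when the current Dragon device is 'cpu', the pipeline is created with ``device_id=None`` instead of index 0. A minimal sketch of that branch, with an assumed device dict:

```python
# Minimal sketch; the dict mirrors what context.get_device() is expected to
# return for a CPU context (an assumption for illustration).
device = {'device_type': 'cpu', 'device_index': 0}
if device['device_type'] == 'cpu':
    device['device_index'] = None   # DALI then builds a CPU-only pipeline
print(device['device_index'])       # None
```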
...@@ -17,6 +17,10 @@ import numpy ...@@ -17,6 +17,10 @@ import numpy
try: try:
from nvidia.dali import types as dali_types from nvidia.dali import types as dali_types
# ConstantWrapper
Constant = dali_types.Constant
ScalarConstant = dali_types.ScalarConstant
# DALIDataType # DALIDataType
BOOL = dali_types.BOOL BOOL = dali_types.BOOL
FLOAT = dali_types.FLOAT FLOAT = dali_types.FLOAT
...@@ -59,6 +63,10 @@ except ImportError: ...@@ -59,6 +63,10 @@ except ImportError:
dali_types = None dali_types = None
NO_DALI = -1 NO_DALI = -1
# ConstantWrapper
Constant = NO_DALI
ScalarConstant = NO_DALI
# DALIDataType # DALIDataType
BOOL = NO_DALI BOOL = NO_DALI
FLOAT = NO_DALI FLOAT = NO_DALI
......
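The module stays importable without DALI: every exported name, including the new ``Constant`` and ``ScalarConstant`` wrappers, falls back to the ``NO_DALI`` sentinel. A condensed sketch of the pattern:

```python
# Condensed sketch of the import-or-sentinel pattern used by types.py above.
try:
    from nvidia.dali import types as dali_types
    Constant = dali_types.Constant
    ScalarConstant = dali_types.ScalarConstant
except ImportError:
    dali_types, NO_DALI = None, -1
    Constant = ScalarConstant = NO_DALI
```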
...@@ -48,7 +48,7 @@ class ImageDecoder(object): ...@@ -48,7 +48,7 @@ class ImageDecoder(object):
Parameters Parameters
---------- ----------
output_type : {'BGR', 'RGB'}, optional output_type : str, optional, default='BGR'
The output color space. The output color space.
host_memory_padding : int, optional, default=8388608 host_memory_padding : int, optional, default=8388608
The number of bytes for host buffer. The number of bytes for host buffer.
...@@ -103,7 +103,7 @@ class ImageDecoderRandomCrop(object): ...@@ -103,7 +103,7 @@ class ImageDecoderRandomCrop(object):
Parameters Parameters
---------- ----------
output_type : {'BGR', 'RGB'}, optional output_type : str, optional, default='BGR'
The output color space. The output color space.
host_memory_padding : int, optional, default=8388608 host_memory_padding : int, optional, default=8388608
The number of bytes for host buffer. The number of bytes for host buffer.
......
...@@ -114,6 +114,46 @@ class Erase(object): ...@@ -114,6 +114,46 @@ class Erase(object):
) )
class Flip(object):
"""Flip input in selected dimensions.
Examples:
```python
flip_rng = dali.ops.CoinFlip(0.5)
flip = dali.ops.Flip()
y = flip(inputs['x'], horizontal=flip_rng())
```
"""
def __new__(cls, horizontal=None, vertical=None, depthwise=None, **kwargs):
"""Create a ``Flip`` operator.
Parameters
----------
horizontal : int, optional
Whether to apply the horizontal flip.
vertical : int, optional
Whether to apply the vertical flip.
depthwise : int, optional
Whether to apply the depthwise flip.
Returns
-------
nvidia.dali.ops.Flip
The operator.
"""
return ops.Flip(
horizontal=horizontal,
vertical=vertical,
depthwise=depthwise,
device=context.get_device_type(),
**kwargs
)
class Pad(object): class Pad(object):
"""Pad input to have the same dimensions. """Pad input to have the same dimensions.
...@@ -245,6 +285,7 @@ class Slice(object): ...@@ -245,6 +285,7 @@ class Slice(object):
return ops.Slice( return ops.Slice(
axes=axes, axes=axes,
normalized_anchor=normalized_anchor, normalized_anchor=normalized_anchor,
normalized_shape=normalized_shape,
device=context.get_device_type(), device=context.get_device_type(),
**kwargs **kwargs
) )
...@@ -31,11 +31,9 @@ class Brightness(object): ...@@ -31,11 +31,9 @@ class Brightness(object):
Examples: Examples:
```python ```python
# Historical jitter range for brightness rng = dali.ops.Uniform((0.6, 1.4))
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
brightness = dali.ops.Brightness() brightness = dali.ops.Brightness()
y = brightness(inputs['x'], brightness=twist_rng()) y = brightness(inputs['x'], brightness=rng())
``` ```
""" """
...@@ -58,11 +56,9 @@ class BrightnessContrast(object): ...@@ -58,11 +56,9 @@ class BrightnessContrast(object):
Examples: Examples:
```python ```python
# Historical jitter range for brightness and contrast rng = dali.ops.Uniform((0.6, 1.4))
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
bc = dali.ops.BrightnessContrast() bc = dali.ops.BrightnessContrast()
y = bc(inputs['x'], brightness=twist_rng(), contrast=twist_rng()) y = bc(inputs['x'], brightness=rng(), contrast=rng())
``` ```
""" """
...@@ -79,17 +75,82 @@ class BrightnessContrast(object): ...@@ -79,17 +75,82 @@ class BrightnessContrast(object):
return ops.BrightnessContrast(device=context.get_device_type(), **kwargs) return ops.BrightnessContrast(device=context.get_device_type(), **kwargs)
class ColorSpaceConversion(object):
"""Convert the color space of image.
Examples:
```python
convert = dali.ops.ColorSpaceConversion('BGR', 'RGB')
y = convert(inputs['x'])
```
"""
def __new__(cls, image_type, output_type, **kwargs):
"""Create a ``ColorSpaceConversion`` operator.
Parameters
----------
image_type : str
The color space of input image.
output_type : str
The color space of output image.
Returns
-------
nvidia.dali.ops.ColorSpaceConversion
The operator.
"""
if isinstance(image_type, six.string_types):
image_type = getattr(types, image_type)
if isinstance(output_type, six.string_types):
output_type = getattr(types, output_type)
return ops.ColorSpaceConversion(
image_type=image_type,
output_type=output_type,
device=context.get_device_type(),
**kwargs
)
class ColorTwist(object):
"""Adjust the hue, saturation and brightness of image.
Examples:
```python
rng1 = dali.ops.Uniform((0.6, 1.4))
rng2 = dali.ops.Uniform((-36., 36.))
twist = dali.ops.ColorTwist()
y = twist(inputs['x'], brightness=rng1(), contrast=rng1(),
saturation=rng1(), hue=rng2())
```
"""
def __new__(cls, **kwargs):
"""Create a ``Brightness`` operator.
Returns
-------
nvidia.dali.ops.Brightness
The operator.
"""
return ops.ColorTwist(device=context.get_device_type(), **kwargs)
class Contrast(object): class Contrast(object):
"""Adjust the contrast of image. """Adjust the contrast of image.
Examples: Examples:
```python ```python
# Historical jitter range for contrast rng = dali.ops.Uniform((0.6, 1.4))
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
contrast = dali.ops.Contrast() contrast = dali.ops.Contrast()
y = contrast(inputs['x'], contrast=twist_rng()) y = contrast(inputs['x'], contrast=rng())
``` ```
""" """
...@@ -180,6 +241,48 @@ class CropMirrorNormalize(object): ...@@ -180,6 +241,48 @@ class CropMirrorNormalize(object):
) )
class GaussianBlur(object):
"""Apply gaussian blur to image.
Examples:
```python
sigma_rng = dali.ops.Uniform((0.1, 2.0))
blur = dali.ops.GaussianBlur()
y = blur(inputs['x'], sigma=sigma_rng())
```
"""
def __new__(cls, sigma=None, window_size=None, dtype=None, **kwargs):
"""Create a ``GaussianBlur`` operator.
Parameters
----------
sigma : Union[float, Sequence[float]], optional
The sigma value of the gaussian kernel.
window_size : Union[int, Sequence[int]], optional
The window size of the gaussian kernel.
dtype : str, optional
The output data type.
Returns
-------
nvidia.dali.ops.GaussianBlur
The operator.
"""
if isinstance(dtype, six.string_types):
dtype = getattr(types, dtype.upper())
return ops.GaussianBlur(
sigma=sigma,
window_size=window_size,
dtype=dtype,
device=context.get_device_type(),
**kwargs
)
class Hsv(object): class Hsv(object):
"""Adjust the hue and saturation. """Adjust the hue and saturation.
...@@ -419,7 +522,7 @@ class Resize(object): ...@@ -419,7 +522,7 @@ class Resize(object):
```python ```python
# Resize to a fixed area # Resize to a fixed area
resize1 = dali.ops.Resize(resize_x=300, resize_y=300) resize1 = dali.ops.Resize(size=300)
# Resize along the shorter side # Resize along the shorter side
resize2 = dali.ops.Resize(resize_shorter=600, max_size=1000) resize2 = dali.ops.Resize(resize_shorter=600, max_size=1000)
...@@ -432,8 +535,7 @@ class Resize(object): ...@@ -432,8 +535,7 @@ class Resize(object):
def __new__( def __new__(
cls, cls,
resize_x=None, size=None,
resize_y=None,
resize_shorter=None, resize_shorter=None,
resize_longer=None, resize_longer=None,
max_size=None, max_size=None,
...@@ -446,10 +548,8 @@ class Resize(object): ...@@ -446,10 +548,8 @@ class Resize(object):
Parameters Parameters
---------- ----------
resize_x : int, optional size : Union[int, Sequence[int]]
The output image width. The output image size.
resize_y : int, optional
The output image height.
resize_shorter : int, optional resize_shorter : int, optional
Resize along the shorter side and limited by ``max_size``. Resize along the shorter side and limited by ``max_size``.
resize_longer : int, optional resize_longer : int, optional
...@@ -476,8 +576,7 @@ class Resize(object): ...@@ -476,8 +576,7 @@ class Resize(object):
if isinstance(min_filter, six.string_types): if isinstance(min_filter, six.string_types):
min_filter = getattr(types, 'INTERP_' + min_filter.upper()) min_filter = getattr(types, 'INTERP_' + min_filter.upper())
return ops.Resize( return ops.Resize(
resize_x=resize_x, size=size,
resize_y=resize_y,
resize_shorter=resize_shorter, resize_shorter=resize_shorter,
resize_longer=resize_longer, resize_longer=resize_longer,
max_size=max_size, max_size=max_size,
......
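After the ``Resize`` signature change, the split ``resize_x``/``resize_y`` arguments collapse into a single ``size``. A hedged usage sketch; the sequence form follows the documented ``Union[int, Sequence[int]]`` type and the import path is assumed:

```python
from dragon.vm import dali  # assumed import path for the DALI virtual API

resize_square = dali.ops.Resize(size=300)                            # was resize_x=300, resize_y=300
resize_hw = dali.ops.Resize(size=(240, 320))                         # sequence form, per the docstring
resize_shorter = dali.ops.Resize(resize_shorter=600, max_size=1000)  # unchanged
```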
...@@ -20,6 +20,9 @@ except ImportError: ...@@ -20,6 +20,9 @@ except ImportError:
from dragon.core.util import deprecation from dragon.core.util import deprecation
ops = deprecation.not_installed('nvidia.dali') ops = deprecation.not_installed('nvidia.dali')
from dragon.core.util import six
from dragon.vm.dali.core.framework import types
class CoinFlip(object): class CoinFlip(object):
"""Sample values from a bernoulli distribution. """Sample values from a bernoulli distribution.
...@@ -33,13 +36,15 @@ class CoinFlip(object): ...@@ -33,13 +36,15 @@ class CoinFlip(object):
""" """
def __new__(cls, probability=0.5, **kwargs): def __new__(cls, probability=0.5, dtype=None, **kwargs):
"""Create a ``CoinFlip`` operator. """Create a ``CoinFlip`` operator.
Parameters Parameters
---------- ----------
probability : float, optional, default=0.5 probability : float, optional, default=0.5
The probability to return 1. The probability to return 1.
dtype : str, optional
The output data type.
Returns Returns
------- -------
...@@ -47,7 +52,10 @@ class CoinFlip(object): ...@@ -47,7 +52,10 @@ class CoinFlip(object):
The operator. The operator.
""" """
return ops.random.CoinFlip(probability=probability, **kwargs) if isinstance(dtype, six.string_types):
dtype = getattr(types, dtype.upper())
return ops.random.CoinFlip(probability=probability,
dtype=dtype, **kwargs)
class Uniform(object): class Uniform(object):
...@@ -62,13 +70,15 @@ class Uniform(object): ...@@ -62,13 +70,15 @@ class Uniform(object):
""" """
def __new__(cls, range=(-1., 1.), **kwargs): def __new__(cls, range=(-1., 1.), dtype=None, **kwargs):
"""Create an ``Uniform`` operator. """Create an ``Uniform`` operator.
Parameters Parameters
---------- ----------
range : Tuple[float, float], optional range : Tuple[float, float], optional
The lower and upper bound of distribution. The lower and upper bound of distribution.
dtype : str, optional
The output data type.
Returns Returns
------- -------
...@@ -76,4 +86,6 @@ class Uniform(object): ...@@ -76,4 +86,6 @@ class Uniform(object):
The operator. The operator.
""" """
return ops.random.Uniform(range=range, **kwargs) if isinstance(dtype, six.string_types):
dtype = getattr(types, dtype.upper())
return ops.random.Uniform(range=range, dtype=dtype, **kwargs)
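``CoinFlip`` and ``Uniform`` now accept a string ``dtype`` that is resolved against the DALI type enums before construction. A minimal sketch of the resolution step, assuming 'int32' as the requested type:

```python
# Minimal sketch of the string-to-enum resolution used above; 'int32' is an
# assumed example value.
from dragon.vm.dali.core.framework import types

dtype = 'int32'
if isinstance(dtype, str):
    dtype = getattr(types, dtype.upper())  # types.INT32, or the NO_DALI sentinel without DALI
```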
...@@ -11,10 +11,6 @@ Constructors ...@@ -11,10 +11,6 @@ Constructors
Public Functions Public Functions
---------------- ----------------
Buffer
######
.. doxygenfunction:: dragon::Operator::Buffer
DeriveFrom DeriveFrom
########## ##########
.. doxygenfunction:: dragon::Operator::DeriveFrom .. doxygenfunction:: dragon::Operator::DeriveFrom
...@@ -33,23 +29,31 @@ GetArgument ...@@ -33,23 +29,31 @@ GetArgument
Input Input
##### #####
.. doxygenfunction:: dragon::Operator::Input .. doxygenfunction:: dragon::Operator::Input(int index)
Input
#####
.. doxygenfunction:: dragon::Operator::Input(const string &name)
InputSize InputSize
######### #########
.. doxygenfunction:: dragon::Operator::InputSize .. doxygenfunction:: dragon::Operator::InputSize
Output
######
.. doxygenfunction:: dragon::Operator::Output(int i)
MessageForUnsupported MessageForUnsupported
##################### #####################
.. doxygenfunction:: dragon::Operator::MessageForUnsupported .. doxygenfunction:: dragon::Operator::MessageForUnsupported
Output Output
###### ######
.. doxygenfunction:: dragon::Operator::Output(int i, const vec32_t &inputs) .. doxygenfunction:: dragon::Operator::Output(int index)
Output
######
.. doxygenfunction:: dragon::Operator::Output(int index, const vector<int> &inputs_at)
Output
######
.. doxygenfunction:: dragon::Operator::Output(const string &name)
OutputSize OutputSize
########## ##########
...@@ -59,10 +63,6 @@ Run ...@@ -59,10 +63,6 @@ Run
### ###
.. doxygenfunction:: dragon::Operator::Run .. doxygenfunction:: dragon::Operator::Run
data_format
###########
.. doxygenfunction:: dragon::Operator::data_format
arg arg
### ###
.. doxygenfunction:: dragon::Operator::arg .. doxygenfunction:: dragon::Operator::arg
...@@ -71,18 +71,18 @@ args ...@@ -71,18 +71,18 @@ args
#### ####
.. doxygenfunction:: dragon::Operator::args .. doxygenfunction:: dragon::Operator::args
data_format
###########
.. doxygenfunction:: dragon::Operator::data_format
data_type
#########
.. doxygenfunction:: dragon::Operator::data_type
def def
### ###
.. doxygenfunction:: dragon::Operator::def .. doxygenfunction:: dragon::Operator::def
dtype
#####
.. doxygenfunction:: dragon::Operator::dtype
handle
######
.. doxygenfunction:: dragon::Operator::handle
name name
#### ####
.. doxygenfunction:: dragon::Operator::name .. doxygenfunction:: dragon::Operator::name
......
...@@ -57,11 +57,11 @@ UniqueName ...@@ -57,11 +57,11 @@ UniqueName
data data
#### ####
.. doxygenfunction:: dragon::Workspace::data(const vector<size_t> &segments, const string &name = "data:0") .. doxygenfunction:: dragon::Workspace::data(size_t size, const string &name = "BufferShared")
data data
#### ####
.. doxygenfunction:: dragon::Workspace::data(const vector<int64_t> &segments, const string &name = "data:0") .. doxygenfunction:: dragon::Workspace::data(int64_t size, const string &name = "BufferShared")
graphs graphs
###### ######
......
...@@ -147,7 +147,6 @@ vm.caffe.layers ...@@ -147,7 +147,6 @@ vm.caffe.layers
layers/SmoothL1Loss layers/SmoothL1Loss
layers/Softmax layers/Softmax
layers/SoftmaxWithLoss layers/SoftmaxWithLoss
layers/StopGradient
layers/TanH layers/TanH
layers/Tile layers/Tile
......
...@@ -34,7 +34,7 @@ extensions = [ ...@@ -34,7 +34,7 @@ extensions = [
'sphinx.ext.autodoc', 'sphinx.ext.autodoc',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinxcontrib.katex', 'sphinxcontrib.katex',
'sphinx_seeta_theme.ext.viewcode', # 'sphinx_seeta_theme.ext.viewcode',
] ]
napoleon_use_rtype = False napoleon_use_rtype = False
......
...@@ -24,6 +24,12 @@ vm.dali.ops ...@@ -24,6 +24,12 @@ vm.dali.ops
`class CoinFlip <ops/CoinFlip.html>`_ `class CoinFlip <ops/CoinFlip.html>`_
: Sample values from a bernoulli distribution. : Sample values from a bernoulli distribution.
`class ColorSpaceConversion <ops/ColorSpaceConversion.html>`_
: Convert the color space of image.
`class ColorTwist <ops/ColorTwist.html>`_
: Adjust the hue, saturation and brightness of image.
`class Contrast <ops/Contrast.html>`_ `class Contrast <ops/Contrast.html>`_
: Adjust the contrast of image. : Adjust the contrast of image.
...@@ -36,6 +42,9 @@ vm.dali.ops ...@@ -36,6 +42,9 @@ vm.dali.ops
`class ExternalSource <ops/ExternalSource.html>`_ `class ExternalSource <ops/ExternalSource.html>`_
: Create a placeholder providing data from feeding. : Create a placeholder providing data from feeding.
`class GaussianBlur <ops/GaussianBlur.html>`_
: Apply gaussian blur to image.
`class Hsv <ops/Hsv.html>`_ `class Hsv <ops/Hsv.html>`_
: Adjust the hue and saturation of image. : Adjust the hue and saturation of image.
...@@ -93,10 +102,13 @@ vm.dali.ops ...@@ -93,10 +102,13 @@ vm.dali.ops
ops/BrightnessContrast ops/BrightnessContrast
ops/Cast ops/Cast
ops/CoinFlip ops/CoinFlip
ops/ColorSpaceConversion
ops/ColorTwist
ops/Contrast ops/Contrast
ops/CropMirrorNormalize ops/CropMirrorNormalize
ops/Erase ops/Erase
ops/ExternalSource ops/ExternalSource
ops/GaussianBlur
ops/Hsv ops/Hsv
ops/ImageDecoder ops/ImageDecoder
ops/ImageDecoderRandomCrop ops/ImageDecoderRandomCrop
......
ColorSpaceConversion
====================
.. autoclass:: dragon.vm.dali.ops.ColorSpaceConversion
__new__
--------
.. automethod:: dragon.vm.dali.ops.ColorSpaceConversion.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
ColorTwist
==========
.. autoclass:: dragon.vm.dali.ops.ColorTwist
__new__
--------
.. automethod:: dragon.vm.dali.ops.ColorTwist.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
GaussianBlur
============
.. autoclass:: dragon.vm.dali.ops.GaussianBlur
__new__
--------
.. automethod:: dragon.vm.dali.ops.GaussianBlur.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
...@@ -116,7 +116,7 @@ dragon ...@@ -116,7 +116,7 @@ dragon
: Return a tensor of ones with shape as the other. : Return a tensor of ones with shape as the other.
`one_hot(...) <dragon/one_hot.html>`_ `one_hot(...) <dragon/one_hot.html>`_
: Return the one-hot representation for input. : Return the one-hot representation of input.
`pad(...) <dragon/pad.html>`_ `pad(...) <dragon/pad.html>`_
: Pad the input according to the given sizes. : Pad the input according to the given sizes.
...@@ -184,6 +184,9 @@ dragon ...@@ -184,6 +184,9 @@ dragon
`triu(...) <dragon/triu.html>`_ `triu(...) <dragon/triu.html>`_
: Return the upper triangular part of input. : Return the upper triangular part of input.
`unstack(...) <dragon/unstack.html>`_
: Unpack input into chunks along the given axis.
`unique(...) <dragon/unique.html>`_ `unique(...) <dragon/unique.html>`_
: Return the unique elements of input. : Return the unique elements of input.
...@@ -257,6 +260,7 @@ dragon ...@@ -257,6 +260,7 @@ dragon
dragon/tril dragon/tril
dragon/triu dragon/triu
dragon/unique dragon/unique
dragon/unstack
dragon/variable_scope dragon/variable_scope
dragon/where dragon/where
dragon/zeros dragon/zeros
......
...@@ -15,9 +15,6 @@ dragon.cuda ...@@ -15,9 +15,6 @@ dragon.cuda
`current_device(...) <cuda/current_device.html>`_ `current_device(...) <cuda/current_device.html>`_
: Return the index of current selected device. : Return the index of current selected device.
`enable_cudnn(...) <cuda/enable_cudnn.html>`_
: Enable backend to use the cuDNN library.
`get_device_capability(...) <cuda/get_device_capability.html>`_ `get_device_capability(...) <cuda/get_device_capability.html>`_
: Return the capability of specified device. : Return the capability of specified device.
...@@ -27,6 +24,9 @@ dragon.cuda ...@@ -27,6 +24,9 @@ dragon.cuda
`memory_allocated(...) <cuda/memory_allocated.html>`_ `memory_allocated(...) <cuda/memory_allocated.html>`_
: Return the size of memory used by tensors in current workspace. : Return the size of memory used by tensors in current workspace.
`set_cudnn_flags(...) <cuda/set_cudnn_flags.html>`_
: Set the flags of cuDNN library.
`set_default_device(...) <cuda/set_default_device.html>`_ `set_default_device(...) <cuda/set_default_device.html>`_
: Set the default device. : Set the default device.
...@@ -41,10 +41,10 @@ dragon.cuda ...@@ -41,10 +41,10 @@ dragon.cuda
cuda/Stream cuda/Stream
cuda/current_device cuda/current_device
cuda/enable_cudnn
cuda/get_device_capability cuda/get_device_capability
cuda/is_available cuda/is_available
cuda/memory_allocated cuda/memory_allocated
cuda/set_cudnn_flags
cuda/set_default_device cuda/set_default_device
cuda/set_device cuda/set_device
cuda/synchronize cuda/synchronize
......
enable_cudnn set_cudnn_flags
============ ===============
.. autofunction:: dragon.cuda.enable_cudnn .. autofunction:: dragon.cuda.set_cudnn_flags
.. raw:: html .. raw:: html
......
...@@ -22,6 +22,10 @@ add_strings ...@@ -22,6 +22,10 @@ add_strings
########### ###########
.. automethod:: dragon.io.TFRecordExample.add_strings .. automethod:: dragon.io.TFRecordExample.add_strings
parse_from
##########
.. automethod:: dragon.io.TFRecordExample.parse_from
serialize_to serialize_to
############ ############
.. automethod:: dragon.io.TFRecordExample.serialize_to .. automethod:: dragon.io.TFRecordExample.serialize_to
......
...@@ -51,6 +51,9 @@ dragon.math ...@@ -51,6 +51,9 @@ dragon.math
`greater_equal(...) <math/greater_equal.html>`_ `greater_equal(...) <math/greater_equal.html>`_
: Compute the element-wise greater-equal comparison. : Compute the element-wise greater-equal comparison.
`is_finite(...) <math/is_finite.html>`_
: Check if the elements of input are finite.
`is_inf(...) <math/is_inf.html>`_ `is_inf(...) <math/is_inf.html>`_
: Check if the elements of input are infinite. : Check if the elements of input are infinite.
...@@ -105,6 +108,9 @@ dragon.math ...@@ -105,6 +108,9 @@ dragon.math
`negative(...) <math/negative.html>`_ `negative(...) <math/negative.html>`_
: Compute the element-wise negative. : Compute the element-wise negative.
`norm(...) <math/norm.html>`_
: Compute the norm value of elements along the given axis.
`not_equal(...) <math/not_equal.html>`_ `not_equal(...) <math/not_equal.html>`_
: Compute the element-wise not-equal comparison. : Compute the element-wise not-equal comparison.
...@@ -165,6 +171,7 @@ dragon.math ...@@ -165,6 +171,7 @@ dragon.math
math/gemm math/gemm
math/greater math/greater
math/greater_equal math/greater_equal
math/is_finite
math/is_inf math/is_inf
math/is_nan math/is_nan
math/less math/less
...@@ -183,6 +190,7 @@ dragon.math ...@@ -183,6 +190,7 @@ dragon.math
math/minimum math/minimum
math/mul math/mul
math/negative math/negative
math/norm
math/not_equal math/not_equal
math/pow math/pow
math/reciprocal math/reciprocal
......
is_finite
=========
.. autofunction:: dragon.math.is_finite
.. raw:: html
<style>
h1:before {
content: "dragon.math.";
color: #103d3e;
}
</style>
swish norm
===== ====
.. autofunction:: dragon.vm.tensorflow.nn.swish .. autofunction:: dragon.math.norm
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "tf.nn."; content: "dragon.math.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -14,16 +14,12 @@ dragon.optimizers ...@@ -14,16 +14,12 @@ dragon.optimizers
: The optimizer to apply AdamW algorithm. : The optimizer to apply AdamW algorithm.
`[Loshchilov & Hutter, 2017] <https://arxiv.org/abs/1711.05101>`_. `[Loshchilov & Hutter, 2017] <https://arxiv.org/abs/1711.05101>`_.
`class Nesterov <optimizers/Nesterov.html>`_
: The optimizer to apply NesterovSGD algorithm.
`[Sutskever et.al, 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_.
`class RMSProp <optimizers/RMSprop.html>`_ `class RMSProp <optimizers/RMSprop.html>`_
: The optimizer to apply RMSprop algorithm. : The optimizer to apply RMSprop algorithm.
`[Hinton et.al, 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_. `[Hinton et.al, 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_.
`class SGD <optimizers/SGD.html>`_ `class SGD <optimizers/SGD.html>`_
: The optimizer to apply MomentumSGD algorithm. : The optimizer to apply SGD algorithm.
`[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_. `[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_.
.. toctree:: .. toctree::
...@@ -31,7 +27,6 @@ dragon.optimizers ...@@ -31,7 +27,6 @@ dragon.optimizers
optimizers/Adam optimizers/Adam
optimizers/AdamW optimizers/AdamW
optimizers/Nesterov
optimizers/Optimizer optimizers/Optimizer
optimizers/RMSprop optimizers/RMSprop
optimizers/SGD optimizers/SGD
......
unstack
=======
.. autofunction:: dragon.unstack
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
...@@ -111,6 +111,8 @@ PyTorch ...@@ -111,6 +111,8 @@ PyTorch
* `torch <torch.html>`_ * `torch <torch.html>`_
* `torch.autograd <torch/autograd.html>`_ * `torch.autograd <torch/autograd.html>`_
* `torch.backends <torch/backends.html>`_
* `torch.cuda <torch/cuda.html>`_
* `torch.distributed <torch/distributed.html>`_ * `torch.distributed <torch/distributed.html>`_
* `torch.jit <torch/jit.html>`_ * `torch.jit <torch/jit.html>`_
* `torch.nn <torch/nn.html>`_ * `torch.nn <torch/nn.html>`_
...@@ -118,6 +120,7 @@ PyTorch ...@@ -118,6 +120,7 @@ PyTorch
* `torch.nn.init <torch/nn/init.html>`_ * `torch.nn.init <torch/nn/init.html>`_
* `torch.onnx <torch/onnx.html>`_ * `torch.onnx <torch/onnx.html>`_
* `torch.optim <torch/optim.html>`_ * `torch.optim <torch/optim.html>`_
* `torch.utils.checkpoint <torch/utils/checkpoint.html>`_
* `torch.utils.dlpack <torch/utils/dlpack.html>`_ * `torch.utils.dlpack <torch/utils/dlpack.html>`_
* `torchvision.ops <torchvision/ops.html>`_ * `torchvision.ops <torchvision/ops.html>`_
...@@ -274,6 +277,12 @@ Modules ...@@ -274,6 +277,12 @@ Modules
`Module vm.torch.autograd <torch/autograd.html>`_ `Module vm.torch.autograd <torch/autograd.html>`_
: Virtual API for ``torch.autograd`` namespace. : Virtual API for ``torch.autograd`` namespace.
`Module vm.torch.backends <torch/backends.html>`_
: Virtual API for ``torch.backends`` namespace.
`Module vm.torch.cuda <torch/cuda.html>`_
: Virtual API for ``torch.cuda`` namespace.
`Module vm.torch.distributed <torch/distributed.html>`_ `Module vm.torch.distributed <torch/distributed.html>`_
: Virtual API for ``torch.distributed`` namespace. : Virtual API for ``torch.distributed`` namespace.
...@@ -295,6 +304,9 @@ Modules ...@@ -295,6 +304,9 @@ Modules
`Module vm.torch.optim <torch/optim.html>`_ `Module vm.torch.optim <torch/optim.html>`_
: Virtual API for ``torch.optim`` namespace. : Virtual API for ``torch.optim`` namespace.
`Module vm.torch.utils.checkpoint <torch/utils/checkpoint.html>`_
: Virtual API for ``torch.utils.checkpoint`` namespace.
`Module vm.torch.utils.dlpack <torch/utils/dlpack.html>`_ `Module vm.torch.utils.dlpack <torch/utils/dlpack.html>`_
: Virtual API for ``torch.utils.dlpack`` namespace. : Virtual API for ``torch.utils.dlpack`` namespace.
...@@ -340,6 +352,8 @@ Modules ...@@ -340,6 +352,8 @@ Modules
tensorrt/onnx tensorrt/onnx
torch torch
torch/autograd torch/autograd
torch/backends
torch/cuda
torch/distributed torch/distributed
torch/jit torch/jit
torch/nn torch/nn
...@@ -347,5 +361,6 @@ Modules ...@@ -347,5 +361,6 @@ Modules
torch/nn/init torch/nn/init
torch/onnx torch/onnx
torch/optim torch/optim
torch/utils/checkpoint
torch/utils/dlpack torch/utils/dlpack
torchvision/ops torchvision/ops
...@@ -76,7 +76,7 @@ vm.tensorflow ...@@ -76,7 +76,7 @@ vm.tensorflow
: Return a tensor of ones with shape as the other. : Return a tensor of ones with shape as the other.
`one_hot(...) <tensorflow/one_hot.html>`_ `one_hot(...) <tensorflow/one_hot.html>`_
: Return the one-hot representation for input. : Return the one-hot representation of input.
`pad(...) <tensorflow/pad.html>`_ `pad(...) <tensorflow/pad.html>`_
: Pad the input according to the given sizes. : Pad the input according to the given sizes.
...@@ -120,6 +120,9 @@ vm.tensorflow ...@@ -120,6 +120,9 @@ vm.tensorflow
`unique_with_counts(...) <tensorflow/unique_with_counts.html>`_ `unique_with_counts(...) <tensorflow/unique_with_counts.html>`_
: Return the unique elements of input with counts. : Return the unique elements of input with counts.
`unstack(...) <tensorflow/unstack.html>`_
: Unpack input into chunks along the given axis.
`zeros(...) <tensorflow/zeros.html>`_ `zeros(...) <tensorflow/zeros.html>`_
: Return a tensor filled with zeros. : Return a tensor filled with zeros.
...@@ -166,6 +169,7 @@ vm.tensorflow ...@@ -166,6 +169,7 @@ vm.tensorflow
tensorflow/transpose tensorflow/transpose
tensorflow/unique tensorflow/unique
tensorflow/unique_with_counts tensorflow/unique_with_counts
tensorflow/unstack
tensorflow/zeros tensorflow/zeros
tensorflow/zeros_like tensorflow/zeros_like
......
...@@ -30,13 +30,6 @@ variables ...@@ -30,13 +30,6 @@ variables
######### #########
.. autoattribute:: dragon.vm.tensorflow.Module.variables .. autoattribute:: dragon.vm.tensorflow.Module.variables
Methods
-------
flatten
#######
.. automethod:: dragon.vm.tensorflow.Module.flatten
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -6,6 +6,9 @@ layers ...@@ -6,6 +6,9 @@ layers
Classes Classes
------- -------
`class Activation <layers/Activation.html>`_
: Activation layer.
`class Add <layers/Add.html>`_ `class Add <layers/Add.html>`_
: Layer to add a sequence of inputs. : Layer to add a sequence of inputs.
...@@ -82,6 +85,10 @@ layers ...@@ -82,6 +85,10 @@ layers
`class Layer <layers/Layer.html>`_ `class Layer <layers/Layer.html>`_
: The base class of layers. : The base class of layers.
`class LayerNormalization <layers/LayerNormalization.html>`_
: Layer normalization layer.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
`class LeakyReLU <layers/LeakyReLU.html>`_ `class LeakyReLU <layers/LeakyReLU.html>`_
: Layer to apply the leaky rectified linear unit. : Layer to apply the leaky rectified linear unit.
...@@ -144,6 +151,7 @@ layers ...@@ -144,6 +151,7 @@ layers
.. toctree:: .. toctree::
:hidden: :hidden:
layers/Activation
layers/Add layers/Add
layers/AveragePooling1D layers/AveragePooling1D
layers/AveragePooling2D layers/AveragePooling2D
...@@ -168,6 +176,7 @@ layers ...@@ -168,6 +176,7 @@ layers
layers/GlobalMaxPool2D layers/GlobalMaxPool2D
layers/GlobalMaxPool3D layers/GlobalMaxPool3D
layers/Layer layers/Layer
layers/LayerNormalization
layers/LeakyReLU layers/LeakyReLU
layers/Maximum layers/Maximum
layers/MaxPool1D layers/MaxPool1D
......
Nesterov Activation
======== ==========
.. autoclass:: dragon.optimizers.Nesterov .. autoclass:: dragon.vm.tensorflow.keras.layers.Activation
__init__ __init__
-------- --------
.. automethod:: dragon.optimizers.Nesterov.__init__ .. automethod:: dragon.vm.tensorflow.keras.layers.Activation.__init__
Methods
-------
apply_gradients
################
.. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex:
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.optimizers."; content: "tf.keras.layers.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -14,18 +14,10 @@ dtype ...@@ -14,18 +14,10 @@ dtype
##### #####
.. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.dtype .. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.dtype
layers
######
.. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.layers
non_trainable_weights non_trainable_weights
##################### #####################
.. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.non_trainable_weights .. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.non_trainable_weights
non_trainable_variables
#######################
.. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.non_trainable_variables
trainable trainable
######### #########
.. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.trainable .. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.trainable
...@@ -34,10 +26,6 @@ trainable_weights ...@@ -34,10 +26,6 @@ trainable_weights
################# #################
.. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.trainable_weights .. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.trainable_weights
trainable_variables
###################
.. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.trainable_variables
weights weights
####### #######
.. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.weights .. autoattribute:: dragon.vm.tensorflow.keras.layers.Layer.weights
......
LayerNormalization
==================
.. autoclass:: dragon.vm.tensorflow.keras.layers.LayerNormalization
__init__
--------
.. automethod:: dragon.vm.tensorflow.keras.layers.LayerNormalization.__init__
.. raw:: html
<style>
h1:before {
content: "tf.keras.layers.";
color: #103d3e;
}
</style>
...@@ -48,6 +48,9 @@ vm.tensorflow.math ...@@ -48,6 +48,9 @@ vm.tensorflow.math
`greater_equal(...) <math/greater_equal.html>`_ `greater_equal(...) <math/greater_equal.html>`_
: Compute the element-wise greater-equal comparison. : Compute the element-wise greater-equal comparison.
`is_finite(...) <math/is_finite.html>`_
: Check if the elements of input are finite.
`is_inf(...) <math/is_inf.html>`_ `is_inf(...) <math/is_inf.html>`_
: Check if the elements of input are infinite. : Check if the elements of input are infinite.
...@@ -140,6 +143,7 @@ vm.tensorflow.math ...@@ -140,6 +143,7 @@ vm.tensorflow.math
math/floor math/floor
math/greater math/greater
math/greater_equal math/greater_equal
math/is_finite
math/is_inf math/is_inf
math/is_nan math/is_nan
math/l2_normalize math/l2_normalize
......
is_finite
=========
.. autofunction:: dragon.vm.tensorflow.math.is_finite
.. raw:: html
<style>
h1:before {
content: "tf.math.";
color: #103d3e;
}
</style>
...@@ -18,10 +18,6 @@ vm.tensorflow.nn ...@@ -18,10 +18,6 @@ vm.tensorflow.nn
`avg_pool3d(...) <nn/avg_pool3d.html>`_ `avg_pool3d(...) <nn/avg_pool3d.html>`_
: Apply the 3d average pooling. : Apply the 3d average pooling.
`batch_normalization(...) <nn/batch_normalization.html>`_
: Apply the batch normalization.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`bias_add(...) <nn/bias_add.html>`_ `bias_add(...) <nn/bias_add.html>`_
: Add the bias across channels to input. : Add the bias across channels to input.
...@@ -64,6 +60,10 @@ vm.tensorflow.nn ...@@ -64,6 +60,10 @@ vm.tensorflow.nn
: Apply the exponential linear unit to input. : Apply the exponential linear unit to input.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_. `[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
`fused_batch_norm(...) <nn/fused_batch_norm.html>`_
: Apply the batch normalization.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`gelu(...) <nn/gelu.html>`_ `gelu(...) <nn/gelu.html>`_
: Apply the gaussian error linear unit. : Apply the gaussian error linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_. `[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
...@@ -121,10 +121,6 @@ vm.tensorflow.nn ...@@ -121,10 +121,6 @@ vm.tensorflow.nn
`sparse_softmax_cross_entropy_with_logits(...) <nn/sparse_softmax_cross_entropy_with_logits.html>`_ `sparse_softmax_cross_entropy_with_logits(...) <nn/sparse_softmax_cross_entropy_with_logits.html>`_
: Compute the softmax cross entropy with sparse labels. : Compute the softmax cross entropy with sparse labels.
`swish(...) <nn/swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -132,7 +128,6 @@ vm.tensorflow.nn ...@@ -132,7 +128,6 @@ vm.tensorflow.nn
nn/avg_pool1d nn/avg_pool1d
nn/avg_pool2d nn/avg_pool2d
nn/avg_pool3d nn/avg_pool3d
nn/batch_normalization
nn/bias_add nn/bias_add
nn/conv1d nn/conv1d
nn/conv1d_transpose nn/conv1d_transpose
...@@ -146,6 +141,7 @@ vm.tensorflow.nn ...@@ -146,6 +141,7 @@ vm.tensorflow.nn
nn/depth_to_space nn/depth_to_space
nn/dropout nn/dropout
nn/elu nn/elu
nn/fused_batch_norm
nn/gelu nn/gelu
nn/leaky_relu nn/leaky_relu
nn/local_response_normalization nn/local_response_normalization
...@@ -163,7 +159,6 @@ vm.tensorflow.nn ...@@ -163,7 +159,6 @@ vm.tensorflow.nn
nn/softmax_cross_entropy_with_logits nn/softmax_cross_entropy_with_logits
nn/space_to_depth nn/space_to_depth
nn/sparse_softmax_cross_entropy_with_logits nn/sparse_softmax_cross_entropy_with_logits
nn/swish
.. raw:: html .. raw:: html
......
batch_normalization fused_batch_norm
=================== ================
.. autofunction:: dragon.vm.tensorflow.nn.batch_normalization .. autofunction:: dragon.vm.tensorflow.nn.fused_batch_norm
.. raw:: html .. raw:: html
......
unstack
=======
.. autofunction:: dragon.vm.tensorflow.unstack
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
...@@ -4,7 +4,7 @@ vm.torch ...@@ -4,7 +4,7 @@ vm.torch
.. only:: html .. only:: html
Classes Classes
####### -------
`class device <torch/device.html>`_ `class device <torch/device.html>`_
: Represent the device spec. : Represent the device spec.
...@@ -28,7 +28,7 @@ vm.torch ...@@ -28,7 +28,7 @@ vm.torch
: A multi-dimensional array containing elements of a single data type. : A multi-dimensional array containing elements of a single data type.
Functions Functions
######### ---------
`abs(...) <torch/abs.html>`_ `abs(...) <torch/abs.html>`_
: Compute the absolute value of input. : Compute the absolute value of input.
...@@ -144,6 +144,9 @@ vm.torch ...@@ -144,6 +144,9 @@ vm.torch
`index_select(...) <torch/index_select.html>`_ `index_select(...) <torch/index_select.html>`_
: Select elements along the given dimension using index. : Select elements along the given dimension using index.
`isfinite(...) <torch/isfinite.html>`_
: Check if the elements of input are finite.
`isinf(...) <torch/isinf.html>`_ `isinf(...) <torch/isinf.html>`_
: Check if the elements of input are infinite. : Check if the elements of input are infinite.
...@@ -219,15 +222,15 @@ vm.torch ...@@ -219,15 +222,15 @@ vm.torch
`nonzero(...) <torch/nonzero.html>`_ `nonzero(...) <torch/nonzero.html>`_
: Return the index of non-zero elements. : Return the index of non-zero elements.
`norm(...) <torch/norm.html>`_
: Compute the norm value of elements along the given dimension.
`ones(...) <torch/ones.html>`_ `ones(...) <torch/ones.html>`_
: Return a tensor filled with ones. : Return a tensor filled with ones.
`ones_like(...) <torch/ones_like.html>`_ `ones_like(...) <torch/ones_like.html>`_
: Return a tensor of ones with shape as the other. : Return a tensor of ones with shape as the other.
`one_hot(...) <torch/one_hot.html>`_
: Return the one-hot representation for input.
`permute(...) <torch/permute.html>`_ `permute(...) <torch/permute.html>`_
: Return a new tensor with the specific order of dimensions. : Return a new tensor with the specific order of dimensions.
...@@ -279,6 +282,9 @@ vm.torch ...@@ -279,6 +282,9 @@ vm.torch
`sqrt(...) <torch/sqrt.html>`_ `sqrt(...) <torch/sqrt.html>`_
: Compute the square root of input. : Compute the square root of input.
`square(...) <torch/square.html>`_
: Compute the square of input.
`squeeze(...) <torch/squeeze.html>`_ `squeeze(...) <torch/squeeze.html>`_
: Remove the dimensions of input with size 1. : Remove the dimensions of input with size 1.
...@@ -309,6 +315,9 @@ vm.torch ...@@ -309,6 +315,9 @@ vm.torch
`triu(...) <torch/triu.html>`_ `triu(...) <torch/triu.html>`_
: Return the upper triangular part of input. : Return the upper triangular part of input.
`unbind(...) <torch/unbind.html>`_
: Unpack input into chunks along the given dimension.
`unique(...) <torch/unique.html>`_ `unique(...) <torch/unique.html>`_
: Return the unique elements of input. : Return the unique elements of input.
...@@ -370,6 +379,7 @@ vm.torch ...@@ -370,6 +379,7 @@ vm.torch
torch/ge torch/ge
torch/gt torch/gt
torch/index_select torch/index_select
torch/isfinite
torch/isinf torch/isinf
torch/isnan torch/isnan
torch/le torch/le
...@@ -396,9 +406,9 @@ vm.torch ...@@ -396,9 +406,9 @@ vm.torch
torch/neg torch/neg
torch/no_grad torch/no_grad
torch/nonzero torch/nonzero
torch/norm
torch/ones torch/ones
torch/ones_like torch/ones_like
torch/one_hot
torch/permute torch/permute
torch/pow torch/pow
torch/rand torch/rand
...@@ -417,6 +427,7 @@ vm.torch ...@@ -417,6 +427,7 @@ vm.torch
torch/sort torch/sort
torch/split torch/split
torch/sqrt torch/sqrt
torch/square
torch/squeeze torch/squeeze
torch/stack torch/stack
torch/sub torch/sub
...@@ -427,6 +438,7 @@ vm.torch ...@@ -427,6 +438,7 @@ vm.torch
torch/transpose torch/transpose
torch/tril torch/tril
torch/triu torch/triu
torch/unbind
torch/unique torch/unique
torch/unsqueeze torch/unsqueeze
torch/where torch/where
......
...@@ -305,6 +305,10 @@ int\_ ...@@ -305,6 +305,10 @@ int\_
###### ######
.. automethod:: dragon.vm.torch.Tensor.int_ .. automethod:: dragon.vm.torch.Tensor.int_
isfinite
########
.. automethod:: dragon.vm.torch.Tensor.isfinite
isinf isinf
##### #####
.. automethod:: dragon.vm.torch.Tensor.isinf .. automethod:: dragon.vm.torch.Tensor.isinf
...@@ -461,6 +465,10 @@ nonzero ...@@ -461,6 +465,10 @@ nonzero
####### #######
.. automethod:: dragon.vm.torch.Tensor.nonzero .. automethod:: dragon.vm.torch.Tensor.nonzero
norm
####
.. automethod:: dragon.vm.torch.Tensor.norm
normal\_ normal\_
######## ########
.. automethod:: dragon.vm.torch.Tensor.normal_ .. automethod:: dragon.vm.torch.Tensor.normal_
...@@ -581,6 +589,10 @@ sqrt\_ ...@@ -581,6 +589,10 @@ sqrt\_
###### ######
.. automethod:: dragon.vm.torch.Tensor.sqrt_ .. automethod:: dragon.vm.torch.Tensor.sqrt_
square
######
.. automethod:: dragon.vm.torch.Tensor.square
squeeze squeeze
####### #######
.. automethod:: dragon.vm.torch.Tensor.squeeze .. automethod:: dragon.vm.torch.Tensor.squeeze
...@@ -641,6 +653,10 @@ type ...@@ -641,6 +653,10 @@ type
#### ####
.. automethod:: dragon.vm.torch.Tensor.type .. automethod:: dragon.vm.torch.Tensor.type
unbind
######
.. automethod:: dragon.vm.torch.Tensor.unbind
uniform\_ uniform\_
######### #########
.. automethod:: dragon.vm.torch.Tensor.uniform_ .. automethod:: dragon.vm.torch.Tensor.uniform_
...@@ -706,6 +722,7 @@ zero\_ ...@@ -706,6 +722,7 @@ zero\_
.. _torch.gather(...): gather.html .. _torch.gather(...): gather.html
.. _torch.ge(...): ge.html .. _torch.ge(...): ge.html
.. _torch.gt(...): gt.html .. _torch.gt(...): gt.html
.. _torch.isfinite(...): isfinite.html
.. _torch.isinf(...): isinf.html .. _torch.isinf(...): isinf.html
.. _torch.isnan(...): isnan.html .. _torch.isnan(...): isnan.html
.. _torch.le(...): le.html .. _torch.le(...): le.html
...@@ -726,6 +743,7 @@ zero\_ ...@@ -726,6 +743,7 @@ zero\_
.. _torch.ne(...): ne.html .. _torch.ne(...): ne.html
.. _torch.neg(...): neg.html .. _torch.neg(...): neg.html
.. _torch.nonzero(...): nonzero.html .. _torch.nonzero(...): nonzero.html
.. _torch.norm(...): norm.html
.. _torch.ones(...): ones.html .. _torch.ones(...): ones.html
.. _torch.pow(...): pow.html .. _torch.pow(...): pow.html
.. _torch.reciprocal(...): reciprocal.html .. _torch.reciprocal(...): reciprocal.html
...@@ -740,6 +758,7 @@ zero\_ ...@@ -740,6 +758,7 @@ zero\_
.. _torch.sort(...): sort.html .. _torch.sort(...): sort.html
.. _torch.split(...): split.html .. _torch.split(...): split.html
.. _torch.sqrt(...): sqrt.html .. _torch.sqrt(...): sqrt.html
.. _torch.square(...): square.html
.. _torch.squeeze(...): squeeze.html .. _torch.squeeze(...): squeeze.html
.. _torch.sub(...): sub.html .. _torch.sub(...): sub.html
.. _torch.sum(...): sum.html .. _torch.sum(...): sum.html
...@@ -748,6 +767,7 @@ zero\_ ...@@ -748,6 +767,7 @@ zero\_
.. _torch.transpose(...): transpose.html .. _torch.transpose(...): transpose.html
.. _torch.tril(...): tril.html .. _torch.tril(...): tril.html
.. _torch.triu(...): triu.html .. _torch.triu(...): triu.html
.. _torch.unbind(...): unbind.html
.. _torch.unique(...): unique.html .. _torch.unique(...): unique.html
.. _torch.unsqueeze(...): unsqueeze.html .. _torch.unsqueeze(...): unsqueeze.html
.. _torch.where(...): where.html .. _torch.where(...): where.html
......
vm.torch.backends
=================
.. only:: html
Modules
-------
`Module cudnn <backends/cudnn.html>`_
: The cuDNN backend module.
.. toctree::
:hidden:
backends/cudnn
.. raw:: html
<style>
h1:before {
content: "Module: dragon.";
color: #103d3e;
}
</style>
cudnn
=====
Properties
----------
allow_tf32
##########
.. data:: dragon.vm.torch.backends.cudnn.allow_tf32
:annotation: = False
The flag that allows cuDNN TF32 math type or not.
benchmark
#########
.. data:: dragon.vm.torch.backends.cudnn.benchmark
:annotation: = False
The flag that benchmarks fastest cuDNN algorithms or not.
deterministic
#############
.. data:: dragon.vm.torch.backends.cudnn.deterministic
:annotation: = False
The flag that selects deterministic cuDNN algorithms or not.
enabled
#######
.. data:: dragon.vm.torch.backends.cudnn.enabled
:annotation: = True
The flag that uses cuDNN or not.
Functions
---------
is_available
############
.. automethod:: dragon.vm.torch.backends.cudnn.is_available
version
#######
.. automethod:: dragon.vm.torch.backends.cudnn.version
.. raw:: html
<style>
h1:before {
content: "torch.backends.";
color: #103d3e;
}
</style>
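A minimal sketch of how the documented cuDNN flags might be toggled; the ``from dragon.vm import torch`` import path is assumed from the module names above.

.. code-block:: python

    from dragon.vm import torch

    torch.backends.cudnn.enabled = True         # use cuDNN kernels when possible
    torch.backends.cudnn.benchmark = True       # benchmark the fastest algorithms
    torch.backends.cudnn.deterministic = False  # allow non-deterministic algorithms
    torch.backends.cudnn.allow_tf32 = False     # keep strict FP32 math

    if torch.backends.cudnn.is_available():
        print('cuDNN version:', torch.backends.cudnn.version())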
vm.torch.cuda
=============
.. only:: html
Functions
---------
`current_device(...) <cuda/current_device.html>`_
: Return the index of the currently selected device.
`get_device_capability(...) <cuda/get_device_capability.html>`_
: Return the capability of the specified device.
`is_available(...) <cuda/is_available.html>`_
: Return a bool reporting if the runtime is available.
`set_device(...) <cuda/set_device.html>`_
: Set the current device.
`synchronize(...) <cuda/synchronize.html>`_
: Synchronize all streams on a device.
.. toctree::
:hidden:
cuda/current_device
cuda/get_device_capability
cuda/is_available
cuda/set_device
cuda/synchronize
.. raw:: html
<style>
h1:before {
content: "Module: dragon.";
color: #103d3e;
}
</style>
current_device
==============
.. autofunction:: dragon.vm.torch.cuda.current_device
.. raw:: html
<style>
h1:before {
content: "torch.cuda.";
color: #103d3e;
}
</style>
get_device_capability
=====================
.. autofunction:: dragon.vm.torch.cuda.get_device_capability
.. raw:: html
<style>
h1:before {
content: "torch.cuda.";
color: #103d3e;
}
</style>
StopGradient is_available
============ ============
.. autoclass:: dragon.vm.caffe.core.layers.StopGradient .. autofunction:: dragon.vm.torch.cuda.is_available
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "torch.cuda.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
set_device
==========
.. autofunction:: dragon.vm.torch.cuda.set_device
.. raw:: html
<style>
h1:before {
content: "torch.cuda.";
color: #103d3e;
}
</style>
synchronize
===========
.. autofunction:: dragon.vm.torch.cuda.synchronize
.. raw:: html
<style>
h1:before {
content: "torch.cuda.";
color: #103d3e;
}
</style>
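A minimal sketch combining the functions documented above; the ``(major, minor)`` return value of ``get_device_capability`` and the device argument of ``synchronize`` are assumptions based on the PyTorch-style API.

.. code-block:: python

    from dragon.vm import torch

    if torch.cuda.is_available():
        torch.cuda.set_device(0)                     # select device 0
        index = torch.cuda.current_device()          # -> 0
        major, minor = torch.cuda.get_device_capability(index)
        torch.cuda.synchronize(index)                # wait for all pending streams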
isfinite
========
.. autofunction:: dragon.vm.torch.isfinite
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
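A minimal sketch of ``isfinite`` next to the existing ``isinf`` check; the ``torch.ones`` and ``torch.zeros`` constructors are assumed to be available.

.. code-block:: python

    from dragon.vm import torch

    x = torch.ones(3) / torch.zeros(3)  # produces inf values
    print(torch.isfinite(x))            # expected: [False, False, False]
    print(torch.isinf(x))               # expected: [True, True, True]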
...@@ -82,11 +82,14 @@ vm.torch.nn ...@@ -82,11 +82,14 @@ vm.torch.nn
`class ConvTranspose3d <nn/ConvTranspose3d.html>`_ `class ConvTranspose3d <nn/ConvTranspose3d.html>`_
: Apply the 3d deconvolution. : Apply the 3d deconvolution.
`class CosineSimilarity <nn/CosineSimilarity.html>`_
: Compute the cosine similarity between inputs.
`class CrossEntropyLoss <nn/CrossEntropyLoss.html>`_ `class CrossEntropyLoss <nn/CrossEntropyLoss.html>`_
: Compute the softmax cross entropy with sparse labels. : Compute the softmax cross entropy.
`class CTCLoss <nn/CTCLoss.html>`_ `class CTCLoss <nn/CTCLoss.html>`_
: Compute the ctc loss with batched labels. : Compute the ctc loss.
`[Graves & Gomez, 2006] <http://www.cs.utoronto.ca/~graves/icml_2006.pdf>`_. `[Graves & Gomez, 2006] <http://www.cs.utoronto.ca/~graves/icml_2006.pdf>`_.
`class DepthwiseConv2d <nn/DepthwiseConv2d.html>`_ `class DepthwiseConv2d <nn/DepthwiseConv2d.html>`_
...@@ -192,7 +195,7 @@ vm.torch.nn ...@@ -192,7 +195,7 @@ vm.torch.nn
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_. `[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class NLLLoss <nn/NLLLoss.html>`_ `class NLLLoss <nn/NLLLoss.html>`_
: Compute the negative likelihood loss with sparse labels. : Compute the negative likelihood loss.
`class Parameter <nn/Parameter.html>`_ `class Parameter <nn/Parameter.html>`_
: A wrapped tensor considered to be a module parameter. : A wrapped tensor considered to be a module parameter.
...@@ -248,7 +251,7 @@ vm.torch.nn ...@@ -248,7 +251,7 @@ vm.torch.nn
: Apply the sigmoid function. : Apply the sigmoid function.
`class SigmoidFocalLoss <nn/SigmoidFocalLoss.html>`_ `class SigmoidFocalLoss <nn/SigmoidFocalLoss.html>`_
: Compute the sigmoid focal loss with sparse labels. : Compute the sigmoid focal loss.
`[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`__. `[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`__.
`class SiLU <nn/SiLU.html>`_ `class SiLU <nn/SiLU.html>`_
...@@ -327,6 +330,7 @@ vm.torch.nn ...@@ -327,6 +330,7 @@ vm.torch.nn
nn/ConvTranspose1d nn/ConvTranspose1d
nn/ConvTranspose2d nn/ConvTranspose2d
nn/ConvTranspose3d nn/ConvTranspose3d
nn/CosineSimilarity
nn/CrossEntropyLoss nn/CrossEntropyLoss
nn/CTCLoss nn/CTCLoss
nn/DepthwiseConv2d nn/DepthwiseConv2d
......
CosineSimilarity
================
.. autoclass:: dragon.vm.torch.nn.CosineSimilarity
__init__
--------
.. automethod:: dragon.vm.torch.nn.CosineSimilarity.__init__
.. _torch.nn.functional.cosine_similarity(...): functional/cosine_similarity.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
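A minimal sketch of the new module; the ``dim`` and ``eps`` constructor arguments are assumptions following the PyTorch convention and should be checked against ``__init__`` above.

.. code-block:: python

    from dragon.vm import torch
    import dragon.vm.torch.nn as nn

    cos = nn.CosineSimilarity(dim=1, eps=1e-8)
    x1, x2 = torch.rand(8, 128), torch.rand(8, 128)
    sim = cos(x1, x2)  # expected shape: (8,)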
...@@ -9,6 +9,13 @@ __init__ ...@@ -9,6 +9,13 @@ __init__
.. _torch.nn.functional.sync_batch_norm(...): functional/sync_batch_norm.html .. _torch.nn.functional.sync_batch_norm(...): functional/sync_batch_norm.html
Methods
-------
convert_sync_batchnorm
######################
.. automethod:: dragon.vm.torch.nn.SyncBatchNorm.convert_sync_batchnorm
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -62,11 +62,14 @@ vm.torch.nn.functional ...@@ -62,11 +62,14 @@ vm.torch.nn.functional
`conv_transpose3d(...) <functional/conv_transpose3d.html>`_ `conv_transpose3d(...) <functional/conv_transpose3d.html>`_
: Apply the 3d deconvolution to input. : Apply the 3d deconvolution to input.
`cosine_similarity(...) <functional/cosine_similarity.html>`_
: Compute the cosine similarity between inputs.
`cross_entropy(...) <functional/cross_entropy.html>`_ `cross_entropy(...) <functional/cross_entropy.html>`_
: Compute the softmax cross entropy with sparse labels. : Compute the softmax cross entropy.
`ctc_loss(...) <functional/ctc_loss.html>`_ `ctc_loss(...) <functional/ctc_loss.html>`_
: Compute the ctc loss with batched labels. : Compute the ctc loss.
`[Graves & Gomez, 2006] <http://www.cs.utoronto.ca/~graves/icml_2006.pdf>`_. `[Graves & Gomez, 2006] <http://www.cs.utoronto.ca/~graves/icml_2006.pdf>`_.
`depthwise_conv2d(...) <functional/depthwise_conv2d.html>`_ `depthwise_conv2d(...) <functional/depthwise_conv2d.html>`_
...@@ -147,11 +150,14 @@ vm.torch.nn.functional ...@@ -147,11 +150,14 @@ vm.torch.nn.functional
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_. `[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`nll_loss(...) <functional/nll_loss.html>`_ `nll_loss(...) <functional/nll_loss.html>`_
: Compute the negative likelihood loss with sparse labels. : Compute the negative likelihood loss.
`normalize(...) <functional/normalize.html>`_ `normalize(...) <functional/normalize.html>`_
: Apply the :math:`L_{p}` normalization to the input. : Apply the :math:`L_{p}` normalization to the input.
`one_hot(...) <functional/one_hot.html>`_
: Return the one-hot representation of input.
`pad(...) <functional/pad.html>`_ `pad(...) <functional/pad.html>`_
: Pad the input according to the given sizes. : Pad the input according to the given sizes.
...@@ -174,14 +180,14 @@ vm.torch.nn.functional ...@@ -174,14 +180,14 @@ vm.torch.nn.functional
`[Krizhevsky, 2010] <http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf>`_. `[Krizhevsky, 2010] <http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf>`_.
`selu(...) <functional/selu.html>`_ `selu(...) <functional/selu.html>`_
: Compute the sigmoid focal loss with sparse labels. : Apply the scaled exponential linear unit.
`[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`__. `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`sigmoid(...) <functional/sigmoid.html>`_ `sigmoid(...) <functional/sigmoid.html>`_
: Apply the sigmoid function to input. : Apply the sigmoid function to input.
`sigmoid_focal_loss(...) <functional/sigmoid_focal_loss.html>`_ `sigmoid_focal_loss(...) <functional/sigmoid_focal_loss.html>`_
: Compute the sigmoid focal loss with sparse labels. : Compute the sigmoid focal loss.
`[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`__. `[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`__.
`silu(...) <functional/silu.html>`_ `silu(...) <functional/silu.html>`_
...@@ -235,6 +241,7 @@ vm.torch.nn.functional ...@@ -235,6 +241,7 @@ vm.torch.nn.functional
functional/conv_transpose1d functional/conv_transpose1d
functional/conv_transpose2d functional/conv_transpose2d
functional/conv_transpose3d functional/conv_transpose3d
functional/cosine_similarity
functional/cross_entropy functional/cross_entropy
functional/ctc_loss functional/ctc_loss
functional/depthwise_conv2d functional/depthwise_conv2d
...@@ -261,6 +268,7 @@ vm.torch.nn.functional ...@@ -261,6 +268,7 @@ vm.torch.nn.functional
functional/multi_head_attention_forward functional/multi_head_attention_forward
functional/nll_loss functional/nll_loss
functional/normalize functional/normalize
functional/one_hot
functional/pad functional/pad
functional/pixel_shuffle functional/pixel_shuffle
functional/pixel_unshuffle functional/pixel_unshuffle
......
cosine_similarity
=================
.. autofunction:: dragon.vm.torch.nn.functional.cosine_similarity
.. _torch.nn.CosineSimilarity(...): ../CosineSimilarity.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
one_hot one_hot
======= =======
.. autofunction:: dragon.vm.torch.one_hot .. autofunction:: dragon.vm.torch.nn.functional.one_hot
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "torch."; content: "torch.nn.functional.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
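A minimal sketch of the relocated ``one_hot`` function; the integer ``dtype`` handling and the positional depth argument are assumptions based on the PyTorch-style API.

.. code-block:: python

    from dragon.vm import torch
    import dragon.vm.torch.nn.functional as F

    labels = torch.ones(4, dtype=torch.int64)  # class indices
    onehot = F.one_hot(labels, 3)              # expected shape: (4, 3)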
norm
====
.. autofunction:: dragon.vm.torch.norm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
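A minimal sketch of ``norm``; the ``p`` and ``dim`` keywords are assumptions following the PyTorch convention.

.. code-block:: python

    from dragon.vm import torch

    x = torch.rand(4, 3)
    total = torch.norm(x)             # norm over all elements
    rows = torch.norm(x, p=2, dim=1)  # per-row L2 norms, expected shape: (4,)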
...@@ -14,6 +14,10 @@ vm.torch.optim ...@@ -14,6 +14,10 @@ vm.torch.optim
: The optimizer to apply AdamW algorithm. : The optimizer to apply AdamW algorithm.
`[Loshchilov & Hutter, 2017] <https://arxiv.org/abs/1711.05101>`_. `[Loshchilov & Hutter, 2017] <https://arxiv.org/abs/1711.05101>`_.
`class LARS <optim/LARS.html>`_
: The optimizer to apply LARS algorithm.
`[You et.al, 2017] <https://arxiv.org/abs/1708.03888>`_.
`class Optimizer <optim/Optimizer.html>`_ `class Optimizer <optim/Optimizer.html>`_
: The base class of optimizers. : The base class of optimizers.
...@@ -29,6 +33,7 @@ vm.torch.optim ...@@ -29,6 +33,7 @@ vm.torch.optim
optim/Adam optim/Adam
optim/AdamW optim/AdamW
optim/LARS
optim/Optimizer optim/Optimizer
optim/RMSprop optim/RMSprop
optim/SGD optim/SGD
......
LARS
====
.. autoclass:: dragon.vm.torch.optim.LARS
__init__
--------
.. automethod:: dragon.vm.torch.optim.LARS.__init__
Methods
-------
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
:noindex:
step
####
.. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex:
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
:noindex:
.. raw:: html
<style>
h1:before {
content: "torch.optim.";
color: #103d3e;
}
</style>
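A minimal training-step sketch with the new optimizer; ``nn.Linear`` and the ``lr``/``momentum`` arguments are assumptions and should be checked against ``LARS.__init__`` above.

.. code-block:: python

    from dragon.vm import torch
    import dragon.vm.torch.nn as nn

    model = nn.Linear(16, 4)
    optimizer = torch.optim.LARS(model.parameters(), lr=0.1, momentum=0.9)

    loss = model(torch.rand(2, 16)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()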
square
======
.. autofunction:: dragon.vm.torch.square
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
unbind
======
.. autofunction:: dragon.vm.torch.unbind
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
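A minimal sketch of the two new tensor functions documented above; the ``dim`` keyword of ``unbind`` is an assumption following the PyTorch convention.

.. code-block:: python

    from dragon.vm import torch

    x = torch.rand(3, 4)
    chunks = torch.unbind(x, dim=0)  # a sequence of 3 tensors with shape (4,)
    y = torch.square(x)              # element-wise x * x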
vm.torch.utils.checkpoint
=========================
.. only:: html
Classes
-------
`class no_checkpoint <checkpoint/no_checkpoint.html>`_
: Context-manager to disable checkpointing.
Functions
---------
`checkpoint(...) <checkpoint/checkpoint.html>`_
: Call function and create a checkpoint.
`checkpoint_sequential(...) <checkpoint/checkpoint_sequential.html>`_
: Call functions and create segmental checkpoints.
.. toctree::
:hidden:
checkpoint/checkpoint
checkpoint/checkpoint_sequential
checkpoint/no_checkpoint
.. raw:: html
<style>
h1:before {
content: "Module: dragon.";
color: #103d3e;
}
</style>
checkpoint
==========
.. autofunction:: dragon.vm.torch.utils.checkpoint.checkpoint
.. raw:: html
<style>
h1:before {
content: "torch.utils.checkpoint.";
color: #103d3e;
}
</style>
checkpoint_sequential
=====================
.. autofunction:: dragon.vm.torch.utils.checkpoint.checkpoint_sequential
.. raw:: html
<style>
h1:before {
content: "torch.utils.checkpoint.";
color: #103d3e;
}
</style>
no_checkpoint
=============
.. autoclass:: dragon.vm.torch.utils.checkpoint.no_checkpoint
__init__
--------
.. automethod:: dragon.vm.torch.utils.checkpoint.no_checkpoint.__init__
.. raw:: html
<style>
h1:before {
content: "torch.utils.checkpoint.";
color: #103d3e;
}
</style>
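A minimal sketch of the checkpoint utilities documented above; the ``nn`` layer names and the argument order of ``checkpoint_sequential`` are assumptions based on the PyTorch-style API.

.. code-block:: python

    from dragon.vm import torch
    import dragon.vm.torch.nn as nn
    from dragon.vm.torch.utils import checkpoint

    blocks = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    x = torch.rand(2, 16)

    y1 = checkpoint.checkpoint(blocks, x)                # recompute activations in backward
    y2 = checkpoint.checkpoint_sequential(blocks, 2, x)  # split into 2 checkpointed segments

    with checkpoint.no_checkpoint():                     # disable checkpointing in this scope
        y3 = blocks(x)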
...@@ -24,7 +24,7 @@ class Workspace; ...@@ -24,7 +24,7 @@ class Workspace;
*/ */
class DRAGON_API CPUContext { class DRAGON_API CPUContext {
public: public:
/*! \brief Default Constructor */ /*! \brief Constructor */
CPUContext() : random_seed_(3) {} CPUContext() : random_seed_(3) {}
/*! \brief Constructor with the random seed */ /*! \brief Constructor with the random seed */
......
...@@ -26,7 +26,7 @@ class Workspace; ...@@ -26,7 +26,7 @@ class Workspace;
class CUDAObjects { class CUDAObjects {
public: public:
/*! \brief Default Constructor */ /*! \brief Constructor */
CUDAObjects() { CUDAObjects() {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) { for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
cuda_streams_[i] = vector<cudaStream_t>(); cuda_streams_[i] = vector<cudaStream_t>();
...@@ -143,9 +143,6 @@ class CUDAObjects { ...@@ -143,9 +143,6 @@ class CUDAObjects {
#ifdef USE_CUDNN #ifdef USE_CUDNN
/*! \brief The cached cuDNN handles of each device */ /*! \brief The cached cuDNN handles of each device */
vector<cudnnHandle_t> cudnn_handles_[CUDA_MAX_DEVICES]; vector<cudnnHandle_t> cudnn_handles_[CUDA_MAX_DEVICES];
/*! \brief The disabled cuDNN operators */
Set<string> cudnn_disabled_ops_;
#endif #endif
#ifdef USE_NCCL #ifdef USE_NCCL
...@@ -153,15 +150,15 @@ class CUDAObjects { ...@@ -153,15 +150,15 @@ class CUDAObjects {
Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES]; Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES];
#endif #endif
/*! \brief The flag that allows cuDNN or not */ /*! \brief The flag that uses cuDNN or not */
bool cudnn_enabled_ = true; bool cudnn_enabled_ = true;
/*! \brief The flag that enforces deterministic cuDNN algorithms or not */
bool cudnn_deterministic_ = false;
/*! \brief The flag that benchmarks fastest cuDNN algorithms or not */ /*! \brief The flag that benchmarks fastest cuDNN algorithms or not */
bool cudnn_benchmark_ = false; bool cudnn_benchmark_ = false;
/*! \brief The flag that selects deterministic cuDNN algorithms or not */
bool cudnn_deterministic_ = false;
/*! \brief The flag that allows cuDNN TF32 math type or not */ /*! \brief The flag that allows cuDNN TF32 math type or not */
bool cudnn_allow_tf32_ = false; bool cudnn_allow_tf32_ = false;
...@@ -174,7 +171,7 @@ class CUDAObjects { ...@@ -174,7 +171,7 @@ class CUDAObjects {
*/ */
class DRAGON_API CUDAContext { class DRAGON_API CUDAContext {
public: public:
/*! \brief Default constructor */ /*! \brief Constructor */
CUDAContext() : device_id_(0), random_seed_(DEFAULT_RNG_SEED) {} CUDAContext() : device_id_(0), random_seed_(DEFAULT_RNG_SEED) {}
/*! \brief Constructor with the device index */ /*! \brief Constructor with the device index */
...@@ -366,7 +363,7 @@ class DRAGON_API CUDAContext { ...@@ -366,7 +363,7 @@ class DRAGON_API CUDAContext {
class DRAGON_API CUDAContext { class DRAGON_API CUDAContext {
public: public:
/*! \brief Default constructor */ /*! \brief Constructor */
explicit CUDAContext() { explicit CUDAContext() {
CUDA_NOT_COMPILED; CUDA_NOT_COMPILED;
} }
......
...@@ -31,9 +31,11 @@ class GradientMakerBase { ...@@ -31,9 +31,11 @@ class GradientMakerBase {
virtual bool CopyArguments() const { virtual bool CopyArguments() const {
return true; return true;
} }
virtual bool CopyDeviceOption() const { virtual bool CopyDeviceOption() const {
return true; return true;
} }
virtual bool CopyEngine() const { virtual bool CopyEngine() const {
return true; return true;
} }
...@@ -46,7 +48,7 @@ class GradientMakerBase { ...@@ -46,7 +48,7 @@ class GradientMakerBase {
if (arg.name() == "cache_key") cache_key = arg.s(); if (arg.name() == "cache_key") cache_key = arg.s();
} }
Argument new_arg; Argument new_arg;
new_arg.set_name("handle"); new_arg.set_name("name");
new_arg.set_s(def_.name()); new_arg.set_s(def_.name());
for (auto& grad_def : grad_defs_) { for (auto& grad_def : grad_defs_) {
if (CopyDeviceOption() && def_.has_device_option()) { if (CopyDeviceOption() && def_.has_device_option()) {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace dragon { namespace dragon {
GraphBase::GraphBase(const GraphDef& def, Workspace* ws) GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
: def_(def), ws_(ws), name_(def.name()), phase_("TEST") { : def_(def), workspace_(ws), name_(def.name()), phase_("TEST") {
// Collect arguments. // Collect arguments.
for (auto& arg : def_.arg()) { for (auto& arg : def_.arg()) {
CHECK_GT(arg.name().size(), 0); CHECK_GT(arg.name().size(), 0);
...@@ -19,7 +19,7 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws) ...@@ -19,7 +19,7 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
Set<string> outputs; Set<string> outputs;
for (const auto& op : def.op()) { for (const auto& op : def.op()) {
for (const auto& input : op.input()) for (const auto& input : op.input())
CHECK(outputs.count(input) || ws_->HasTensor(input)) CHECK(outputs.count(input) || workspace_->HasTensor(input))
<< "\nInput " << input << " is not in the graph."; << "\nInput " << input << " is not in the graph.";
for (const auto& output : op.output()) { for (const auto& output : op.output()) {
outputs.insert(output); outputs.insert(output);
...@@ -27,7 +27,7 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws) ...@@ -27,7 +27,7 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
} }
// Check outputs. // Check outputs.
for (const auto& output : def.output()) { for (const auto& output : def.output()) {
CHECK(outputs.count(output) || ws_->HasTensor(output)) CHECK(outputs.count(output) || workspace_->HasTensor(output))
<< "\nOutput " << output << " is not in the graph."; << "\nOutput " << output << " is not in the graph.";
} }
} }
...@@ -37,52 +37,38 @@ bool Graph::Create(const GraphDef& def) { ...@@ -37,52 +37,38 @@ bool Graph::Create(const GraphDef& def) {
bool has_device_option = def.has_device_option(); bool has_device_option = def.has_device_option();
for (int i = 0; i < def.op_size(); i++) { for (int i = 0; i < def.op_size(); i++) {
auto op_def(def.op(i)); auto op_def(def.op(i));
// Inherit device if not provided.
if (!op_def.has_device_option() && has_device_option) { if (!op_def.has_device_option() && has_device_option) {
op_def.mutable_device_option()->CopyFrom(def.device_option()); op_def.mutable_device_option()->CopyFrom(def.device_option());
} }
LOG(DEBUG) << "Create: " << op_def.name() << " [" << op_def.type() << "]"; LOG(DEBUG) << "Create: " << op_def.name() << " [" << op_def.type() << "]";
ops_.push_back(OperatorBase::New(op_def, ws_)); auto* op_ptr = OperatorBase::New(op_def, workspace_);
ops_.back()->set_output_aliases(output_aliases_); operators_.push_back(unique_ptr<OperatorBase>(op_ptr));
operators_.back()->set_outputs_from(outputs_from_);
} }
return true; return true;
} }
Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
// Apply the optimizations. // Apply optimizations.
GraphDef def_v2(def); GraphDef optimized_def(def);
GraphOptimizer optimizer(ws); GraphOptimizer optimizer;
Map<string, vec32_t> subgraph_indices; int opt_level = 1;
int opt = 1; if (args().count("optimization")) opt_level = arg("optimization").i();
if (args().count("optimization")) opt = arg("optimization").i(); if (opt_level >= 2) optimizer.PlanInplace(def, outputs_from_);
if (opt >= 2) optimizer.PlanInplace(def_v2, output_aliases_); if (opt_level >= 3) {
if (opt >= 3) {
if (phase() == "TRAIN") { if (phase() == "TRAIN") {
def_v2 = optimizer.PlanCheckpoint(def_v2, subgraph_indices);
if (args().count("grad_sources")) { if (args().count("grad_sources")) {
GradientTape tape(def_v2); GradientTape tape(def);
auto& grad_sources = args_["grad_sources"]->strings(); const auto& sources = arg("grad_sources").strings();
tape.Optimize({grad_sources.begin(), grad_sources.end()}); tape.Optimize({sources.begin(), sources.end()});
def_v2 = tape.def(); optimized_def = tape.def();
} }
} else { } else {
def_v2 = optimizer.EliminateIntermediates(def_v2); optimized_def = optimizer.EliminateIntermediates(def);
} }
} }
// Create graph. // Create graph.
Create(def_v2); Create(optimized_def);
// Create subgraphs.
if (subgraph_indices.size() > 0) {
Map<string, vector<OperatorBase*>> subgraph;
for (const auto& it : subgraph_indices) {
subgraph[it.first] = vector<OperatorBase*>();
for (auto op_idx : subgraph_indices[it.first])
subgraph[it.first].push_back(ops_[op_idx]);
}
for (auto* op : ops_) {
op->set_subgraph(subgraph);
}
}
} }
bool Graph::Run(int stream, const string& include, const string& exclude) { bool Graph::Run(int stream, const string& include, const string& exclude) {
...@@ -90,13 +76,14 @@ bool Graph::Run(int stream, const string& include, const string& exclude) { ...@@ -90,13 +76,14 @@ bool Graph::Run(int stream, const string& include, const string& exclude) {
if (!include.empty()) regex_incl.reset(new std::regex(include)); if (!include.empty()) regex_incl.reset(new std::regex(include));
if (!exclude.empty()) regex_excl.reset(new std::regex(exclude)); if (!exclude.empty()) regex_excl.reset(new std::regex(exclude));
LOG(DEBUG) << "Run: " << name(); LOG(DEBUG) << "Run: " << name();
for (auto* op : ops_) { for (size_t op_index = 0; op_index < operators_.size(); ++op_index) {
if (regex_incl && !regex_match(op->type(), *regex_incl)) continue; auto* op_ptr = operators_[op_index].get();
if (regex_excl && regex_match(op->type(), *regex_excl)) continue; if (regex_incl && !regex_match(op_ptr->type(), *regex_incl)) continue;
op->SwitchToPhase(phase()); if (regex_excl && regex_match(op_ptr->type(), *regex_excl)) continue;
LOG(DEBUG) << "Run: " << op->name(); op_ptr->SwitchToPhase(phase());
op->Run(stream); LOG(DEBUG) << "Run: " << op_ptr->name();
LOG(DEBUG) << "Finish: " << op->name(); op_ptr->Run(stream);
LOG(DEBUG) << "Finish: " << op_ptr->name();
} }
LOG(DEBUG) << "Finish: " << name(); LOG(DEBUG) << "Finish: " << name();
return true; return true;
...@@ -104,8 +91,7 @@ bool Graph::Run(int stream, const string& include, const string& exclude) { ...@@ -104,8 +91,7 @@ bool Graph::Run(int stream, const string& include, const string& exclude) {
GraphBase* GraphBase::New(const GraphDef& def, Workspace* ws) { GraphBase* GraphBase::New(const GraphDef& def, Workspace* ws) {
if (!def.has_type() || def.type().empty()) { if (!def.has_type() || def.type().empty()) {
// Sequential scheduler. return new Graph(def, ws); // Sequential scheduler.
return new Graph(def, ws);
} }
return GraphRegistry()->Create(def.type(), def, ws); return GraphRegistry()->Create(def.type(), def, ws);
} }
......
...@@ -25,7 +25,7 @@ class Workspace; ...@@ -25,7 +25,7 @@ class Workspace;
*/ */
class DRAGON_API GraphBase { class DRAGON_API GraphBase {
public: public:
/*! \brief Constructor with the def and workspace */ /*! \brief Constructor */
GraphBase(const GraphDef& def, Workspace* ws); GraphBase(const GraphDef& def, Workspace* ws);
/*! \brief Destructor */ /*! \brief Destructor */
...@@ -75,24 +75,27 @@ class DRAGON_API GraphBase { ...@@ -75,24 +75,27 @@ class DRAGON_API GraphBase {
/*! \brief Return the parent workspace */ /*! \brief Return the parent workspace */
Workspace* workspace() const { Workspace* workspace() const {
return ws_; return workspace_;
} }
protected: protected:
/*! \brief The name and executing phase */ /*! \brief The graph def */
string name_, phase_; GraphDef def_;
/*! \brief The defined arguments */ /*! \brief The optimized graph def */
Map<string, const Argument*> args_; GraphDef optimized_def_;
/*! \brief The parent workspace */ /*! \brief The parent workspace */
Workspace* ws_; Workspace* workspace_;
/*! \brief The graph definition */ /*! \brief The graph name */
GraphDef def_; string name_;
/*! \brief The optimized graph definition */ /*! \brief The executing phase */
GraphDef optimized_def_; string phase_;
/*! \brief The arguments */
Map<string, const Argument*> args_;
DISABLE_COPY_AND_ASSIGN(GraphBase); DISABLE_COPY_AND_ASSIGN(GraphBase);
}; };
...@@ -102,16 +105,9 @@ class DRAGON_API GraphBase { ...@@ -102,16 +105,9 @@ class DRAGON_API GraphBase {
*/ */
class Graph : public GraphBase { class Graph : public GraphBase {
public: public:
/*! \brief Constructor with the def and workspace */ /*! \brief Constructor */
Graph(const GraphDef& def, Workspace* ws); Graph(const GraphDef& def, Workspace* ws);
/*! \brief Destructor */
virtual ~Graph() {
for (auto* op : ops_) {
delete op;
}
}
/*! \brief Create graph in the workspace */ /*! \brief Create graph in the workspace */
bool Create(const GraphDef& def) override; bool Create(const GraphDef& def) override;
...@@ -122,11 +118,11 @@ class Graph : public GraphBase { ...@@ -122,11 +118,11 @@ class Graph : public GraphBase {
const string& exclude = "") override; const string& exclude = "") override;
protected: protected:
/*! \brief The created operators */ /*! \brief The operators */
vector<OperatorBase*> ops_; vector<unique_ptr<OperatorBase>> operators_;
/*! \brief The output aliases */ /*! \brief The output sourcing tensors */
Map<string, Set<string>> output_aliases_; Map<string, Set<string>> outputs_from_;
}; };
/* Macros */ /* Macros */
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include "dragon/core/operator_schema.h" #include "dragon/core/operator_schema.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#define GRAPH_TEMPORAL_OUTPUT_MAX_SIZE 2
namespace dragon { namespace dragon {
void GraphOptimizer::BuildDAG(const GraphDef& graph) { void GraphOptimizer::BuildDAG(const GraphDef& graph) {
...@@ -33,10 +31,10 @@ void GraphOptimizer::BuildDAG(const GraphDef& graph) { ...@@ -33,10 +31,10 @@ void GraphOptimizer::BuildDAG(const GraphDef& graph) {
void GraphOptimizer::PlanInplace( void GraphOptimizer::PlanInplace(
const GraphDef& graph, const GraphDef& graph,
Map<string, Set<string>>& output_aliases) { Map<string, Set<string>>& sources) {
// Initialization. // Initialization.
BuildDAG(graph); BuildDAG(graph);
// Generate aliases map to apply in-place. // Add source for outputs to apply in-place.
for (const auto& iter : inputs_count_) { for (const auto& iter : inputs_count_) {
if (iter.second > 1 || iter.first.empty()) continue; if (iter.second > 1 || iter.first.empty()) continue;
const auto& input = iter.first; const auto& input = iter.first;
...@@ -48,122 +46,29 @@ void GraphOptimizer::PlanInplace( ...@@ -48,122 +46,29 @@ void GraphOptimizer::PlanInplace(
if (op.input(i) != input) continue; if (op.input(i) != input) continue;
for (int j = 0; j < op.output_size(); ++j) { for (int j = 0; j < op.output_size(); ++j) {
if (!schema->CheckInplace(i, j)) continue; if (!schema->CheckInplace(i, j)) continue;
output_aliases[op.output(j)].insert(input); sources[op.output(j)].insert(input);
}
}
}
}
GraphDef GraphOptimizer::PlanCheckpoint(
const GraphDef& graph,
Map<string, vec32_t>& subgraph_indices) {
GraphDef graph_v2(graph);
Map<string, set<int>> op_indices;
Map<string, string> rename_map;
Map<string, int> versions;
// Check the mirror stage setting.
for (const auto& op : graph.op()) {
if (str::find(op.type(), "Gradient")) continue;
bool mirror_stage = false;
for (auto& arg : op.arg()) {
if (arg.name() == "mirror_stage") {
mirror_stage |= (bool)arg.i();
}
}
if (mirror_stage) {
// We only assume X(0) can be recomputed.
rename_map[op.input(0)] = "placeholder";
}
}
// Allocate the temporal buffers.
string v2_name, version_name;
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx);
auto* op_v2 = graph_v2.mutable_op(op_idx);
vector<string> used_buffers;
for (int i = 0; i < op.input_size(); ++i) {
const auto& it = rename_map.find(op.input(i));
if (it != rename_map.end() && it->second != "placeholder") {
*op_v2->mutable_input(i) = it->second;
used_buffers.emplace_back(it->second);
}
}
for (int i = 0; i < op.output_size(); ++i) {
bool inplace_flag = false;
for (const auto& in : op.input()) {
if (in == op.output(i)) inplace_flag = true;
}
if (rename_map.count(op.output(i))) {
if (inplace_flag && rename_map[op.output(i)] != "placeholder") {
*op_v2->mutable_output(i) = rename_map[op.output(i)];
continue;
}
for (int j = 0; j < GRAPH_TEMPORAL_OUTPUT_MAX_SIZE; ++j) {
v2_name = "shared/buffer/output:" + str::to(j);
for (const auto& buffer : used_buffers)
if (str::find(buffer, v2_name)) {
v2_name.clear();
}
if (!v2_name.empty()) {
used_buffers.emplace_back(v2_name);
break;
}
}
CHECK(!v2_name.empty()) << "\nNo enough buffers for outputs.";
ws_->CreateTensor(v2_name)->set_version(0);
version_name = "/ver:" + str::to(versions[v2_name]++);
*op_v2->mutable_output(i) = rename_map[op.output(i)] =
v2_name + version_name;
} }
} }
} }
// Determine the recomputing ops for temporal buffers
for (int i = 0; i < graph.op_size(); ++i) {
const auto &op = graph.op(i), &op_v2 = graph_v2.op(i);
set<int> recomputing_ops = {i};
for (int j = 0; j < op.input_size(); ++j) {
if (op.input(j) != op_v2.input(j)) {
for (auto op_idx : op_indices[op.input(j)]) {
recomputing_ops.insert(op_idx);
}
}
}
for (const auto& out : op.output()) {
for (auto op_idx : recomputing_ops) {
op_indices[out].insert(op_idx);
}
}
}
// Bind to the renamed tensors
for (const auto& it : rename_map) {
for (auto op_idx : op_indices[it.first]) {
subgraph_indices[it.second].push_back(op_idx);
}
}
// Done
return graph_v2;
} }
GraphDef GraphOptimizer::EliminateIntermediates(const GraphDef& graph) { GraphDef GraphOptimizer::EliminateIntermediates(const GraphDef& graph) {
Set<string> required_outputs; Set<string> graph_outputs;
Map<string, int> inputs_count; Map<string, int> inputs_count;
Map<string, string> outputs_to_buffers; Map<string, string> outputs_to;
static Set<string> skip_ops = {"Shape"}; static Set<string> noop_types = {"Shape"};
// Prepare pool. auto optimized_graph(graph);
// Initialize buffer pool.
int buffer_idx = 0; int buffer_idx = 0;
std::deque<string> pool; std::deque<string> buffer_pool;
auto get_buffer = [&]() mutable { auto get_buffer = [&]() mutable {
if (pool.empty()) { if (buffer_pool.empty()) {
return "shared/buffer/output:" + str::to(++buffer_idx); return "Buffer_" + str::to(++buffer_idx);
} else { } else {
auto buffer = pool.back(); auto buffer = buffer_pool.back();
pool.pop_back(); buffer_pool.pop_back();
return buffer; return buffer;
} }
}; };
...@@ -175,56 +80,58 @@ GraphDef GraphOptimizer::EliminateIntermediates(const GraphDef& graph) { ...@@ -175,56 +80,58 @@ GraphDef GraphOptimizer::EliminateIntermediates(const GraphDef& graph) {
} }
} }
// Initialize the required outputs before optimization. // Initialize graph outputs before optimization.
for (const auto& output : graph.output()) { for (const auto& output : graph.output()) {
required_outputs.insert(output); graph_outputs.insert(output);
} }
// Rewrite the inputs and outputs. // Rewrite inputs and outputs.
auto graph_v2(graph);
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) { for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx); auto* op = optimized_graph.mutable_op(op_idx);
if (op.input_size() == 0) continue; if (op->input_size() == 0) continue;
auto* op_v2 = graph_v2.mutable_op(op_idx); vector<int> outputs_at(op->output_size(), -1);
// Check output aliases. for (int i = 0; i < op->output_size(); ++i) {
vec32_t output_aliases(op.output_size(), -1); for (int j = 0; j < op->input_size(); ++j) {
for (int i = 0; i < op.output_size(); ++i) { if (op->output(i) != op->input(j)) continue;
for (int j = 0; j < op.input_size(); ++j) { outputs_at[i] = j;
if (op.output(i) != op.input(j)) continue;
output_aliases[i] = j;
break; break;
} }
} }
// Rewrite inputs. // Rewrite inputs.
vector<string> dead_buffers; vector<string> dead_buffers;
for (int i = 0; i < op.input_size(); ++i) { for (int i = 0; i < op->input_size(); ++i) {
const auto& input = op.input(i); const auto& input = op->input(i);
const auto& count_iter = inputs_count.find(input); const auto& count_iter = inputs_count.find(input);
count_iter->second--; count_iter->second--;
const auto& buffer_iter = outputs_to_buffers.find(input); const auto& buffer_iter = outputs_to.find(input);
if (buffer_iter == outputs_to_buffers.end()) continue; if (buffer_iter == outputs_to.end()) continue;
if (count_iter->second == 0) { if (count_iter->second == 0) {
dead_buffers.emplace_back(buffer_iter->second); dead_buffers.emplace_back(buffer_iter->second);
} }
op_v2->set_input(i, buffer_iter->second); op->set_input(i, buffer_iter->second);
} }
if (skip_ops.count(op.type())) continue; if (noop_types.count(op->type())) continue;
// Rewrite outputs. // Rewrite outputs.
for (int i = 0; i < op.output_size(); ++i) { for (int i = 0; i < op->output_size(); ++i) {
const auto& output = op.output(i); const auto& output = op->output(i);
if (output.empty() || required_outputs.count(output) > 0) continue; if (output.empty() || graph_outputs.count(output) > 0) continue;
if (output_aliases[i] >= 0) { if (inputs_count.count(output) == 0) {
op_v2->set_output(i, op_v2->input(output_aliases[i])); op->mutable_output(i)->clear();
continue;
}
if (outputs_at[i] >= 0) {
op->set_output(i, op->input(outputs_at[i]));
} else { } else {
*op_v2->mutable_output(i) = outputs_to_buffers[output] = get_buffer(); op->set_output(i, outputs_to[output] = get_buffer());
} }
} }
// Update pool. // Update buffer pool.
for (auto& buffer : dead_buffers) { for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer); buffer_pool.emplace_back(buffer);
} }
} }
return graph_v2;
return optimized_graph;
} }
} // namespace dragon } // namespace dragon
...@@ -29,29 +29,19 @@ class GraphOptimizer { ...@@ -29,29 +29,19 @@ class GraphOptimizer {
OperatorDef op_def; OperatorDef op_def;
}; };
/*! \brief Default constructor */ /*! \brief Constructor */
GraphOptimizer(Workspace* ws) : ws_(ws) {} GraphOptimizer() {}
/*! \brief Build the DAG */ /*! \brief Build the DAG */
void BuildDAG(const GraphDef& graph); void BuildDAG(const GraphDef& graph);
/*! \brief Plan the inplace for inputs */ /*! \brief Plan the in-place for outputs */
void PlanInplace( void PlanInplace(const GraphDef& graph, Map<string, Set<string>>& sources);
const GraphDef& graph,
Map<string, Set<string>>& output_aliases);
/*! \brief Plan the checkpoint for inputs */
GraphDef PlanCheckpoint(
const GraphDef& graph,
Map<string, vec32_t>& subgraph_indices);
/*! \brief Eliminate the intermediate outputs */ /*! \brief Eliminate the intermediate outputs */
GraphDef EliminateIntermediates(const GraphDef& graph); GraphDef EliminateIntermediates(const GraphDef& graph);
protected: protected:
/* \brief The graph workspace */
Workspace* ws_;
/* \brief The graph nodes */ /* \brief The graph nodes */
Map<string, Node> nodes_; Map<string, Node> nodes_;
......
...@@ -37,7 +37,7 @@ class DRAGON_API UnifiedMemory { ...@@ -37,7 +37,7 @@ class DRAGON_API UnifiedMemory {
SYNCED = 3, SYNCED = 3,
}; };
/*! \brief Default constructor */ /*! \brief Constructor */
UnifiedMemory() {} UnifiedMemory() {}
/*! \brief Constructor with the type meta and size */ /*! \brief Constructor with the type meta and size */
...@@ -49,23 +49,23 @@ class DRAGON_API UnifiedMemory { ...@@ -49,23 +49,23 @@ class DRAGON_API UnifiedMemory {
/*! \brief Switch to the given cuda device */ /*! \brief Switch to the given cuda device */
void SwitchToCUDADevice(int device); void SwitchToCUDADevice(int device);
/*! \brief Involve the state to cpu */ /*! \brief Set to the cpu state */
void ToCPU(size_t size = 0); void ToCPU(size_t size = 0);
/*! \brief Involve the state to cuda */ /*! \brief Set to the cuda state */
void ToCUDA(size_t size = 0); void ToCUDA(size_t size = 0);
/*! \brief Return the memory state */ /*! \brief Return the state */
State state() const { State state() const {
return state_; return state_;
} }
/*! \brief Return the total number of bytes */ /*! \brief Return the data size */
size_t size() const { size_t size() const {
return size_; return size_;
} }
/*! \brief Return the total number of bytes on given device */ /*! \brief Return the data size on given device */
size_t size(const string& device_type, int device_id) const { size_t size(const string& device_type, int device_id) const {
if (device_type == "cuda") { if (device_type == "cuda") {
if (own_cuda_ptr_ && cuda_ptr_ && device_id_ == device_id) { if (own_cuda_ptr_ && cuda_ptr_ && device_id_ == device_id) {
......
...@@ -5,72 +5,65 @@ namespace dragon { ...@@ -5,72 +5,65 @@ namespace dragon {
OperatorBase::OperatorBase(const OperatorDef& def, Workspace* ws) OperatorBase::OperatorBase(const OperatorDef& def, Workspace* ws)
: def_(def), : def_(def),
ws_(ws), workspace_(ws),
phase_("TRAIN"), phase_("TRAIN"),
handle_(def.name()), name_(def.name()),
dtype_("float32"), data_type_("float32"),
data_format_("NCHW") { data_format_("NCHW") {
// Scan the defined arguments // Set arguments.
for (auto& arg : def_.arg()) { for (auto& arg : def_.arg()) {
CHECK_GT(arg.name().size(), 0); CHECK_GT(arg.name().size(), 0);
CHECK_EQ(args_.count(arg.name()), 0); CHECK_EQ(args_.count(arg.name()), 0);
args_[arg.name()] = &arg; args_[arg.name()] = &arg;
if (arg.name() == "handle") { if (arg.name() == "name") {
handle_ = arg.s(); name_ = arg.s();
} else if (arg.name() == "dtype") { } else if (arg.name() == "dtype") {
dtype_ = arg.s(); data_type_ = arg.s();
} else if (arg.name() == "data_format") { } else if (arg.name() == "data_format") {
data_format_ = arg.s(); data_format_ = arg.s();
} }
} }
// Set inputs and outputs.
// Set the inputs and outputs
size_t version_pos;
for (const auto& input : def.input()) { for (const auto& input : def.input()) {
string name = input; inputs_.push_back(ws->GetTensor(input));
if ((version_pos = input.find("/ver:")) != string::npos) {
name = input.substr(0, version_pos);
}
inputs_.push_back(ws->GetTensor(name));
} }
for (const auto& output : def.output()) { for (const auto& output : def.output()) {
string name = output; outputs_.push_back(ws->CreateTensor(output));
if ((version_pos = output.find("/ver:")) != string::npos) {
name = output.substr(0, version_pos);
}
outputs_.push_back(ws->CreateTensor(name));
} }
} }
Tensor& OperatorBase::Input(int i) { Tensor& OperatorBase::Input(int index) {
CHECK_LT(i, (int)inputs_.size()); CHECK_LT(index, InputSize());
CHECK_GE(i, -(int)inputs_.size()); CHECK_GE(index, -InputSize());
if (i >= 0) return *inputs_[i]; if (index >= 0) return *inputs_[index];
return *inputs_[i + inputs_.size()]; return *inputs_[index + inputs_.size()];
} }
Tensor* OperatorBase::Output(int i) { Tensor& OperatorBase::Input(const string& name) {
CHECK_LT(i, (int)outputs_.size()); return *workspace_->GetTensor(name_ + "/" + name);
CHECK_GE(i, -(int)outputs_.size());
if (i >= 0) return outputs_[i]->MapFrom(nullptr);
return outputs_[i + outputs_.size()]->MapFrom(nullptr);
} }
Tensor* OperatorBase::Output(int i, const vec32_t& inputs) { Tensor* OperatorBase::Output(int index) {
auto* Y = Output(i); CHECK_LT(index, OutputSize());
if (i < output_aliases_.size()) { CHECK_GE(index, -OutputSize());
for (auto j : inputs) { if (index >= 0) return outputs_[index];
auto& X = Input(j); return outputs_[index + outputs_.size()];
if (output_aliases_[i].count(X.name())) { }
return Y->ReshapeLike(X)->MapFrom(&X);
} Tensor* OperatorBase::Output(int index, const vec32_t& inputs_at) {
auto* output = Output(index);
if (index >= outputs_from_.size()) return output;
for (auto input_index : inputs_at) {
auto& input = Input(input_index);
if (outputs_from_[index].count(input.name())) {
return output->ReshapeLike(input)->MapFrom(&input);
} }
} }
return Y->MapFrom(nullptr); return output;
} }
Tensor* OperatorBase::Buffer(const string& name) { Tensor* OperatorBase::Output(const string& name) {
return workspace()->CreateTensor(handle_ + "/" + name); return workspace_->CreateTensor(name_ + "/" + name);
} }
OperatorBase* OperatorBase::New(const OperatorDef& def, Workspace* ws) { OperatorBase* OperatorBase::New(const OperatorDef& def, Workspace* ws) {
...@@ -83,8 +76,7 @@ OperatorBase* OperatorBase::New(const OperatorDef& def, Workspace* ws) { ...@@ -83,8 +76,7 @@ OperatorBase* OperatorBase::New(const OperatorDef& def, Workspace* ws) {
case PROTO_CUDA: case PROTO_CUDA:
#ifdef USE_CUDNN #ifdef USE_CUDNN
if (CUDNNOperatorRegistry()->Has(op_type) && if (CUDNNOperatorRegistry()->Has(op_type) &&
CUDAContext::objects().cudnn_enabled_ && CUDAContext::objects().cudnn_enabled_) {
!CUDAContext::objects().cudnn_disabled_ops_.count(op_type)) {
return CUDNNOperatorRegistry()->Create(op_type, def, ws); return CUDNNOperatorRegistry()->Create(op_type, def, ws);
} }
#endif #endif
...@@ -96,58 +88,22 @@ OperatorBase* OperatorBase::New(const OperatorDef& def, Workspace* ws) { ...@@ -96,58 +88,22 @@ OperatorBase* OperatorBase::New(const OperatorDef& def, Workspace* ws) {
} }
OperatorBase* OperatorBase::DeriveFrom(const OperatorDef& def) { OperatorBase* OperatorBase::DeriveFrom(const OperatorDef& def) {
handle_ = def.name(); name_ = def.name();
if (def.arg().size() > 1) { if (def.arg().size() > 1) {
const auto& arg = *(def.arg().end() - 2); const auto& arg = *(def.arg().end() - 2);
if (arg.name() == "handle") handle_ = arg.s(); if (arg.name() == "name") name_ = arg.s();
} }
inputs_.resize(def.input_size()); inputs_.resize(def.input_size());
outputs_.resize(def.output_size()); outputs_.resize(def.output_size());
for (int i = 0; i < inputs_.size(); i++) { for (int i = 0; i < inputs_.size(); i++) {
inputs_[i] = workspace()->GetTensor(def.input(i)); inputs_[i] = workspace_->GetTensor(def.input(i));
} }
for (int i = 0; i < outputs_.size(); i++) { for (int i = 0; i < outputs_.size(); i++) {
outputs_[i] = workspace()->CreateTensor(def.output(i)); outputs_[i] = workspace_->CreateTensor(def.output(i));
} }
return this; return this;
} }
template <class Context>
void Operator<Context>::Prepare() {
for (int i = 0; i < InputSize(); ++i) {
auto& X = *inputs_[i];
if (X.version() >= 0) {
const auto& name = def().input(i);
auto ver_pos = name.find("/ver:");
auto version = std::atoi(name.substr(ver_pos + 5).c_str());
if (version == X.version()) continue;
LOG(DEBUG) << "Excepted version of Tensor(" + X.name() + ") "
<< "is " << version << ", got " << X.version()
<< ". Recompute.";
Tensor* flag = workspace()->GetTensor("flagged/recomp");
flag->mutable_data<bool, CPUContext>()[0] = true;
vector<OperatorBase*>& chain = subgraph()[name];
for (auto* op : chain) {
op->Run(ctx()->stream());
}
flag->mutable_data<bool, CPUContext>()[0] = false;
}
}
}
template <class Context>
void Operator<Context>::Release() {
for (int i = 0; i < OutputSize(); ++i) {
auto* Y = outputs_[i];
if (Y->version() >= 0) {
const auto& name = def().output(i);
auto ver_pos = name.find("/ver:");
auto version = std::atoi(name.substr(ver_pos + 5).c_str());
Y->set_version(version);
}
}
}
/* Operator Registry */ /* Operator Registry */
DEFINE_REGISTRY( DEFINE_REGISTRY(
......
...@@ -154,6 +154,7 @@ class DRAGON_API Tensor { ...@@ -154,6 +154,7 @@ class DRAGON_API Tensor {
offset_ = offset; offset_ = offset;
} }
} }
version_ = -1;
return this; return this;
} }
...@@ -208,7 +209,7 @@ class DRAGON_API Tensor { ...@@ -208,7 +209,7 @@ class DRAGON_API Tensor {
} }
/*! \brief Return the tensor version */ /*! \brief Return the tensor version */
int version() const { int64_t version() const {
return version_; return version_;
} }
...@@ -330,7 +331,7 @@ class DRAGON_API Tensor { ...@@ -330,7 +331,7 @@ class DRAGON_API Tensor {
template <class Context> template <class Context>
void raw_mutable_data(void** data_ptr) { void raw_mutable_data(void** data_ptr) {
auto* memory_ptr = memory(); auto* memory_ptr = memory();
if (!memory_ptr) { if (memory_ptr == nullptr) {
*data_ptr = nullptr; *data_ptr = nullptr;
} else { } else {
const auto context_type = TypeMeta::Id<Context>(); const auto context_type = TypeMeta::Id<Context>();
...@@ -349,11 +350,18 @@ class DRAGON_API Tensor { ...@@ -349,11 +350,18 @@ class DRAGON_API Tensor {
void* raw_mutable_data() { void* raw_mutable_data() {
CHECK_NE(meta_.id(), 0) << "\nTensor(" << name_ << "): unknown type, " CHECK_NE(meta_.id(), 0) << "\nTensor(" << name_ << "): unknown type, "
<< "or does not have a type."; << "or does not have a type.";
if (mapped_memory_ != nullptr) {
if (version_ == 0) {
MapFrom(nullptr);
} else {
version_ = 0;
}
}
void* data_ptr; void* data_ptr;
raw_mutable_data<Context>(&data_ptr); raw_mutable_data<Context>(&data_ptr);
if (data_ptr) return data_ptr; if (data_ptr) return data_ptr;
CHECK_GT(size_, 0) << "\nInvalid tensor size."; CHECK_GT(size_, 0) << "\nInvalid tensor size.";
capacity_ = size_ * meta_.itemsize(); capacity_ = ((size_ * meta_.itemsize() + 511) / 512) * 512;
memory_.reset(new UnifiedMemory(meta_, capacity_)); memory_.reset(new UnifiedMemory(meta_, capacity_));
raw_mutable_data<Context>(&data_ptr); raw_mutable_data<Context>(&data_ptr);
if (meta_.ctor()) meta_.ctor()(data_ptr, size_); if (meta_.ctor()) meta_.ctor()(data_ptr, size_);
...@@ -377,7 +385,7 @@ class DRAGON_API Tensor { ...@@ -377,7 +385,7 @@ class DRAGON_API Tensor {
} }
/*! \brief Set the tensor version */ /*! \brief Set the tensor version */
void set_version(int version) { void set_version(int64_t version) {
version_ = version; version_ = version;
} }
...@@ -416,7 +424,7 @@ class DRAGON_API Tensor { ...@@ -416,7 +424,7 @@ class DRAGON_API Tensor {
size_t offset_ = 0; size_t offset_ = 0;
/*! \brief The tensor version */ /*! \brief The tensor version */
int version_ = -1; int64_t version_ = -1;
/*! \brief The dimensions */ /*! \brief The dimensions */
vec64_t dims_; vec64_t dims_;
......
...@@ -80,6 +80,9 @@ using Integral = TypesBase<uint8_t, int8_t, int, int64_t>; ...@@ -80,6 +80,9 @@ using Integral = TypesBase<uint8_t, int8_t, int, int64_t>;
/*! \brief Floating types */ /*! \brief Floating types */
using Floating = TypesBase<float16, float, double>; using Floating = TypesBase<float16, float, double>;
/*! \brief Accumulated types */
using Accumulated = dtypes::TypesBase<int, int64_t, float16, float, double>;
/*! \brief Convert the type string to meta */ /*! \brief Convert the type string to meta */
inline const TypeMeta& to_meta(const std::string& type) { inline const TypeMeta& to_meta(const std::string& type) {
static TypeMeta unknown_type; static TypeMeta unknown_type;
......
...@@ -6,24 +6,21 @@ namespace dragon { ...@@ -6,24 +6,21 @@ namespace dragon {
Workspace::Workspace(const string& name) : name_(name) { Workspace::Workspace(const string& name) : name_(name) {
CreateTensor(""); // Empty placeholder CreateTensor(""); // Empty placeholder
CreateTensor("flagged/recomp")
->Reshape({})
->mutable_data<bool, CPUContext>()[0] = false;
} }
void Workspace::MergeFrom(Workspace* other) { void Workspace::MergeFrom(Workspace* other) {
if (other != nullptr) { if (other != nullptr) {
// Add the external tensors // Add the external tensors
for (const auto& it : other->tensor_map_) { for (const auto& it : other->tensors_) {
if (!it.first.empty() && !str::startswith(it.first, "/")) { if (!it.first.empty() && !str::startswith(it.first, "/")) {
external_tensor_map_[it.first] = it.second.get(); external_tensors_[it.first] = it.second.get();
} }
} }
// Recount the unique index to avoid duplicate names // Recount the unique index to avoid duplicate names
for (const auto& i : other->unique_index_map_) { for (const auto& i : other->scope_counters_) {
auto& index_map = unique_index_map_[i.first]; auto& counters = scope_counters_[i.first];
for (const auto& j : i.second) { for (const auto& j : i.second) {
index_map[j.first] = std::max(index_map[j.first], j.second); counters[j.first] = std::max(counters[j.first], j.second);
} }
} }
} }
...@@ -33,70 +30,66 @@ void Workspace::Clear() { ...@@ -33,70 +30,66 @@ void Workspace::Clear() {
// Following resources usually take large memory blob. // Following resources usually take large memory blob.
// It's necessary to clear them manually if workspace referenced // It's necessary to clear them manually if workspace referenced
// by the frontend GC circularly. // by the frontend GC circularly.
graph_map_.clear(); graphs_.clear();
operator_map_.clear(); operators_.clear();
for (const auto& it : tensor_map_) { for (const auto& iter : tensors_) {
// The tensor pointer may be referenced by the frontend. // The tensor pointer may be referenced by the frontend.
// Reset memory only to avoid the dangling pointer. // Reset memory only to avoid the dangling pointer.
it.second->Reset(); iter.second->Reset();
} }
// Reinitialize the tensor flags
GetTensor("flagged/recomp")
->Reshape({})
->mutable_data<bool, CPUContext>()[0] = false;
} }
Tensor* Workspace::TryGetTensor(const string& name, bool external) const { Tensor* Workspace::TryGetTensor(const string& name, bool external) const {
// Check the alias firstly // Check the alias.
const auto& alias_it = alias_map_.find(name); const auto& alias_iter = aliases_.find(name);
auto name_v2 = alias_it != alias_map_.end() ? alias_it->second : name; auto name_v2 = alias_iter != aliases_.end() ? alias_iter->second : name;
// Search this workspace // Search this workspace.
const auto& it = tensor_map_.find(name_v2); const auto& iter = tensors_.find(name_v2);
if (it != tensor_map_.end()) return it->second.get(); if (iter != tensors_.end()) return iter->second.get();
if (external) { if (external) {
// Search external workspaces // Search external workspaces.
const auto& it = external_tensor_map_.find(name_v2); const auto& iter = external_tensors_.find(name_v2);
if (it != external_tensor_map_.end()) return it->second; if (iter != external_tensors_.end()) return iter->second;
} }
return nullptr; return nullptr;
} }
Tensor* Workspace::CreateTensor(const string& name) { Tensor* Workspace::CreateTensor(const string& name) {
auto* tensor = TryGetTensor(name); auto* tensor_ptr = TryGetTensor(name);
// Create only if name not existed // Create only if name not existed.
if (tensor == nullptr) { if (tensor_ptr == nullptr) {
tensor = new Tensor(name); tensor_ptr = new Tensor(name);
tensor_map_[name] = unique_ptr<Tensor>(tensor); tensors_[name] = unique_ptr<Tensor>(tensor_ptr);
} }
return tensor; return tensor_ptr;
} }
Tensor* Workspace::GetTensor(const string& name, bool external) const { Tensor* Workspace::GetTensor(const string& name, bool external) const {
auto* tensor = TryGetTensor(name, external); auto* tensor_ptr = TryGetTensor(name, external);
CHECK(tensor) << "\nTensor(" << name << ") is not in current workspace."; CHECK(tensor_ptr) << "\nTensor '" << name << "' is not in workspace.";
return tensor; return tensor_ptr;
} }
void Workspace::RunOperator(const OperatorDef& def) { void Workspace::RunOperator(const OperatorDef& def) {
string cache_key; string cache_key;
OperatorBase* execute_op = nullptr; OperatorBase* op_ptr = nullptr;
if (!def.arg().empty()) { if (!def.arg().empty()) {
const auto& arg = *(def.arg().end() - 1); const auto& arg = *(def.arg().end() - 1);
if (arg.name() == "cache_key") cache_key = arg.s(); if (arg.name() == "cache_key") cache_key = arg.s();
} }
if (cache_key.empty()) { if (cache_key.empty()) {
execute_op = OperatorBase::New(def, this); op_ptr = OperatorBase::New(def, this);
execute_op->Run(); op_ptr->Run();
delete execute_op; delete op_ptr;
} else { } else {
const auto& iter = operator_map_.find(cache_key); const auto& iter = operators_.find(cache_key);
if (iter == operator_map_.end()) { if (iter == operators_.end()) {
execute_op = OperatorBase::New(def, this); op_ptr = OperatorBase::New(def, this);
operator_map_[cache_key] = unique_ptr<OperatorBase>(execute_op); operators_[cache_key] = unique_ptr<OperatorBase>(op_ptr);
} else { } else {
execute_op = iter->second.get(); op_ptr = iter->second.get();
} }
execute_op->DeriveFrom(def)->Run(); op_ptr->DeriveFrom(def)->Run();
} }
} }
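For context, the cached path above keys operators on a trailing "cache_key" argument and reuses the created OperatorBase on later calls. A minimal sketch of how a caller might opt into that cache, assuming the usual protobuf-generated setters on OperatorDef/Argument (set_type, add_input, add_output, add_arg, set_name, set_s); the op type and key below are illustrative only:

// Sketch: appending a "cache_key" argument so Workspace::RunOperator
// memoizes the operator instead of rebuilding it every call.
OperatorDef def;
def.set_type("Relu");         // hypothetical op type
def.add_input("x");
def.add_output("y");
auto* arg = def.add_arg();    // RunOperator inspects the *last* argument
arg->set_name("cache_key");
arg->set_s("Relu/x->y");      // any stable key works
workspace->RunOperator(def);  // first call: creates and caches the op
workspace->RunOperator(def);  // later calls: DeriveFrom(def)->Run()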
...@@ -105,19 +98,19 @@ GraphBase* Workspace::CreateGraph(const GraphDef& def) { ...@@ -105,19 +98,19 @@ GraphBase* Workspace::CreateGraph(const GraphDef& def) {
GraphDef def_v2(def); // Copy to set a unique name GraphDef def_v2(def); // Copy to set a unique name
def_v2.set_name(UniqueName(def.name(), "", "Graph", false)); def_v2.set_name(UniqueName(def.name(), "", "Graph", false));
LOG(DEBUG) << "Create: " << def_v2.name(); LOG(DEBUG) << "Create: " << def_v2.name();
auto* graph = GraphBase::New(def_v2, this); auto* graph_ptr = GraphBase::New(def_v2, this);
graph_map_[def_v2.name()] = unique_ptr<GraphBase>(graph); graphs_[def_v2.name()] = unique_ptr<GraphBase>(graph_ptr);
return graph; return graph_ptr;
} }
void Workspace::RunGraph( void Workspace::RunGraph(
const string& name, const string& name,
const string& include, const string& include,
const string& exclude, const string& exclude,
const int stream) { int stream) {
CHECK(graph_map_.count(name)) CHECK(graphs_.count(name))
<< "\nGraph " << name << " is not in current workspace."; << "\nGraph " << name << " is not in current workspace.";
graph_map_[name]->Run(stream, include, exclude); graphs_[name]->Run(stream, include, exclude);
} }
string Workspace::UniqueName( string Workspace::UniqueName(
...@@ -125,22 +118,22 @@ string Workspace::UniqueName( ...@@ -125,22 +118,22 @@ string Workspace::UniqueName(
const string& suffix, const string& suffix,
const string& scope, const string& scope,
bool zero_based) { bool zero_based) {
auto& index_map = unique_index_map_[scope]; auto& counters = scope_counters_[scope];
auto required_name = name + suffix; auto target_name = name + suffix;
auto index = index_map[required_name]++; auto index = counters[target_name]++;
if (index > 0) return name + "_" + str::to(index) + suffix; if (index > 0) return name + "_" + str::to(index) + suffix;
if (zero_based) return required_name; if (zero_based) return target_name;
return name + "_" + str::to(index_map[required_name]++) + suffix; return name + "_" + str::to(counters[target_name]++) + suffix;
} }
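The UniqueName logic above keeps a per-scope counter for each requested name+suffix. A small standalone sketch of the same behaviour (an illustrative re-implementation, not the Dragon code itself), with the sequences it is expected to produce:

// Sketch of the per-scope counter behaviour shown above.
#include <cstdint>
#include <map>
#include <string>

std::string UniqueNameSketch(std::map<std::string, int64_t>& counters,
                             const std::string& name,
                             const std::string& suffix,
                             bool zero_based) {
  const auto target = name + suffix;
  const auto index = counters[target]++;
  if (index > 0) return name + "_" + std::to_string(index) + suffix;
  if (zero_based) return target;
  return name + "_" + std::to_string(counters[target]++) + suffix;
}

// zero_based = true : "x", "x_1", "x_2", ...
// zero_based = false: "x_1", "x_2", "x_3", ...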
vector<string> Workspace::tensors(bool external) const { vector<string> Workspace::tensors(bool external) const {
vector<string> names; vector<string> names;
for (const auto& it : tensor_map_) { for (const auto& iter : tensors_) {
names.emplace_back(it.first); names.emplace_back(iter.first);
} }
if (external) { if (external) {
for (const auto& it : external_tensor_map_) { for (const auto& iter : external_tensors_) {
names.emplace_back(it.first); names.emplace_back(iter.first);
} }
} }
return names; return names;
...@@ -148,8 +141,8 @@ vector<string> Workspace::tensors(bool external) const { ...@@ -148,8 +141,8 @@ vector<string> Workspace::tensors(bool external) const {
vector<string> Workspace::graphs() const { vector<string> Workspace::graphs() const {
vector<string> names; vector<string> names;
for (const auto& it : graph_map_) { for (const auto& iter : graphs_) {
names.emplace_back(it.first); names.emplace_back(iter.first);
} }
return names; return names;
} }
......
...@@ -40,7 +40,7 @@ class DRAGON_API Workspace { ...@@ -40,7 +40,7 @@ class DRAGON_API Workspace {
/*! \brief Set an alias for the target */ /*! \brief Set an alias for the target */
void SetAlias(const string& target, const string& alias) { void SetAlias(const string& target, const string& alias) {
alias_map_[alias] = target; aliases_[alias] = target;
} }
/*! \brief Return whether the tensor exists */ /*! \brief Return whether the tensor exists */
...@@ -68,73 +68,54 @@ class DRAGON_API Workspace { ...@@ -68,73 +68,54 @@ class DRAGON_API Workspace {
const string& name, const string& name,
const string& include = "", const string& include = "",
const string& exclude = "", const string& exclude = "",
const int stream = 0); int stream = 0);
/*! \brief Return the workspace name */ /*! \brief Return the workspace name */
const string& name() { const string& name() {
return name_; return name_;
} }
/*! \brief Return the names of cached tensors */ /*! \brief Return the names of created tensors */
vector<string> tensors(bool external = true) const; vector<string> tensors(bool external = true) const;
/*! \brief Return the names of cached graphs */ /*! \brief Return the names of created graphs */
vector<string> graphs() const; vector<string> graphs() const;
/*! \brief Return a group of shared raw data buffers */ /*! \brief Return a shared raw data buffer */
template <class Context> template <class Context>
vector<void*> data( void* data(size_t size, const string& name = "BufferShared") {
const vector<size_t>& segments, size = size > size_t(0) ? size : size_t(1);
const string& name = "data:0") { auto* tensor = CreateTensor(name)->Reshape({int64_t(size)});
vector<void*> group(segments.size()); return (void*)tensor->template mutable_data<uint8_t, Context>();
group[0] = CreateTensor("shared/buffer/" + name)
->Reshape({(int64_t)std::accumulate(
segments.begin(), segments.end(), size_t(0))})
->template mutable_data<uint8_t, Context>();
for (int i = 1; i < segments.size(); ++i) {
group[i] = (uint8_t*)group[i - 1] + segments[i - 1];
}
return group;
} }
/*! \brief Return a group of shared typed data buffers */ /*! \brief Return a shared typed data buffer */
template <typename T, class Context> template <typename T, class Context>
vector<T*> data( T* data(int64_t size, const string& name = "BufferShared") {
const vector<int64_t>& segments, return (T*)data<Context>(size_t(size) * sizeof(T), name);
const string& name = "data:0") {
vector<T*> group(segments.size());
vector<size_t> segments_v2;
for (const auto size : segments) {
segments_v2.push_back(size * sizeof(T));
}
auto group_v2 = data<Context>(segments_v2, name);
for (int i = 0; i < segments.size(); ++i) {
group[i] = (T*)group_v2[i];
}
return group;
} }
private: private:
/*! \brief The workspace name */ /*! \brief The workspace name */
string name_; string name_;
/*! \brief The unique indices */ /*! \brief The scope counters */
Map<string, Map<string, int64_t>> unique_index_map_; Map<string, Map<string, int64_t>> scope_counters_;
/*! \brief The created aliases */ /*! \brief The aliases */
Map<string, string> alias_map_; Map<string, string> aliases_;
/*! \brief The created tensors */ /*! \brief The tensors */
Map<string, unique_ptr<Tensor>> tensor_map_; Map<string, unique_ptr<Tensor>> tensors_;
/*! \brief The external tensors */ /*! \brief The external tensors */
Map<string, Tensor*> external_tensor_map_; Map<string, Tensor*> external_tensors_;
/*! \brief The created operators */ /*! \brief The operators */
Map<string, unique_ptr<OperatorBase>> operator_map_; Map<string, unique_ptr<OperatorBase>> operators_;
/*! \brief The created graphs */ /*! \brief The graphs */
Map<string, unique_ptr<GraphBase>> graph_map_; Map<string, unique_ptr<GraphBase>> graphs_;
DISABLE_COPY_AND_ASSIGN(Workspace); DISABLE_COPY_AND_ASSIGN(Workspace);
}; };
......
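The reworked data<...>() helpers above drop the segmented buffer groups in favour of a single shared scratch tensor. A usage sketch under that reading: ws is a Workspace*, the typed overload reserves size * sizeof(T) bytes, and the backing tensor defaults to "BufferShared", so repeated calls reuse and grow the same storage. The second buffer name below is purely illustrative:

// Sketch: requesting shared scratch space from the workspace (assumed usage).
int64_t count = 4096;
float* scratch = ws->data<float, CPUContext>(count);       // count * sizeof(float) bytes
void* raw = ws->data<CPUContext>(1024, "BufferScratch2");   // a second, independently named buffer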
...@@ -12,15 +12,19 @@ if (NOT BUILD_RUNTIME) ...@@ -12,15 +12,19 @@ if (NOT BUILD_RUNTIME)
add_subdirectory(training) add_subdirectory(training)
endif() endif()
# ---[ Merge CUDA kernels to speed up compiling # ---[ Merge CUDA kernels
if (USE_CUDA) if (USE_CUDA)
set(_gen_file ${CMAKE_CURRENT_BINARY_DIR}/../codegen/op_kernels.cu) if (MSVC)
file(WRITE ${_gen_file} "") set(_gen_file ${CMAKE_CURRENT_BINARY_DIR}/../kernels/op_kernels.cu)
foreach(_file ${KERNEL_CUDA_SOURCES}) file(WRITE ${_gen_file} "")
file(STRINGS ${_file} tmp NEWLINE_CONSUME) foreach(_file ${KERNEL_CUDA_SOURCES})
file(APPEND ${_gen_file} ${tmp} "\n") file(STRINGS ${_file} tmp NEWLINE_CONSUME)
endforeach() file(APPEND ${_gen_file} ${tmp} "\n")
set(MODULE_CUDA_SOURCES ${MODULE_CUDA_SOURCES} ${_gen_file}) endforeach()
set(MODULE_CUDA_SOURCES ${MODULE_CUDA_SOURCES} ${_gen_file})
else()
set(MODULE_CUDA_SOURCES ${MODULE_CUDA_SOURCES} ${KERNEL_CUDA_SOURCES})
endif()
endif() endif()
# ---[ Submit to the parent scope # ---[ Submit to the parent scope
......
...@@ -15,8 +15,7 @@ void _Assign( ...@@ -15,8 +15,7 @@ void _Assign(
const int64_t* starts, const int64_t* starts,
const T* x, const T* x,
T* y) { T* y) {
const auto N = const auto N = math::utils::Prod(num_dims, x_dims);
std::accumulate(x_dims, x_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0); vec64_t index(num_dims, 0);
int yi; int yi;
for (int xi = 0; xi < N; ++xi) { for (int xi = 0; xi < N; ++xi) {
......
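Throughout these kernels the std::accumulate idiom is replaced by math::utils::Prod. Judging from the call sites, Prod(num_dims, dims) is simply the product of the first num_dims entries; a stand-in with the assumed signature, for readers without the Dragon headers:

// Assumed behaviour of math::utils::Prod as used above: the product of dims,
// i.e. exactly what the removed std::accumulate expression computed.
#include <cstdint>
#include <functional>
#include <numeric>

inline int64_t ProdSketch(int num_dims, const int64_t* dims) {
  return std::accumulate(dims, dims + num_dims, int64_t(1),
                         std::multiplies<int64_t>());
}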
...@@ -46,8 +46,7 @@ __global__ void _Assign( ...@@ -46,8 +46,7 @@ __global__ void _Assign(
CUDAContext* ctx) { \ CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \ CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_dims, Y_strides, X_starts; \ SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_dims, Y_strides, X_starts; \
const auto N = std::accumulate( \ const auto N = math::utils::Prod(num_dims, x_dims); \
x_dims, x_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \ for (int i = 0; i < num_dims; ++i) { \
X_dims.data[i] = x_dims[i]; \ X_dims.data[i] = x_dims[i]; \
Y_strides.data[i] = y_strides[i]; \ Y_strides.data[i] = y_strides[i]; \
......
...@@ -17,8 +17,7 @@ void _ChannelNormalize( ...@@ -17,8 +17,7 @@ void _ChannelNormalize(
const float* mean, const float* mean,
const float* std, const float* std,
OutputT* y) { OutputT* y) {
const auto N = const auto N = math::utils::Prod(num_dims, y_dims);
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t idx(num_dims, 0); vec64_t idx(num_dims, 0);
int64_t xi, wi; int64_t xi, wi;
for (int yi = 0; yi < N; ++yi) { for (int yi = 0; yi < N; ++yi) {
......
...@@ -38,32 +38,31 @@ __global__ void _ChannelNormalize( ...@@ -38,32 +38,31 @@ __global__ void _ChannelNormalize(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void ChannelNormalize<InputT, OutputT, CUDAContext>( \ void ChannelNormalize<InputT, OutputT, CUDAContext>( \
const int axis, \ const int axis, \
const int num_dims, \ const int num_dims, \
const int64_t* x_strides, \ const int64_t* x_strides, \
const int64_t* y_dims, \ const int64_t* y_dims, \
const InputT* x, \ const InputT* x, \
const float* mean, \ const float* mean, \
const float* std, \ const float* std, \
OutputT* y, \ OutputT* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \ CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims; \ SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims; \
const auto N = std::accumulate( \ const auto N = math::utils::Prod(num_dims, y_dims); \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \ for (int i = 0; i < num_dims; ++i) { \
for (int i = 0; i < num_dims; ++i) { \ X_strides.data[i] = x_strides[i]; \
X_strides.data[i] = x_strides[i]; \ Y_dims.data[i] = y_dims[i]; \
Y_dims.data[i] = y_dims[i]; \ } \
} \ _ChannelNormalize<<< \
_ChannelNormalize<<< \ CUDA_BLOCKS(N), \
CUDA_BLOCKS(N), \ CUDA_THREADS, \
CUDA_THREADS, \ 0, \
0, \ ctx->cuda_stream()>>>( \
ctx->cuda_stream()>>>( \ N, axis, num_dims, X_strides, Y_dims, x, mean, std, y); \
N, axis, num_dims, X_strides, Y_dims, x, mean, std, y); \
} }
DEFINE_KERNEL_LAUNCHER(uint8_t, float16); DEFINE_KERNEL_LAUNCHER(uint8_t, float16);
......
...@@ -60,8 +60,7 @@ void _GatherElements( ...@@ -60,8 +60,7 @@ void _GatherElements(
const int64_t* index, const int64_t* index,
const T* x, const T* x,
T* y) { T* y) {
const auto N = const auto N = math::utils::Prod(num_dims, y_dims);
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t dim_index(num_dims, 0); vec64_t dim_index(num_dims, 0);
for (int yi = 0; yi < N; ++yi) { for (int yi = 0; yi < N; ++yi) {
int64_t xi = 0; int64_t xi = 0;
......
...@@ -28,7 +28,7 @@ __global__ void _Gather( ...@@ -28,7 +28,7 @@ __global__ void _Gather(
} }
} }
template <typename T> template <typename T, typename AccT>
__global__ void _GatherGrad( __global__ void _GatherGrad(
const int NxKxS, const int NxKxS,
const int S, const int S,
...@@ -36,14 +36,14 @@ __global__ void _GatherGrad( ...@@ -36,14 +36,14 @@ __global__ void _GatherGrad(
const int K, const int K,
const int64_t* index, const int64_t* index,
const T* dy, const T* dy,
float* dx) { AccT* dx) {
CUDA_1D_KERNEL_LOOP(yi, NxKxS) { CUDA_1D_KERNEL_LOOP(yi, NxKxS) {
const int j = yi % S; const int j = yi % S;
const int i = yi / S / K; const int i = yi / S / K;
int pos = __ldg(index + yi / S % K); int pos = __ldg(index + yi / S % K);
pos = (pos >= 0 ? pos : pos + C); pos = (pos >= 0 ? pos : pos + C);
math::utils::AtomicAdd( math::utils::AtomicAdd(
dx + (i * C + pos) * S + j, convert::To<float>(dy[yi])); dx + (i * C + pos) * S + j, convert::To<AccT>(dy[yi]));
} }
} }
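The extra AccT parameter above lets GatherGrad accumulate into a wider type than the gradient storage type (for example, float16 gradients summed into a float buffer), which avoids precision loss in the scatter-add. A host-side C++ sketch of the same accumulate-then-convert pattern, ignoring the N/C/S layout (illustrative, not the kernel itself):

// Sketch: scatter-add low-precision gradients into a wider accumulator.
#include <cstdint>
#include <vector>

template <typename T, typename AccT>
void GatherGradSketch(const std::vector<int64_t>& index,
                      const std::vector<T>& dy,
                      std::vector<AccT>& dx) {
  for (size_t i = 0; i < index.size(); ++i) {
    dx[index[i]] += static_cast<AccT>(dy[i]);  // AtomicAdd on the device side
  }
}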
...@@ -51,7 +51,6 @@ template <typename T, int D> ...@@ -51,7 +51,6 @@ template <typename T, int D>
__global__ void _GatherElements( __global__ void _GatherElements(
const int N, const int N,
const int axis, const int axis,
const int num_dims,
const SimpleArray<int, D> X_strides, const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims, const SimpleArray<int, D> Y_dims,
const int64_t* index, const int64_t* index,
...@@ -59,7 +58,8 @@ __global__ void _GatherElements( ...@@ -59,7 +58,8 @@ __global__ void _GatherElements(
T* y) { T* y) {
CUDA_1D_KERNEL_LOOP(yi, N) { CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi; int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) { #pragma unroll
for (int d = D - 1; d >= 0; --d) {
int r; int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r); FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
xi += (d == axis ? index[yi] : r) * X_strides.data[d]; xi += (d == axis ? index[yi] : r) * X_strides.data[d];
...@@ -68,6 +68,25 @@ __global__ void _GatherElements( ...@@ -68,6 +68,25 @@ __global__ void _GatherElements(
} }
} }
template <typename T, int D>
void _GatherElementsImpl(
const int axis,
const int64_t* x_strides,
const int64_t* y_dims,
const int64_t* index,
const T* x,
T* y,
CUDAContext* ctx) {
const auto N = math::utils::Prod(D, y_dims);
SimpleArray<int, D> X_strides, Y_dims;
for (int i = 0; i < D; ++i) {
X_strides.data[i] = x_strides[i];
Y_dims.data[i] = y_dims[i];
}
_GatherElements<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, axis, X_strides, Y_dims, index, x, y);
}
} // namespace } // namespace
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
...@@ -107,28 +126,29 @@ DEFINE_KERNEL_LAUNCHER(GatherGrad, float, float); // GatherGrad ...@@ -107,28 +126,29 @@ DEFINE_KERNEL_LAUNCHER(GatherGrad, float, float); // GatherGrad
DEFINE_KERNEL_LAUNCHER(GatherGrad, double, float); // GatherGrad DEFINE_KERNEL_LAUNCHER(GatherGrad, double, float); // GatherGrad
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T) \ #define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void name<T, CUDAContext>( \ void name<T, CUDAContext>( \
const int axis, \ const int axis, \
const int num_dims, \ const int num_dims, \
const int64_t* x_strides, \ const int64_t* x_strides, \
const int64_t* y_dims, \ const int64_t* y_dims, \
const int64_t* index, \ const int64_t* index, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \ CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides; \ DISPATCH_FUNC_BY_VALUE_WITH_TYPE_1( \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> Y_dims; \ _GatherElementsImpl, \
const auto N = std::accumulate( \ T, \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \ num_dims, \
for (int i = 0; i < num_dims; ++i) { \ axis, \
X_strides.data[i] = x_strides[i]; \ x_strides, \
Y_dims.data[i] = y_dims[i]; \ y_dims, \
} \ index, \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ x, \
N, axis, num_dims, X_strides, Y_dims, index, x, y); \ y, \
ctx); \
} }
DEFINE_KERNEL_LAUNCHER(GatherElements, bool); DEFINE_KERNEL_LAUNCHER(GatherElements, bool);
......
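The new launcher defers to DISPATCH_FUNC_BY_VALUE_WITH_TYPE_1, which, from the call pattern above, is assumed to switch on the runtime num_dims and instantiate _GatherElementsImpl with a matching compile-time D, so the per-element loop can be fully unrolled. A hand-written sketch of that kind of dispatch (not the actual macro), using a trivial reduction in place of the kernel:

// Sketch: turning a runtime rank into a compile-time template parameter.
#include <cstdint>
#include <stdexcept>

template <int D>
int64_t ProdFixed(const int64_t* dims) {
  int64_t n = 1;
  for (int i = 0; i < D; ++i) n *= dims[i];  // trip count known at compile time
  return n;
}

int64_t DispatchProd(int num_dims, const int64_t* dims) {
  switch (num_dims) {
    case 1: return ProdFixed<1>(dims);
    case 2: return ProdFixed<2>(dims);
    case 3: return ProdFixed<3>(dims);
    case 4: return ProdFixed<4>(dims);
    default: throw std::runtime_error("Unsupported rank.");
  }
}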
...@@ -17,8 +17,7 @@ void _ConstPad( ...@@ -17,8 +17,7 @@ void _ConstPad(
const T value, const T value,
const T* x, const T* x,
T* y) { T* y) {
const auto N = const auto N = math::utils::Prod(num_dims, y_dims);
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0); vec64_t index(num_dims, 0);
int64_t xi, d, r; int64_t xi, d, r;
for (int yi = 0; yi < N; ++yi) { for (int yi = 0; yi < N; ++yi) {
...@@ -42,8 +41,7 @@ void _ReflectPad( ...@@ -42,8 +41,7 @@ void _ReflectPad(
const int64_t* pads, const int64_t* pads,
const T* x, const T* x,
T* y) { T* y) {
const auto N = const auto N = math::utils::Prod(num_dims, y_dims);
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0); vec64_t index(num_dims, 0);
int64_t xi, d, r; int64_t xi, d, r;
for (int yi = 0; yi < N; ++yi) { for (int yi = 0; yi < N; ++yi) {
...@@ -68,8 +66,7 @@ void _EdgePad( ...@@ -68,8 +66,7 @@ void _EdgePad(
const int64_t* pads, const int64_t* pads,
const T* x, const T* x,
T* y) { T* y) {
const auto N = const auto N = math::utils::Prod(num_dims, y_dims);
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0); vec64_t index(num_dims, 0);
int64_t xi, d, r; int64_t xi, d, r;
for (int yi = 0; yi < N; ++yi) { for (int yi = 0; yi < N; ++yi) {
......
...@@ -92,8 +92,7 @@ void _PadImpl( ...@@ -92,8 +92,7 @@ void _PadImpl(
T* y, T* y,
CUDAContext* ctx) { CUDAContext* ctx) {
SimpleArray<int, D> X_dims, X_strides, Y_dims, X_pads; SimpleArray<int, D> X_dims, X_strides, Y_dims, X_pads;
const auto N = const auto N = math::utils::Prod(D, y_dims);
std::accumulate(y_dims, y_dims + D, 1, std::multiplies<int64_t>());
for (int i = 0; i < D; ++i) { for (int i = 0; i < D; ++i) {
X_dims.data[i] = x_dims[i]; X_dims.data[i] = x_dims[i];
X_strides.data[i] = x_strides[i]; X_strides.data[i] = x_strides[i];
......
...@@ -15,8 +15,7 @@ void _Reverse( ...@@ -15,8 +15,7 @@ void _Reverse(
const int64_t* y_dims, const int64_t* y_dims,
const T* x, const T* x,
T* y) { T* y) {
const auto N = const auto N = math::utils::Prod(num_dims, y_dims);
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0); vec64_t index(num_dims, 0);
int64_t xi; int64_t xi;
for (int yi = 0; yi < N; ++yi) { for (int yi = 0; yi < N; ++yi) {
......
...@@ -47,8 +47,7 @@ __global__ void _Reverse( ...@@ -47,8 +47,7 @@ __global__ void _Reverse(
CUDA_TENSOR_DIMS_CHECK(num_dims); \ CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<uint8_t, CUDA_TENSOR_MAX_DIMS> X_flips; \ SimpleArray<uint8_t, CUDA_TENSOR_MAX_DIMS> X_flips; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims; \ SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims; \
const auto N = std::accumulate( \ const auto N = math::utils::Prod(num_dims, y_dims); \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \ for (int i = 0; i < num_dims; ++i) { \
X_flips.data[i] = x_flips[i]; \ X_flips.data[i] = x_flips[i]; \
X_strides.data[i] = x_strides[i]; \ X_strides.data[i] = x_strides[i]; \
......
...@@ -15,8 +15,7 @@ void _Roll( ...@@ -15,8 +15,7 @@ void _Roll(
const int64_t* y_dims, const int64_t* y_dims,
const T* x, const T* x,
T* y) { T* y) {
const auto N = const auto N = math::utils::Prod(num_dims, y_dims);
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0); vec64_t index(num_dims, 0);
for (int yi = 0; yi < N; ++yi) { for (int yi = 0; yi < N; ++yi) {
int64_t xi = 0, r; int64_t xi = 0, r;
......
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -40,8 +41,7 @@ void _RollImpl( ...@@ -40,8 +41,7 @@ void _RollImpl(
T* y, T* y,
CUDAContext* ctx) { CUDAContext* ctx) {
SimpleArray<int, D> X_shifts, X_strides, Y_dims; SimpleArray<int, D> X_shifts, X_strides, Y_dims;
const auto N = const auto N = math::utils::Prod(D, y_dims);
std::accumulate(y_dims, y_dims + D, 1, std::multiplies<int64_t>());
for (int i = 0; i < D; ++i) { for (int i = 0; i < D; ++i) {
X_shifts.data[i] = x_shifts[i]; X_shifts.data[i] = x_shifts[i];
X_strides.data[i] = x_strides[i]; X_strides.data[i] = x_strides[i];
......