Commit bd84b7fd by Ting PAN

Add masked AssignOp

1 parent e90a8f1a
...@@ -168,6 +168,7 @@ List Brief ...@@ -168,6 +168,7 @@ List Brief
=============== ====================================================================== =============== ======================================================================
`Copy`_ Copy the *value* to *ref*. `Copy`_ Copy the *value* to *ref*.
`Assign`_ Assign the *value* to *ref*. `Assign`_ Assign the *value* to *ref*.
`MaskedAssign`_ Assign the *value* to *ref* where mask is *1*.
`Equal`_ *Equal* Comparing between A and B. `Equal`_ *Equal* Comparing between A and B.
`Less`_ *Less* Comparing between A and B. `Less`_ *Less* Comparing between A and B.
`LessEqual`_ *LessEqual* Comparing between A and B. `LessEqual`_ *LessEqual* Comparing between A and B.
...@@ -308,8 +309,9 @@ List Brief ...@@ -308,8 +309,9 @@ List Brief
.. _Arange: operators/array.html#dragon.operators.array.Arange .. _Arange: operators/array.html#dragon.operators.array.Arange
.. _Multinomial: operators/array.html#dragon.operators.array.Multinomial .. _Multinomial: operators/array.html#dragon.operators.array.Multinomial
.. _Copy: operators/control_flow.html#dAragon.operators.control_flow.Copy .. _Copy: operators/control_flow.html#dragon.operators.control_flow.Copy
.. _Assign: operators/control_flow.html#dAragon.operators.control_flow.Assign .. _Assign: operators/control_flow.html#dragon.operators.control_flow.Assign
.. _MaskedAssign: operators/control_flow.html#dragon.operators.control_flow.MaskedAssign
.. _Equal: operators/control_flow.html#dragon.operators.control_flow.Equal .. _Equal: operators/control_flow.html#dragon.operators.control_flow.Equal
.. _Less: operators/control_flow.html#dragon.operators.control_flow.Less .. _Less: operators/control_flow.html#dragon.operators.control_flow.Less
.. _LessEqual: operators/control_flow.html#dragon.operators.control_flow.LessEqual .. _LessEqual: operators/control_flow.html#dragon.operators.control_flow.LessEqual
......
...@@ -72,7 +72,7 @@ class CUDAObject { ...@@ -72,7 +72,7 @@ class CUDAObject {
if (streams.size() <= (unsigned)stream_id) if (streams.size() <= (unsigned)stream_id)
streams.resize(stream_id + 1, nullptr); streams.resize(stream_id + 1, nullptr);
if (!streams[stream_id]) { if (!streams[stream_id]) {
DeviceGuard guard(device_id); CUDADeviceGuard guard(device_id);
unsigned int flags = !stream_id ? unsigned int flags = !stream_id ?
cudaStreamDefault : cudaStreamDefault :
cudaStreamNonBlocking; cudaStreamNonBlocking;
...@@ -97,7 +97,7 @@ class CUDAObject { ...@@ -97,7 +97,7 @@ class CUDAObject {
if (handles.size() <= (unsigned)stream_id) if (handles.size() <= (unsigned)stream_id)
handles.resize(stream_id + 1, nullptr); handles.resize(stream_id + 1, nullptr);
if (!handles[stream_id]) { if (!handles[stream_id]) {
DeviceGuard guard(device_id); CUDADeviceGuard guard(device_id);
CUBLAS_CHECK(cublasCreate_v2(&handles[stream_id])); CUBLAS_CHECK(cublasCreate_v2(&handles[stream_id]));
CUBLAS_CHECK(cublasSetStream_v2( CUBLAS_CHECK(cublasSetStream_v2(
handles[stream_id], handles[stream_id],
...@@ -120,7 +120,7 @@ class CUDAObject { ...@@ -120,7 +120,7 @@ class CUDAObject {
if (handles.size() <= (unsigned)stream_id) if (handles.size() <= (unsigned)stream_id)
handles.resize(stream_id + 1, nullptr); handles.resize(stream_id + 1, nullptr);
if (!handles[stream_id]) { if (!handles[stream_id]) {
DeviceGuard guard(device_id); CUDADeviceGuard guard(device_id);
CUDNN_CHECK(cudnnCreate(&handles[stream_id])); CUDNN_CHECK(cudnnCreate(&handles[stream_id]));
CUDNN_CHECK(cudnnSetStream( CUDNN_CHECK(cudnnSetStream(
handles[stream_id], handles[stream_id],
...@@ -292,7 +292,7 @@ class CUDAContext { ...@@ -292,7 +292,7 @@ class CUDAContext {
/*! \brief Return the internal cuda random generator */ /*! \brief Return the internal cuda random generator */
curandGenerator_t& curand_generator() { curandGenerator_t& curand_generator() {
if (!curand_generator_) { if (!curand_generator_) {
DeviceGuard guard(device_id_); CUDADeviceGuard guard(device_id_);
CURAND_CHECK(curandCreateGenerator( CURAND_CHECK(curandCreateGenerator(
&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); &curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));
CURAND_CHECK(curandSetPseudoRandomGeneratorSeed( CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_MASKED_ASSIGN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_MASKED_ASSIGN_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class MaskedAssignOp final : public Operator<Context> {
public:
MaskedAssignOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunImpl();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_CONTROL_FLOW_MASKED_ASSIGN_OP_H_
\ No newline at end of file
...@@ -135,19 +135,19 @@ struct CUDADeviceProps { ...@@ -135,19 +135,19 @@ struct CUDADeviceProps {
vector<cudaDeviceProp> props; vector<cudaDeviceProp> props;
}; };
inline const cudaDeviceProp& GetDeviceProperty( inline const cudaDeviceProp& GetCUDADeviceProp(
const int device_id) { int device_id) {
static CUDADeviceProps props; static CUDADeviceProps props;
CHECK_LT(device_id, (int)props.props.size()) CHECK_LT(device_id, (int)props.props.size())
<< "Invalid device id: " << device_id << "\nInvalid device id: " << device_id
<< "\nDetected " << props.props.size() << "\nDetected " << props.props.size()
<< " eligible cuda devices."; << " devices.";
return props.props[device_id]; return props.props[device_id];
} }
inline bool CUDA_TRUE_FP16_AVAILABLE() { inline bool CUDA_TRUE_FP16_AVAILABLE() {
int device = CUDA_GET_DEVICE(); int device = CUDA_GET_DEVICE();
auto& prop = GetDeviceProperty(device); auto& prop = GetCUDADeviceProp(device);
return prop.major >= 6; return prop.major >= 6;
} }
...@@ -156,21 +156,26 @@ inline bool TENSOR_CORE_AVAILABLE() { ...@@ -156,21 +156,26 @@ inline bool TENSOR_CORE_AVAILABLE() {
return false; return false;
#else #else
int device = CUDA_GET_DEVICE(); int device = CUDA_GET_DEVICE();
auto& prop = GetDeviceProperty(device); auto& prop = GetCUDADeviceProp(device);
return prop.major >= 7; return prop.major >= 7;
#endif #endif
} }
class DeviceGuard { class CUDADeviceGuard {
public: public:
DeviceGuard(int new_id) : prev_id(CUDA_GET_DEVICE()) { CUDADeviceGuard(int new_id)
if (prev_id != new_id) CUDA_CHECK(cudaSetDevice(new_id)); : prev_id_(CUDA_GET_DEVICE()) {
if (prev_id_ != new_id) {
CUDA_CHECK(cudaSetDevice(new_id));
}
} }
~DeviceGuard() { CUDA_CHECK(cudaSetDevice(prev_id)); } ~CUDADeviceGuard() {
CUDA_CHECK(cudaSetDevice(prev_id_));
}
private: private:
int prev_id; int prev_id_;
}; };
#else #else
......
...@@ -100,38 +100,38 @@ void CuDNNSetTensor3dDesc( ...@@ -100,38 +100,38 @@ void CuDNNSetTensor3dDesc(
template <typename T> template <typename T>
void CuDNNSetTensorDesc( void CuDNNSetTensorDesc(
cudnnTensorDescriptor_t* desc, cudnnTensorDescriptor_t* desc,
const vec64_t& dims); const vec64_t& dims);
template <typename T> template <typename T>
void CuDNNSetTensor4dDesc( void CuDNNSetTensor4dDesc(
cudnnTensorDescriptor_t* desc, cudnnTensorDescriptor_t* desc,
const string& data_format, const string& data_format,
const vec64_t& dims); const vec64_t& dims);
template <typename T> template <typename T>
void CuDNNSetTensor4dDescWithGroup( void CuDNNSetTensor4dDescWithGroup(
cudnnTensorDescriptor_t* desc, cudnnTensorDescriptor_t* desc,
const string& data_format, const string& data_format,
const vec64_t& dims, const vec64_t& dims,
const int64_t group); const int64_t group);
template <typename T> template <typename T>
void CuDNNSetTensor5dDesc( void CuDNNSetTensor5dDesc(
cudnnTensorDescriptor_t* desc, cudnnTensorDescriptor_t* desc,
const string& data_format, const string& data_format,
const vec64_t& dims); const vec64_t& dims);
template <typename T> template <typename T>
void CuDNNSetTensor3dDesc( void CuDNNSetTensor3dDesc(
cudnnTensorDescriptor_t* desc, cudnnTensorDescriptor_t* desc,
const string& data_format, const string& data_format,
const vec64_t& dims); const vec64_t& dims);
template <typename T> template <typename T>
void CuDNNSetTensorDesc( void CuDNNSetTensorDesc(
cudnnTensorDescriptor_t* desc, cudnnTensorDescriptor_t* desc,
const vec64_t& dims, const vec64_t& dims,
const vec64_t& strides); const vec64_t& strides);
} // namespace dragon } // namespace dragon
......
...@@ -657,6 +657,16 @@ void GreaterEqual( ...@@ -657,6 +657,16 @@ void GreaterEqual(
bool* y, bool* y,
Context* ctx); Context* ctx);
/*! control_flow.masked_assign */
template <typename T, class Context>
void MaskedAssign(
const int count,
const uint8_t* mask,
const T* x,
T* y,
Context* ctx);
/*! loss.l1_loss */ /*! loss.l1_loss */
template <typename T, class Context> template <typename T, class Context>
......
...@@ -488,9 +488,9 @@ class Tensor(object): ...@@ -488,9 +488,9 @@ class Tensor(object):
Parameters Parameters
---------- ----------
key : int or slice key : int, slice or Tensor
The indices. The indices.
value : Tensor, number or sequence value : number, sequence or Tensor
The value. The value.
Returns Returns
...@@ -498,11 +498,20 @@ class Tensor(object): ...@@ -498,11 +498,20 @@ class Tensor(object):
None None
""" """
starts, sizes = self._process_indices(key)
if not isinstance(value, Tensor): if not isinstance(value, Tensor):
value = self._from_constant(value) value = self._from_constant(value)
return self.CreateOperator('Assign', [value], if isinstance(key, Tensor):
existing_outputs=[self], starts=starts, sizes=sizes) return self.CreateOperator(
'MaskedAssign', [value, key],
existing_outputs=[self],
)
else:
starts, sizes = self._process_indices(key)
return self.CreateOperator(
'Assign', [value],
starts=starts, sizes=sizes,
existing_outputs=[self],
)
def _from_constant(self, value, name=None): def _from_constant(self, value, name=None):
if not isinstance(value, numpy.ndarray): if not isinstance(value, numpy.ndarray):
......
...@@ -75,10 +75,36 @@ def Assign(inputs, starts=None, sizes=None, **kwargs): ...@@ -75,10 +75,36 @@ def Assign(inputs, starts=None, sizes=None, **kwargs):
@OpSchema.ConvertConstantInputs() @OpSchema.ConvertConstantInputs()
@OpSchema.Inputs(2) @OpSchema.Inputs(2)
def MaskedAssign(inputs, mask, **kwargs):
"""Assign the ``value`` to ``ref`` where ``mask`` is *1*.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
Parameters
----------
inputs : sequence of Tensor
The ``ref`` and ``value`` respectively.
mask : Tensor
The mask, with the same size as ``ref``.
Returns
-------
Tensor
The ``ref``.
"""
arguments = ParseArgs(locals())
arguments['existing_outputs'] = [arguments['inputs'][0]]
arguments['inputs'] = [arguments['inputs'][1], mask]
return Tensor.CreateOperator('Assign', **arguments)
@OpSchema.ConvertConstantInputs()
@OpSchema.Inputs(2)
def Equal(inputs, to_uint8=False, **kwargs): def Equal(inputs, to_uint8=False, **kwargs):
"""``Equal`` comparing between A and B. """*Equal* comparing between A and B.
Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``. Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*) **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
...@@ -87,7 +113,7 @@ def Equal(inputs, to_uint8=False, **kwargs): ...@@ -87,7 +113,7 @@ def Equal(inputs, to_uint8=False, **kwargs):
inputs : sequence of Tensor inputs : sequence of Tensor
The inputs, represent A and B respectively. The inputs, represent A and B respectively.
to_uint8 : bool to_uint8 : bool
``True`` to convert to ``uint8`` results. *True* to convert to *uint8* results.
Returns Returns
------- -------
...@@ -102,9 +128,9 @@ def Equal(inputs, to_uint8=False, **kwargs): ...@@ -102,9 +128,9 @@ def Equal(inputs, to_uint8=False, **kwargs):
@OpSchema.ConvertConstantInputs() @OpSchema.ConvertConstantInputs()
@OpSchema.Inputs(2) @OpSchema.Inputs(2)
def Less(inputs, to_uint8=False, **kwargs): def Less(inputs, to_uint8=False, **kwargs):
"""``Less`` comparing between A and B. """*Less* comparing between A and B.
Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``. Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*) **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
...@@ -113,7 +139,7 @@ def Less(inputs, to_uint8=False, **kwargs): ...@@ -113,7 +139,7 @@ def Less(inputs, to_uint8=False, **kwargs):
inputs : sequence of Tensor inputs : sequence of Tensor
The inputs, represent A and B respectively. The inputs, represent A and B respectively.
to_uint8 : bool to_uint8 : bool
``True`` to convert to ``uint8`` results. *True* to convert to *uint8* results.
Returns Returns
------- -------
...@@ -128,9 +154,9 @@ def Less(inputs, to_uint8=False, **kwargs): ...@@ -128,9 +154,9 @@ def Less(inputs, to_uint8=False, **kwargs):
@OpSchema.ConvertConstantInputs() @OpSchema.ConvertConstantInputs()
@OpSchema.Inputs(2) @OpSchema.Inputs(2)
def LessEqual(inputs, to_uint8=False, **kwargs): def LessEqual(inputs, to_uint8=False, **kwargs):
"""``LessEqual`` comparing between A and B. """*LessEqual* comparing between A and B.
Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``. Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*) **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
...@@ -139,7 +165,7 @@ def LessEqual(inputs, to_uint8=False, **kwargs): ...@@ -139,7 +165,7 @@ def LessEqual(inputs, to_uint8=False, **kwargs):
inputs : sequence of Tensor inputs : sequence of Tensor
The inputs, represent A and B respectively. The inputs, represent A and B respectively.
to_uint8 : bool to_uint8 : bool
``True`` to convert to ``uint8`` results. *True* to convert to *uint8* results.
Returns Returns
------- -------
...@@ -154,9 +180,9 @@ def LessEqual(inputs, to_uint8=False, **kwargs): ...@@ -154,9 +180,9 @@ def LessEqual(inputs, to_uint8=False, **kwargs):
@OpSchema.ConvertConstantInputs() @OpSchema.ConvertConstantInputs()
@OpSchema.Inputs(2) @OpSchema.Inputs(2)
def Greater(inputs, to_uint8=False, **kwargs): def Greater(inputs, to_uint8=False, **kwargs):
"""``Greater`` comparing between A and B. """*Greater* comparing between A and B.
Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``. Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*) **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
...@@ -165,7 +191,7 @@ def Greater(inputs, to_uint8=False, **kwargs): ...@@ -165,7 +191,7 @@ def Greater(inputs, to_uint8=False, **kwargs):
inputs : sequence of Tensor inputs : sequence of Tensor
The inputs, represent A and B respectively. The inputs, represent A and B respectively.
to_uint8 : bool to_uint8 : bool
``True`` to convert to ``uint8`` results. *True* to convert to *uint8* results.
Returns Returns
------- -------
...@@ -180,9 +206,9 @@ def Greater(inputs, to_uint8=False, **kwargs): ...@@ -180,9 +206,9 @@ def Greater(inputs, to_uint8=False, **kwargs):
@OpSchema.ConvertConstantInputs() @OpSchema.ConvertConstantInputs()
@OpSchema.Inputs(2) @OpSchema.Inputs(2)
def GreaterEqual(inputs, to_uint8=False, **kwargs): def GreaterEqual(inputs, to_uint8=False, **kwargs):
"""``GreaterEqual`` comparing between A and B. """*GreaterEqual* comparing between A and B.
Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``. Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*) **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
...@@ -191,7 +217,7 @@ def GreaterEqual(inputs, to_uint8=False, **kwargs): ...@@ -191,7 +217,7 @@ def GreaterEqual(inputs, to_uint8=False, **kwargs):
inputs : sequence of Tensor inputs : sequence of Tensor
The inputs, represent A and B respectively. The inputs, represent A and B respectively.
to_uint8 : bool to_uint8 : bool
``True`` to convert to ``uint8`` results. *True* to convert to *uint8* results.
Returns Returns
------- -------
......
...@@ -143,6 +143,7 @@ Multinomial = _array_ops.Multinomial ...@@ -143,6 +143,7 @@ Multinomial = _array_ops.Multinomial
# Control Flow # Control Flow
Copy = _control_flow_ops.Copy Copy = _control_flow_ops.Copy
Assign = _control_flow_ops.Assign Assign = _control_flow_ops.Assign
MaskedAssign = _control_flow_ops.MaskedAssign
Equal = _control_flow_ops.Equal Equal = _control_flow_ops.Equal
Less = _control_flow_ops.Less Less = _control_flow_ops.Less
LessEqual = _control_flow_ops.LessEqual LessEqual = _control_flow_ops.LessEqual
......
...@@ -70,8 +70,8 @@ class DataTransformer(multiprocessing.Process): ...@@ -70,8 +70,8 @@ class DataTransformer(multiprocessing.Process):
self._cutout_size = kwargs.get('cutout_size', 0) self._cutout_size = kwargs.get('cutout_size', 0)
self._mirror = kwargs.get('mirror', False) self._mirror = kwargs.get('mirror', False)
self._color_aug = kwargs.get('color_augmentation', False) self._color_aug = kwargs.get('color_augmentation', False)
self._min_random_scale = kwargs.get('min_random_scale', 1.0) self._min_rand_scale = kwargs.get('min_random_scale', 1.0)
self._max_random_scale = kwargs.get('max_random_scale', 1.0) self._max_rand_scale = kwargs.get('max_random_scale', 1.0)
self._force_color = kwargs.get('force_color', False) self._force_color = kwargs.get('force_color', False)
self._phase = kwargs.get('phase', 'TRAIN') self._phase = kwargs.get('phase', 'TRAIN')
self._random_seed = _cfg.GetRandomSeed() self._random_seed = _cfg.GetRandomSeed()
...@@ -102,12 +102,16 @@ class DataTransformer(multiprocessing.Process): ...@@ -102,12 +102,16 @@ class DataTransformer(multiprocessing.Process):
im = im.reshape((datum.height, datum.width, datum.channels)) im = im.reshape((datum.height, datum.width, datum.channels))
# Random scale # Random scale
random_scale = numpy.random.uniform() * ( rand_scale = numpy.random.uniform() * (
self._max_random_scale - self._min_random_scale) \ self._max_rand_scale - self._min_rand_scale
+ self._min_random_scale ) + self._min_rand_scale
if random_scale != 1.0: if rand_scale != 1.0:
im = cv2.resize(im, None, fx=random_scale, im = cv2.resize(
fy=random_scale, interpolation=cv2.INTER_LINEAR) im, None,
fx=rand_scale,
fy=rand_scale,
interpolation=cv2.INTER_LINEAR,
)
# Padding # Padding
if self._padding > 0: if self._padding > 0:
...@@ -149,7 +153,7 @@ class DataTransformer(multiprocessing.Process): ...@@ -149,7 +153,7 @@ class DataTransformer(multiprocessing.Process):
# Gray Transformation # Gray Transformation
if self._force_color: if self._force_color:
if im.shape[2] == 1: if im.shape[2] == 1:
# duplicate to 3 channels # Duplicate to 3 channels
im = numpy.concatenate([im, im, im], axis=2) im = numpy.concatenate([im, im, im], axis=2)
# Color Augmentation # Color Augmentation
......
...@@ -338,11 +338,13 @@ class Module(object): ...@@ -338,11 +338,13 @@ class Module(object):
def run(self, inputs, outputs, auto_grad=True, callback=None): def run(self, inputs, outputs, auto_grad=True, callback=None):
if self._module_def is None: self._gen_module_def() if self._module_def is None: self._gen_module_def()
meta = (self.module_key, self._module_def)
return RunOperator( return RunOperator(
inputs, outputs, meta, inputs=inputs,
auto_grad=auto_grad, outputs=outputs,
callback_on_run=callback) meta=(self.module_key, self._module_def),
auto_grad=auto_grad,
callback_on_run=callback,
)
def train(self, mode=True): def train(self, mode=True):
self.training = mode self.training = mode
......
...@@ -17,7 +17,10 @@ from dragon.core import mpi ...@@ -17,7 +17,10 @@ from dragon.core import mpi
from dragon.vm.torch.tensor import Tensor, _LeafTensor, _Device from dragon.vm.torch.tensor import Tensor, _LeafTensor, _Device
from dragon.vm.torch.ops.primitive import MakeDevice, WrapScalar from dragon.vm.torch.ops.primitive import MakeDevice, WrapScalar
from dragon.vm.torch.ops.factory import get_module from dragon.vm.torch.ops.factory import get_module
from dragon.vm.torch.ops.modules.control_flow import Compare
from dragon.vm.torch.ops.modules.control_flow import (
Assign, MaskedAssign, Compare
)
from dragon.vm.torch.ops.modules.arithmetic import ( from dragon.vm.torch.ops.modules.arithmetic import (
Fundamental, Log, Exp, Sqrt, Fundamental, Log, Exp, Sqrt,
...@@ -32,9 +35,8 @@ from dragon.vm.torch.ops.modules.init import ( ...@@ -32,9 +35,8 @@ from dragon.vm.torch.ops.modules.init import (
from dragon.vm.torch.ops.modules.array import ( from dragon.vm.torch.ops.modules.array import (
Reshape, Squeeze, UnSqueeze, Permute, Reshape, Squeeze, UnSqueeze, Permute,
Indexing, Assigning, Indexing, IndexSelect,
Repeat, Concat, Stack, Repeat, Concat, Stack,
IndexSelect,
Reduce, ArgReduce, OneHot, Multinomial, Reduce, ArgReduce, OneHot, Multinomial,
) )
...@@ -48,8 +50,8 @@ from dragon.vm.torch.ops.modules.vision import ( ...@@ -48,8 +50,8 @@ from dragon.vm.torch.ops.modules.vision import (
__all__ = [ __all__ = [
'add', 'sub', 'mul', 'div',
'accumulate', 'accumulate',
'add', 'sub', 'mul', 'div',
'maximum', 'minimum', 'clamp', 'maximum', 'minimum', 'clamp',
'log', 'exp', 'sqrt', 'log', 'exp', 'sqrt',
'mm', 'xw_plus_b', 'mm', 'xw_plus_b',
...@@ -59,9 +61,12 @@ __all__ = [ ...@@ -59,9 +61,12 @@ __all__ = [
'gt', 'lt', 'eq', 'ge', 'le', 'gt', 'lt', 'eq', 'ge', 'le',
'cat', 'stack', 'narrow', 'cat', 'stack', 'narrow',
'index_select', 'index_select',
'one_hot', 'multinomial', 'rand', 'randn', 'one_hot', 'multinomial',
'zeros', 'zeros_like', 'ones', 'ones_like', 'rand', 'randn',
'nn_resize', 'bilinear_resize', 'roi_pool', 'roi_align', 'ones', 'ones_like',
'zeros', 'zeros_like',
'nn_resize', 'bilinear_resize',
'roi_pool', 'roi_align',
] ]
...@@ -409,52 +414,64 @@ def xw_plus_b(x, w, bias=None, transW=True, out=None): ...@@ -409,52 +414,64 @@ def xw_plus_b(x, w, bias=None, transW=True, out=None):
def _reshape(input, shape, shape_like=None): def _reshape(input, shape, shape_like=None):
if shape_like is not None: shape = shape_like.shape if shape_like is not None: shape = shape_like.shape
dev = MakeDevice(inputs=[input]); n_dim = len(shape) dev = MakeDevice(inputs=[input]); ndim = len(shape)
key = 'Reshape/{}/n_dim:{}'.format(dev, n_dim) key = 'Reshape/{}/ndim:{}'.format(dev, ndim)
module = get_module(Reshape, key, dev, n_dim=n_dim) module = get_module(Reshape, key, dev, ndim=ndim)
return module.forward(input, shape) return module.forward(input, shape)
def _permute(input, perm): def _permute(input, perm):
dev = MakeDevice(inputs=[input]); n_perm = len(perm) dev = MakeDevice(inputs=[input]); nperm = len(perm)
key = 'Permute/{}/n_perm:{}'.format(dev, n_perm) key = 'Permute/{}/nperm:{}'.format(dev, nperm)
module = get_module(Permute, key, dev, n_perm=n_perm) module = get_module(Permute, key, dev, nperm=nperm)
return module.forward(input, perm) return module.forward(input, perm)
def _repeat(input, times): def _repeat(input, times):
dev = MakeDevice(inputs=[input]); n_times = len(times) dev = MakeDevice(inputs=[input]); ntimes = len(times)
key = 'Repeat/{}/n_times:{}'.format(dev, n_times) key = 'Repeat/{}/ntimes:{}'.format(dev, ntimes)
module = get_module(Repeat, key, dev, n_times=n_times) module = get_module(Repeat, key, dev, ntimes=ntimes)
return module.forward(input, times) return module.forward(input, times)
def _fill(input, shape, value): def _fill(input, shape, value):
dev = MakeDevice(inputs=[input]); n_dim = len(shape) dev = MakeDevice(inputs=[input]); ndim = len(shape)
key = 'Fill/{}/dtype:{}/n_dim:{}/value:{}'.format( key = 'Fill/{}/dtype:{}/ndim:{}/value:{}' \
dev, input.dtype, n_dim, value) .format(dev, input.dtype, ndim, value)
module = get_module(Fill, key, dev, n_dim=n_dim, module = get_module(
value=value, dtype=input.dtype) Fill, key, dev,
ndim=ndim,
value=value,
dtype=input.dtype,
)
return module.forward(input, shape) return module.forward(input, shape)
def _uniform(input, shape, low, high): def _uniform(input, shape, low, high):
dev = MakeDevice(inputs=[input]); n_dim = len(shape) dev = MakeDevice(inputs=[input]); ndim = len(shape)
key = 'Uniform/{}/dtype:{}/n_dim:{}/low:{}/high:{}'.format( key = 'Uniform/{}/dtype:{}/ndim:{}/low:{}/high:{}'.format(
dev, input.dtype, n_dim, float(low), float(high)) dev, input.dtype, ndim, float(low), float(high))
module = get_module( module = get_module(
RandomUniform, key, dev, n_dim=n_dim, RandomUniform, key, dev,
low=low, high=high, dtype=input.dtype) ndim=ndim,
low=low,
high=high,
dtype=input.dtype,
)
return module.forward(input, shape) return module.forward(input, shape)
def _normal(input, shape, mean, std): def _normal(input, shape, mean, std):
dev = MakeDevice(inputs=[input]); n_dim = len(shape) dev = MakeDevice(inputs=[input]); ndim = len(shape)
key = 'Normal/{}/dtype:{}/n_dim:{}/mean:{}/std:{}'.format( key = 'Normal/{}/dtype:{}/ndim:{}/mean:{}/std:{}'.format(
dev, input.dtype, n_dim, float(mean), float(std)) dev, input.dtype, ndim, float(mean), float(std))
module = get_module( module = get_module(
RandomNormal, key, dev, n_dim=n_dim, RandomNormal, key, dev,
mean=mean, std=std, dtype=input.dtype) ndim=ndim,
mean=mean,
std=std,
dtype=input.dtype,
)
return module.forward(input, shape) return module.forward(input, shape)
...@@ -464,44 +481,62 @@ def _reduce(input, operation, dim=None, keepdim=False, out=None): ...@@ -464,44 +481,62 @@ def _reduce(input, operation, dim=None, keepdim=False, out=None):
key = '{}/{}/dim:{}/keepdim:{}'.format( key = '{}/{}/dim:{}/keepdim:{}'.format(
operation, dev, dim, int(keepdim)) operation, dev, dim, int(keepdim))
module = get_module( module = get_module(
Reduce, key, dev, operation=operation, Reduce, key, dev,
dim=dim, keepdim=keepdim) dim=dim,
keepdim=keepdim,
operation=operation,
)
return module.forward(input, out) return module.forward(input, out)
def _arg_reduce(input, operation, dim=None, keepdim=False, top_k=1, out=None): def _arg_reduce(input, operation, dim=None, keepdim=False, topk=1, out=None):
if dim is None: keepdim = False if dim is None: keepdim = False
dev = MakeDevice(inputs=[input]) dev = MakeDevice(inputs=[input])
key = '{}/{}/dim:{}/keepdim:{}/top_k:{}'.format( key = '{}/{}/dim:{}/keepdim:{}/topk:{}'.format(
operation, dev, dim, int(keepdim), top_k) operation, dev, dim, int(keepdim), topk)
module = get_module( module = get_module(
ArgReduce, key, dev, ArgReduce, key, dev,
operation=operation, axis=dim, axis=dim,
keepdim=keepdim, top_k=top_k) topk=topk,
keepdim=keepdim,
operation=operation,
)
return module.forward(input, out) return module.forward(input, out)
def _indexing(input, starts, sizes): def _index(input, starts, sizes):
n_starts, n_sizes = len(starts), len(sizes) nstarts, nsizes = len(starts), len(sizes)
dev = MakeDevice(inputs=[input]) dev = MakeDevice(inputs=[input])
key = 'Index/{}/n_starts:{}/n_sizes:{}'.format(dev, n_starts, n_sizes) key = 'Index/{}/nstarts:{}/nsizes:{}'.format(dev, nstarts, nsizes)
module = get_module(Indexing, key, dev, n_starts=n_starts, n_sizes=n_sizes) module = get_module(Indexing, key, dev, nstarts=nstarts, nsizes=nsizes)
return module.forward(input, starts, sizes) return module.forward(input, starts, sizes)
def _assigning(output, input, starts, sizes): def _assign(output, starts, sizes, input):
if not isinstance(input, Tensor): if not isinstance(input, Tensor):
if isinstance(input, (tuple, list)): if isinstance(input, (tuple, list)):
input = Tensor(input, dtype=output.dtype, device=output.device) input = Tensor(input, dtype=output.dtype, device=output.device)
else: else:
input = WrapScalar(input, output.dtype, output.device) input = WrapScalar(input, output.dtype, output.device)
n_starts, n_sizes = len(starts), len(sizes) nstarts, nsizes = len(starts), len(sizes)
dev = MakeDevice(inputs=[input]) dev = MakeDevice(inputs=[input])
key = 'Assign/{}/n_starts:{}/n_sizes:{}'.format(dev, n_starts, n_sizes) key = 'Assign/{}/nstarts:{}/nsizes:{}'.format(dev, nstarts, nsizes)
module = get_module(Assigning, key, dev, n_starts=n_starts, n_sizes=n_sizes) module = get_module(Assign, key, dev, nstarts=nstarts, nsizes=nsizes)
return module.forward(input, output, starts, sizes) return module.forward(input, output, starts, sizes)
def _masked_assign(output, mask, input):
if not isinstance(input, Tensor):
if isinstance(input, (tuple, list)):
input = Tensor(input, dtype=output.dtype, device=output.device)
else:
input = WrapScalar(input, output.dtype, output.device)
dev = MakeDevice(inputs=[input])
key = 'MaskedAssign/{}'.format(dev)
module = get_module(MaskedAssign, key, dev)
return module.forward(input, output, mask)
def _compare(input, other, operation, out=None): def _compare(input, other, operation, out=None):
if not isinstance(other, Tensor): if not isinstance(other, Tensor):
other = WrapScalar(other, input.dtype, input.device) other = WrapScalar(other, input.dtype, input.device)
...@@ -927,7 +962,7 @@ def narrow(input, dimension, start, length): ...@@ -927,7 +962,7 @@ def narrow(input, dimension, start, length):
""" """
sizes = list(input.shape[:]); starts = [0] * len(sizes) sizes = list(input.shape[:]); starts = [0] * len(sizes)
starts[dimension], sizes[dimension] = start, length starts[dimension], sizes[dimension] = start, length
return _indexing(input, starts, sizes) return _index(input, starts, sizes)
def one_hot(input, depth): def one_hot(input, depth):
...@@ -1159,8 +1194,13 @@ def _update( ...@@ -1159,8 +1194,13 @@ def _update(
): ):
dev = MakeDevice(inputs=[param]) dev = MakeDevice(inputs=[param])
key = '{}/{}/{}/{}'.format(op_type, dev, slot, param.name) key = '{}/{}/{}/{}'.format(op_type, dev, slot, param.name)
module = get_module(Update, key, dev, op_type=op_type, module = get_module(
lr_mult=lr_mult, decay_mult=decay_mult, slot=slot) Update, key, dev,
op_type=op_type,
lr_mult=lr_mult,
decay_mult=decay_mult,
slot=slot,
)
return module.forward(param, grad) return module.forward(param, grad)
...@@ -1183,8 +1223,12 @@ def _resize_2d(input, op_type, dsize, fx, fy): ...@@ -1183,8 +1223,12 @@ def _resize_2d(input, op_type, dsize, fx, fy):
dev = MakeDevice(inputs=[input]) dev = MakeDevice(inputs=[input])
key = '{}/{}/dsize:{}/fx:{}/fy:{}'.format( key = '{}/{}/dsize:{}/fx:{}/fy:{}'.format(
op_type, dev, '2' if dsize else 'none', fx, fy) op_type, dev, '2' if dsize else 'none', fx, fy)
module = get_module(Resize2d, key, dev, module = get_module(
op_type=op_type, dsize=dsize, fx=fx, fy=fy) Resize2d, key, dev,
dsize=dsize,
fx=fx, fy=fy,
op_type=op_type,
)
return module.forward(input, dsize) return module.forward(input, dsize)
......
...@@ -27,8 +27,8 @@ class Indexing(BaseModule): ...@@ -27,8 +27,8 @@ class Indexing(BaseModule):
""" """
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Indexing, self).__init__(key, dev, **kwargs) super(Indexing, self).__init__(key, dev, **kwargs)
self.n_starts = kwargs.get('n_starts', 0) self.nstarts = kwargs.get('nstarts', 0)
self.n_sizes = kwargs.get('n_sizes', 0) self.nsizes = kwargs.get('nsizes', 0)
self.register_op() self.register_op()
def register_op(self): def register_op(self):
...@@ -37,61 +37,25 @@ class Indexing(BaseModule): ...@@ -37,61 +37,25 @@ class Indexing(BaseModule):
'arguments': { 'arguments': {
'starts_desc': [ 'starts_desc': [
'${{ANCHOR}}/starts[{}]'.format(n) '${{ANCHOR}}/starts[{}]'.format(n)
for n in range(self.n_starts)], for n in range(self.nstarts)],
'sizes_desc': [ 'sizes_desc': [
'${{ANCHOR}}/sizes[{}]'.format(n) '${{ANCHOR}}/sizes[{}]'.format(n)
for n in range(self.n_sizes)], for n in range(self.nsizes)],
}, },
} }
def update_arguments(self, A, starts, sizes): def update_args(self, A, starts, sizes):
for i, e in enumerate(starts): for i, e in enumerate(starts):
self.set_argument_i64('{}/starts[{}]'.format(A, i), e) self.set_arg_i64('{}/starts[{}]'.format(A, i), e)
self.set_argument_i64('{}/sizes[{}]'.format(A, i), sizes[i]) self.set_arg_i64('{}/sizes[{}]'.format(A, i), sizes[i])
def forward(self, x, starts, sizes): def forward(self, x, starts, sizes):
inputs = [x]; self.unify_devices(inputs) inputs = [x]; self.unify_devices(inputs)
outputs = [self.register_output()] outputs = [self.register_output()]
callback = lambda A: self.update_arguments(A, starts, sizes) callback = lambda A: self.update_args(A, starts, sizes)
return self.run(inputs, outputs, callback=callback) return self.run(inputs, outputs, callback=callback)
class Assigning(BaseModule):
"""This module imports the *AssignOp* from backend.
Arbitrary length of starts and sizes will be take.
"""
def __init__(self, key, dev, **kwargs):
super(Assigning, self).__init__(key, dev, **kwargs)
self.n_starts = kwargs.get('n_starts', 0)
self.n_sizes = kwargs.get('n_sizes', 0)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Assign',
'arguments': {
'starts_desc': [
'${{ANCHOR}}/starts[{}]'.format(n)
for n in range(self.n_starts)],
'sizes_desc': [
'${{ANCHOR}}/sizes[{}]'.format(n)
for n in range(self.n_sizes)],
},
}
def update_arguments(self, A, starts, sizes):
for i, e in enumerate(starts):
self.set_argument_i64('{}/starts[{}]'.format(A, i), e)
self.set_argument_i64('{}/sizes[{}]'.format(A, i), sizes[i])
def forward(self, x, y, starts, sizes):
self.unify_devices([x, y])
callback = lambda A: self.update_arguments(A, starts, sizes)
return self.run([x], [y], callback=callback, auto_grad=False)
class Concat(BaseModule): class Concat(BaseModule):
"""This module imports the *ConcatOp* from backend. """This module imports the *ConcatOp* from backend.
...@@ -200,18 +164,19 @@ class ArgReduce(BaseModule): ...@@ -200,18 +164,19 @@ class ArgReduce(BaseModule):
self.operation = kwargs.get('operation', 'ARGMAX') self.operation = kwargs.get('operation', 'ARGMAX')
self.axis = kwargs.get('axis', None) self.axis = kwargs.get('axis', None)
self.keepdim = kwargs.get('keepdim', True) self.keepdim = kwargs.get('keepdim', True)
self.top_k = kwargs.get('top_k', 1) self.topk = kwargs.get('topk', 1)
self.register_op() self.register_op()
def register_op(self): def register_op(self):
self.op_meta = { self.op_meta = {
'op_type': 'ArgReduce', 'op_type': 'ArgReduce',
'arguments': { 'arguments': {
'operation': self.operation if 'ARG' in self.operation \ 'operation': self.operation
if 'ARG' in self.operation \
else 'ARG' + self.operation, else 'ARG' + self.operation,
'axis': self.axis if self.axis else 2147483647, 'axis': self.axis if self.axis else 2147483647,
'keep_dims': self.keepdim, 'keep_dims': self.keepdim,
'top_k': self.top_k, 'top_k': self.topk,
}, },
} }
...@@ -241,7 +206,7 @@ class ArgReduce(BaseModule): ...@@ -241,7 +206,7 @@ class ArgReduce(BaseModule):
class Reshape(BaseModule): class Reshape(BaseModule):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Reshape, self).__init__(key, dev, **kwargs) super(Reshape, self).__init__(key, dev, **kwargs)
self.n_dim = kwargs.get('n_dim', 0) self.ndim = kwargs.get('ndim', 0)
self.register_op() self.register_op()
def register_op(self): def register_op(self):
...@@ -250,19 +215,19 @@ class Reshape(BaseModule): ...@@ -250,19 +215,19 @@ class Reshape(BaseModule):
'arguments': { 'arguments': {
'dims_desc': [ 'dims_desc': [
'${{ANCHOR}}/dims[{}]'.format(n) '${{ANCHOR}}/dims[{}]'.format(n)
for n in range(self.n_dim) for n in range(self.ndim)
], ],
}, },
} }
def update_arguments(self, A, shape): def update_args(self, A, shape):
for i, e in enumerate(shape): for i, e in enumerate(shape):
self.set_argument_i64('{}/dims[{}]'.format(A, i), e) self.set_arg_i64('{}/dims[{}]'.format(A, i), e)
def forward(self, x, shape): def forward(self, x, shape):
inputs = [x]; self.unify_devices(inputs) inputs = [x]; self.unify_devices(inputs)
outputs = [_ReferenceTensor(x)] outputs = [_ReferenceTensor(x)]
callback = lambda A: self.update_arguments(A, shape) callback = lambda A: self.update_args(A, shape)
return self.run(inputs, outputs, callback=callback) return self.run(inputs, outputs, callback=callback)
...@@ -275,7 +240,9 @@ class Squeeze(BaseModule): ...@@ -275,7 +240,9 @@ class Squeeze(BaseModule):
def register_op(self): def register_op(self):
self.op_meta = { self.op_meta = {
'op_type': 'Squeeze', 'op_type': 'Squeeze',
'arguments': {'axis': self.dim}, 'arguments': {
'axis': self.dim,
},
} }
def forward(self, x, out=None): def forward(self, x, out=None):
...@@ -293,7 +260,9 @@ class UnSqueeze(BaseModule): ...@@ -293,7 +260,9 @@ class UnSqueeze(BaseModule):
def register_op(self): def register_op(self):
self.op_meta = { self.op_meta = {
'op_type': 'ExpandDims', 'op_type': 'ExpandDims',
'arguments': {'axis': self.dim}, 'arguments': {
'axis': self.dim,
},
} }
def forward(self, x, out=None): def forward(self, x, out=None):
...@@ -305,7 +274,7 @@ class UnSqueeze(BaseModule): ...@@ -305,7 +274,7 @@ class UnSqueeze(BaseModule):
class Permute(BaseModule): class Permute(BaseModule):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Permute, self).__init__(key, dev, **kwargs) super(Permute, self).__init__(key, dev, **kwargs)
self.n_perm = kwargs.get('n_perm', 0) self.nperm = kwargs.get('nperm', 0)
self.register_op() self.register_op()
def register_op(self): def register_op(self):
...@@ -313,26 +282,26 @@ class Permute(BaseModule): ...@@ -313,26 +282,26 @@ class Permute(BaseModule):
'op_type': 'Transpose', 'op_type': 'Transpose',
'arguments': { 'arguments': {
'perm_desc': ['${{ANCHOR}}/perm[{}]'.format(n) 'perm_desc': ['${{ANCHOR}}/perm[{}]'.format(n)
for n in range(self.n_perm)], for n in range(self.nperm)],
}, },
} }
def update_arguments(self, A, perm): def update_args(self, A, perm):
if perm: if perm:
for i, e in enumerate(perm): for i, e in enumerate(perm):
self.set_argument_i64('{}/perm[{}]'.format(A, i), e) self.set_arg_i64('{}/perm[{}]'.format(A, i), e)
def forward(self, x, perm): def forward(self, x, perm):
inputs = [x]; self.unify_devices(inputs) inputs = [x]; self.unify_devices(inputs)
outputs = [self.register_output()] outputs = [self.register_output()]
callback = lambda A: self.update_arguments(A, perm) callback = lambda A: self.update_args(A, perm)
return self.run(inputs, outputs, callback=callback) return self.run(inputs, outputs, callback=callback)
class Repeat(BaseModule): class Repeat(BaseModule):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Repeat, self).__init__(key, dev, **kwargs) super(Repeat, self).__init__(key, dev, **kwargs)
self.n_times = kwargs.get('n_times', 0) self.ntimes = kwargs.get('ntimes', 0)
self.register_op() self.register_op()
def register_op(self): def register_op(self):
...@@ -341,19 +310,19 @@ class Repeat(BaseModule): ...@@ -341,19 +310,19 @@ class Repeat(BaseModule):
'arguments': { 'arguments': {
'multiples_desc': [ 'multiples_desc': [
'${{ANCHOR}}/multiples[{}]'.format(n) '${{ANCHOR}}/multiples[{}]'.format(n)
for n in range(self.n_times) for n in range(self.ntimes)
], ],
}, },
} }
def update_arguments(self, A, times): def update_args(self, A, times):
for i, d in enumerate(times): for i, d in enumerate(times):
self.set_argument_i64('{}/multiples[{}]'.format(A, i), d) self.set_arg_i64('{}/multiples[{}]'.format(A, i), d)
def forward(self, x, times): def forward(self, x, times):
inputs = [x]; self.unify_devices(inputs) inputs = [x]; self.unify_devices(inputs)
outputs = [self.register_output()] outputs = [self.register_output()]
callback = lambda A: self.update_arguments(A, times) callback = lambda A: self.update_args(A, times)
return self.run(inputs, outputs, callback=callback) return self.run(inputs, outputs, callback=callback)
...@@ -409,7 +378,6 @@ class Multinomial(BaseModule): ...@@ -409,7 +378,6 @@ class Multinomial(BaseModule):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Multinomial, self).__init__(key, dev, **kwargs) super(Multinomial, self).__init__(key, dev, **kwargs)
self.num_samples = kwargs.get('num_samples', 1) self.num_samples = kwargs.get('num_samples', 1)
self.normalize = kwargs.get('normalize', False)
self.register_op() self.register_op()
def register_op(self): def register_op(self):
...@@ -417,7 +385,7 @@ class Multinomial(BaseModule): ...@@ -417,7 +385,7 @@ class Multinomial(BaseModule):
'op_type': 'Multinomial', 'op_type': 'Multinomial',
'arguments': { 'arguments': {
'num_samples': self.num_samples, 'num_samples': self.num_samples,
'normalize': self.normalize, 'normalize': False,
}, },
} }
......
...@@ -14,9 +14,9 @@ from __future__ import division ...@@ -14,9 +14,9 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy import numpy
from dragon.core import proto_utils as _proto_utils
from dragon.core import workspace as _workspace from dragon.core import workspace as _workspace
from dragon.core import proto_utils as _proto_utils
from dragon.vm.torch.module import Module from dragon.vm.torch.module import Module
...@@ -25,10 +25,14 @@ class BaseModule(Module): ...@@ -25,10 +25,14 @@ class BaseModule(Module):
super(BaseModule, self).__init__() super(BaseModule, self).__init__()
self._module_key = key self._module_key = key
self._device = dev self._device = dev
self._args_dev = _proto_utils.\ self._arg_dev = _proto_utils \
GetDeviceOption('cpu').SerializeToString() .GetDeviceOption('cpu')\
.SerializeToString()
def set_argument_i64(self, name, value): def set_arg_i64(self, name, value):
_workspace.get_default_workspace()\ _workspace.get_default_workspace() \
.FeedTensor(name, numpy.array( .FeedTensor(
value, dtype=numpy.int64), self._args_dev) name,
\ No newline at end of file numpy.array(value, 'int64'),
self._arg_dev,
)
\ No newline at end of file
...@@ -46,4 +46,53 @@ class Compare(BaseModule): ...@@ -46,4 +46,53 @@ class Compare(BaseModule):
def forward(self, x1, x2, y): def forward(self, x1, x2, y):
inputs = [x1, x2]; self.unify_devices(inputs) inputs = [x1, x2]; self.unify_devices(inputs)
outputs = [y] if y else [self.register_output()] outputs = [y] if y else [self.register_output()]
return self.run(inputs, outputs) return self.run(inputs, outputs)
\ No newline at end of file
class Assign(BaseModule):
"""This module imports the *AssignOp* from backend.
Arbitrary length of starts and sizes will be take.
"""
def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs)
self.nstarts = kwargs.get('nstarts', 0)
self.nsizes = kwargs.get('nsizes', 0)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Assign',
'arguments': {
'starts_desc': [
'${{ANCHOR}}/starts[{}]'.format(n)
for n in range(self.nstarts)],
'sizes_desc': [
'${{ANCHOR}}/sizes[{}]'.format(n)
for n in range(self.nsizes)],
},
}
def update_args(self, A, starts, sizes):
for i, e in enumerate(starts):
self.set_arg_i64('{}/starts[{}]'.format(A, i), e)
self.set_arg_i64('{}/sizes[{}]'.format(A, i), sizes[i])
def forward(self, x, y, starts, sizes):
self.unify_devices([x, y])
callback = lambda A: self.update_args(A, starts, sizes)
return self.run([x], [y], callback=callback, auto_grad=False)
class MaskedAssign(BaseModule):
def __init__(self, key, dev, **kwargs):
super(MaskedAssign, self).__init__(key, dev, **kwargs)
self.register_op()
def register_op(self):
self.op_meta = {'op_type': 'MaskedAssign', 'arguments': {}}
def forward(self, x, y, mask):
self.unify_devices([x, y])
return self.run([x, mask], [y])
\ No newline at end of file
...@@ -19,16 +19,16 @@ from dragon.vm.torch.ops.modules.base import BaseModule ...@@ -19,16 +19,16 @@ from dragon.vm.torch.ops.modules.base import BaseModule
class _InitModule(BaseModule): class _InitModule(BaseModule):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(_InitModule, self).__init__(key, dev, **kwargs) super(_InitModule, self).__init__(key, dev, **kwargs)
self.n_dim = kwargs.get('n_dim', 0) self.ndim = kwargs.get('ndim', 0)
self.dtype = kwargs.get('dtype', 'float32') self.dtype = kwargs.get('dtype', 'float32')
def update_arguments(self, A, shape): def update_args(self, A, shape):
for i, e in enumerate(shape): for i, e in enumerate(shape):
self.set_argument_i64('{}/dims[{}]'.format(A, i), e) self.set_arg_i64('{}/dims[{}]'.format(A, i), e)
def forward(self, x, shape): def forward(self, x, shape):
outputs = [x]; self.unify_devices(outputs) outputs = [x]; self.unify_devices(outputs)
callback = lambda A: self.update_arguments(A, shape) callback = lambda A: self.update_args(A, shape)
return self.run([], outputs, callback=callback) return self.run([], outputs, callback=callback)
...@@ -46,7 +46,7 @@ class Fill(_InitModule): ...@@ -46,7 +46,7 @@ class Fill(_InitModule):
'value': float(self.value), 'value': float(self.value),
'dims_desc': [ 'dims_desc': [
'${{ANCHOR}}/dims[{}]'.format(n) '${{ANCHOR}}/dims[{}]'.format(n)
for n in range(self.n_dim) for n in range(self.ndim)
], ],
}, },
} }
...@@ -68,7 +68,7 @@ class RandomNormal(_InitModule): ...@@ -68,7 +68,7 @@ class RandomNormal(_InitModule):
'std': float(self.std), 'std': float(self.std),
'dims_desc': [ 'dims_desc': [
'${{ANCHOR}}/dims[{}]'.format(n) '${{ANCHOR}}/dims[{}]'.format(n)
for n in range(self.n_dim) for n in range(self.ndim)
], ],
}, },
} }
...@@ -90,7 +90,7 @@ class RandomUniform(_InitModule): ...@@ -90,7 +90,7 @@ class RandomUniform(_InitModule):
'high': float(self.high), 'high': float(self.high),
'dims_desc': [ 'dims_desc': [
'${{ANCHOR}}/dims[{}]'.format(n) '${{ANCHOR}}/dims[{}]'.format(n)
for n in range(self.n_dim) for n in range(self.ndim)
], ],
}, },
} }
\ No newline at end of file
...@@ -13,7 +13,7 @@ from __future__ import absolute_import ...@@ -13,7 +13,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import dragon.core.mpi as mpi from dragon.core import mpi as _mpi
from dragon.vm.torch.ops.modules.base import BaseModule from dragon.vm.torch.ops.modules.base import BaseModule
...@@ -50,11 +50,13 @@ class Collective(BaseModule): ...@@ -50,11 +50,13 @@ class Collective(BaseModule):
self.register_op() self.register_op()
def register_op(self): def register_op(self):
idx, group = mpi.AllowParallel() idx, group = _mpi.AllowParallel()
if idx == -1: if idx == -1:
raise RuntimeError('The mpi node({}) dost not in ' raise RuntimeError(
'parallel groups. \nSet it using mpi.Parallel([..]).'.format(mpi.Rank())) 'The mpi node({}) dost not in groups.\n'
mpi_comm, mpi_group = mpi.CreateGroup(root=group[0], incl=group) 'Set it using mpi.Parallel([..]).'.format(_mpi.Rank())
)
mpi_comm, mpi_group = _mpi.CreateGroup(root=group[0], incl=group)
self.op_meta = { self.op_meta = {
'op_type': 'CollectiveUpdate', 'op_type': 'CollectiveUpdate',
'arguments': { 'arguments': {
...@@ -78,7 +80,10 @@ class Accumulate(BaseModule): ...@@ -78,7 +80,10 @@ class Accumulate(BaseModule):
def register_op(self): def register_op(self):
self.op_meta = { self.op_meta = {
'op_type': 'Accumulate', 'op_type': 'Accumulate',
'arguments': {'alpha': 1., 'beta': 1.}, 'arguments': {
'alpha': 1.,
'beta': 1.,
},
} }
def forward(self, grads): def forward(self, grads):
......
...@@ -19,10 +19,10 @@ from dragon.vm.torch.ops.modules.base import BaseModule ...@@ -19,10 +19,10 @@ from dragon.vm.torch.ops.modules.base import BaseModule
class Resize2d(BaseModule): class Resize2d(BaseModule):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Resize2d, self).__init__(key, dev, **kwargs) super(Resize2d, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', 'NNResize')
self.dsize = kwargs.get('dsize', None) self.dsize = kwargs.get('dsize', None)
self.fx = kwargs.get('fx', None) self.fx = kwargs.get('fx', None)
self.fy = kwargs.get('fy', None) self.fy = kwargs.get('fy', None)
self.op_type = kwargs.get('op_type', 'NNResize')
self.register_op() self.register_op()
def register_op(self): def register_op(self):
...@@ -38,15 +38,15 @@ class Resize2d(BaseModule): ...@@ -38,15 +38,15 @@ class Resize2d(BaseModule):
}, },
} }
def update_arguments(self, A, dsize): def update_args(self, A, dsize):
if self.dsize: if self.dsize:
for i, e in enumerate(dsize): for i, e in enumerate(dsize):
self.set_argument_i64('{}/dsize[{}]'.format(A, i), e) self.set_arg_i64('{}/dsize[{}]'.format(A, i), e)
def forward(self, input, dsize=None): def forward(self, input, dsize=None):
inputs = [input]; self.unify_devices(inputs) inputs = [input]; self.unify_devices(inputs)
outputs = [self.register_output()] outputs = [self.register_output()]
callback = lambda A: self.update_arguments(A, dsize) callback = lambda A: self.update_args(A, dsize)
return self.run(inputs, outputs, callback=callback) return self.run(inputs, outputs, callback=callback)
...@@ -62,7 +62,8 @@ class RoIPool(BaseModule): ...@@ -62,7 +62,8 @@ class RoIPool(BaseModule):
self.op_meta = { self.op_meta = {
'op_type': 'ROIPool', 'op_type': 'ROIPool',
'arguments': { 'arguments': {
'pool_h': self.pool_h, 'pool_w': self.pool_w, 'pool_h': self.pool_h,
'pool_w': self.pool_w,
'spatial_scale': self.spatial_scale, 'spatial_scale': self.spatial_scale,
}, },
} }
...@@ -86,7 +87,8 @@ class RoIAlign(BaseModule): ...@@ -86,7 +87,8 @@ class RoIAlign(BaseModule):
self.op_meta = { self.op_meta = {
'op_type': 'ROIAlign', 'op_type': 'ROIAlign',
'arguments': { 'arguments': {
'pool_h': self.pool_h, 'pool_w': self.pool_w, 'pool_h': self.pool_h,
'pool_w': self.pool_w,
'spatial_scale': self.spatial_scale, 'spatial_scale': self.spatial_scale,
'sampling_ratio': self.sampling_ratio, 'sampling_ratio': self.sampling_ratio,
}, },
......
...@@ -23,9 +23,9 @@ from dragon.vm.torch.ops.builtin import ( ...@@ -23,9 +23,9 @@ from dragon.vm.torch.ops.builtin import (
_fundamental, _rfundamental, _fundamental, _rfundamental,
log, exp, sqrt, clamp, log, exp, sqrt, clamp,
_reshape, squeeze, unsqueeze, _reshape, squeeze, unsqueeze,
_permute, _repeat, _permute, _repeat, narrow,
_indexing, _assigning, _index, index_select,
narrow, index_select, _assign, _masked_assign,
mean, sum, max, min, mean, sum, max, min,
gt, lt, eq, ge, le, gt, lt, eq, ge, le,
) )
...@@ -41,6 +41,7 @@ def _type_to(input, dtype='float32', inplace=False): ...@@ -41,6 +41,7 @@ def _type_to(input, dtype='float32', inplace=False):
Tensor.fill_ = lambda self, value: _fill(self, self.shape, value) Tensor.fill_ = lambda self, value: _fill(self, self.shape, value)
Tensor.masked_fill_ = lambda *args, **kwargs: _masked_assign(*args, **kwargs)
Tensor.uniform_ = lambda self, low=0, high=1: _uniform(self, self.shape, low, high) Tensor.uniform_ = lambda self, low=0, high=1: _uniform(self, self.shape, low, high)
Tensor.normal_ = lambda self, mean=0, std=1: _normal(self, self.shape, mean, std) Tensor.normal_ = lambda self, mean=0, std=1: _normal(self, self.shape, mean, std)
Tensor.multinomial = lambda *args, **kwargs: multinomial(*args, **kwargs) Tensor.multinomial = lambda *args, **kwargs: multinomial(*args, **kwargs)
...@@ -85,8 +86,8 @@ Tensor.le = lambda *args, **kwargs: le(*args, **kwargs) ...@@ -85,8 +86,8 @@ Tensor.le = lambda *args, **kwargs: le(*args, **kwargs)
Tensor.eq = lambda *args, **kwargs: eq(*args, **kwargs) Tensor.eq = lambda *args, **kwargs: eq(*args, **kwargs)
Tensor.index_select = lambda *args, **kwargs: index_select(*args, **kwargs) Tensor.index_select = lambda *args, **kwargs: index_select(*args, **kwargs)
Tensor.narrow = lambda *args, **kwargs: narrow(*args, **kwargs) Tensor.narrow = lambda *args, **kwargs: narrow(*args, **kwargs)
Tensor._indexing = lambda *args, **kwargs: _indexing(*args, **kwargs) Tensor._index = lambda *args, **kwargs: _index(*args, **kwargs)
Tensor._assigning = lambda *args, **kwargs: _assigning(*args, **kwargs) Tensor._assign = lambda *args, **kwargs: _assign(*args, **kwargs)
Tensor.half = lambda self: _type_to(self, dtype='float16', inplace=False) Tensor.half = lambda self: _type_to(self, dtype='float16', inplace=False)
......
...@@ -533,16 +533,16 @@ class Tensor(object): ...@@ -533,16 +533,16 @@ class Tensor(object):
""" """
starts, sizes = self._process_indices(item) starts, sizes = self._process_indices(item)
return self._indexing(starts, sizes) return self._index(starts, sizes)
def __setitem__(self, key, value): def __setitem__(self, key, value):
"""Set the value at the specific indices. """Set the value at the specific indices.
Parameters Parameters
---------- ----------
key : int, slice key : int, slice or dragon.vm.torch.Tensor
The indices. The indices.
value : dragon.vm.torch.Tensor, number or sequence value : number, sequence or dragon.vm.torch.Tensor
The value. The value.
Returns Returns
...@@ -550,8 +550,11 @@ class Tensor(object): ...@@ -550,8 +550,11 @@ class Tensor(object):
None None
""" """
starts, sizes = self._process_indices(key) if isinstance(key, Tensor):
return self._assigning(value, starts, sizes) return self.masked_fill_(key, value)
else:
starts, sizes = self._process_indices(key)
return self._assign(starts, sizes, value)
def __hash__(self): def __hash__(self):
return id(self) return id(self)
...@@ -886,7 +889,7 @@ class Tensor(object): ...@@ -886,7 +889,7 @@ class Tensor(object):
return self return self
def fill_(self, value): def fill_(self, value):
"""Fills self tensor with the specified value. """Fill self with the given value.
Parameters Parameters
---------- ----------
...@@ -901,6 +904,24 @@ class Tensor(object): ...@@ -901,6 +904,24 @@ class Tensor(object):
""" """
raise NotImplementedError('Refer torch.ops.tensor.fill_') raise NotImplementedError('Refer torch.ops.tensor.fill_')
def masked_fill_(self, mask, value):
"""Fill self with the given value where ``mask`` is *1*.
Parameters
----------
mask : dragon.vm.torch.Tensor
The mask.
value : number
The value to fill.
Returns
-------
dragon.vm.torch.Tensor
The self.
"""
raise NotImplementedError('Refer torch.ops.tensor.masked_fill_')
def zero_(self): def zero_(self):
"""Fills self tensor with zeros. """Fills self tensor with zeros.
......
...@@ -123,7 +123,7 @@ void MixedMemory::SwitchToCUDADevice(int device_id) { ...@@ -123,7 +123,7 @@ void MixedMemory::SwitchToCUDADevice(int device_id) {
if (device_id != ptr_device_) { if (device_id != ptr_device_) {
// Move the memory to another device // Move the memory to another device
void* new_ptr_ = nullptr; void* new_ptr_ = nullptr;
DeviceGuard gurad(device_id); CUDADeviceGuard gurad(device_id);
new_ptr_ = CUDAContext::New(nbytes_); new_ptr_ = CUDAContext::New(nbytes_);
CUDAContext::MemcpyEx<CUDAContext, CUDAContext>( CUDAContext::MemcpyEx<CUDAContext, CUDAContext>(
nbytes_, new_ptr_, cuda_ptr_, ptr_device_); nbytes_, new_ptr_, cuda_ptr_, ptr_device_);
......
...@@ -4,24 +4,31 @@ namespace dragon { ...@@ -4,24 +4,31 @@ namespace dragon {
bool OpSchema::Verify(const OperatorDef& def) const { bool OpSchema::Verify(const OperatorDef& def) const {
if (ignore_verify_) return true; if (ignore_verify_) return true;
string indicator = "[" + def.name() + ", " + def.type() + "]\n"; auto header = "[" + def.name() + ", " + def.type() + "]\n";
if (def.input_size() < min_input_ || def.input_size() > max_input_) { if (def.input_size() < min_input_ ||
LOG(FATAL) << indicator << "Input size: " << def.input_size() def.input_size() > max_input_) {
<< " is not in range [min=" << min_input_ LOG(FATAL)
<< ", max=" << max_input_ << "]"; << header << "Input size: " << def.input_size()
<< " is not in range [min=" << min_input_
<< ", max=" << max_input_ << "]";
} }
if (def.output_size() < min_output_ || def.output_size() > max_output_) { if (def.output_size() < min_output_ ||
LOG(FATAL) << indicator << "Output size: " << def.output_size() def.output_size() > max_output_) {
<< " is not in range [min=" << min_output_ LOG(FATAL)
<< ", max=" << max_output_ << "]"; << header << "Output size: " << def.output_size()
<< " is not in range [min=" << min_output_
<< ", max=" << max_output_ << "]";
} }
for (int in = 0; in < def.input_size(); in++) { for (int i = 0; i < def.input_size(); ++i) {
if (def.input(in) == "NULL") continue; if (def.input(i) == "NULL") continue;
for (int out = 0; out < def.output_size(); out++) { for (int j = 0; j < def.output_size(); ++j) {
if (def.output(out) == "NULL") continue; if (def.output(j) == "NULL") continue;
if (def.input(in) == def.output(out) && (!CheckInplace(in, out))) if (def.input(i) == def.output(j) &&
LOG(FATAL) << indicator << "Input(" << in << ") and " !CheckInplace(i, j))
<< "Output(" << out << ") can not be set to inplace."; LOG(FATAL)
<< header << "Input(" << i
<< ") and Output(" << j << ") "
<< "can not be set to inplace.";
} }
} }
return true; return true;
......
...@@ -54,7 +54,7 @@ __global__ void _Assign( ...@@ -54,7 +54,7 @@ __global__ void _Assign(
const T* x, \ const T* x, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_Assign<T> \ _Assign \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ << < CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >> >( \
count, \ count, \
......
#include "utils/op_kernel.h"
#include "utils/math_utils.h"
#include "utils/omp_alternative.h"
namespace dragon {
namespace kernel {
/* <T = ?, Device = CPU> */
template <typename T>
void _MaskedAssign(
const int count,
const uint8_t* mask,
const T* x,
T* y) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
y[i] = mask[i] ? x[i] : y[i];
}
}
/* Kernel Launchers */
#define DEFINE_ASSIGN_KERNEL_LAUNCHER(T) \
template<> void MaskedAssign<T, CPUContext>( \
const int count, \
const uint8_t* mask, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_MaskedAssign(count, mask, x, y); \
}
DEFINE_ASSIGN_KERNEL_LAUNCHER(bool);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(uint8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int64_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float16);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float);
DEFINE_ASSIGN_KERNEL_LAUNCHER(double);
#undef DEFINE_ASSIGN_KERNEL_LAUNCHER
} // namespace kernel
} // namepsace dragon
\ No newline at end of file
#ifdef WITH_CUDA
#include "core/context_cuda.h"
#include "utils/op_kernel.h"
namespace dragon {
namespace kernel {
/* <T = ?, Device = CUDA> */
template<typename T>
__global__ void _MaskedAssign(
const int nthreads,
const uint8_t* mask,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = mask[i] ? x[i] : y[i];
}
}
/* Kernel Launchers */
#define DEFINE_ASSIGN_KERNEL_LAUNCHER(T) \
template<> void MaskedAssign<T, CUDAContext>( \
const int count, \
const uint8_t* mask, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
_MaskedAssign \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \
count, mask, x, y \
); \
}
DEFINE_ASSIGN_KERNEL_LAUNCHER(bool);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(uint8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int64_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float16);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float);
DEFINE_ASSIGN_KERNEL_LAUNCHER(double);
#undef DEFINE_ASSIGN_KERNEL_LAUNCHER
} // namespace kernel
} // namepsace dragon
#endif // WITH_CUDA
\ No newline at end of file
...@@ -49,7 +49,7 @@ void AssignOp<Context>::RunImpl() { ...@@ -49,7 +49,7 @@ void AssignOp<Context>::RunImpl() {
} else if (X(0).count() == X_.count()) { } else if (X(0).count() == X_.count()) {
x = X(0).template data<T, Context>(); x = X(0).template data<T, Context>();
} else { } else {
LOG(FATAL) LOG(FATAL)
<< "Could not assign " << "Could not assign "
<< X(0).DimString() << X(0).DimString()
<< " to " << " to "
...@@ -146,7 +146,10 @@ DEPLOY_CUDA(Assign); ...@@ -146,7 +146,10 @@ DEPLOY_CUDA(Assign);
#endif #endif
OPERATOR_SCHEMA(Assign) OPERATOR_SCHEMA(Assign)
.NumInputs(1).NumOutputs(1); /* V */
.NumInputs(1)
/* X */
.NumOutputs(1);
NO_GRADIENT(Assign); NO_GRADIENT(Assign);
......
#include "core/workspace.h"
#include "utils/op_kernel.h"
#include "utils/math_utils.h"
#include "utils/math_functions.h"
#include "operators/control_flow/masked_assign_op.h"
namespace dragon {
template <class Context> template <typename T>
void MaskedAssignOp<Context>::RunImpl() {
const T* x = nullptr;
auto* mask = X(1).template raw_data<Context>();
auto* y = Y(0)->template mutable_data<T, Context>();
if (X(0).count() < Y(0)->count()) {
int rows, cols;
auto* scratch = ws()
->template data<T, Context>
({ Y(0)->count() })[0];
auto* rx = X(0).template data<T, Context>();
if (utils::IsRowwiseBroadcast(
Y(0)->dims(), X(0).dims(),
&rows, &cols)) {
math::BroadcastSet(
rows, cols, 0,
rx, scratch, ctx()
);
} else if (utils::IsColwiseBroadcast(
Y(0)->dims(), X(0).dims(),
&rows, &cols)) {
math::BroadcastSet(
rows, cols, 1,
rx, scratch, ctx()
);
} else {
LOG(FATAL)
<< "Could not broadcast "
<< X(0).DimString()
<< " to "
<< Y(0)->DimString();
}
x = scratch;
} else if (X(0).count() == Y(0)->count()) {
x = X(0).template data<T, Context>();
} else {
LOG(FATAL)
<< "Could not assign "
<< X(0).DimString()
<< " to "
<< Y(0)->DimString();
}
kernel::MaskedAssign(
Y(0)->count(),
(const uint8_t*)mask,
x, y, ctx()
);
}
template <class Context>
void MaskedAssignOp<Context>::RunOnDevice() {
CHECK_EQ(X(1).count(), Y(0)->count())
<< "\nSize of mask and input should be equal.";
CHECK(XIsType(X(1), bool) || XIsType(X(1), uint8_t))
<< "\nExcepted bool or uint8 mask.";
if (XIsType(X(0), bool)) {
RunImpl<bool>();
} else if (XIsType(X(0), int8_t)) {
RunImpl<int8_t>();
} else if (XIsType(X(0), uint8_t)) {
RunImpl<uint8_t>();
} else if (XIsType(X(0), int)) {
RunImpl<int>();
} else if (XIsType(X(0), int64_t)) {
RunImpl<int64_t>();
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else if (XIsType(X(0), float)) {
RunImpl<float>();
} else if (XIsType(X(0), double)) {
RunImpl<double>();
} else {
LOG(FATAL) << DTypeString(X(0), {
"bool", "int8", "uint8", "int32", "int64",
"float16", "float32", "float64",
});
}
}
DEPLOY_CPU(MaskedAssign);
#ifdef WITH_CUDA
DEPLOY_CUDA(MaskedAssign);
#endif
OPERATOR_SCHEMA(MaskedAssign)
/* V, M */
.NumInputs(2)
/* X */
.NumOutputs(1);
NO_GRADIENT(MaskedAssign);
} // namespace dragon
\ No newline at end of file
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!