Commit 79a52211 by Ting PAN

Dismantle FP16 & TensorCore Support

1 parent bab9931e
Showing with 652 additions and 278 deletions
......@@ -133,6 +133,9 @@ class CUDAContext {
} else {
DeviceGuard guard(gpu_id_);
CUBLAS_CHECK(cublasCreate_v2(&handle));
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
#endif
return handle;
}
}
......
......@@ -9,17 +9,17 @@
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
#ifndef DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ScaleOp : public Operator<Context> {
class AffineOp : public Operator<Context> {
public:
ScaleOp(const OperatorDef& op_def, Workspace* ws)
AffineOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {}
......@@ -30,14 +30,14 @@ class ScaleOp : public Operator<Context> {
protected:
TIndex axis, start_axis, num_axes;
TIndex inner_dim;
TIndex outer_dim, scale_dim, inner_dim;
Tensor* bias_multiplier;
};
template <class Context>
class ScaleGradientOp final : public Operator<Context> {
class AffineGradientOp final : public Operator<Context> {
public:
ScaleGradientOp(const OperatorDef& op_def, Workspace* ws)
AffineGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
......@@ -57,4 +57,4 @@ class ScaleGradientOp final : public Operator<Context> {
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
\ No newline at end of file
......@@ -33,7 +33,7 @@ class AccuracyOp final: public Operator<Context> {
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename Tx, typename Ty> void RunWithType();
protected:
TIndex top_k, axis, outer_dim, inner_dim, num_classes;
......
......@@ -9,8 +9,8 @@
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_
#define DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_
#ifndef DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
#define DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
#include "core/operator.h"
......@@ -62,4 +62,4 @@ class StopGradientOp final : public Operator<Context> {
} // namespace dragon
#endif // DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
\ No newline at end of file
......@@ -56,11 +56,13 @@ template <class Context>
class CuDNNConv2dOp : public Conv2dOp<Context> {
public:
CuDNNConv2dOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) {
: Conv2dOp<Context>(def, ws), enable_tensor_core(true) {
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else
cudnn_group = this->group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
......@@ -109,17 +111,20 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
size_t workspace_fwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
bool enable_tensor_core;
};
template <class Context>
class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
public:
CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dGradientOp<Context>(def, ws) {
: Conv2dGradientOp<Context>(def, ws), enable_tensor_core(true) {
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else
cudnn_group = this->group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group * 3];
stream = new cudaStream_t[cudnn_group * 3];
......@@ -168,6 +173,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
size_t workspace_bwd_filter_size, workspace_bwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
bool enable_tensor_core;
};
#endif // WITH_CUDNN
......
......@@ -60,11 +60,13 @@ template <class Context>
class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
public:
CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws) {
: Conv2dTransposeOp<Context>(def, ws), enable_tensor_core(true) {
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else
cudnn_group = this->group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
......@@ -112,17 +114,20 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
size_t workspace_fwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
bool enable_tensor_core;
};
template <class Context>
class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context> {
public:
CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeGradientOp<Context>(def, ws) {
: Conv2dTransposeGradientOp<Context>(def, ws), enable_tensor_core(true) {
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else
cudnn_group = this->group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group * 3];
stream = new cudaStream_t[cudnn_group * 3];
......@@ -171,6 +176,7 @@ public:
size_t workspace_bwd_filter_size, workspace_bwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
bool enable_tensor_core;
};
#endif // WITH_CUDNN
......
......@@ -26,9 +26,9 @@ template<> inline int dragon_cast<int, float>(float val) {
return static_cast<int>(val);
}
template<> inline float dragon_cast<float, float>(float val) {
return val;
}
template<> inline float dragon_cast<float, float>(float val) { return val; }
template<> inline float16 dragon_cast<float16, float16>(float16 val) { return val; }
template<> inline float16 dragon_cast<float16, float>(float val) {
float16 ret;
......
......@@ -77,7 +77,7 @@ inline int GET_BLOCKS(const int N) {
#define __hdiv hdiv
#endif
inline int NUM_DEVICES() {
inline int CUDA_NUM_DEVICES() {
static int count = -1;
if (count < 0) {
auto err = cudaGetDeviceCount(&count);
......@@ -86,21 +86,47 @@ inline int NUM_DEVICES() {
return count;
}
inline int CURRENT_DEVICE() {
inline int CUDA_CURRENT_DEVICE() {
int gpu_id;
cudaGetDevice(&gpu_id);
return gpu_id;
}
inline int POINTER_DEVICE(const void* ptr) {
inline int CUDA_POINTER_DEVICE(const void* ptr) {
cudaPointerAttributes attr;
CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
return attr.device;
}
struct CUDADeviceProps {
CUDADeviceProps() : props(CUDA_NUM_DEVICES()) {
for (int i = 0; i < CUDA_NUM_DEVICES(); ++i)
CUDA_CHECK(cudaGetDeviceProperties(&props[i], i));
}
vector<cudaDeviceProp> props;
};
inline const cudaDeviceProp& GetDeviceProperty(const int device_id) {
static CUDADeviceProps props;
CHECK_LT(device_id, (int)props.props.size())
<< "Invalid device id: " << device_id
<< "\nDetected " << props.props.size() << " eligible cuda devices.";
return props.props[device_id];
}
inline bool TENSOR_CORE_AVAILABLE() {
#if CUDA_VERSION < 9000
return false;
#else
int device = CUDA_CURRENT_DEVICE();
auto& prop = GetDeviceProperty(device);
return prop.major >= 7;
#endif
}
class DeviceGuard {
public:
DeviceGuard(int newDevice) : previous_(CURRENT_DEVICE()) {
DeviceGuard(int newDevice) : previous_(CUDA_CURRENT_DEVICE()) {
if (previous_ != newDevice)
CUDA_CHECK(cudaSetDevice(newDevice));
}
......
......@@ -150,6 +150,28 @@ void Tanh(const int count, const T* x, T* y);
template <typename T, class Context>
void TanhGrad(const int count, const T* dy, const T* y, T* dx);
/******************** arithmetic.affine ********************/
template <typename T, class Context>
void Affine(const int count,
const int outer_dim,
const int scale_dim,
const int inner_dim,
const T* x,
const T* alpha,
const T* beta,
const T* beta_multiplier,
T* y);
template <typename T, class Context>
void AffineGrad(const int count,
const int outer_dim,
const int scale_dim,
const int inner_dim,
const T* dy,
const T* alpha,
T* dx);
/******************** arithmetic.bias_add ********************/
template <typename T, class Context>
......@@ -172,24 +194,6 @@ void Clip(const int count,
T* mask,
T* y);
/******************** arithmetic.scale ********************/
template <typename T, class Context>
void Scale(const int axis,
Tensor* x,
Tensor* gamma,
Tensor* beta,
Tensor* BMul,
Tensor* y);
template <typename T, class Context>
void ScaleGrad(const int axis, Tensor* dy, Tensor* gamma, Tensor* dx);
/******************** cast.float2half ********************/
template <typename T, class Context>
void Float2Half(const int count, const float* x, float16* y);
/******************** control_flow.compare ********************/
template <typename T, class Context>
......@@ -286,7 +290,7 @@ void SparseSoftmaxFocalLossGrad(const int count,
Tensor* ignore,
T* dx);
/******************** misc.dtype ********************/
/******************** misc.astype ********************/
template <typename Ta, typename Tb, class Context>
void TypeA2B(const int count, const Ta* a, Tb* b);
......
......@@ -88,7 +88,7 @@ std::string CreateGraph(const std::string& graph_file, const Device& device, Wor
// overwritten device options
DeviceOption* device_option = meta_graph.mutable_device_option();
device_option->set_device_type((DeviceType)device.device_type());
device_option->set_gpu_id(device.device_id());
device_option->set_device_id(device.device_id());
device_option->set_engine("CUDNN");
dragon::GraphBase* graph = ws->CreateGraph(meta_graph);
if (!graph) LOG(FATAL) << "Can not create the graph.";
......
......@@ -80,7 +80,8 @@ def FromShape(shape, dtype='float32', ctx=None, name=None):
The wrapper of ``TensorFromShapeCC``.
"""
tensor = Tensor(name)
if name is None: tensor = Tensor(name=name)
else: tensor = Tensor(_name=name)
if not isinstance(shape, (tuple, list)):
raise TypeError('The shape should be a tuple or list.')
if ctx is None: ctx = MakeDeviceOption(0, 0) # CPUContext
......@@ -134,7 +135,8 @@ def FromPyArray(array, name=None):
The wrapper of ``TensorFromPyArrayCC``.
"""
tensor = Tensor(name)
if name is None: tensor = Tensor(name=name)
else: tensor = Tensor(_name=name)
if not isinstance(array, np.ndarray):
raise TypeError('The given nd-array should be numpy.ndarray.')
TensorFromPyArrayCC(_stringify_tensor(tensor), array)
......@@ -258,15 +260,18 @@ def GetTensorInfo(tensor, stream=1):
The string info contains the following fields:
stream #1: ``dtype``, ``from_numpy``, ``init``
``mem``, ``mem_at``, ``device_id``
stream #1: ``dtype``, ``from_numpy``, ``init``, ``mem``, ``mem_at``, ``device_id``
stream #2: ``shape``
stream #3: #1 + #2
Parameters
----------
tensor : Tensor or str
The input tensor.
stream : int
The stream id.
Returns
-------
......
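For orientation, a minimal sketch of how the touched wrappers fit together; the module path ``dragon.core.tensor_utils`` and the tensor names are illustrative assumptions, not part of this commit:

```python
import numpy as np
import dragon.core.tensor_utils as tensor_utils

# Wrap an existing numpy array into a named workspace tensor.
t = tensor_utils.FromPyArray(np.ones((2, 3), dtype='float32'), name='x')

# Or allocate a tensor directly from a shape descriptor (tuple or list).
s = tensor_utils.FromShape((4, 4), dtype='float32', name='y')

# Query the merged info string (stream #3 = dtype/memory fields + shape).
print(tensor_utils.GetTensorInfo(t, stream=3))
```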
......@@ -14,7 +14,6 @@ Common
operators/ndarray
operators/control_flow
operators/misc
operators/cast
operators/mpi
=================================== =====================================================================
......@@ -26,7 +25,6 @@ List Brief
`dragon.operators.ndarray`_ The ndarray operators.
`dragon.operators.control_flow`_ The control flow operators.
`dragon.operators.misc`_ The misc operators.
`dragon.operators.cast`_ The cast operators.
`dragon.operators.mpi`_ The MPI operators.
=================================== =====================================================================
......@@ -95,7 +93,6 @@ List Brief
.. _dragon.operators.ndarray: operators/ndarray.html
.. _dragon.operators.control_flow: operators/control_flow.html
.. _dragon.operators.misc: operators/misc.html
.. _dragon.operators.cast: operators/cast.html
.. _dragon.operators.mpi: operators/mpi.html
.. _dragon.operators.activation: operators/activation.html
.. _dragon.operators.vision: operators/vision.html
......
......@@ -101,7 +101,7 @@ List Brief
`Matmul`_ Matrix Multiplication.
`InnerProduct`_ InnerProduct Function.
`Eltwise`_ Eltwise Sum/Product Function.
`Scale`_ Scale Function.
`Affine`_ Calculate ``y = Ax + b`` along the given range of axes.
`GramMatrix`_ Calculate the gram matrix, introduced by `[Gatys et.al, 2016] <https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf>`_.
=============== ======================================================================
......@@ -244,7 +244,7 @@ List Brief
.. _Dot: operators/arithmetic.html#dragon.operators.arithmetic.Dot
.. _InnerProduct: operators/arithmetic.html#dragon.operators.arithmetic.InnerProduct
.. _Eltwise: operators/arithmetic.html#dragon.operators.arithmetic.Eltwise
.. _Scale: operators/arithmetic.html#dragon.operators.arithmetic.Scale
.. _Affine: operators/arithmetic.html#dragon.operators.arithmetic.Affine
.. _GramMatrix: operators/arithmetic.html#dragon.operators.arithmetic.GramMatrix
.. _BatchNorm: operators/norm.html#dragon.operators.norm.BatchNorm
......
......@@ -39,9 +39,9 @@ class DataBatch(object):
source : str
The path of database.
multiple_nodes: boolean
Whether to split data for multiple parallel nodes.
Whether to split data for multiple parallel nodes. Default is ``False``.
shuffle : boolean
Whether to shuffle the data.
Whether to shuffle the data. Default is ``False``.
num_chunks : int
The number of chunks to split. Default is ``2048``.
chunk_size : int
......@@ -132,7 +132,7 @@ class DataBatch(object):
part_idx = i
if self._readers[i]._multiple_nodes or \
self._readers[i]._use_shuffle:
self._readers[i]._use_shuffle:
num_parts *= group_size
part_idx += local_rank * self._num_readers
......
......@@ -36,11 +36,9 @@ class DataReader(Process):
source : str
The path of database.
multiple_nodes: boolean
Whether to split data for multiple parallel nodes.
Whether to split data for multiple parallel nodes. Default is ``False``.
shuffle : boolean
Whether to shuffle the data.
instance_chunk : boolean
Whether to limit each chunk with at most 1 instance.
Whether to shuffle the data. Default is ``False``.
num_chunks : int
The number of chunks to split. Default is ``2048``.
chunk_size : int
......@@ -51,7 +49,6 @@ class DataReader(Process):
self._source = kwargs.get('source', '')
self._multiple_nodes = kwargs.get('multiple_nodes', False)
self._use_shuffle = kwargs.get('shuffle', False)
self._use_instance_chunk = kwargs.get('instance_chunk', False)
self._num_chunks = kwargs.get('num_chunks', 2048)
self._chunk_size = kwargs.get('chunk_size', -1)
......@@ -172,19 +169,27 @@ class DataReader(Process):
self._db_zfill = int(self._db.get('zfill'))
self._epoch_size = int(self._db_size / self._num_parts + 1)
if self._use_instance_chunk:
self._chunk_size = 1
self._num_shuffle_parts = int(self._db_size / self._chunk_size / self._num_parts) + 1
if self._use_shuffle:
if self._chunk_size == 1:
# each chunk has at most 1 record [For Fully Shuffle]
self._num_shuffle_parts = int(self._db_size / self._chunk_size / self._num_parts) + 1
else:
if self._use_shuffle and self._chunk_size == -1:
# search an optimal chunk size from the number of chunks [For Chunk Shuffle]
max_chunk_size = self._db._total_size / ((self._num_chunks * (1 << 20)))
min_chunk_size = 1
while min_chunk_size * 2 < max_chunk_size: min_chunk_size *= 2
self._chunk_size = min_chunk_size
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20)))
self._chunk_size = int(self._db_size / self._num_shuffle_parts / self._num_parts + 1)
else:
# search an optimal chunk size from the number of chunks
if self._chunk_size == -1:
max_chunk_size = self._db._total_size / ((self._num_chunks * (1 << 20)))
min_chunk_size = 1
while min_chunk_size * 2 < max_chunk_size: min_chunk_size *= 2
self._chunk_size = min_chunk_size
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20)))
self._chunk_size = int(self._db_size / self._num_shuffle_parts / self._num_parts + 1)
# each chunk has at most K records [For Multiple Nodes]
# note that if ``shuffle`` and ``multiple_nodes`` are both ``False``,
# ``chunk_size`` and ``num_shuffle_parts`` are meaningless
self._chunk_size = int(self._db_size / self._num_parts) + 1
self._num_shuffle_parts = 1
self._perm = np.arange(self._num_shuffle_parts)
# init env
......
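As a worked example of the chunk-size search above, assume a hypothetical 10 GB database read by a single part with the default ``num_chunks=2048`` (pure arithmetic, no dragon imports):

```python
import math

total_size_mb = 10 * 1024            # hypothetical db size, in MB for readability
num_chunks, num_parts = 2048, 1

max_chunk_size = total_size_mb / num_chunks      # at most 5.0 MB per chunk
chunk_size = 1
while chunk_size * 2 < max_chunk_size:           # largest power of two below it
    chunk_size *= 2                              # 1 -> 2 -> 4
num_shuffle_parts = int(math.ceil(
    total_size_mb * 1.1 / (num_parts * chunk_size)))
print(chunk_size, num_shuffle_parts)             # 4 2816
```

The reader then re-derives the per-chunk record count as ``db_size / num_shuffle_parts / num_parts + 1``.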
......@@ -428,8 +428,8 @@ def Sqrt(inputs, **kwargs):
return output
def Scale(inputs, axis=1, num_axes=1, **kwargs):
"""Scale Function.
def Affine(inputs, axis=1, num_axes=1, **kwargs):
"""Calculate ``y = Ax + b`` along the given range of axes.
The number of inputs varies from ``2`` to ``3`` (without or with ``bias``).
......@@ -457,7 +457,7 @@ def Scale(inputs, axis=1, num_axes=1, **kwargs):
CheckInputs(inputs, 2, 3)
arguments = ParseArguments(locals())
output = Tensor.CreateOperator(nout=1, op_type='Scale', **arguments)
output = Tensor.CreateOperator(nout=1, op_type='Affine', **arguments)
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
......
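A usage sketch of the renamed operator; the ``dragon.ops`` alias comes from the next hunk, and the tensor construction and shapes are illustrative assumptions:

```python
import numpy as np
import dragon.ops as ops
import dragon.core.tensor_utils as tensor_utils

x = tensor_utils.FromPyArray(
    np.random.rand(2, 3, 4, 4).astype('float32'), name='x')              # NCHW input
A = tensor_utils.FromPyArray(np.ones((3,), dtype='float32'), name='A')   # scale
b = tensor_utils.FromPyArray(np.zeros((3,), dtype='float32'), name='b')  # bias

# y = A * x + b, broadcast over axis 1 with num_axes=1 (the channel axis)
y = ops.Affine([x, A, b], axis=1, num_axes=1)
```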
......@@ -90,7 +90,7 @@ Square = math.Square
Sqrt = math.Sqrt
InnerProduct = math.InnerProduct
Eltwise = math.Eltwise
Scale = math.Scale
Affine = math.Affine
GramMatrix = math.GramMatrix
# normalization
......
......@@ -494,7 +494,7 @@ class ScaleLayer(Layer):
def Setup(self, bottom):
super(ScaleLayer, self).Setup(bottom)
return ops.Scale(bottom + [blob['data'] for blob in self._blobs], **self._param)
return ops.Affine(bottom + [blob['data'] for blob in self._blobs], **self._param)
class BNLayer(Layer):
......@@ -606,11 +606,13 @@ class NormalizeLayer(Layer):
def __init__(self, LayerParameter):
super(NormalizeLayer, self).__init__(LayerParameter)
param = LayerParameter.normalize_param
self._l2norm_param = {'axis': 1,
'num_axes': -1 if param.across_spatial else 1,
'eps': param.eps}
self._scale_param = {'axis': 1,
'num_axes': 0 if param.channel_shared else 1}
self._l2norm_param = {
'axis': 1,
'num_axes': -1 if param.across_spatial else 1,
'eps': param.eps}
self._scale_param = {
'axis': 1,
'num_axes': 0 if param.channel_shared else 1}
scope = LayerParameter.name
scale = Tensor(scope + '/param:0')
if param.HasField('scale_filler'):
......@@ -622,8 +624,8 @@ class NormalizeLayer(Layer):
def Setup(self, bottom):
super(NormalizeLayer, self).Setup(bottom)
norm_out = [ops.L2Norm(bottom[0], **self._l2norm_param)]
scale_out = ops.Scale(norm_out + [blob['data'] for blob in self.scale_blobs],
**self._scale_param)
scale_out = ops.Affine(norm_out + [blob['data'] for blob in self.scale_blobs],
**self._scale_param)
return scale_out
......
......@@ -71,7 +71,7 @@ class DataLayer(Layer):
'min_random_scale': transform_param.min_random_scale,
'max_random_scale': transform_param.max_random_scale,
'shuffle': parallel_param.shuffle,
'node_step': parallel_param.node_step,
'multiple_nodes': parallel_param.multiple_nodes,
'partition': parallel_param.partition}
def Setup(self, bottom):
......
......@@ -112,6 +112,9 @@ class Module(object):
module.state_dict(destination, prefix + name + '.', to_numpy=to_numpy)
return destination
def _load_state_dict_key_mismatch(self, full_name, name, is_missing):
pass
def load_state_dict(self, state_dict, strict=True):
logger.info('Load the state dict from numpy arrays.')
def submodule_key_mismatch(full_name, is_missing):
......@@ -159,7 +162,7 @@ class Module(object):
', '.join('"{}"'.format(k) for k in unexpected))
if len(missing) > 0:
error_msg += 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected))
', '.join('"{}"'.format(k) for k in missing))
if len(error_msg) > 0:
raise KeyError(error_msg)
......@@ -210,9 +213,9 @@ class Module(object):
def __call__(self, *args, **kwargs):
with dg.name_scope(get_module_name(self)):
return self.forward(*args)
return self.forward(*args, **kwargs)
def forward(self, *inputs):
def forward(self, *inputs, **kwargs):
raise NotImplementedError('The base module can not be called.')
def name_scope(self, remove_separator=True):
......
......@@ -16,12 +16,13 @@ as it will be reused by the ``torch.ops``.
from dragon.vm.torch.module import Module
from dragon.vm.torch.tensor import Parameter
from .modules.conv import Conv2d
from .modules.conv import Conv2d, ConvTranspose2d
from .modules.pooling import MaxPool2d, AvgPool2d
from .modules.activation import ReLU, Softmax
from .modules.activation import ReLU, Sigmoid, Softmax
from .modules.linear import Linear
from .modules.loss import CrossEntropyLoss
from .modules.container import Container, Sequential, ModuleList
from .modules.batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
from .modules.affine import Affine
from .modules.dropout import Dropout, Dropout2d, Dropout3d
from . import init
\ No newline at end of file
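With the expanded export list, the usual PyTorch-style composition becomes available; a minimal sketch, assuming ``Sequential`` and the layer constructors mirror their PyTorch signatures (channel sizes are illustrative):

```python
from dragon.vm.torch.nn import Sequential, Conv2d, ReLU, ConvTranspose2d, Sigmoid

# A toy encoder/decoder built from the newly exported layers.
model = Sequential(
    Conv2d(3, 16, kernel_size=3, stride=2, padding=1),            # downsample 2x
    ReLU(),
    ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),   # upsample back
    Sigmoid(),
)
```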
......@@ -35,6 +35,25 @@ class ReLU(Module):
return self.run(inputs, outputs)
class Sigmoid(Module):
def __init__(self, inplace=False):
super(Sigmoid, self).__init__()
self._inplace = inplace
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Sigmoid',
'n_inputs': 1, 'n_outputs': 1,
'arguments': {}
}
def forward(self, x):
inputs = [x]; self.unify_devices(inputs)
outputs = [x if self._inplace else self.register_output(x.dtype)]
return self.run(inputs, outputs)
class Softmax(Module):
def __init__(self, dim=None):
super(Softmax, self).__init__()
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch.nn import Module, Parameter
from dragon.vm.torch.ops.creation import zeros, ones
class Affine(Module):
def __init__(self, num_features, bias=True, fix_weight=False, fix_bias=False):
super(Affine, self).__init__()
self.num_features = num_features
self.weight = Parameter(ones(num_features), requires_grad=not fix_weight)
if bias:
self.bias = Parameter(zeros(num_features), requires_grad=not fix_bias)
else:
self.bias = None
self.inputs = [self.weight, self.bias] if bias else [self.weight]
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Affine',
'n_inputs': 3 if self.bias else 2, 'n_outputs': 1,
'arguments': {
'axis': 1, # Data format: NCHW
'num_axes': 1,
}
}
def forward(self, input):
inputs = [input] + self.inputs
self.unify_devices(inputs)
outputs = [self.register_output(input.dtype)]
return self.run(inputs, outputs)
\ No newline at end of file
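A usage sketch of the ``Affine`` module defined above; the channel count is illustrative and ``ones(*sizes)`` is assumed to accept multiple dimensions, PyTorch-style:

```python
from dragon.vm.torch.nn import Affine
from dragon.vm.torch.ops.creation import ones

m = Affine(64)              # learnable per-channel scale + shift (axis=1, NCHW)
x = ones(2, 64, 14, 14)     # dummy feature map (assumed *sizes signature)
y = m(x)                    # y = weight * x + bias, broadcast over channels
```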
......@@ -20,8 +20,8 @@ from dragon.vm.torch.module import RunOperator
class _BatchNorm(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
def __init__(self, num_features, eps=1e-5, momentum=0.1,
affine=True, track_running_stats=True):
super(_BatchNorm, self).__init__()
self.num_features = num_features
self.eps = eps
......@@ -103,6 +103,10 @@ class _BatchNorm(Module):
self.unify_devices(inputs)
outputs = [self.register_output(input.dtype)]
phase = 'TRAIN' if input.requires_grad else 'TEST'
# Always normalize the input with batch statistics
# Note that the update of the moving average is then meaningless
# (we can not skip it; ask NVIDIA and cuDNN why :-)
if not self.track_running_stats: phase = 'TRAIN'
meta = ['PERSISTENT',] + self.make_meta_from_phase(phase)
return RunOperator(inputs, outputs, meta)
......
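A minimal sketch of the behavior noted in the comment above: with ``track_running_stats=False`` the layer always normalizes with per-batch statistics, even in the TEST phase.

```python
from dragon.vm.torch.nn import BatchNorm2d

# num_features=64 for an NCHW input; no running mean/var are tracked or used.
bn = BatchNorm2d(64, track_running_stats=False)
```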
......@@ -52,7 +52,8 @@ class _ConvNd(Module):
def register_op(self):
self.op_meta = {
'op_type': 'Conv2d',
'op_type': 'Conv{}d{}'.format(len(self.kernel_size),
'Transpose' if self.transposed else ''),
'n_inputs': 3 if self.bias else 2, 'n_outputs': 1,
'arguments': {
'num_output': self.weight.shape[0],
......@@ -105,4 +106,22 @@ class Conv2d(_ConvNd):
inputs = [input, self.weight] + ([self.bias] if self.bias else [])
self.unify_devices(inputs)
outputs = [self.register_output(input.dtype)]
return self.run(inputs, outputs)
class ConvTranspose2d(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True, dilation=1):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super(ConvTranspose2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, _pair(0), groups, bias)
def forward(self, input):
inputs = [input, self.weight] + ([self.bias] if self.bias else [])
self.unify_devices(inputs)
outputs = [self.register_output(input.dtype)]
return self.run(inputs, outputs)
\ No newline at end of file
......@@ -20,4 +20,8 @@ from .arithmetic import (
from .ndarray import (
sum, mean, argmin, argmax, max, topk, cat
)
from .vision import (
nn_resize, bilinear_resize, roi_pool, roi_align
)
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch.ops.modules.base import BaseModule
class Resize2d(BaseModule):
def __init__(self, key, ctx, **kwargs):
super(Resize2d, self).__init__(key, ctx, **kwargs)
self.op_type = kwargs.get('op_type', 'NNResize')
self.dsize = kwargs.get('dsize', None)
self.fx = kwargs.get('fx', None)
self.fy = kwargs.get('fy', None)
self.register_arguments()
self.register_op()
def register_arguments(self):
if self.dsize:
self.dsize = [self.register_argument('dsize[{}]'.format(i))
for i in range(2)]
def register_op(self):
self.op_meta = {
'op_type': self.op_type,
'n_inputs': 1, 'n_outputs': 1,
'arguments': {
'dsize_desc': [d for d in self.dsize] if self.dsize else None,
'fx': self.fx, 'fy': self.fy,
'data_format': 'NCHW',
}
}
def forward(self, input, dsize=None):
inputs = [input]; self.unify_devices(inputs)
outputs = [self.register_output(input.dtype)]
if dsize is not None:
for ix, d in enumerate(dsize):
self.set_argument_i(self.dsize[ix], d)
return self.run(inputs, outputs)
class RoIPool(BaseModule):
def __init__(self, key, ctx, **kwargs):
super(RoIPool, self).__init__(key, ctx, **kwargs)
self.pool_h = kwargs.get('pooled_h', 0)
self.pool_w = kwargs.get('pooled_w', 0)
self.spatial_scale = kwargs.get('spatial_scale', 1.0)
self.register_arguments()
self.register_op()
def register_arguments(self):
"""No arguments for roi-pool op."""
pass
def register_op(self):
self.op_meta = {
'op_type': 'ROIPooling',
'n_inputs': 2, 'n_outputs': 1,
'arguments': {
'pool_h': self.pool_h, 'pool_w': self.pool_w,
'spatial_scale': self.spatial_scale,
}
}
def forward(self, feature, rois, dsize=None):
inputs = [feature, rois]; self.unify_devices(inputs)
outputs = [self.register_output(feature.dtype)]
return self.run(inputs, outputs)
class RoIAlign(BaseModule):
def __init__(self, key, ctx, **kwargs):
super(RoIAlign, self).__init__(key, ctx, **kwargs)
self.pool_h = kwargs.get('pooled_h', 0)
self.pool_w = kwargs.get('pooled_w', 0)
self.spatial_scale = kwargs.get('spatial_scale', 1.0)
self.sampling_ratio = kwargs.get('sampling_ratio', 2)
self.register_arguments()
self.register_op()
def register_arguments(self):
"""No arguments for roi-pool op."""
pass
def register_op(self):
self.op_meta = {
'op_type': 'ROIAlign',
'n_inputs': 2, 'n_outputs': 1,
'arguments': {
'pool_h': self.pool_h, 'pool_w': self.pool_w,
'spatial_scale': self.spatial_scale,
'sampling_ratio': self.sampling_ratio,
}
}
def forward(self, feature, rois, dsize=None):
inputs = [feature, rois]; self.unify_devices(inputs)
outputs = [self.register_output(feature.dtype)]
return self.run(inputs, outputs)
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch.ops.primitive import MakeContext
from dragon.vm.torch.ops.factory import get_module
from dragon.vm.torch.ops.modules.vision import Resize2d
from dragon.vm.torch.ops.modules.vision import RoIPool, RoIAlign
def _resize_2d(input, op_type, dsize, fx, fy):
if dsize is None:
if fx < 0 or fy < 0:
raise ValueError('Set fx and fy if dsize is None.')
else:
if len(dsize) != 2:
raise ValueError('The dsize should be a list with 2 elements.')
if dsize is None and (fy == -1.0 or fx == -1.0):
raise RuntimeError('Either dsize or both fx and fy should be specified.')
ctx = MakeContext(inputs=[input])
key = 'torch/ops/{}/{}:{}/dsize:{}/fx:{}/fy:{}'.format(
op_type.lower(), ctx[0].lower(), ctx[1], '2' if dsize else 'none', fx, fy)
module = get_module(Resize2d, key, ctx,
op_type=op_type, dsize=dsize, fx=fx, fy=fy)
return module.forward(input, dsize)
def nn_resize(input, dsize, fx=-1.0, fy=-1.0):
return _resize_2d(input, 'NNResize', dsize, fx, fy)
def bilinear_resize(input, dsize, fx=-1.0, fy=-1.0):
return _resize_2d(input, 'BilinearResize', dsize, fx, fy)
def roi_pool(feature, rois, pooled_h, pooled_w, spatial_scale):
ctx = MakeContext(inputs=[feature])
key = 'torch/ops/roi_pool/{}:{}/pool_h:{}/pool_w:{}/spatial_scale:{}'.format(
ctx[0].lower(), ctx[1], pooled_h, pooled_w, spatial_scale)
module = get_module(RoIPool, key, ctx, pooled_h=pooled_h,
pooled_w=pooled_w, spatial_scale=spatial_scale)
return module.forward(feature, rois)
def roi_align(feature, rois, pooled_h, pooled_w,
spatial_scale, sampling_ratio=2):
ctx = MakeContext(inputs=[feature])
key = 'torch/ops/roi_align/{}:{}/pool_h:{}/pool_w:{}/' \
'spatial_scale:{}/sampling_ratio:{}'.format(
ctx[0].lower(), ctx[1], pooled_h, pooled_w, spatial_scale, sampling_ratio)
module = get_module(RoIAlign, key, ctx, pooled_h=pooled_h,
pooled_w=pooled_w, spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
return module.forward(feature, rois)
\ No newline at end of file
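A usage sketch of the functional wrappers above; the tensor construction, the RoI layout (``[batch_idx, x1, y1, x2, y2]``), and the availability of a CUDA build for the RoI kernels are assumptions for illustration:

```python
from dragon.vm.torch.ops.creation import ones, zeros
from dragon.vm.torch.ops.vision import nn_resize, roi_align

feature = ones(1, 8, 32, 32)     # dummy NCHW feature map (assumed *sizes signature)
rois = zeros(2, 5)               # two all-zero RoIs on image 0

up = nn_resize(feature, dsize=None, fx=2.0, fy=2.0)      # nearest-neighbor 2x
crops = roi_align(feature, rois, pooled_h=7, pooled_w=7,
                  spatial_scale=1.0 / 16, sampling_ratio=2)
```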
......@@ -110,7 +110,7 @@ class Optimizer(object):
_update(p, g, op_type=self._update_type,
slot=group['slot'],
lr_mult=group.get('lr_mult', 1.0),
decay_mult=group.get('lr_mult', 1.0))
decay_mult=group.get('decay_mult', 1.0))
def zero_grad(self):
"""Set all gradients to zeros.
......
......@@ -144,6 +144,17 @@ class Tensor(object):
"Use .cpu() to move the tensor to host memory first.")
return tensor_utils.ToPyArray(self._dg_tensor)
def numpy_ex(self):
"""Create a numpy const nd-array sharing this tensor.
Returns
-------
numpy.ndarray
The numpy nd-array.
"""
return tensor_utils.ToPyArrayEx(self._dg_tensor)
def dragon(self):
"""Create a dragon tensor sharing this tensor.
......@@ -533,7 +544,7 @@ class Tensor(object):
Parameters
----------
args : tuple
args : tuple or int
The new size.
Returns
......@@ -606,6 +617,21 @@ class Tensor(object):
"""
self.fill_(0.)
def one_(self):
"""Fills self tensor with ones.
Returns
-------
vm.torch.Tensor
The self.
"""
self.fill_(1.)
def uniform_(self, low=0, high=1):
"""Fill self tensor with the specified uniform distribution.
......
......@@ -34,7 +34,7 @@ class _DataLoaderIter(object):
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False,
partition=False, multiple_nodes=False, instance_chunk=False):
partition=False, multiple_nodes=False, num_chunks=2048, chunk_size=-1):
"""A MPI-Aware DataLoader. Forked from ``dragon.io``.
Parameters
......@@ -43,16 +43,16 @@ class DataLoader(object):
The dataset.
batch_size : int
The batch size. Divided by n mpi-nodes if ``partition`` is True.
instance_chunk : boolean
Whether to limit each chunk with at most 1 instance.
shuffle : boolean
Whether to shuffle the data.
partition : boolean
Whether to partition batch. Default is ``False``.
multiple_nodes: boolean
Whether to split data for multiple parallel nodes.
instance_chunk : boolean
Whether to limit each chunk with at most 1 instance.
num_chunks : int
The number of chunks to split. Default is ``2048``.
chunk_size : int
The size (MB) of each chunk. Default is ``-1`` (refer to ``num_chunks``).
"""
self.dataset = dataset
......@@ -65,7 +65,8 @@ class DataLoader(object):
'source': dataset.database,
'multiple_nodes': multiple_nodes,
'shuffle': shuffle,
'instance_chunk': instance_chunk,
'num_chunks': num_chunks,
'chunk_size': chunk_size,
'batch_size': batch_size,
'partition': partition,
'transform': dataset.transform,
......
......@@ -39,11 +39,9 @@ class DataBatch(object):
source : str
The path of database.
multiple_nodes: boolean
Whether to split data for multiple parallel nodes.
Whether to split data for multiple parallel nodes. Default is ``False``.
shuffle : boolean
Whether to shuffle the data.
instance_chunk : boolean
Whether to limit each chunk with at most 1 instance.
Whether to shuffle the data. Default is ``False``.
num_chunks : int
The number of chunks to split. Default is ``2048``.
chunk_size : int
......
......@@ -36,11 +36,9 @@ class DataReader(Process):
source : str
The path of database.
multiple_nodes: boolean
Whether to split data for multiple parallel nodes.
Whether to split data for multiple parallel nodes. Default is ``False``.
shuffle : boolean
Whether to shuffle the data.
instance_chunk : boolean
Whether to limit each chunk with at most 1 instance.
Whether to shuffle the data. Default is ``False``.
num_chunks : int
The number of chunks to split. Default is ``2048``.
chunk_size : int
......@@ -51,7 +49,6 @@ class DataReader(Process):
self._source = kwargs.get('source', '')
self._multiple_nodes = kwargs.get('multiple_nodes', False)
self._use_shuffle = kwargs.get('shuffle', False)
self._use_instance_chunk = kwargs.get('instance_chunk', False)
self._num_chunks = kwargs.get('num_chunks', 2048)
self._chunk_size = kwargs.get('chunk_size', -1)
......@@ -172,19 +169,27 @@ class DataReader(Process):
self._db_zfill = int(self._db.get('zfill'))
self._epoch_size = int(self._db_size / self._num_parts + 1)
if self._use_instance_chunk:
self._chunk_size = 1
self._num_shuffle_parts = int(self._db_size / self._chunk_size / self._num_parts) + 1
if self._use_shuffle:
if self._chunk_size == 1:
# each chunk has at most 1 record [For Fully Shuffle]
self._num_shuffle_parts = int(self._db_size / self._chunk_size / self._num_parts) + 1
else:
if self._use_shuffle and self._chunk_size == -1:
# search an optimal chunk size from the number of chunks [For Chunk Shuffle]
max_chunk_size = self._db._total_size / ((self._num_chunks * (1 << 20)))
min_chunk_size = 1
while min_chunk_size * 2 < max_chunk_size: min_chunk_size *= 2
self._chunk_size = min_chunk_size
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20)))
self._chunk_size = int(self._db_size / self._num_shuffle_parts / self._num_parts + 1)
else:
# search an optimal chunk size from the number of chunks
if self._chunk_size == -1:
max_chunk_size = self._db._total_size / ((self._num_chunks * (1 << 20)))
min_chunk_size = 1
while min_chunk_size * 2 < max_chunk_size: min_chunk_size *= 2
self._chunk_size = min_chunk_size
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20)))
self._chunk_size = int(self._db_size / self._num_shuffle_parts / self._num_parts + 1)
# each chunk has at most K records [For Multiple Nodes]
# note that if ``shuffle`` and ``multiple_nodes`` are both ``False``,
# ``chunk_size`` and ``num_shuffle_parts`` are meaningless
self._chunk_size = int(self._db_size / self._num_parts) + 1
self._num_shuffle_parts = 1
self._perm = np.arange(self._num_shuffle_parts)
# init env
......
......@@ -13,7 +13,7 @@
#
# ------------------------------------------------------------
import torch
import dragon.vm.torch as torch
import hashlib
import os
......@@ -44,28 +44,6 @@ HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
def load_url(url, model_dir=None, map_location=None, progress=True):
r"""Loads the Torch serialized object at the given URL.
If the object is already present in `model_dir`, it's deserialized and
returned. The filename part of the URL should follow the naming convention
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file.
The default value of `model_dir` is ``$TORCH_HOME/models`` where
``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
overriden with the ``$TORCH_MODEL_ZOO`` environment variable.
Args:
url (string): URL of the object to download
model_dir (string, optional): directory in which to save the object
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
progress (bool, optional): whether or not to display a progress bar to stderr
Example:
>>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
"""
if model_dir is None:
torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
......
......@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules()
setup(name = 'dragon',
version='0.2.2.1',
version='0.2.2.2',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/seetaresearch/Dragon',
author='Ting Pan',
......
......@@ -12,7 +12,7 @@ CUDAObject CUDAContext::cuda_object_;
template<> void CPUContext::Memcpy<CPUContext, CUDAContext>(
size_t nbytes, void* dst, const void* src) {
#ifdef WITH_CUDA
CUDAContext ctx(POINTER_DEVICE(src));
CUDAContext ctx(CUDA_POINTER_DEVICE(src));
ctx.Memcpy<CPUContext, CUDAContext>(nbytes, dst, src);
#else
LOG(FATAL) << "CUDA was not compiled.";
......@@ -23,7 +23,7 @@ template<> void CPUContext::Memcpy<CPUContext, CUDAContext>(
template<> void CPUContext::Memcpy<CUDAContext, CPUContext>(
size_t nbytes, void* dst, const void* src) {
#ifdef WITH_CUDA
CUDAContext ctx(POINTER_DEVICE(dst));
CUDAContext ctx(CUDA_POINTER_DEVICE(dst));
ctx.Memcpy<CUDAContext, CPUContext>(nbytes, dst, src);
#else
LOG(FATAL) << "CUDA was not compiled.";
......
......@@ -128,8 +128,8 @@ MixedMemory::~MixedMemory() {
void MixedMemory::SwitchToDevice() {
if (cuda_ptr_) {
#ifdef WITH_CUDA
int ptr_device = POINTER_DEVICE(cuda_ptr_);
int cur_device = CURRENT_DEVICE();
int ptr_device = CUDA_POINTER_DEVICE(cuda_ptr_);
int cur_device = CUDA_CURRENT_DEVICE();
if (ptr_device != cur_device) state_ = SWITCHED;
#endif
}
......@@ -139,7 +139,7 @@ void MixedMemory::SwitchToCUDADevice(int device_id) {
#ifdef WITH_CUDA
DeviceGuard guard(device_id);
if (cuda_ptr_) {
int ptr_device = POINTER_DEVICE(cuda_ptr_);
int ptr_device = CUDA_POINTER_DEVICE(cuda_ptr_);
if (ptr_device != device_id) state_ = SWITCHED;
}
ToCUDA();
......@@ -164,7 +164,7 @@ const Map<string, string> MixedMemory::info() const {
}
s2s["mem_at"] = _state_;
if (cpu_ptr_) s2s["CPU"] = "0";
if (cuda_ptr_) s2s["CUDA"] = dragon_cast<string, int>(POINTER_DEVICE(cuda_ptr_));
if (cuda_ptr_) s2s["CUDA"] = dragon_cast<string, int>(CUDA_POINTER_DEVICE(cuda_ptr_));
return s2s;
}
......
#include "operators/arithmetic/scale_op.h"
#include "operators/arithmetic/affine_op.h"
#include "core/workspace.h"
#include "utils/filler.h"
#include "utils/op_kernel.h"
......@@ -6,7 +6,7 @@
namespace dragon {
template <class Context> template <typename T>
void ScaleOp<Context>::RunWithType() {
void AffineOp<Context>::RunWithType() {
start_axis = axis;
if (start_axis < 0) start_axis += (int)Input(0).ndim();
if (num_axes == -1) num_axes = (int)Input(0).ndim() - start_axis;
......@@ -18,27 +18,29 @@ void ScaleOp<Context>::RunWithType() {
const vector<TIndex>::const_iterator& dim_start = Input(0).dims().begin() + start_axis;
const vector<TIndex>::const_iterator& dim_end = dim_start + num_axes;
vector<TIndex> param_dims(dim_start, dim_end);
TENSOR_FILL(Input(1), param_dims);
TENSOR_FILL(Input(1), param_dims);;
outer_dim = Input(0).count(0, start_axis);
inner_dim = Input(0).count(start_axis + num_axes);
scale_dim = Input(1).count();
if (InputSize() > 2) {
TENSOR_FILL(Input(2), param_dims);
inner_dim = Input(0).count(start_axis + num_axes);
INIT_MULTIPLIER(bias_multiplier, inner_dim);
}
if (InputSize() > 2) {
kernel::Scale<T, Context>(start_axis, &Input(0), &Input(1),
&Input(2), bias_multiplier,
Output(0));
} else {
kernel::Scale<T, Context>(start_axis, &Input(0), &Input(1),
nullptr, nullptr,
Output(0));
}
auto* Xdata = Input(0).template data<T, Context>();
auto* Adata = Input(1).template data<T, Context>();
auto* Bdata = InputSize() > 2 ? Input(2).template data<T, Context>() : nullptr;
auto* BMdata = InputSize() > 2 ? bias_multiplier->template data<T, Context>() : nullptr;
auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Affine<T, Context>(Output(0)->count(),
outer_dim, scale_dim, inner_dim,
Xdata, Adata, Bdata, BMdata,
Ydata);
}
template <class Context>
void ScaleOp<Context>::RunOnDevice() {
void AffineOp<Context>::RunOnDevice() {
Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) RunWithType<float>();
......@@ -46,14 +48,14 @@ void ScaleOp<Context>::RunOnDevice() {
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
}
DEPLOY_CPU(Scale);
DEPLOY_CPU(Affine);
#ifdef WITH_CUDA
DEPLOY_CUDA(Scale);
DEPLOY_CUDA(Affine);
#endif
OPERATOR_SCHEMA(Scale).NumInputs(2, 3).NumOutputs(1);
OPERATOR_SCHEMA(Affine).NumInputs(2, 3).NumOutputs(1);
template <class Context> template <typename T>
void ScaleGradientOp<Context>::BiasRunWithType() {
void AffineGradientOp<Context>::BiasRunWithType() {
Output(2)->ReshapeLike(Input(1));
INIT_MULTIPLIER(bias_multiplier, inner_dim);
auto* BMul_data = this->bias_multiplier->template data<T, Context>();
......@@ -71,7 +73,7 @@ void ScaleGradientOp<Context>::BiasRunWithType() {
}
template <class Context> template <typename T>
void ScaleGradientOp<Context>::ScaleRunWithType() {
void AffineGradientOp<Context>::ScaleRunWithType() {
Output(0)->ReshapeLike(Input(0));
Output(1)->ReshapeLike(Input(1));
INIT_MULTIPLIER(sum_multiplier, sum_dim);
......@@ -119,15 +121,20 @@ void ScaleGradientOp<Context>::ScaleRunWithType() {
}
template <class Context> template <typename T>
void ScaleGradientOp<Context>::RunWithType() {
void AffineGradientOp<Context>::RunWithType() {
Output(0)->ReshapeLike(Input(0));
kernel::ScaleGrad<T, Context>(start_axis,
&Input(-1), &Input(1), Output(0));
auto* dYdata = Input(-1).template data<T, Context>();
auto* Adata = Input(1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>();
kernel::AffineGrad<T, Context>(Output(0)->count(),
outer_dim, scale_dim, inner_dim,
dYdata, Adata, dXdata);
}
template <class Context>
void ScaleGradientOp<Context>::RunOnDevice() {
void AffineGradientOp<Context>::RunOnDevice() {
start_axis = axis;
if (start_axis < 0) start_axis += (int)Input(0).ndim();
if (num_axes == -1) num_axes = (int)Input(0).ndim() - start_axis;
......@@ -151,21 +158,21 @@ void ScaleGradientOp<Context>::RunOnDevice() {
}
}
DEPLOY_CPU(ScaleGradient);
DEPLOY_CPU(AffineGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(ScaleGradient);
DEPLOY_CUDA(AffineGradient);
#endif
OPERATOR_SCHEMA(ScaleGradient).NumInputs(3).NumOutputs(3);
OPERATOR_SCHEMA(AffineGradient).NumInputs(3).NumOutputs(3);
class GetScaleGradient final : public GradientMakerBase {
class GetAffineGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetScaleGradient);
GRADIENT_MAKER_CTOR(GetAffineGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), GO(0)},
vector<string> {GI(0), GI(1), GI(2)});
}
};
REGISTER_GRADIENT(Scale, GetScaleGradient);
REGISTER_GRADIENT(Affine, GetAffineGradient);
} // namespace dragon
\ No newline at end of file
......@@ -3,17 +3,18 @@
#include "utils/math_functions.h"
namespace dragon {
template <class Context> template <typename T>
template <class Context> template <typename Tx, typename Ty>
void AccuracyOp<Context>::RunWithType() {
if (OutputSize() > 1) {
math::Set<T, CPUContext>(num_classes, 0,
Output(1)->template mutable_data<T, CPUContext>());
math::Set<float, CPUContext>(num_classes, 0,
Output(1)->template mutable_data<float, CPUContext>());
}
Map<int, int> num_per_class;
Map<int, TIndex> num_per_class;
T acc = 0, count = 0;
auto* Xdata = Input(0).template data<T, CPUContext>();
auto* labels = Input(1).template data<T, CPUContext>();
TIndex acc = 0, count = 0;
auto* Xdata = Input(0).template data<Tx, CPUContext>();
auto* labels = Input(1).template data<Ty, CPUContext>();
auto* ignores = ignore_labels.count() > 0 ?
ignore_labels.data<int, CPUContext>() : nullptr;
const TIndex dim = Input(0).count() / outer_dim;
......@@ -23,14 +24,14 @@ void AccuracyOp<Context>::RunWithType() {
for (int k = 0; k < ignore_labels.count(); k++)
if (label == ignores[k]) continue;
if (OutputSize() > 1) num_per_class[label]++;
vector<pair<T, int> > vec;
vector<pair<Tx, int> > vec;
for (int k = 0; k < num_classes; k++)
vec.push_back(std::make_pair(Xdata[i * dim + k * inner_dim + j], k));
std::partial_sort(vec.begin(), vec.begin() + top_k, vec.end(), std::greater<pair<T, int> >());
std::partial_sort(vec.begin(), vec.begin() + top_k, vec.end(), std::greater<pair<Tx, int> >());
for (int k = 0; k < top_k; k++) {
if (vec[k].second == label) {
if (OutputSize() > 1)
Output(1)->template mutable_data<T, CPUContext>()[label]++;
Output(1)->template mutable_data<float, CPUContext>()[label]++;
acc++;
break;
}
......@@ -39,11 +40,12 @@ void AccuracyOp<Context>::RunWithType() {
} // end inner_dim
} // end outer_dim
Output(0)->template mutable_data<T, CPUContext>()[0] = acc / count;
Output(0)->template mutable_data<float, CPUContext>()[0] = (float)acc / count;
if (OutputSize() > 1) {
auto* acc_per_class = Output(1)->template mutable_data<T, CPUContext>();
auto* acc_per_class = Output(1)->template mutable_data<float, CPUContext>();
for (int i = 0; i < num_classes; i++)
acc_per_class[i] = num_per_class[i] == 0 ? 0 : acc_per_class[i] / acc_per_class[i];
acc_per_class[i] = num_per_class[i] == 0 ?
0 : acc_per_class[i] / num_per_class[i];
}
}
......@@ -58,8 +60,11 @@ void AccuracyOp<Context>::RunOnDevice() {
Output(0)->Reshape(vector<TIndex>(1, 1));
if (OutputSize() > 1) Output(1)->Reshape(vector<TIndex>(1, num_classes));
if (XIsType(Input(0), float)) RunWithType<float>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
if (XIsType(Input(0), float)) {
if (XIsType(Input(1), float)) RunWithType<float, float>();
else if(XIsType(Input(1), int64_t)) RunWithType<float, int64_t>();
else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" });
} else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
}
DEPLOY_CPU(Accuracy);
......
......@@ -128,6 +128,8 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
if (enable_tensor_core)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
RunWithType<float>();
} else if (XIsType(Input(0), float16)) {
......@@ -148,6 +150,8 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
if (enable_tensor_core)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
RunWithType<float16>();
#endif // WITH_CUDA_FP16
......@@ -306,6 +310,8 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
if (enable_tensor_core)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
RunWithType<float>();
} else if (XIsType(Input(0), float16)) {
......@@ -325,6 +331,8 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
if (enable_tensor_core)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
RunWithType<float16>();
#endif // WITH_CUDA_FP16
......
......@@ -130,6 +130,8 @@ void CuDNNConv2dTransposeOp<Context>::RunOnDevice() {
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
if (enable_tensor_core)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
RunWithType<float>();
} else if (XIsType(Input(0), float16)) {
......@@ -150,6 +152,8 @@ void CuDNNConv2dTransposeOp<Context>::RunOnDevice() {
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
if (enable_tensor_core)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
RunWithType<float16>();
#endif // WITH_CUDA_FP16
......@@ -310,6 +314,8 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() {
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
if (enable_tensor_core)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
RunWithType<float>();
} else if (XIsType(Input(0), float16)) {
......@@ -330,6 +336,8 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() {
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
if (enable_tensor_core)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
RunWithType<float16>();
#endif // WITH_CUDA_FP16
......
......@@ -6,6 +6,7 @@
#include "utils/omp_alternative.h"
#include "utils/sse_alternative.h"
#include "utils/math_functions.h"
#include "utils/cast.h"
bool judge(int a, int b) { return unsigned(a) < unsigned(b); }
......@@ -359,6 +360,70 @@ template<> void TanhGrad<float, CPUContext>(const int count,
}
}
/******************** arithmetic.affine ********************/
template<> void Affine<float, CPUContext>(const int count,
const int outer_dim,
const int scale_dim,
const int inner_dim,
const float* x,
const float* alpha,
const float* beta,
const float* beta_multiplier,
float* y) {
// Ax
auto* Xdata = x; auto* Ydata = y;
for (int n = 0; n < outer_dim; ++n) {
for (int d = 0; d < scale_dim; ++d) {
math::Scale<float, CPUContext>(inner_dim, alpha[d], Xdata, Ydata);
Xdata += inner_dim;
Ydata += inner_dim;
}
}
// Pb
if (beta != nullptr && beta_multiplier != nullptr) {
int dim = scale_dim * inner_dim;
Ydata = y;
for (int n = 0; n < outer_dim; ++n) {
math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans,
scale_dim, inner_dim, 1,
1.0,
beta, beta_multiplier,
1.0,
Ydata);
Ydata += dim;
}
}
}
template<> void Affine<float16, CPUContext>(const int count,
const int outer_dim,
const int scale_dim,
const int inner_dim,
const float16* x,
const float16* alpha,
const float16* beta,
const float16* beta_multiplier,
float16* y) {
LOG(FATAL) << "float16 is unsupported for CPUContext.";
}
template <> void AffineGrad<float, CPUContext>(const int count,
const int outer_dim,
const int scale_dim,
const int inner_dim,
const float* dy,
const float* alpha,
float* dx) {
auto* dYdata = dy; auto* dXdata = dx;
for (int n = 0; n < outer_dim; ++n) {
for (int d = 0; d < scale_dim; ++d) {
math::Scale<float, CPUContext>(inner_dim, alpha[d], dYdata, dXdata);
dYdata += inner_dim; dXdata += inner_dim;
}
}
}
/******************** arithmetic.bias_add ********************/
template<> void BiasAdd<float, CPUContext>(const int count,
......@@ -408,77 +473,6 @@ template <> void Clip<float, CPUContext>(const int count,
}
}
/******************** arithmetic.scale ********************/
template<> void Scale<float, CPUContext>(const int axis,
Tensor* x,
Tensor* gamma,
Tensor* beta,
Tensor* BMul,
Tensor* y) {
int outer_dim = x->count(0, axis);
int inner_dim = x->count(axis + gamma->ndim());
int scale_dim = gamma->count();
auto* Xdata = x->data<float, CPUContext>();
auto* Ydata = y->mutable_data<float, CPUContext>();
auto* Sdata = gamma->data<float, CPUContext>();
auto* Bdata = beta != nullptr ?
beta->data<float, CPUContext>() :
nullptr;
auto* BMul_data = BMul != nullptr ?
BMul->data<float, CPUContext>() :
nullptr;
for (int n = 0; n < outer_dim; ++n) {
for (int d = 0; d < scale_dim; ++d) {
const float factor = Sdata[d];
math::Scale<float, CPUContext>(inner_dim, factor, Xdata, Ydata);
Xdata += inner_dim;
Ydata += inner_dim;
}
}
if (Bdata != nullptr) {
int dim = scale_dim * inner_dim;
Ydata = y->mutable_data<float, CPUContext>();
for (int n = 0; n < outer_dim; ++n) {
math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans,
scale_dim, inner_dim, 1,
1.0,
Bdata, BMul_data,
1.0,
Ydata);
Ydata += dim;
}
}
}
template<> void Scale<float16, CPUContext>(const int axis,
Tensor* x,
Tensor* gamma,
Tensor* beta,
Tensor* BMul,
Tensor* y) {
LOG(FATAL) << "float16 is unsupported for CPUContext.";
}
template <> void ScaleGrad<float, CPUContext>(const int axis,
Tensor* dy,
Tensor* gamma,
Tensor* dx) {
int outer_dim = dx->count(0, axis);
int inner_dim = dx->count(axis + gamma->ndim());
int scale_dim = gamma->count();
auto* dYdata = dy->data<float, CPUContext>();
auto* dXdata = dx->mutable_data<float, CPUContext>();
auto* Sdata = gamma->data<float, CPUContext>();
for (int n = 0; n < outer_dim; ++n) {
for (int d = 0; d < scale_dim; ++d) {
const float factor = Sdata[d];
math::Scale<float, CPUContext>(inner_dim, factor, dYdata, dXdata);
dYdata += inner_dim; dXdata += inner_dim;
}
}
}
/******************** control_flow.compare ********************/
template <> void Equal<float, CPUContext>(const int count,
......@@ -809,7 +803,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CPUContext>(const int count,
}
}
/******************** misc.dtype ********************/
/******************** misc.astype ********************/
template <typename Ta, typename Tb>
void _TypeA2B(const int count, const Ta* a, Tb* b) {
......@@ -819,6 +813,14 @@ void _TypeA2B(const int count, const Ta* a, Tb* b) {
for (int i = 0; i < count; ++i) b[i] = a[i];
}
template <typename Ta, typename Tb>
void _TypeA2B_v2(const int count, const Ta* a, Tb* b) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) b[i] = dragon_cast<Tb, Ta>(a[i]);
}
#define DEFINE_TYPE_A2B(type_a, type_b) \
template <> void TypeA2B<type_a, type_b, CPUContext>(const int count, \
const type_a* a, \
......@@ -826,6 +828,13 @@ void _TypeA2B(const int count, const Ta* a, Tb* b) {
_TypeA2B<type_a, type_b>(count, a, b); \
}
#define DEFINE_TYPE_A2B_V2(type_a, type_b) \
template <> void TypeA2B<type_a, type_b, CPUContext>(const int count, \
const type_a* a, \
type_b* b) { \
_TypeA2B_v2<type_a, type_b>(count, a, b); \
}
#define DEFINE_TYPE_DISABLE_FP16(type) \
template <> void TypeA2B<float16, type, CPUContext>(const int count, \
const float16* a, \
......@@ -845,13 +854,10 @@ void _TypeA2B(const int count, const Ta* a, Tb* b) {
DEFINE_TYPE_A2B(type_a, int64_t); \
DEFINE_TYPE_A2B(type_a, uint8_t);
template <> void TypeA2B<float16, float16, CPUContext>(const int count,
const float16* a,
float16* b) {
LOG(FATAL) << "float16 is unsupported for CPUContext.";
}
DEFINE_TYPE_A2ALL(float); DEFINE_TYPE_DISABLE_FP16(float);
DEFINE_TYPE_A2B_V2(float16, float);
DEFINE_TYPE_A2B_V2(float, float16);
DEFINE_TYPE_A2B_V2(float16, float16);
DEFINE_TYPE_A2ALL(float);
DEFINE_TYPE_A2ALL(double); DEFINE_TYPE_DISABLE_FP16(double);
DEFINE_TYPE_A2ALL(int); DEFINE_TYPE_DISABLE_FP16(int);
DEFINE_TYPE_A2ALL(int64_t); DEFINE_TYPE_DISABLE_FP16(int64_t);
......