Commit fdf26ef2 by Ting PAN

Use local workspace for Context

Summary:
This commit uses a local (thread-local on CPU, stream-local on CUDA) workspace for Context,
which provides a cleaner way to dispatch kernels that require scratch memory.
In addition, the TF32 math type is exposed as a cuDNN option for Ampere devices.
1 parent 1dd8aeef
Showing with 1813 additions and 1654 deletions
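
Before the diff, a minimal sketch of the dispatch pattern the summary describes, assuming only the workspace() accessor and the Workspace::data() helper introduced in the changes below (the launcher itself is illustrative, not part of this commit):

#include "dragon/core/context_cuda.h"
#include "dragon/core/workspace.h"

// Illustrative launcher: scratch comes from the context's local workspace
// (thread-local for CPUContext, cached per (device, stream) for CUDAContext),
// so a (scratch, scratch_size) pair no longer threads through the signature.
template <class Context>
void LaunchWithScratch(size_t scratch_nbytes, Context* ctx) {
  void* scratch = ctx->workspace()->template data<Context>({scratch_nbytes})[0];
  (void)scratch;  // ... dispatch the real kernel with this buffer ...
}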
...@@ -9,7 +9,7 @@ dragon/core ...@@ -9,7 +9,7 @@ dragon/core
`class CPUContext <core/CPUContext.html>`_ `class CPUContext <core/CPUContext.html>`_
: The cpu device context. : The cpu device context.
`class CUDAContext <core/CPUContext.html>`_ `class CUDAContext <core/CUDAContext.html>`_
: The cuda device context. : The cuda device context.
`class Graph <core/Graph.html>`_ `class Graph <core/Graph.html>`_
......
...@@ -69,6 +69,10 @@ stream ...@@ -69,6 +69,10 @@ stream
###### ######
.. doxygenfunction:: dragon::CPUContext::stream .. doxygenfunction:: dragon::CPUContext::stream
workspace
#########
.. doxygenfunction:: dragon::CPUContext::workspace
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -97,6 +97,14 @@ stream ...@@ -97,6 +97,14 @@ stream
###### ######
.. doxygenfunction:: dragon::CUDAContext::stream .. doxygenfunction:: dragon::CUDAContext::stream
workspace
#########
.. doxygenfunction:: dragon::CUDAContext::workspace()
workspace
#########
.. doxygenfunction:: dragon::CUDAContext::workspace(int device, int stream)
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -43,9 +43,9 @@ phase ...@@ -43,9 +43,9 @@ phase
##### #####
.. doxygenfunction:: dragon::Graph::phase .. doxygenfunction:: dragon::Graph::phase
ws workspace
## #########
.. doxygenfunction:: dragon::Graph::ws .. doxygenfunction:: dragon::Graph::workspace
.. raw:: html .. raw:: html
......
...@@ -95,9 +95,9 @@ phase ...@@ -95,9 +95,9 @@ phase
##### #####
.. doxygenfunction:: dragon::Operator::phase .. doxygenfunction:: dragon::Operator::phase
ws workspace
## #########
.. doxygenfunction:: dragon::Operator::ws .. doxygenfunction:: dragon::Operator::workspace
.. raw:: html .. raw:: html
......
...@@ -30,6 +30,9 @@ dragon ...@@ -30,6 +30,9 @@ dragon
`cast(...) <dragon/cast.html>`_ `cast(...) <dragon/cast.html>`_
: Cast the data type of input. : Cast the data type of input.
`channel_affine(...) <dragon/channel_affine.html>`_
: Apply affine transformation along the channels.
`channel_normalize(...) <dragon/channel_normalize.html>`_ `channel_normalize(...) <dragon/channel_normalize.html>`_
: Normalize channels with mean and standard deviation. : Normalize channels with mean and standard deviation.
...@@ -171,6 +174,7 @@ dragon ...@@ -171,6 +174,7 @@ dragon
dragon/assign dragon/assign
dragon/broadcast_to dragon/broadcast_to
dragon/cast dragon/cast
dragon/channel_affine
dragon/channel_normalize dragon/channel_normalize
dragon/channel_shuffle dragon/channel_shuffle
dragon/concat dragon/concat
......
affine channel_affine
====== ==============
.. autofunction:: dragon.math.affine .. autofunction:: dragon.channel_affine
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.math."; content: "dragon.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -12,9 +12,6 @@ dragon.math ...@@ -12,9 +12,6 @@ dragon.math
`add(...) <math/add.html>`_ `add(...) <math/add.html>`_
: Compute the element-wise addition. : Compute the element-wise addition.
`affine(...) <math/affine.html>`_
: Compute the affine transformation along the given axes.
`argmax(...) <math/argmax.html>`_ `argmax(...) <math/argmax.html>`_
: Compute the index of maximum elements along the given axis. : Compute the index of maximum elements along the given axis.
...@@ -149,7 +146,6 @@ dragon.math ...@@ -149,7 +146,6 @@ dragon.math
math/abs math/abs
math/add math/add
math/affine
math/argmax math/argmax
math/argmin math/argmin
math/axpby math/axpby
......
...@@ -60,6 +60,9 @@ vm.torch ...@@ -60,6 +60,9 @@ vm.torch
`ceil(...) <torch/ceil.html>`_ `ceil(...) <torch/ceil.html>`_
: Compute the smallest integer not less than input. : Compute the smallest integer not less than input.
`channel_affine(...) <torch/channel_affine.html>`_
: Apply affine transformation along the channels.
`channel_normalize(...) <torch/channel_normalize.html>`_ `channel_normalize(...) <torch/channel_normalize.html>`_
: Normalize channels with mean and standard deviation. : Normalize channels with mean and standard deviation.
...@@ -263,6 +266,7 @@ vm.torch ...@@ -263,6 +266,7 @@ vm.torch
torch/bitwise_xor torch/bitwise_xor
torch/cat torch/cat
torch/ceil torch/ceil
torch/channel_affine
torch/channel_normalize torch/channel_normalize
torch/channel_shuffle torch/channel_shuffle
torch/chunk torch/chunk
......
affine channel_affine
====== ==============
.. autofunction:: dragon.vm.torch.nn.functional.affine .. autofunction:: dragon.vm.torch.channel_affine
.. _torch.nn.Affine(...): ../Affine.html
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "torch.nn.functional."; content: "torch.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -6,8 +6,8 @@ vm.torch.nn ...@@ -6,8 +6,8 @@ vm.torch.nn
Classes Classes
------- -------
`class Affine <nn/Affine.html>`_ `class AffineChannel <nn/AffineChannel.html>`_
: Apply the affine transformation over input. : Apply affine transformation along the channels.
`class AvgPool2d <nn/AvgPool2d.html>`_ `class AvgPool2d <nn/AvgPool2d.html>`_
: Apply the 2d average pooling. : Apply the 2d average pooling.
...@@ -197,7 +197,7 @@ vm.torch.nn ...@@ -197,7 +197,7 @@ vm.torch.nn
.. toctree:: .. toctree::
:hidden: :hidden:
nn/Affine nn/AffineChannel
nn/AvgPool2d nn/AvgPool2d
nn/BatchNorm1d nn/BatchNorm1d
nn/BatchNorm2d nn/BatchNorm2d
......
Affine AffineChannel
====== =============
.. autoclass:: dragon.vm.torch.nn.Affine .. autoclass:: dragon.vm.torch.nn.AffineChannel
__init__ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.Affine.__init__ .. automethod:: dragon.vm.torch.nn.AffineChannel.__init__
.. _torch.nn.functional.affine(...): functional/affine.html .. _torch.channel_affine(...): ../channel_affine.html
.. raw:: html .. raw:: html
......
...@@ -6,9 +6,6 @@ vm.torch.nn.functional ...@@ -6,9 +6,6 @@ vm.torch.nn.functional
Functions Functions
--------- ---------
`affine(...) <functional/affine.html>`_
: Apply the affine transformation to input.
`avg_pool2d(...) <functional/avg_pool2d.html>`_ `avg_pool2d(...) <functional/avg_pool2d.html>`_
: Apply the 2d average pooling to input. : Apply the 2d average pooling to input.
...@@ -132,7 +129,6 @@ vm.torch.nn.functional ...@@ -132,7 +129,6 @@ vm.torch.nn.functional
.. toctree:: .. toctree::
:hidden: :hidden:
functional/affine
functional/avg_pool2d functional/avg_pool2d
functional/batch_norm functional/batch_norm
functional/binary_cross_entropy_with_logits functional/binary_cross_entropy_with_logits
......
#include "context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/core/workspace.h"
namespace dragon { namespace dragon {
Workspace* CPUContext::workspace() {
static thread_local Workspace workspace("");
return &workspace;
}
#ifdef USE_CUDA #ifdef USE_CUDA
CUDAObjects::~CUDAObjects() {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
#ifdef USE_NCCL
for (auto& comm_iter : nccl_comms_[i]) {
if (comm_iter.second) {
NCCL_CHECK(ncclCommDestroy(comm_iter.second));
}
}
#endif
#ifdef USE_CUDNN
for (auto& handle : cudnn_handles_[i]) {
/*!
* Temporarily disable the handle destroying,
* to avoid the segmentation fault in CUDNN v8.
*
* if (handle) CUDNN_CHECK(cudnnDestroy(handle));
*/
}
#endif
for (auto& handle : cublas_handles_[i]) {
if (handle) CUBLAS_CHECK(cublasDestroy(handle));
}
for (int j = 0; j < cuda_streams_[i].size(); j++) {
auto& stream = cuda_streams_[i][j];
/*!
* Do not check the stream destroying,
* error code 29 (driver shutting down) is inevitable.
*/
if (stream) cudaStreamDestroy(stream);
}
for (auto& workspace : cuda_workspaces_[i]) {
if (workspace) delete workspace;
}
}
}
Workspace* CUDAObjects::workspace(int device_id, int stream_id) {
auto& workspaces = cuda_workspaces_[device_id];
if (workspaces.size() <= (unsigned)stream_id) {
workspaces.resize(stream_id + 1, nullptr);
}
if (!workspaces[stream_id]) {
workspaces[stream_id] = new Workspace("");
}
return workspaces[stream_id];
}
std::mutex& CUDAContext::mutex() { std::mutex& CUDAContext::mutex() {
static std::mutex m; static std::mutex m;
return m; return m;
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
namespace dragon { namespace dragon {
class Workspace;
/*! /*!
* \brief The cpu device context. * \brief The cpu device context.
*/ */
...@@ -94,6 +96,9 @@ class DRAGON_API CPUContext { ...@@ -94,6 +96,9 @@ class DRAGON_API CPUContext {
/*! \brief Wait for the dispatched computation to complete */ /*! \brief Wait for the dispatched computation to complete */
void FinishDeviceComputation() {} void FinishDeviceComputation() {}
/*! \brief Return the current workspace */
Workspace* workspace();
/*! \brief Return the device index */ /*! \brief Return the device index */
int device() const { int device() const {
return 0; return 0;
......
...@@ -22,12 +22,15 @@ namespace dragon { ...@@ -22,12 +22,15 @@ namespace dragon {
#ifdef USE_CUDA #ifdef USE_CUDA
class Workspace;
class CUDAObjects { class CUDAObjects {
public: public:
/*! \brief Default Constructor */ /*! \brief Default Constructor */
CUDAObjects() { CUDAObjects() {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) { for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
cuda_streams_[i] = vector<cudaStream_t>(); cuda_streams_[i] = vector<cudaStream_t>();
cuda_workspaces_[i] = vector<Workspace*>();
cublas_handles_[i] = vector<cublasHandle_t>(); cublas_handles_[i] = vector<cublasHandle_t>();
#ifdef USE_CUDNN #ifdef USE_CUDNN
cudnn_handles_[i] = vector<cudnnHandle_t>(); cudnn_handles_[i] = vector<cudnnHandle_t>();
...@@ -39,38 +42,7 @@ class CUDAObjects { ...@@ -39,38 +42,7 @@ class CUDAObjects {
} }
/*! \brief Destructor */ /*! \brief Destructor */
~CUDAObjects() { ~CUDAObjects();
for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
#ifdef USE_NCCL
for (auto& comm_iter : nccl_comms_[i]) {
if (comm_iter.second) {
NCCL_CHECK(ncclCommDestroy(comm_iter.second));
}
}
#endif
#ifdef USE_CUDNN
for (auto& handle : cudnn_handles_[i]) {
/*!
* Temporarily disable the handle destroying,
* to avoid the segmentation fault in CUDNN v8.
*
* if (handle) CUDNN_CHECK(cudnnDestroy(handle));
*/
}
#endif
for (auto& handle : cublas_handles_[i]) {
if (handle) CUBLAS_CHECK(cublasDestroy(handle));
}
for (int j = 0; j < cuda_streams_[i].size(); j++) {
auto& stream = cuda_streams_[i][j];
/*!
* Do not check the stream destroying,
* error code 29 (driver shutting down) is inevitable.
*/
if (stream) cudaStreamDestroy(stream);
}
}
}
/*! \brief Return the specified cublas handle */ /*! \brief Return the specified cublas handle */
cublasHandle_t cublas_handle(int device_id, int stream_id) { cublasHandle_t cublas_handle(int device_id, int stream_id) {
...@@ -142,8 +114,9 @@ class CUDAObjects { ...@@ -142,8 +114,9 @@ class CUDAObjects {
/*! \brief Return the specified cuda stream */ /*! \brief Return the specified cuda stream */
cudaStream_t stream(int device_id, int stream_id) { cudaStream_t stream(int device_id, int stream_id) {
auto& streams = cuda_streams_[device_id]; auto& streams = cuda_streams_[device_id];
if (streams.size() <= (unsigned)stream_id) if (streams.size() <= (unsigned)stream_id) {
streams.resize(stream_id + 1, nullptr); streams.resize(stream_id + 1, nullptr);
}
if (!streams[stream_id]) { if (!streams[stream_id]) {
CUDADeviceGuard guard(device_id); CUDADeviceGuard guard(device_id);
unsigned int flags = unsigned int flags =
...@@ -153,19 +126,37 @@ class CUDAObjects { ...@@ -153,19 +126,37 @@ class CUDAObjects {
return streams[stream_id]; return streams[stream_id];
} }
/*! \brief Return the workspace for the specified cuda stream */
Workspace* workspace(int device_id, int stream_id);
/*! \brief The cached CUDA streams of each device */
vector<cudaStream_t> cuda_streams_[CUDA_MAX_DEVICES]; vector<cudaStream_t> cuda_streams_[CUDA_MAX_DEVICES];
/*! \brief The cached CUDA workspaces of each device */
vector<Workspace*> cuda_workspaces_[CUDA_MAX_DEVICES];
/*! \brief The cached cuBLAS handles of each device */
vector<cublasHandle_t> cublas_handles_[CUDA_MAX_DEVICES]; vector<cublasHandle_t> cublas_handles_[CUDA_MAX_DEVICES];
#ifdef USE_CUDNN #ifdef USE_CUDNN
/*! \brief The cached cuDNN handles of each device */
vector<cudnnHandle_t> cudnn_handles_[CUDA_MAX_DEVICES]; vector<cudnnHandle_t> cudnn_handles_[CUDA_MAX_DEVICES];
#endif #endif
#ifdef USE_NCCL #ifdef USE_NCCL
/*! \brief The cached NCCL comms of each device */
Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES]; Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES];
#endif #endif
/*! \brief The flag that allows cuDNN or not */
bool cudnn_enabled_ = true; bool cudnn_enabled_ = true;
/*! \brief The flag that allows cuDNN benchmark or not */
bool cudnn_benchmark_ = false; bool cudnn_benchmark_ = false;
/*! \brief The flag that allows cuDNN TF32 math type or not */
bool cudnn_allow_tf32_ = false;
private: private:
DISABLE_COPY_AND_ASSIGN(CUDAObjects); DISABLE_COPY_AND_ASSIGN(CUDAObjects);
}; };
...@@ -190,11 +181,19 @@ class DRAGON_API CUDAContext { ...@@ -190,11 +181,19 @@ class DRAGON_API CUDAContext {
CHECK_EQ(option.device_type(), PROTO_CUDA); CHECK_EQ(option.device_type(), PROTO_CUDA);
} }
/*! \brief Allocate a block of memory */ /*! \brief Allocate a block of device memory */
static void* New(size_t size) { static void* New(size_t size) {
void* data; void* data;
cudaMalloc(&data, size); cudaMalloc(&data, size);
CHECK(data) << "\nAllocate cuda memory with " << size << " bytes failed."; CHECK(data) << "\nAllocate device memory with " << size << " bytes failed.";
return data;
}
/*! \brief Allocate a block of host memory */
static void* NewHost(size_t size) {
void* data;
cudaMallocHost(&data, size);
CHECK(data) << "\nAllocate host memory with " << size << " bytes failed.";
return data; return data;
} }
...@@ -237,11 +236,16 @@ class DRAGON_API CUDAContext { ...@@ -237,11 +236,16 @@ class DRAGON_API CUDAContext {
CHECK_EQ(err, cudaSuccess) << "\nCUDA Error: " << cudaGetErrorString(err); CHECK_EQ(err, cudaSuccess) << "\nCUDA Error: " << cudaGetErrorString(err);
} }
/*! \brief Deallocate a memory block */ /*! \brief Deallocate a device memory block */
static void Delete(void* ptr) { static void Delete(void* ptr) {
cudaFree(ptr); cudaFree(ptr);
} }
/*! \brief Deallocate a host memory block */
static void DeleteHost(void* ptr) {
cudaFreeHost(ptr);
}
/*! \brief Switch to the device in current thread */ /*! \brief Switch to the device in current thread */
void SwitchToDevice() { void SwitchToDevice() {
SwitchToDevice(0); SwitchToDevice(0);
...@@ -265,9 +269,19 @@ class DRAGON_API CUDAContext { ...@@ -265,9 +269,19 @@ class DRAGON_API CUDAContext {
SynchronizeStream(cuda_stream()); SynchronizeStream(cuda_stream());
} }
/*! \brief Return the cuda stream */ /*! \brief Return the current workspace */
Workspace* workspace() {
return objects().workspace(device_id_, stream_id_);
}
/*! \brief Return the specified workspace */
Workspace* workspace(int device, int stream) {
return objects().workspace(device, stream);
}
/*! \brief Return the current cuda stream */
cudaStream_t cuda_stream() { cudaStream_t cuda_stream() {
return cuda_stream(device_id_, stream_id_); return objects().stream(device_id_, stream_id_);
} }
/*! \brief Return the specified cuda stream */ /*! \brief Return the specified cuda stream */
...@@ -359,12 +373,18 @@ class DRAGON_API CUDAContext { ...@@ -359,12 +373,18 @@ class DRAGON_API CUDAContext {
CUDA_NOT_COMPILED; CUDA_NOT_COMPILED;
} }
/*! \brief Allocate a block of memory */ /*! \brief Allocate a block of device memory */
static void* New(size_t nbytes) { static void* New(size_t nbytes) {
CUDA_NOT_COMPILED; CUDA_NOT_COMPILED;
return nullptr; return nullptr;
} }
/*! \brief Allocate a block of host memory */
static void* NewHost(size_t nbytes) {
CUDA_NOT_COMPILED;
return nullptr;
}
/*! \brief Set a memory block to the given value */ /*! \brief Set a memory block to the given value */
static void Memset(size_t nbytes, void* ptr, int value = 0) { static void Memset(size_t nbytes, void* ptr, int value = 0) {
CUDA_NOT_COMPILED; CUDA_NOT_COMPILED;
...@@ -387,11 +407,16 @@ class DRAGON_API CUDAContext { ...@@ -387,11 +407,16 @@ class DRAGON_API CUDAContext {
CUDA_NOT_COMPILED; CUDA_NOT_COMPILED;
} }
/*! \brief Deallocate a memory block */ /*! \brief Deallocate a device memory block */
static void Delete(void* ptr) { static void Delete(void* ptr) {
CUDA_NOT_COMPILED; CUDA_NOT_COMPILED;
} }
/*! \brief Deallocate a host memory block */
static void DeleteHost(void* ptr) {
CUDA_NOT_COMPILED;
}
/*! \brief Copy the memory asynchronously */ /*! \brief Copy the memory asynchronously */
template <class DestContext, class SrcContext> template <class DestContext, class SrcContext>
void MemcpyAsync(size_t nbytes, void* dest, const void* src) { void MemcpyAsync(size_t nbytes, void* dest, const void* src) {
......
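
The new cudnn_allow_tf32_ flag is only declared in this file; a hedged sketch of how a cuDNN op could honor it on Ampere follows (the wiring is an assumption, not taken from this commit; the calls are standard cuDNN 8 API):

#include <cudnn.h>

// Assumed wiring for cudnn_allow_tf32_: with cuDNN 8 on Ampere,
// CUDNN_DEFAULT_MATH already permits TF32 for convolutions, while
// CUDNN_FMA_MATH restricts them to plain FP32 accumulation.
void SetConvMathType(cudnnConvolutionDescriptor_t desc, bool allow_tf32) {
  cudnnSetConvolutionMathType(
      desc, allow_tf32 ? CUDNN_DEFAULT_MATH : CUDNN_FMA_MATH);
}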
...@@ -69,7 +69,7 @@ class DRAGON_API GraphBase { ...@@ -69,7 +69,7 @@ class DRAGON_API GraphBase {
} }
/*! \brief Return the parent workspace */ /*! \brief Return the parent workspace */
Workspace* ws() const { Workspace* workspace() const {
return ws_; return ws_;
} }
......
...@@ -147,7 +147,7 @@ class DRAGON_API UnifiedMemory { ...@@ -147,7 +147,7 @@ class DRAGON_API UnifiedMemory {
/*! \brief Set to use an external block of cpu data */ /*! \brief Set to use an external block of cpu data */
void set_cpu_data(void* cpu_ptr, size_t size); void set_cpu_data(void* cpu_ptr, size_t size);
/*! \brief Set to use an extenral block of cuda data */ /*! \brief Set to use an external block of cuda data */
void set_cuda_data(void* cuda_ptr, size_t size, int device); void set_cuda_data(void* cuda_ptr, size_t size, int device);
private: private:
......
...@@ -71,7 +71,7 @@ Tensor* OperatorBase::Output(int i, const vec32_t& inputs) { ...@@ -71,7 +71,7 @@ Tensor* OperatorBase::Output(int i, const vec32_t& inputs) {
} }
Tensor* OperatorBase::Buffer(const string& name) { Tensor* OperatorBase::Buffer(const string& name) {
return ws()->CreateTensor("/share/buffer/" + handle_ + "/" + name); return workspace()->CreateTensor("/share/buffer/" + handle_ + "/" + name);
} }
string OperatorBase::MessageForUnsupported( string OperatorBase::MessageForUnsupported(
...@@ -94,10 +94,10 @@ OperatorBase* OperatorBase::UpdateFrom(const OperatorDef& def) { ...@@ -94,10 +94,10 @@ OperatorBase* OperatorBase::UpdateFrom(const OperatorDef& def) {
inputs_.resize(def.input_size()); inputs_.resize(def.input_size());
outputs_.resize(def.output_size()); outputs_.resize(def.output_size());
for (int i = 0; i < inputs_.size(); i++) { for (int i = 0; i < inputs_.size(); i++) {
inputs_[i] = ws()->GetTensor(def.input(i)); inputs_[i] = workspace()->GetTensor(def.input(i));
} }
for (int i = 0; i < outputs_.size(); i++) { for (int i = 0; i < outputs_.size(); i++) {
outputs_[i] = ws()->CreateTensor(def.output(i)); outputs_[i] = workspace()->CreateTensor(def.output(i));
} }
return this; return this;
} }
...@@ -113,7 +113,7 @@ void Operator<Context>::Prepare() { ...@@ -113,7 +113,7 @@ void Operator<Context>::Prepare() {
LOG(DEBUG) << "Excepted version of Tensor(" + Input(i).name() + ") " LOG(DEBUG) << "Excepted version of Tensor(" + Input(i).name() + ") "
<< "is " << version << ", got " << Input(i).version() << "is " << version << ", got " << Input(i).version()
<< ". Recompute."; << ". Recompute.";
Tensor* flag = ws()->GetTensor("/share/flag/recomputing"); Tensor* flag = workspace()->GetTensor("/share/flag/recomputing");
flag->mutable_data<bool, CPUContext>()[0] = true; flag->mutable_data<bool, CPUContext>()[0] = true;
vector<OperatorBase*>& chain = subgraph()[name]; vector<OperatorBase*>& chain = subgraph()[name];
for (auto* op : chain) { for (auto* op : chain) {
......
...@@ -139,7 +139,7 @@ class DRAGON_API OperatorBase { ...@@ -139,7 +139,7 @@ class DRAGON_API OperatorBase {
} }
/*! \brief Return the parent workspace */ /*! \brief Return the parent workspace */
Workspace* ws() const { Workspace* workspace() const {
return ws_; return ws_;
} }
...@@ -219,7 +219,7 @@ class DRAGON_API Operator : public OperatorBase { ...@@ -219,7 +219,7 @@ class DRAGON_API Operator : public OperatorBase {
ctx()->SwitchToDevice(stream); ctx()->SwitchToDevice(stream);
SwitchToDevice(); SwitchToDevice();
RunOnDevice(); RunOnDevice();
if (do_sync_ || stream > 0) { if (do_sync_) {
ctx()->FinishDeviceComputation(); ctx()->FinishDeviceComputation();
} }
Release(); Release();
...@@ -262,7 +262,7 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*); ...@@ -262,7 +262,7 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
using OperatorBase::data_format; \ using OperatorBase::data_format; \
using OperatorBase::handle; \ using OperatorBase::handle; \
using OperatorBase::def; \ using OperatorBase::def; \
using OperatorBase::ws using OperatorBase::workspace
#define USE_OPERATOR_FUNCTIONS \ #define USE_OPERATOR_FUNCTIONS \
USE_OPERATOR_BASE_FUNCTIONS; \ USE_OPERATOR_BASE_FUNCTIONS; \
...@@ -274,7 +274,7 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*); ...@@ -274,7 +274,7 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
->set_meta(Input(i).meta())) ->set_meta(Input(i).meta()))
#define RESTORE_INPUT_SPEC(i) \ #define RESTORE_INPUT_SPEC(i) \
*(ws()->GetTensor( \ *(workspace()->GetTensor( \
"/share/buffer/" + handle() + "/X_spec:" + std::to_string(i))) "/share/buffer/" + handle() + "/X_spec:" + std::to_string(i)))
/* Dispatchers */ /* Dispatchers */
...@@ -341,7 +341,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType); ...@@ -341,7 +341,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
#define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \ #define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \
if (tensor.count() == 0) { \ if (tensor.count() == 0) { \
auto* filler_info = ws()->GetFillerInfo(tensor.name()); \ auto* filler_info = workspace()->GetFillerInfo(tensor.name()); \
CHECK(filler_info) << "\nTensor(" << tensor.name() << ") is empty.\n" \ CHECK(filler_info) << "\nTensor(" << tensor.name() << ") is empty.\n" \
<< "May be specify a filler for it?"; \ << "May be specify a filler for it?"; \
tensor.Reshape(shape); \ tensor.Reshape(shape); \
...@@ -362,7 +362,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType); ...@@ -362,7 +362,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
#define TENSOR_FILL(tensor, shape) \ #define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \ if (tensor.count() == 0) { \
auto* filler_info = ws()->GetFillerInfo(tensor.name()); \ auto* filler_info = workspace()->GetFillerInfo(tensor.name()); \
CHECK(filler_info) << "\nTensor(" << tensor.name() << ") is empty.\n" \ CHECK(filler_info) << "\nTensor(" << tensor.name() << ") is empty.\n" \
<< "May be specify a filler for it?"; \ << "May be specify a filler for it?"; \
tensor.Reshape(shape); \ tensor.Reshape(shape); \
...@@ -413,7 +413,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType); ...@@ -413,7 +413,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
template <class Context> \ template <class Context> \
type classname<Context>::arg() { \ type classname<Context>::arg() { \
if (arg##_desc_.empty()) return arg##_; \ if (arg##_desc_.empty()) return arg##_; \
auto* arg##_tensor = ws()->GetTensor( \ auto* arg##_tensor = workspace()->GetTensor( \
str::replace_first(arg##_desc_, "${HANDLE}", handle())); \ str::replace_first(arg##_desc_, "${HANDLE}", handle())); \
CHECK_EQ(arg##_tensor->count(), 1) \ CHECK_EQ(arg##_tensor->count(), 1) \
<< "\nThe argument <" << #arg << "> should be a scalar."; \ << "\nThe argument <" << #arg << "> should be a scalar."; \
...@@ -423,35 +423,35 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType); ...@@ -423,35 +423,35 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
return arg##_tensor->template data<type, CPUContext>()[0]; \ return arg##_tensor->template data<type, CPUContext>()[0]; \
} }
#define DEFINE_OP_REPEATED_ARG_WITH_DESC(type, classname, arg) \ #define DEFINE_OP_REPEATED_ARG_WITH_DESC(type, classname, arg) \
template <class Context> \ template <class Context> \
type classname<Context>::arg(int i, int* num) { \ type classname<Context>::arg(int i, int* num) { \
const type* data; \ const type* data; \
string desc; \ string desc; \
if (!arg##_desc_.empty()) { \ if (!arg##_desc_.empty()) { \
desc = arg##_desc_; \ desc = arg##_desc_; \
} else if (!arg##_descs_.empty()) { \ } else if (!arg##_descs_.empty()) { \
desc = arg##_descs_[i]; \ desc = arg##_descs_[i]; \
} \ } \
if (!desc.empty()) { \ if (!desc.empty()) { \
auto* arg##_tensor = \ auto* arg##_tensor = workspace()->GetTensor( \
ws()->GetTensor(str::replace_first(desc, "${HANDLE}", handle())); \ str::replace_first(desc, "${HANDLE}", handle())); \
CHECK(arg##_tensor->template IsType<type>()) \ CHECK(arg##_tensor->template IsType<type>()) \
<< "\nThe type of argument <" << #arg << "> should be " \ << "\nThe type of argument <" << #arg << "> should be " \
<< types::to_string<type>() << "."; \ << types::to_string<type>() << "."; \
data = arg##_tensor->template data<type, CPUContext>(); \ data = arg##_tensor->template data<type, CPUContext>(); \
if (num != nullptr) { \ if (num != nullptr) { \
*num = arg##_desc_.empty() ? (int)arg##_descs_.size() \ *num = arg##_desc_.empty() ? (int)arg##_descs_.size() \
: (int)arg##_tensor->size(); \ : (int)arg##_tensor->size(); \
} \ } \
} else { \ } else { \
data = arg##_.data(); \ data = arg##_.data(); \
if (num != nullptr) { \ if (num != nullptr) { \
*num = (int)arg##_.size(); \ *num = (int)arg##_.size(); \
} \ } \
} \ } \
if (num != nullptr && (*num) == 0) return type(0); \ if (num != nullptr && (*num) == 0) return type(0); \
return arg##_descs_.empty() ? data[i] : data[0]; \ return arg##_descs_.empty() ? data[i] : data[0]; \
} }
#define CANONICALIZE_AXIS_WITH_TENSOR_AND_OFFSET(tensor, offset) \ #define CANONICALIZE_AXIS_WITH_TENSOR_AND_OFFSET(tensor, offset) \
......
...@@ -89,9 +89,9 @@ class DRAGON_API Workspace { ...@@ -89,9 +89,9 @@ class DRAGON_API Workspace {
template <class Context> template <class Context>
vector<void*> data(const vector<size_t>& segments) { vector<void*> data(const vector<size_t>& segments) {
vector<void*> group(segments.size()); vector<void*> group(segments.size());
auto total_bytes = std::accumulate(segments.begin(), segments.end(), 0);
group[0] = CreateTensor("/share/data") group[0] = CreateTensor("/share/data")
->Reshape({(int64_t)total_bytes}) ->Reshape({(int64_t)std::accumulate(
segments.begin(), segments.end(), size_t(0))})
->template mutable_data<uint8_t, Context>(); ->template mutable_data<uint8_t, Context>();
for (int i = 1; i < segments.size(); ++i) { for (int i = 1; i < segments.size(); ++i) {
group[i] = (uint8_t*)group[i - 1] + segments[i - 1]; group[i] = (uint8_t*)group[i - 1] + segments[i - 1];
......
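
The data(...) helper above backs every requested segment with the single "/share/data" tensor of the workspace; a minimal usage sketch (segment sizes are illustrative):

#include "dragon/core/context_cuda.h"
#include "dragon/core/workspace.h"

// Minimal sketch: carve two scratch segments out of the shared "/share/data"
// tensor. The pointers are offsets into one allocation and remain valid only
// until the next request resizes that tensor.
void UseScratchSegments(dragon::CUDAContext* ctx) {
  auto segs = ctx->workspace()->data<dragon::CUDAContext>({1024, 4096});
  void* buf_a = segs[0];  // first 1024 bytes
  void* buf_b = segs[1];  // the next 4096 bytes, immediately after buf_a
  (void)buf_a, (void)buf_b;
}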
...@@ -8,7 +8,7 @@ namespace kernel { ...@@ -8,7 +8,7 @@ namespace kernel {
namespace { namespace {
template <typename T> template <typename T>
void _Affine( void _ChannelAffine(
const int outer_dim, const int outer_dim,
const int axis_dim, const int axis_dim,
const T* x, const T* x,
...@@ -29,7 +29,7 @@ void _Affine( ...@@ -29,7 +29,7 @@ void _Affine(
} }
template <typename T> template <typename T>
void _Affine( void _ChannelAffine(
const int outer_dim, const int outer_dim,
const int axis_dim, const int axis_dim,
const int inner_dim, const int inner_dim,
...@@ -57,7 +57,7 @@ void _Affine( ...@@ -57,7 +57,7 @@ void _Affine(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> template <>
void Affine<float16, CPUContext>( void ChannelAffine<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int axis_dim, const int axis_dim,
const int inner_dim, const int inner_dim,
...@@ -69,22 +69,22 @@ void Affine<float16, CPUContext>( ...@@ -69,22 +69,22 @@ void Affine<float16, CPUContext>(
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void Affine<T, CPUContext>( \ void ChannelAffine<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \ const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const T* x, \ const T* x, \
const T* w, \ const T* w, \
const T* b, \ const T* b, \
T* y, \ T* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
if (inner_dim == 1) { \ if (inner_dim == 1) { \
_Affine(outer_dim, axis_dim, x, w, b, y); \ _ChannelAffine(outer_dim, axis_dim, x, w, b, y); \
} else { \ } else { \
_Affine(outer_dim, axis_dim, inner_dim, x, w, b, y); \ _ChannelAffine(outer_dim, axis_dim, inner_dim, x, w, b, y); \
} \ } \
} }
DEFINE_KERNEL_LAUNCHER(int8_t); DEFINE_KERNEL_LAUNCHER(int8_t);
...@@ -93,7 +93,6 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -93,7 +93,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void _ChannelAffine(
const int nthreads,
const int axis_dim,
const int inner_dim,
const T* x,
const T* w,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = x[i] * __ldg(w + (i / inner_dim) % axis_dim);
#else
y[i] = x[i] * w[(i / inner_dim) % axis_dim];
#endif
}
}
template <>
__global__ void _ChannelAffine<half>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const half* x,
const half* w,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(x[i], __ldg(w + (i / inner_dim) % axis_dim));
#elif __CUDA_ARCH__ >= 350
y[i] = __float2half(
__half2float(x[i]) *
__half2float(__ldg(w + (i / inner_dim) % axis_dim)));
#else
y[i] = __float2half(
__half2float(x[i]) * __half2float(w[(i / inner_dim) % axis_dim]));
#endif
}
}
template <typename T>
__global__ void _ChannelAffine(
const int nthreads,
const int axis_dim,
const int inner_dim,
const T* x,
const T* w,
const T* b,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int wi = (i / inner_dim) % axis_dim;
#if __CUDA_ARCH__ >= 350
y[i] = x[i] * __ldg(w + wi) + __ldg(b + wi);
#else
y[i] = x[i] * w[wi] + b[wi];
#endif
}
}
template <>
__global__ void _ChannelAffine<half>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const half* x,
const half* w,
const half* b,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int wi = (i / inner_dim) % axis_dim;
#if __CUDA_ARCH__ >= 530
y[i] = __hfma(x[i], __ldg(w + wi), __ldg(b + wi));
#elif __CUDA_ARCH__ >= 350
y[i] = __float2half(fmaf(
__half2float(x[i]),
__half2float(__ldg(w + wi)),
__half2float(__ldg(b + wi))));
#else
y[i] = __float2half(
fmaf(__half2float(x[i]), __half2float(w[wi]), __half2float(b[wi])));
#endif
}
}
template <>
__global__ void _ChannelAffine<float>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const float* x,
const float* w,
const float* b,
float* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int wi = (i / inner_dim) % axis_dim;
#if __CUDA_ARCH__ >= 350
y[i] = fmaf(x[i], __ldg(w + wi), __ldg(b + wi));
#else
y[i] = fmaf(x[i], w[wi], b[wi]);
#endif
}
}
template <>
__global__ void _ChannelAffine<double>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const double* x,
const double* w,
const double* b,
double* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int wi = (i / inner_dim) % axis_dim;
#if __CUDA_ARCH__ >= 350
y[i] = fma(x[i], __ldg(w + wi), __ldg(b + wi));
#else
y[i] = fma(x[i], w[wi], b[wi]);
#endif
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void ChannelAffine<float16, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* x,
const float16* w,
const float16* b,
float16* y,
CUDAContext* ctx) {
const int nthreads = outer_dim * axis_dim * inner_dim;
if (b != nullptr) {
_ChannelAffine<<<
CUDA_BLOCKS(nthreads),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
nthreads,
axis_dim,
inner_dim,
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(w),
reinterpret_cast<const half*>(b),
reinterpret_cast<half*>(y));
} else {
_ChannelAffine<<<
CUDA_BLOCKS(nthreads),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
nthreads,
axis_dim,
inner_dim,
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(w),
reinterpret_cast<half*>(y));
}
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelAffine<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const T* x, \
const T* w, \
const T* b, \
T* y, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \
if (b != nullptr) { \
_ChannelAffine<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(nthreads, axis_dim, inner_dim, x, w, b, y); \
} else { \
_ChannelAffine<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(nthreads, axis_dim, inner_dim, x, w, y); \
} \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
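
For orientation on the renamed launcher above, a hedged sketch of a call site (the wrapper is illustrative; pointers are assumed to be valid device buffers, and b may be nullptr to skip the bias term):

#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"

// Hedged sketch: y[n][c][s] = x[n][c][s] * w[c] + b[c] over an NCHW tensor,
// decomposed as (outer_dim, axis_dim, inner_dim) = (N, C, H * W).
void RunChannelAffineNCHW(
    int N, int C, int H, int W,
    const float* x, const float* w, const float* b, float* y,
    dragon::CUDAContext* ctx) {
  dragon::kernel::ChannelAffine(N, C, H * W, x, w, b, y, ctx);
}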
...@@ -12,18 +12,12 @@ void _Flagged( ...@@ -12,18 +12,12 @@ void _Flagged(
const int count, const int count,
const uint8_t* mask, const uint8_t* mask,
IndexType* index, IndexType* index,
int* num_selected, int* num_selected) {
void* scratch, IndexType* offset_index = index;
size_t& scratch_size) { for (int i = 0; i < count; ++i) {
if (scratch_size <= 0) { if (mask[i]) *(offset_index++) = i;
scratch_size = size_t(1);
} else {
IndexType* offset_index = index;
for (int i = 0; i < count; ++i) {
if (mask[i]) *(offset_index++) = i;
}
num_selected[0] = std::distance(index, offset_index);
} }
num_selected[0] = std::distance(index, offset_index);
} }
template <typename IndexType, typename CoordType> template <typename IndexType, typename CoordType>
...@@ -45,17 +39,15 @@ void _UnravelIndex( ...@@ -45,17 +39,15 @@ void _UnravelIndex(
} // namespace } // namespace
#define DEFINE_KERNEL_LAUNCHER(IndexType) \ #define DEFINE_KERNEL_LAUNCHER(IndexType) \
template <> \ template <> \
void Flagged<IndexType, CPUContext>( \ void Flagged<IndexType, CPUContext>( \
const int count, \ const int count, \
const uint8_t* mask, \ const uint8_t* mask, \
IndexType* index, \ IndexType* index, \
int* num_selected, \ int* num_selected, \
void* scratch, \ CPUContext* ctx) { \
size_t& scratch_size, \ _Flagged(count, mask, index, num_selected); \
CPUContext* ctx) { \
_Flagged(count, mask, index, num_selected, scratch, scratch_size); \
} }
DEFINE_KERNEL_LAUNCHER(int); DEFINE_KERNEL_LAUNCHER(int);
......
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/device/common_cub.h" #include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
...@@ -31,48 +32,44 @@ __global__ void _UnravelIndex( ...@@ -31,48 +32,44 @@ __global__ void _UnravelIndex(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(IndexType) \ #define DEFINE_KERNEL_LAUNCHER(IndexType) \
template <> \ template <> \
void Flagged<IndexType, CUDAContext>( \ void Flagged<IndexType, CUDAContext>( \
const int count, \ const int count, \
const uint8_t* mask, \ const uint8_t* mask, \
IndexType* index, \ IndexType* index, \
int* num_selected, \ int* num_selected, \
void* scratch, \ CUDAContext* ctx) { \
size_t& scratch_size, \ IndexType num_selected_host; \
CUDAContext* ctx) { \ auto* num_selected_dev = index + count; \
cub::CountingInputIterator<int> itr(0); \ size_t ws_nbytes = 0; \
if (scratch_size <= 0) { \ cub::CountingInputIterator<int> itr(0); \
cub::DeviceSelect::Flagged( \ cub::DeviceSelect::Flagged( \
scratch, \ nullptr, \
scratch_size, \ ws_nbytes, \
itr, \ itr, \
mask, \ mask, \
index, \ index, \
static_cast<int64_t*>(nullptr), \ static_cast<int64_t*>(nullptr), \
count, \ count, \
ctx->cuda_stream()); \ ctx->cuda_stream()); \
} else { \ cub::DeviceSelect::Flagged( \
auto* num_selected_dev = index + count; \ ctx->workspace()->template data<CUDAContext>({ws_nbytes})[0], \
cub::DeviceSelect::Flagged( \ ws_nbytes, \
scratch, \ itr, \
scratch_size, \ mask, \
itr, \ index, \
mask, \ num_selected_dev, \
index, \ count, \
num_selected_dev, \ ctx->cuda_stream()); \
count, \ CUDA_CHECK(cudaMemcpyAsync( \
ctx->cuda_stream()); \ &num_selected_host, \
IndexType num_selected_host; \ num_selected_dev, \
CUDA_CHECK(cudaMemcpyAsync( \ sizeof(IndexType), \
&num_selected_host, \ cudaMemcpyDefault, \
num_selected_dev, \ ctx->cuda_stream())); \
sizeof(IndexType), \ ctx->FinishDeviceComputation(); \
cudaMemcpyDefault, \ num_selected[0] = num_selected_host; \
ctx->cuda_stream())); \
ctx->FinishDeviceComputation(); \
num_selected[0] = num_selected_host; \
} \
} }
DEFINE_KERNEL_LAUNCHER(int); DEFINE_KERNEL_LAUNCHER(int);
......
...@@ -23,17 +23,42 @@ void _BroadcastLossGrad( ...@@ -23,17 +23,42 @@ void _BroadcastLossGrad(
} }
} }
} // namespace
template <>
void ReduceLoss<float16, CPUContext>(
const int count,
const int num_masks,
const float normalizer,
const float16* x,
const float16* mask,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <>
void ReduceLossGrad<float16, CPUContext>(
const int count,
const int num_masks,
const float normalizer,
const float16* dy,
const float16* mask,
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <> template <>
void _BroadcastLossGrad<float16>( void BroadcastLossGrad<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int axis_dim, const int axis_dim,
const int inner_dim, const int inner_dim,
const float16* dy, const float16* dy,
float16* dx) { float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} // BroadcastLossGrad }
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
...@@ -42,11 +67,11 @@ void _BroadcastLossGrad<float16>( ...@@ -42,11 +67,11 @@ void _BroadcastLossGrad<float16>(
const int num_masks, \ const int num_masks, \
const float normalizer, \ const float normalizer, \
const T* x, \ const T* x, \
const int* mask, \ const T* mask, \
T* y, \ T* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
float inv_scale = std::max( \ float inv_scale = std::max( \
1e-5F, \ 1.f, \
num_masks > 0 && normalizer < 0.f \ num_masks > 0 && normalizer < 0.f \
? (float)math::Sum(num_masks, 1.f, mask, ctx) \ ? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \ : normalizer); \
...@@ -60,11 +85,11 @@ void _BroadcastLossGrad<float16>( ...@@ -60,11 +85,11 @@ void _BroadcastLossGrad<float16>(
const int num_masks, \ const int num_masks, \
const float normalizer, \ const float normalizer, \
const T* dy, \ const T* dy, \
const int* mask, \ const T* mask, \
T* dx, \ T* dx, \
CPUContext* ctx) { \ CPUContext* ctx) { \
float inv_scale = std::max( \ float inv_scale = std::max( \
1e-5F, \ 0.5f, \
num_masks > 0 && normalizer < 0.f \ num_masks > 0 && normalizer < 0.f \
? (float)math::Sum(num_masks, 1.f, mask, ctx) \ ? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \ : normalizer); \
...@@ -81,11 +106,9 @@ void _BroadcastLossGrad<float16>( ...@@ -81,11 +106,9 @@ void _BroadcastLossGrad<float16>(
_BroadcastLossGrad(outer_dim, axis_dim, inner_dim, dy, dx); \ _BroadcastLossGrad(outer_dim, axis_dim, inner_dim, dy, dx); \
} }
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/device/common_cub.h" #include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -12,84 +13,14 @@ namespace { ...@@ -12,84 +13,14 @@ namespace {
template <typename T> template <typename T>
__global__ void __global__ void
_ReduceLoss(const int nthreads, const T scale, const T* x, T* y) {
__shared__ typename BlockReduce<T>::TempStorage storage;
T val = T(0);
CUDA_2D_KERNEL_LOOP2(i, nthreads) {
val += x[i];
}
val = BlockReduce<T>(storage).Sum(val);
if (threadIdx.x == 0) {
y[0] = val * scale;
}
}
__global__ void
_ReduceLoss(const int nthreads, const float scale, const half* x, half* y) {
__shared__ typename BlockReduce<float>::TempStorage storage;
float val = 0.f;
CUDA_2D_KERNEL_LOOP2(i, nthreads) {
val += __half2float(x[i]);
}
val = BlockReduce<float>(storage).Sum(val);
if (threadIdx.x == 0) {
y[0] = __float2half(val * scale);
}
}
template <typename T>
__global__ void
_ReduceLossWithMask(const int nthreads, const T* x, const int* mask, T* y) {
__shared__ union {
typename BlockReduce<T>::TempStorage loss;
typename BlockReduce<int>::TempStorage mask;
} storage;
T val = T(0);
int num_valids = 0;
CUDA_2D_KERNEL_LOOP2(i, nthreads) {
val += x[i];
num_valids += mask[i];
}
val = BlockReduce<T>(storage.loss).Sum(val);
num_valids = BlockReduce<int>(storage.mask).Sum(num_valids);
if (threadIdx.x == 0) {
y[0] = val / (T)max(1, num_valids);
}
}
template <>
__global__ void _ReduceLossWithMask<half>(
const int nthreads,
const half* x,
const int* mask,
half* y) {
__shared__ union {
typename BlockReduce<float>::TempStorage loss;
typename BlockReduce<int>::TempStorage mask;
} storage;
float val = 0.f;
int num_valids = 0;
CUDA_2D_KERNEL_LOOP2(i, nthreads) {
val += __half2float(x[i]);
num_valids += mask[i];
}
val = BlockReduce<float>(storage.loss).Sum(val);
num_valids = BlockReduce<int>(storage.mask).Sum(num_valids);
if (threadIdx.x == 0) {
y[0] = __float2half(val / (float)max(1, num_valids));
}
}
template <typename T>
__global__ void
_ReduceLossGrad(const int nthreads, const T scale, const T* dy, T* dx) { _ReduceLossGrad(const int nthreads, const T scale, const T* dy, T* dx) {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
const T val = __ldg(dy) * scale; const T alpha = __ldg(dy) * scale;
#else #else
const T val = dy[0] * scale; const T alpha = dy[0] * scale;
#endif #endif
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] *= val; dx[i] *= alpha;
} }
} }
...@@ -99,54 +30,43 @@ __global__ void _ReduceLossGrad( ...@@ -99,54 +30,43 @@ __global__ void _ReduceLossGrad(
const half* dy, const half* dy,
half* dx) { half* dx) {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
const float val = __half2float(__ldg(dy)) * scale; const float alpha = __half2float(__ldg(dy)) * scale;
#else #else
const float val = __half2float(dy[0]) * scale; const float alpha = __half2float(dy[0]) * scale;
#endif #endif
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = __float2half(__half2float(dx[i]) * val); dx[i] = __float2half(__half2float(dx[i]) * alpha);
}
}
__global__ void _ReduceMask(const int num_masks, int* mask) {
__shared__ typename BlockReduce<int>::TempStorage storage;
int num_valids = 0;
CUDA_2D_KERNEL_LOOP2(i, num_masks) {
num_valids += mask[i];
} }
num_valids = BlockReduce<int>(storage).Sum(num_valids);
if (threadIdx.x == 0) mask[0] = max(num_valids, 1);
} }
template <typename T> template <typename T>
__global__ void _ReduceLossGradWithMask( __global__ void
const int nthreads, _ReduceLossGrad(const int nthreads, const T* normalizer, const T* dy, T* dx) {
const T* dy,
const int* mask,
T* dx) {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
const T val = __ldg(dy) / (T)__ldg(mask); const T alpha = __ldg(dy) / max(__ldg(normalizer), T(1));
#else #else
const T val = dy[0] / (T)mask[0]; const T alpha = dy[0] / max(normalizer[0], T(1));
#endif #endif
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] *= val; dx[i] *= alpha;
} }
} }
template <> template <>
__global__ void _ReduceLossGradWithMask<half>( __global__ void _ReduceLossGrad<half>(
const int nthreads, const int nthreads,
const half* normalizer,
const half* dy, const half* dy,
const int* mask,
half* dx) { half* dx) {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
const float val = __half2float(__ldg(dy)) / (float)__ldg(mask); const float alpha =
__half2float(__ldg(dy)) / max(__half2float(__ldg(normalizer)), 1.f);
#else #else
const float val = __half2float(dy[0]) / (float)mask[0]; const float alpha =
__half2float(dy[0]) / max(__half2float(normalizer[0]), 1.f);
#endif #endif
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = __float2half(__half2float(dx[i]) * val); dx[i] = __float2half(__half2float(dx[i]) * alpha);
} }
} }
...@@ -190,49 +110,25 @@ __global__ void _BroadcastLossGrad<half>( ...@@ -190,49 +110,25 @@ __global__ void _BroadcastLossGrad<half>(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> template <>
void ReduceLoss<float16, CUDAContext>(
const int count,
const int num_masks,
const float normalizer,
const float16* x,
const int* mask,
float16* y,
CUDAContext* ctx) {
if (num_masks > 0 && normalizer < 0.f) {
_ReduceLossWithMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
num_masks,
reinterpret_cast<const half*>(x),
mask,
reinterpret_cast<half*>(y));
} else {
_ReduceLoss<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
1.f / std::max(1e-5F, normalizer),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
}
template <>
void ReduceLossGrad<float16, CUDAContext>( void ReduceLossGrad<float16, CUDAContext>(
const int count, const int count,
const int num_masks, const int num_masks,
const float normalizer, const float normalizer,
const float16* dy, const float16* dy,
const int* mask, const float16* mask,
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
if (num_masks > 0 && normalizer < 0.f) { if (num_masks > 0 && normalizer < 0.f) {
_ReduceMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( auto* normalizer_v2 = const_cast<float16*>(mask + num_masks);
num_masks, const_cast<int*>(mask)); math::Sum(num_masks, 1.f, mask, normalizer_v2, ctx);
_ReduceLossGradWithMask<<< _ReduceLossGrad<<<
CUDA_BLOCKS(count), CUDA_BLOCKS(count),
CUDA_THREADS, CUDA_THREADS,
0, 0,
ctx->cuda_stream()>>>( ctx->cuda_stream()>>>(
count, count,
reinterpret_cast<const half*>(normalizer_v2),
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
mask,
reinterpret_cast<half*>(dx)); reinterpret_cast<half*>(dx));
} else { } else {
_ReduceLossGrad<<< _ReduceLossGrad<<<
...@@ -241,11 +137,11 @@ void ReduceLossGrad<float16, CUDAContext>( ...@@ -241,11 +137,11 @@ void ReduceLossGrad<float16, CUDAContext>(
0, 0,
ctx->cuda_stream()>>>( ctx->cuda_stream()>>>(
count, count,
1.f / std::max(1e-5F, normalizer), 1.f / std::max(0.5f, normalizer),
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx)); reinterpret_cast<half*>(dx));
} }
} // ReduceLossGrad }
template <> template <>
void BroadcastLossGrad<float16, CUDAContext>( void BroadcastLossGrad<float16, CUDAContext>(
...@@ -266,71 +162,73 @@ void BroadcastLossGrad<float16, CUDAContext>( ...@@ -266,71 +162,73 @@ void BroadcastLossGrad<float16, CUDAContext>(
inner_dim, inner_dim,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx)); reinterpret_cast<half*>(dx));
} // BroadcastLossGrad }
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ReduceLoss<T, CUDAContext>( \
const int count, \
const int num_masks, \
const float normalizer, \
const T* x, \
const int* mask, \
T* y, \
CUDAContext* ctx) { \
if (num_masks > 0 && normalizer < 0.f) { \
_ReduceLossWithMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
num_masks, x, mask, y); \
} else { \
_ReduceLoss<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, T(1) / (T)std::max(1e-5F, normalizer), x, y); \
} \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void ReduceLossGrad<T, CUDAContext>( \
const int count, \
const int num_masks, \
const float normalizer, \
const T* dy, \
const int* mask, \
T* dx, \
CUDAContext* ctx) { \
if (num_masks > 0 && normalizer < 0.f) { \
_ReduceMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
num_masks, const_cast<int*>(mask)); \
_ReduceLossGradWithMask<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(count, dy, mask, dx); \
} else { \
_ReduceLossGrad<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
count, T(1) / (T)std::max(1e-5F, normalizer), dy, dx); \
} \
} \
template <> \
void BroadcastLossGrad<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
auto nthreads = outer_dim * axis_dim * inner_dim; \
_BroadcastLossGrad<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, axis_dim * inner_dim, inner_dim, dy, dx); \
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ReduceLoss<T, CUDAContext>( \
const int count, \
const int num_masks, \
const float normalizer, \
const T* x, \
const T* mask, \
T* y, \
CUDAContext* ctx) { \
if (num_masks > 0 && normalizer < 0.f) { \
auto* normalizer_v2 = const_cast<T*>(mask + num_masks); \
math::Sum(num_masks, 1.f, mask, normalizer_v2, ctx); \
math::Sum(count, 1.f, x, y, ctx); \
math::Div(1, y, normalizer_v2, y, ctx); \
} else { \
math::Sum(count, 1.f / std::max(1.f, normalizer), x, y, ctx); \
} \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void ReduceLossGrad<T, CUDAContext>( \
const int count, \
const int num_masks, \
const float normalizer, \
const T* dy, \
const T* mask, \
T* dx, \
CUDAContext* ctx) { \
if (num_masks > 0 && normalizer < 0.f) { \
auto* normalizer_v2 = const_cast<T*>(mask + num_masks); \
math::Sum(num_masks, 1.f, mask, normalizer_v2, ctx); \
_ReduceLossGrad<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(count, normalizer_v2, dy, dx); \
} else { \
_ReduceLossGrad<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
count, T(1.f / std::max(0.5f, normalizer)), dy, dx); \
} \
} \
template <> \
void BroadcastLossGrad<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
auto nthreads = outer_dim * axis_dim * inner_dim; \
_BroadcastLossGrad<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, axis_dim * inner_dim, inner_dim, dy, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
......
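
A convention implied by the new ReduceLoss/ReduceLossGrad launchers above (an inference from the diff, not stated in this commit): when normalizer < 0 and masks are given, the reduced sum is written to mask + num_masks, so the caller is expected to provide one element of headroom after the mask. A hypothetical caller-side sketch:

#include "dragon/core/context_cuda.h"
#include "dragon/core/workspace.h"

// Hypothetical allocation (the name and the workspace route are assumptions):
// reserve num_masks + 1 elements so the launcher can stash the reduced
// normalizer right after the mask values.
float* AllocLossMask(int num_masks, dragon::CUDAContext* ctx) {
  const size_t nbytes = sizeof(float) * (num_masks + 1);
  return static_cast<float*>(
      ctx->workspace()->data<dragon::CUDAContext>({nbytes})[0]);
}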
...@@ -16,17 +16,17 @@ void _NLLLoss( ...@@ -16,17 +16,17 @@ void _NLLLoss(
const LogitType* log_prob, const LogitType* log_prob,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask) { LogitType* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
int count = dims[0] * dims[1], k; int count = dims[0] * dims[1], k;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int label = (int)target[i]; const int label = (int)target[i];
if (label == ignore_index) { if (label == ignore_index) {
loss[i] = mask[i] = 0; loss[i] = mask[i] = LogitType(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
loss[i] = -log_prob[k], mask[i] = 1; loss[i] = -log_prob[k], mask[i] = LogitType(1);
} }
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data()); utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
...@@ -41,17 +41,17 @@ void _NLLLossGrad( ...@@ -41,17 +41,17 @@ void _NLLLossGrad(
const LogitType* log_prob, const LogitType* log_prob,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dx,
int* mask) { LogitType* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
int count = dims[0] * dims[1], k; int count = dims[0] * dims[1], k;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int label = (int)target[i]; const int label = (int)target[i];
if (label == ignore_index) { if (label == ignore_index) {
mask[i] = 0; mask[i] = LogitType(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
dx[k] = LogitType(-1), mask[i] = 1; dx[k] = LogitType(-1), mask[i] = LogitType(1);
} }
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data()); utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
...@@ -71,7 +71,7 @@ void _NLLLossGrad( ...@@ -71,7 +71,7 @@ void _NLLLossGrad(
const LogitType* log_prob, \ const LogitType* log_prob, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
int* mask, \ LogitType* mask, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
......
...@@ -18,16 +18,16 @@ __global__ void _NLLLoss( ...@@ -18,16 +18,16 @@ __global__ void _NLLLoss(
const LogitType* log_prob, const LogitType* log_prob,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask) { LogitType* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int label = target[i * inner_dim + j]; const int label = target[i * inner_dim + j];
if (label == ignore_index) { if (label == ignore_index) {
loss[yi] = mask[yi] = 0; loss[yi] = mask[yi] = LogitType(0);
} else { } else {
loss[yi] = -log_prob[(i * axis_dim + label) * inner_dim + j]; loss[yi] = -log_prob[(i * axis_dim + label) * inner_dim + j];
mask[yi] = 1; mask[yi] = LogitType(1);
} }
} }
} }
...@@ -41,16 +41,16 @@ __global__ void _NLLLossGrad( ...@@ -41,16 +41,16 @@ __global__ void _NLLLossGrad(
const LogitType* log_prob, const LogitType* log_prob,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dx,
int* mask) { LogitType* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int label = target[i * inner_dim + j]; const int label = target[i * inner_dim + j];
if (label == ignore_index) { if (label == ignore_index) {
mask[yi] = 0; mask[yi] = LogitType(0);
} else { } else {
dx[(i * axis_dim + label) * inner_dim + j] = LogitType(-1); dx[(i * axis_dim + label) * inner_dim + j] = LogitType(-1);
mask[yi] = 1; mask[yi] = LogitType(1);
} }
} }
} }
...@@ -69,7 +69,7 @@ __global__ void _NLLLossGrad( ...@@ -69,7 +69,7 @@ __global__ void _NLLLossGrad(
const LogitType* log_prob, \ const LogitType* log_prob, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
int* mask, \ LogitType* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
auto nthreads = outer_dim * inner_dim; \ auto nthreads = outer_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
......
...@@ -13,19 +13,19 @@ void _SigmoidCrossEntropy( ...@@ -13,19 +13,19 @@ void _SigmoidCrossEntropy(
const T* logit, const T* logit,
const T* target, const T* target,
T* loss, T* loss,
int* mask) { T* mask) {
#ifdef USE_OPENMP #ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count)) #pragma omp parallel for num_threads(OMP_THREADS(count))
#endif #endif
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
if (target[i] < 0) { if (target[i] < 0) {
loss[i] = mask[i] = 0; loss[i] = mask[i] = T(0);
} else { } else {
loss[i] = loss[i] =
std::log( std::log(
T(1) + std::exp(logit[i] - T(2) * logit[i] * (logit[i] >= 0))) + T(1) + std::exp(logit[i] - T(2) * logit[i] * (logit[i] >= 0))) +
logit[i] * ((logit[i] >= 0) - target[i]); logit[i] * ((logit[i] >= 0) - target[i]);
mask[i] = 1; mask[i] = T(1);
} }
} }
} }
...@@ -36,16 +36,16 @@ void _SigmoidCrossEntropyGrad( ...@@ -36,16 +36,16 @@ void _SigmoidCrossEntropyGrad(
const T* logit, const T* logit,
const T* target, const T* target,
T* dx, T* dx,
int* mask) { T* mask) {
#ifdef USE_OPENMP #ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count)) #pragma omp parallel for num_threads(OMP_THREADS(count))
#endif #endif
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
if (target[i] < 0) { if (target[i] < 0) {
dx[i] = mask[i] = 0; dx[i] = mask[i] = T(0);
} else { } else {
dx[i] = T(1) / (T(1) + std::exp(-logit[i])) - target[i]; dx[i] = T(1) / (T(1) + std::exp(-logit[i])) - target[i];
mask[i] = 1; mask[i] = T(1);
} }
} }
} }
...@@ -61,7 +61,7 @@ void _SigmoidCrossEntropyGrad( ...@@ -61,7 +61,7 @@ void _SigmoidCrossEntropyGrad(
const T* logit, \ const T* logit, \
const T* target, \ const T* target, \
T* loss, \ T* loss, \
int* mask, \ T* mask, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name(count, logit, target, loss, mask); \ _##name(count, logit, target, loss, mask); \
} }
......
...@@ -15,14 +15,14 @@ __global__ void _SigmoidCrossEntropy( ...@@ -15,14 +15,14 @@ __global__ void _SigmoidCrossEntropy(
const T* logit, const T* logit,
const T* target, const T* target,
T* loss, T* loss,
int* mask) { T* mask) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
if (target[i] < 0) { if (target[i] < 0) {
loss[i] = mask[i] = 0; loss[i] = mask[i] = T(0);
} else { } else {
loss[i] = log(T(1) + exp(logit[i] - T(2) * logit[i] * (logit[i] >= 0))) + loss[i] = log(T(1) + exp(logit[i] - T(2) * logit[i] * (logit[i] >= 0))) +
logit[i] * ((logit[i] >= 0) - target[i]); logit[i] * ((logit[i] >= 0) - target[i]);
mask[i] = 1; mask[i] = T(1);
} }
} }
} }
...@@ -33,13 +33,13 @@ __global__ void _SigmoidCrossEntropyGrad( ...@@ -33,13 +33,13 @@ __global__ void _SigmoidCrossEntropyGrad(
const T* logit, const T* logit,
const T* target, const T* target,
T* dx, T* dx,
int* mask) { T* mask) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
if (target[i] < 0) { if (target[i] < 0) {
dx[i] = mask[i] = 0; dx[i] = mask[i] = T(0);
} else { } else {
dx[i] = T(1) / (T(1) + exp(-logit[i])) - target[i]; dx[i] = T(1) / (T(1) + exp(-logit[i])) - target[i];
mask[i] = 1; mask[i] = T(1);
} }
} }
} }
...@@ -55,7 +55,7 @@ __global__ void _SigmoidCrossEntropyGrad( ...@@ -55,7 +55,7 @@ __global__ void _SigmoidCrossEntropyGrad(
const T* logit, \ const T* logit, \
const T* target, \ const T* target, \
T* loss, \ T* loss, \
int* mask, \ T* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##name<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, logit, target, loss, mask); \ count, logit, target, loss, mask); \
......
...@@ -19,7 +19,7 @@ void _SigmoidFocalLoss( ...@@ -19,7 +19,7 @@ void _SigmoidFocalLoss(
const LogitType* logit, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask) { LogitType* mask) {
std::array<int, 3> idx = {0, 0, 0}; std::array<int, 3> idx = {0, 0, 0};
std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim}; std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim};
const int count = dims[0] * dims[1] * dims[2]; const int count = dims[0] * dims[1] * dims[2];
...@@ -64,7 +64,7 @@ void _SigmoidFocalLossGrad( ...@@ -64,7 +64,7 @@ void _SigmoidFocalLossGrad(
const LogitType* logit, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dx,
int* mask) { LogitType* mask) {
std::array<int, 3> idx = {0, 0, 0}; std::array<int, 3> idx = {0, 0, 0};
std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim}; std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim};
const int count = dims[0] * dims[1] * dims[2]; const int count = dims[0] * dims[1] * dims[2];
...@@ -117,7 +117,7 @@ void _SigmoidFocalLossGrad( ...@@ -117,7 +117,7 @@ void _SigmoidFocalLossGrad(
const LogitType* logit, \ const LogitType* logit, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
int* mask, \ LogitType* mask, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
......
...@@ -21,7 +21,7 @@ __global__ void _SigmoidFocalLoss( ...@@ -21,7 +21,7 @@ __global__ void _SigmoidFocalLoss(
const LogitType* logit, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask) { LogitType* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int k = (yi / inner_dim) % axis_dim; const int k = (yi / inner_dim) % axis_dim;
...@@ -62,7 +62,7 @@ __global__ void _SigmoidFocalLossGrad( ...@@ -62,7 +62,7 @@ __global__ void _SigmoidFocalLossGrad(
const LogitType* logit, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dx,
int* mask) { LogitType* mask) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) { CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int j = xi % inner_dim; const int j = xi % inner_dim;
const int k = (xi / inner_dim) % axis_dim; const int k = (xi / inner_dim) % axis_dim;
...@@ -111,7 +111,7 @@ __global__ void _SigmoidFocalLossGrad( ...@@ -111,7 +111,7 @@ __global__ void _SigmoidFocalLossGrad(
const LogitType* logit, \ const LogitType* logit, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
int* mask, \ LogitType* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \ const int nthreads = outer_dim * axis_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
......
...@@ -16,18 +16,18 @@ void _SparseSoftmaxCrossEntropy( ...@@ -16,18 +16,18 @@ void _SparseSoftmaxCrossEntropy(
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask) { LogitType* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
int count = dims[0] * dims[1], k; int count = dims[0] * dims[1], k;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int label = (int)target[i]; const int label = (int)target[i];
if (label == ignore_index) { if (label == ignore_index) {
loss[i] = mask[i] = 0; loss[i] = mask[i] = LogitType(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
loss[i] = -std::log(std::max(prob[k], LogitType(FLT_MIN))); loss[i] = -std::log(std::max(prob[k], LogitType(FLT_MIN)));
mask[i] = 1; mask[i] = LogitType(1);
} }
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data()); utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
...@@ -42,7 +42,7 @@ void _SparseSoftmaxCrossEntropyGrad( ...@@ -42,7 +42,7 @@ void _SparseSoftmaxCrossEntropyGrad(
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dx,
int* mask) { LogitType* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
int count = dims[0] * dims[1], k; int count = dims[0] * dims[1], k;
...@@ -54,11 +54,11 @@ void _SparseSoftmaxCrossEntropyGrad( ...@@ -54,11 +54,11 @@ void _SparseSoftmaxCrossEntropyGrad(
(*offset_dx) = LogitType(0); (*offset_dx) = LogitType(0);
offset_dx += inner_dim; offset_dx += inner_dim;
} }
mask[i] = 0; mask[i] = LogitType(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
dx[k] -= LogitType(1); dx[k] -= LogitType(1);
mask[i] = 1; mask[i] = LogitType(1);
} }
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data()); utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
...@@ -78,7 +78,7 @@ void _SparseSoftmaxCrossEntropyGrad( ...@@ -78,7 +78,7 @@ void _SparseSoftmaxCrossEntropyGrad(
const LogitType* prob, \ const LogitType* prob, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
int* mask, \ LogitType* mask, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
......
...@@ -18,17 +18,17 @@ __global__ void _SparseSoftmaxCrossEntropy( ...@@ -18,17 +18,17 @@ __global__ void _SparseSoftmaxCrossEntropy(
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask) { LogitType* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int label = target[i * inner_dim + j]; const int label = target[i * inner_dim + j];
if (label == ignore_index) { if (label == ignore_index) {
loss[yi] = mask[yi] = 0; loss[yi] = mask[yi] = LogitType(0);
} else { } else {
loss[yi] = -log(max( loss[yi] = -log(max(
prob[(i * axis_dim + label) * inner_dim + j], LogitType(FLT_MIN))); prob[(i * axis_dim + label) * inner_dim + j], LogitType(FLT_MIN)));
mask[yi] = 1; mask[yi] = LogitType(1);
} }
} }
} }
...@@ -42,7 +42,7 @@ __global__ void _SparseSoftmaxCrossEntropyGrad( ...@@ -42,7 +42,7 @@ __global__ void _SparseSoftmaxCrossEntropyGrad(
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dx,
int* mask) { LogitType* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
const int j = yi % inner_dim; const int j = yi % inner_dim;
...@@ -53,10 +53,10 @@ __global__ void _SparseSoftmaxCrossEntropyGrad( ...@@ -53,10 +53,10 @@ __global__ void _SparseSoftmaxCrossEntropyGrad(
(*offset_dx) = LogitType(0); (*offset_dx) = LogitType(0);
offset_dx += inner_dim; offset_dx += inner_dim;
} }
mask[yi] = 0; mask[yi] = LogitType(0);
} else { } else {
dx[(i * axis_dim + label) * inner_dim + j] -= LogitType(1); dx[(i * axis_dim + label) * inner_dim + j] -= LogitType(1);
mask[yi] = 1; mask[yi] = LogitType(1);
} }
} }
} }
...@@ -75,7 +75,7 @@ __global__ void _SparseSoftmaxCrossEntropyGrad( ...@@ -75,7 +75,7 @@ __global__ void _SparseSoftmaxCrossEntropyGrad(
const LogitType* prob, \ const LogitType* prob, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
int* mask, \ LogitType* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * inner_dim; \ const int nthreads = outer_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
......
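The gradient hunks in this file implement the usual softmax-with-cross-entropy gradient: dx starts as the softmax probabilities, 1 is subtracted at the target class, and ignored positions zero the whole class column plus their mask entry. A host sketch for a single position, assuming dx was pre-filled with the probabilities (names are illustrative):

#include <cstdio>
#include <vector>

// dx becomes (p - onehot) at one spatial position, or all zeros when the
// target equals ignore_index; the mask records whether the position counts.
void SoftmaxCEGradAt(
    int axis_dim, int inner_dim, int i, int j, int label, int ignore_index,
    std::vector<float>& dx, std::vector<float>& mask) {
  const int yi = i * inner_dim + j;
  if (label == ignore_index) {
    float* offset_dx = dx.data() + i * axis_dim * inner_dim + j;
    for (int c = 0; c < axis_dim; ++c) {
      *offset_dx = 0.f;   // zero the whole class column
      offset_dx += inner_dim;
    }
    mask[yi] = 0.f;
  } else {
    dx[(i * axis_dim + label) * inner_dim + j] -= 1.f;  // p - onehot
    mask[yi] = 1.f;
  }
}

int main() {
  std::vector<float> dx = {0.7f, 0.2f, 0.1f};  // softmax over 3 classes
  std::vector<float> mask(1);
  SoftmaxCEGradAt(3, 1, 0, 0, /*label=*/0, /*ignore_index=*/-1, dx, mask);
  std::printf("%g %g %g (mask=%g)\n", dx[0], dx[1], dx[2], mask[0]);
  // -0.3 0.2 0.1 (mask=1)
}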
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void _Affine(
const int nthreads,
const int axis_dim,
const int inner_dim,
const T* x,
const T* w,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __ldg(w + (i / inner_dim) % axis_dim) * x[i];
#else
y[i] = w[(i / inner_dim) % axis_dim] * x[i];
#endif
}
}
template <>
__global__ void _Affine<half>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const half* x,
const half* w,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(x[i], __ldg(w + (i / inner_dim) % axis_dim));
#endif
}
}
template <typename T>
__global__ void _Affine(
const int nthreads,
const int axis_dim,
const int inner_dim,
const T* x,
const T* w,
const T* b,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int wi = (i / inner_dim) % axis_dim;
#if __CUDA_ARCH__ >= 350
y[i] = __ldg(w + wi) * x[i] + __ldg(b + wi);
#else
y[i] = w[wi] * x[i] + b[wi];
#endif
}
}
template <>
__global__ void _Affine<half>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const half* x,
const half* w,
const half* b,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const int wi = (i / inner_dim) % axis_dim;
y[i] = __hadd(__hmul(x[i], __ldg(w + wi)), __ldg(b + wi));
#endif
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void Affine<float16, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* x,
const float16* w,
const float16* b,
float16* y,
CUDAContext* ctx) {
const int nthreads = outer_dim * axis_dim * inner_dim;
if (b != nullptr) {
_Affine<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
axis_dim,
inner_dim,
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(w),
reinterpret_cast<const half*>(b),
reinterpret_cast<half*>(y));
} else {
_Affine<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
axis_dim,
inner_dim,
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(w),
reinterpret_cast<half*>(y));
}
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Affine<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const T* x, \
const T* w, \
const T* b, \
T* y, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \
if (b != nullptr) { \
_Affine<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, axis_dim, inner_dim, x, w, b, y); \
} else { \
_Affine<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, axis_dim, inner_dim, x, w, y); \
} \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
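The new file above is the CUDA side of the ChannelAffine kernel: y[i] = w[c] * x[i] (+ b[c]) with c = (i / inner_dim) % axis_dim over a flattened (outer_dim, axis_dim, inner_dim) layout. A host reference of the same indexing, useful for spot-checking the launcher (a sketch, not part of the library):

#include <cstdio>
#include <vector>

// Host reference for ChannelAffine: the weight/bias index is the "axis"
// coordinate of the flattened (outer_dim, axis_dim, inner_dim) layout.
template <typename T>
void ChannelAffineRef(
    int outer_dim, int axis_dim, int inner_dim,
    const T* x, const T* w, const T* b, T* y) {
  const int n = outer_dim * axis_dim * inner_dim;
  for (int i = 0; i < n; ++i) {
    const int wi = (i / inner_dim) % axis_dim;
    y[i] = b ? w[wi] * x[i] + b[wi] : w[wi] * x[i];
  }
}

int main() {
  // outer=1, channels=2, inner=2: channel 0 scaled by 2, channel 1 by 3.
  std::vector<float> x = {1, 2, 3, 4}, w = {2, 3}, b = {0.5f, -0.5f}, y(4);
  ChannelAffineRef(1, 2, 2, x.data(), w.data(), b.data(), y.data());
  for (float v : y) std::printf("%g ", v);  // 2.5 4.5 8.5 11.5
  std::printf("\n");
}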
...@@ -9,7 +9,11 @@ namespace dragon { ...@@ -9,7 +9,11 @@ namespace dragon {
namespace kernel { namespace kernel {
#if __CUDA_ARCH__ >= 350
#define L(x, i) __ldg(x + i) #define L(x, i) __ldg(x + i)
#else
#define L(x, i) x[i]
#endif
namespace { namespace {
...@@ -30,13 +34,8 @@ __global__ void _BatchNormExpectation( ...@@ -30,13 +34,8 @@ __global__ void _BatchNormExpectation(
CUDA_2D_KERNEL_LOOP2(j, outer_dim) { CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i; : j * C + i;
#if __CUDA_ARCH__ >= 350 ex_val += L(x, xi);
ex_val += __ldg(x + xi); ex2_val += utils::math::Square(L(x, xi));
ex2_val += __ldg(x + xi) * __ldg(x + xi);
#else
ex_val += x[xi];
ex2_val += x[xi] * x[xi];
#endif
} }
ex_val = BlockReduce<Tp>(ex_storage).Reduce(ex_val, cub::Sum()); ex_val = BlockReduce<Tp>(ex_storage).Reduce(ex_val, cub::Sum());
ex2_val = BlockReduce<Tp>(ex2_storage).Reduce(ex2_val, cub::Sum()); ex2_val = BlockReduce<Tp>(ex2_storage).Reduce(ex2_val, cub::Sum());
...@@ -67,13 +66,8 @@ __global__ void _BatchNormInternalGrad( ...@@ -67,13 +66,8 @@ __global__ void _BatchNormInternalGrad(
CUDA_2D_KERNEL_LOOP2(j, outer_dim) { CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i; : j * C + i;
#if __CUDA_ARCH__ >= 350
dg_val += L(dy, xi) * (L(x, xi) - L(mu, i)) * L(rsig, i); dg_val += L(dy, xi) * (L(x, xi) - L(mu, i)) * L(rsig, i);
db_val += L(dy, xi); db_val += L(dy, xi);
#else
dg_val += dy[xi] * (x[xi] - mu[i]) * rsig[i];
db_val += dy[xi];
#endif
} }
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum()); dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum()); db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
...@@ -101,15 +95,9 @@ __global__ void _BatchNormTrainingGrad( ...@@ -101,15 +95,9 @@ __global__ void _BatchNormTrainingGrad(
const Tp denom = Tp(1) / Tp(N * S); const Tp denom = Tp(1) / Tp(N * S);
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C; const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
#if __CUDA_ARCH__ >= 350
const Tp x_norm = (L(x, i) - L(mu, pi)) * L(rsig, pi); const Tp x_norm = (L(x, i) - L(mu, pi)) * L(rsig, pi);
dx[i] = L(gamma, pi) * L(rsig, pi) * dx[i] = L(gamma, pi) * L(rsig, pi) *
(L(dy, i) - (x_norm * L(dgamma, pi) + L(dbeta, pi)) * denom); (L(dy, i) - fma(x_norm, L(dgamma, pi), L(dbeta, pi)) * denom);
#else
const Tp x_norm = (x[i] - mu[pi]) * rsig[pi];
dx[i] = gamma[pi] * rsig[pi] *
(dy[i] - (x_norm * dgamma[pi] + dbeta[pi]) * denom);
#endif
} }
} }
...@@ -132,13 +120,8 @@ __global__ void _BatchNormWGrad( ...@@ -132,13 +120,8 @@ __global__ void _BatchNormWGrad(
CUDA_2D_KERNEL_LOOP2(j, outer_dim) { CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i; : j * C + i;
#if __CUDA_ARCH__ >= 350
dg_val += L(dy, xi) * (L(x, xi) - L(mu, i)) * L(rsig, i); dg_val += L(dy, xi) * (L(x, xi) - L(mu, i)) * L(rsig, i);
db_val += L(dy, xi); db_val += L(dy, xi);
#else
dg_val += dy[xi] * (x[xi] - mu[i]) * rsig[i];
db_val += dy[xi];
#endif
} }
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum()); dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum()); db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
...@@ -160,11 +143,7 @@ __global__ void _BatchNormInferenceGrad( ...@@ -160,11 +143,7 @@ __global__ void _BatchNormInferenceGrad(
Tx* dx) { Tx* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C; const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
#if __CUDA_ARCH__ >= 350
dx[i] = L(gamma, pi) * L(dy, i) * L(rsig, pi); dx[i] = L(gamma, pi) * L(dy, i) * L(rsig, pi);
#else
dx[i] = gamma[pi] * dy[i] * rsig[pi];
#endif
} }
} }
......
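Two things happen in the batch-norm hunks above: the __ldg choice is hoisted into a single L(x, i) macro selected once by __CUDA_ARCH__, and x_norm * dgamma + dbeta is folded into fma(x_norm, dgamma, dbeta). The second change is purely algebraic; a quick host check, assuming nothing beyond <cmath>:

#include <cassert>
#include <cmath>
#include <cstdio>

// fma(a, b, c) computes a * b + c with a single rounding, so the fused form
// matches (or slightly improves on) the original expression.
int main() {
  const double x_norm = 0.75, dgamma = -1.25, dbeta = 0.5;
  const double plain = x_norm * dgamma + dbeta;
  const double fused = std::fma(x_norm, dgamma, dbeta);
  std::printf("plain=%.17g fused=%.17g\n", plain, fused);
  assert(std::abs(plain - fused) < 1e-12);
  return 0;
}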
...@@ -9,8 +9,13 @@ namespace dragon { ...@@ -9,8 +9,13 @@ namespace dragon {
namespace kernel { namespace kernel {
#if __CUDA_ARCH__ >= 350
#define L(x, i) __ldg(x + i) #define L(x, i) __ldg(x + i)
#define LF(x, i) __half2float(__ldg(x + i)) #define LF(x, i) __half2float(__ldg(x + i))
#else
#define L(x, i) x[i]
#define LF(x, i) __half2float(x[i])
#endif
namespace { namespace {
...@@ -28,25 +33,14 @@ __global__ void _GroupNormFusedParams( ...@@ -28,25 +33,14 @@ __global__ void _GroupNormFusedParams(
const int outer_dim = N * G; const int outer_dim = N * G;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) { CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const int g = i % G; const int g = i % G;
#if __CUDA_ARCH__ >= 350
const T mu_val = L(mu, i); const T mu_val = L(mu, i);
const T rsig_val = L(rsig, i); const T rsig_val = L(rsig, i);
#else
const T mu_val = mu[i];
const T rsig_val = rsig[i];
#endif
CUDA_2D_KERNEL_LOOP2(j, D) { CUDA_2D_KERNEL_LOOP2(j, D) {
const int wi = i * D + j; const int wi = i * D + j;
const int gi = g * D + j; const int gi = g * D + j;
#if __CUDA_ARCH__ >= 350
const T w = L(gamma, gi) * rsig_val; const T w = L(gamma, gi) * rsig_val;
scale[wi] = w; scale[wi] = w;
bias[wi] = L(beta, gi) - w * mu_val; bias[wi] = fma(-w, mu_val, L(beta, gi));
#else
const T w = gamma[gi] * rsig_val;
scale[wi] = w;
bias[wi] = beta[gi] - w * mu_val;
#endif
} }
} }
} }
...@@ -62,20 +56,11 @@ __global__ void _GroupNormForwardNCHW( ...@@ -62,20 +56,11 @@ __global__ void _GroupNormForwardNCHW(
Tx* y) { Tx* y) {
const int outer_dim = N * C; const int outer_dim = N * C;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) { CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
#if __CUDA_ARCH__ >= 350
const Tp w = L(scale, i); const Tp w = L(scale, i);
const Tp b = L(bias, i); const Tp b = L(bias, i);
#else
const Tp w = scale[i];
const Tp b = bias[i];
#endif
CUDA_2D_KERNEL_LOOP2(j, S) { CUDA_2D_KERNEL_LOOP2(j, S) {
const int xi = i * S + j; const int xi = i * S + j;
#if __CUDA_ARCH__ >= 350 y[xi] = fma(L(x, xi), w, b);
y[xi] = L(x, xi) * w + b;
#else
y[xi] = x[xi] * w + b;
#endif
} }
} }
} }
...@@ -89,17 +74,15 @@ __global__ void _GroupNormForwardNCHW<half, float>( ...@@ -89,17 +74,15 @@ __global__ void _GroupNormForwardNCHW<half, float>(
const float* scale, const float* scale,
const float* bias, const float* bias,
half* y) { half* y) {
#if __CUDA_ARCH__ >= 530
const int outer_dim = N * C; const int outer_dim = N * C;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) { CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const float w = L(scale, i); const float w = L(scale, i);
const float b = L(bias, i); const float b = L(bias, i);
CUDA_2D_KERNEL_LOOP2(j, S) { CUDA_2D_KERNEL_LOOP2(j, S) {
const int xi = i * S + j; const int xi = i * S + j;
y[xi] = __float2half(LF(x, xi) * w + b); y[xi] = __float2half(fmaf(LF(x, xi), w, b));
} }
} }
#endif
} }
template <typename Tx, typename Tp> template <typename Tx, typename Tp>
...@@ -117,11 +100,7 @@ __global__ void _GroupNormForwardNHWC( ...@@ -117,11 +100,7 @@ __global__ void _GroupNormForwardNHWC(
CUDA_2D_KERNEL_LOOP2(j, C) { CUDA_2D_KERNEL_LOOP2(j, C) {
const int xi = i * C + j; const int xi = i * C + j;
const int wi = n * C + j; const int wi = n * C + j;
#if __CUDA_ARCH__ >= 350 y[xi] = fma(L(x, xi), L(scale, wi), L(bias, wi));
y[xi] = L(x, xi) * L(scale, wi) + L(bias, wi);
#else
y[xi] = x[xi] * scale[wi] + bias[wi];
#endif
} }
} }
} }
...@@ -135,17 +114,15 @@ __global__ void _GroupNormForwardNHWC<half, float>( ...@@ -135,17 +114,15 @@ __global__ void _GroupNormForwardNHWC<half, float>(
const float* scale, const float* scale,
const float* bias, const float* bias,
half* y) { half* y) {
#if __CUDA_ARCH__ >= 530
const int outer_dim = N * S; const int outer_dim = N * S;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) { CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const int n = i / S; const int n = i / S;
CUDA_2D_KERNEL_LOOP2(j, C) { CUDA_2D_KERNEL_LOOP2(j, C) {
const int xi = i * C + j; const int xi = i * C + j;
const int wi = n * C + j; const int wi = n * C + j;
y[xi] = __float2half(LF(x, xi) * L(scale, wi) + L(bias, wi)); y[xi] = __float2half(fmaf(LF(x, xi), L(scale, wi), L(bias, wi)));
} }
} }
#endif
} }
template <typename Tx, typename Tp, StorageOrder kOrder> template <typename Tx, typename Tp, StorageOrder kOrder>
...@@ -172,13 +149,8 @@ __global__ void _GroupNormWGrad( ...@@ -172,13 +149,8 @@ __global__ void _GroupNormWGrad(
? (n * outer_dim + i) * S + j % S ? (n * outer_dim + i) * S + j % S
: j * outer_dim + i; : j * outer_dim + i;
const int mi = n * G + i / D; const int mi = n * G + i / D;
#if __CUDA_ARCH__ >= 350
dg_val += L(dy, xi) * (L(x, xi) - L(mu, mi)) * L(rsig, mi); dg_val += L(dy, xi) * (L(x, xi) - L(mu, mi)) * L(rsig, mi);
db_val += L(dy, xi); db_val += L(dy, xi);
#else
dg_val += dy[xi] * (x[xi] - mu[mi]) * rsig[mi];
db_val += dy[xi];
#endif
} }
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum()); dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum()); db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
...@@ -201,7 +173,6 @@ __global__ void _GroupNormWGradHalf( ...@@ -201,7 +173,6 @@ __global__ void _GroupNormWGradHalf(
const half* dy, const half* dy,
float* dgamma, float* dgamma,
float* dbeta) { float* dbeta) {
#if __CUDA_ARCH__ >= 530
const int outer_dim = G * D; const int outer_dim = G * D;
const int inner_dim = N * S; const int inner_dim = N * S;
__shared__ typename BlockReduce<float>::TempStorage dg_storage; __shared__ typename BlockReduce<float>::TempStorage dg_storage;
...@@ -224,7 +195,6 @@ __global__ void _GroupNormWGradHalf( ...@@ -224,7 +195,6 @@ __global__ void _GroupNormWGradHalf(
dbeta[i] = db_val; dbeta[i] = db_val;
} }
} }
#endif
} }
template <typename Tx, typename Tp, StorageOrder kOrder> template <typename Tx, typename Tp, StorageOrder kOrder>
...@@ -249,13 +219,8 @@ __global__ void _GroupNormInternalGrad( ...@@ -249,13 +219,8 @@ __global__ void _GroupNormInternalGrad(
const int xi = kOrder == StorageOrder::NCHW const int xi = kOrder == StorageOrder::NCHW
? i * inner_dim + j ? i * inner_dim + j
: (i / G * S + j % S) * G * D + gi; : (i / G * S + j % S) * G * D + gi;
#if __CUDA_ARCH__ >= 350
ds_val += L(gamma, gi) * L(dy, xi) * L(x, xi); ds_val += L(gamma, gi) * L(dy, xi) * L(x, xi);
db_val += L(gamma, gi) * L(dy, xi); db_val += L(gamma, gi) * L(dy, xi);
#else
ds_val += gamma[gi] * dy[xi] * x[xi];
db_val += gamma[gi] * dy[xi];
#endif
} }
ds_val = BlockReduce<Tp>(ds_storage).Reduce(ds_val, cub::Sum()); ds_val = BlockReduce<Tp>(ds_storage).Reduce(ds_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum()); db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
...@@ -277,7 +242,6 @@ __global__ void _GroupNormInternalGradHalf( ...@@ -277,7 +242,6 @@ __global__ void _GroupNormInternalGradHalf(
const half* dy, const half* dy,
float* ds, float* ds,
float* db) { float* db) {
#if __CUDA_ARCH__ >= 530
const int outer_dim = N * G; const int outer_dim = N * G;
const int inner_dim = D * S; const int inner_dim = D * S;
__shared__ typename BlockReduce<float>::TempStorage ds_storage; __shared__ typename BlockReduce<float>::TempStorage ds_storage;
...@@ -299,7 +263,6 @@ __global__ void _GroupNormInternalGradHalf( ...@@ -299,7 +263,6 @@ __global__ void _GroupNormInternalGradHalf(
db[i] = db_val; db[i] = db_val;
} }
} }
#endif
} }
template <typename Tx, typename Tp, StorageOrder kOrder> template <typename Tx, typename Tp, StorageOrder kOrder>
...@@ -322,17 +285,10 @@ __global__ void _GroupNormGrad( ...@@ -322,17 +285,10 @@ __global__ void _GroupNormGrad(
const int mi = kOrder == StorageOrder::NCHW ? i / (D * S) const int mi = kOrder == StorageOrder::NCHW ? i / (D * S)
: i / (C * S) * G + (i / D % G); : i / (C * S) * G + (i / D % G);
const int gi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C; const int gi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
#if __CUDA_ARCH__ >= 350 const Tp u = fma(L(db, mi), L(mu, mi), -L(ds, mi)) * (L(x, i) - L(mu, mi)) *
const Tp u = (L(db, mi) * L(mu, mi) - L(ds, mi)) * (L(x, i) - L(mu, mi)) *
utils::math::Cube(L(rsig, mi)); utils::math::Cube(L(rsig, mi));
const Tp v = L(db, mi) * L(rsig, mi); const Tp v = L(db, mi) * L(rsig, mi);
dx[i] = L(gamma, gi) * L(dy, i) * L(rsig, mi) + (u - v) * denom; dx[i] = L(gamma, gi) * L(dy, i) * L(rsig, mi) + (u - v) * denom;
#else
const Tp u = (db[mi] * mu[mi] - ds[mi]) * (x[i] - mu[mi]) *
utils::math::Cube(rsig[mi]);
const Tp v = db[mi] * rsig[mi];
dx[i] = gamma[gi] * dy[i] * rsig[mi] + (u - v) * denom;
#endif
} }
} }
...@@ -350,20 +306,18 @@ __global__ void _GroupNormGradHalf( ...@@ -350,20 +306,18 @@ __global__ void _GroupNormGradHalf(
const float* db, const float* db,
const half* dy, const half* dy,
half* dx) { half* dx) {
#if __CUDA_ARCH__ >= 530
const int C = G * D; const int C = G * D;
const float denom = 1.f / float(D * S); const float denom = 1.f / float(D * S);
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int mi = kOrder == StorageOrder::NCHW ? i / (D * S) const int mi = kOrder == StorageOrder::NCHW ? i / (D * S)
: i / (C * S) * G + (i / D % G); : i / (C * S) * G + (i / D % G);
const int gi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C; const int gi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
const float u = (L(db, mi) * L(mu, mi) - L(ds, mi)) * const float u = fmaf(L(db, mi), L(mu, mi), -L(ds, mi)) *
(LF(x, i) - L(mu, mi)) * utils::math::Cube(L(rsig, mi)); (LF(x, i) - L(mu, mi)) * utils::math::Cube(L(rsig, mi));
const float v = L(db, mi) * L(rsig, mi); const float v = L(db, mi) * L(rsig, mi);
dx[i] = dx[i] =
__float2half(L(gamma, gi) * LF(dy, i) * L(rsig, mi) + (u - v) * denom); __float2half(L(gamma, gi) * LF(dy, i) * L(rsig, mi) + (u - v) * denom);
} }
#endif
} }
} // namespace } // namespace
......
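_GroupNormFusedParams precomputes per-(sample, channel) scale = gamma * rsig and bias = beta - scale * mu (written above as fma(-w, mu_val, beta)), so the forward pass is a single fma per element. A host sketch of the fusion, assuming the usual (x - mu) * rsig * gamma + beta definition:

#include <cmath>
#include <cstdio>
#include <vector>

// GroupNorm fuses (x - mu) * rsig * gamma + beta into y = x * scale + bias.
int main() {
  const float mu = 2.f, rsig = 0.5f, gamma = 3.f, beta = 1.f;
  const float scale = gamma * rsig;               // w in the kernel
  const float bias = std::fma(-scale, mu, beta);  // beta - w * mu
  std::vector<float> x = {1.f, 2.f, 3.f};
  for (float xi : x) {
    const float fused = std::fma(xi, scale, bias);
    const float direct = (xi - mu) * rsig * gamma + beta;
    std::printf("%g %g\n", fused, direct);        // pairs agree
  }
}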
#include "dragon/utils/cast.h"
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
template <>
void MixedPrecL2Penalty<float16, CPUContext>(
const int count,
const float alpha,
const float16* x,
float* dx,
CPUContext* ctx) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
dx[i] += (cast::to<float>(x[i]) * alpha);
}
}
template <>
void MixedPrecUpdate<float16, CPUContext>(
const int count,
const float* dx,
float16* x,
CPUContext* ctx) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
x[i] = cast::to<float16>(cast::to<float>(x[i]) - dx[i]);
}
}
} // namespace kernel
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
__global__ void _MixedPrecL2Penalty(
const int nthreads,
const float alpha,
const half* x,
float* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] += __half2float(x[i]) * alpha;
}
}
__global__ void _MixedPrecUpdate(const int nthreads, const float* dx, half* x) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
x[i] = __float2half(__half2float(x[i]) - dx[i]);
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void MixedPrecL2Penalty<float16, CUDAContext>(
const int count,
const float alpha,
const float16* x,
float* dx,
CUDAContext* ctx) {
_MixedPrecL2Penalty<<<
CUDA_BLOCKS(count),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(count, alpha, reinterpret_cast<const half*>(x), dx);
}
template <>
void MixedPrecUpdate<float16, CUDAContext>(
const int count,
const float* dx,
float16* x,
CUDAContext* ctx) {
_MixedPrecUpdate<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, dx, reinterpret_cast<half*>(x));
}
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
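The mixed-precision kernels above keep the parameter x in float16 and accumulate the decayed gradient dx in float32, converting only at the boundaries (cast::to on the CPU, __half2float/__float2half on the GPU). A host sketch of the same two-step flow; plain float stands in for float16 here, so the conversions are only marked in comments:

#include <cstdio>
#include <vector>

// Host sketch of the mixed-precision update pair. In the real kernels x is
// float16 and dx is float32; conversions are noted in the comments.
void MixedPrecL2Penalty(int n, float alpha, const float* x, float* dx) {
  for (int i = 0; i < n; ++i) {
    dx[i] += x[i] * alpha;  // dx[i] += to_float(x[i]) * alpha
  }
}

void MixedPrecUpdate(int n, const float* dx, float* x) {
  for (int i = 0; i < n; ++i) {
    x[i] = x[i] - dx[i];    // x[i] = to_half(to_float(x[i]) - dx[i])
  }
}

int main() {
  std::vector<float> x = {1.f, -2.f}, dx = {0.1f, 0.2f};
  MixedPrecL2Penalty(2, 0.01f, x.data(), dx.data());  // decoupled L2 penalty
  MixedPrecUpdate(2, dx.data(), x.data());            // apply the fp32 update
  std::printf("%g %g\n", x[0], x[1]);                 // 0.89 -2.18
}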
...@@ -11,11 +11,19 @@ namespace { ...@@ -11,11 +11,19 @@ namespace {
template <typename T> template <typename T>
__global__ void __global__ void
_NesterovUpdate(const int nthreads, const T lr, const T momentum, T* g, T* m) { _NesterovUpdate(const int nthreads, const T lr, const T momentum, T* g, T* m);
template <>
__global__ void _NesterovUpdate<float>(
const int nthreads,
const float lr,
const float momentum,
float* g,
float* m) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
T mi = m[i]; float mi = m[i];
T mi_new = m[i] = momentum * mi + lr * g[i]; float mi_new = m[i] = momentum * mi + lr * g[i];
g[i] = (1 + momentum) * mi_new - momentum * mi; g[i] = fmaf(momentum, mi_new - mi, mi_new);
} }
} }
......
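The Nesterov hunk regroups g = (1 + momentum) * m_new - momentum * m_old as m_new + momentum * (m_new - m_old), which fmaf evaluates with a single rounding. A quick host check of the identity (only <cmath> is assumed):

#include <cmath>
#include <cstdio>

// (1 + momentum) * m_new - momentum * m_old
//   = m_new + momentum * (m_new - m_old)
//   = fmaf(momentum, m_new - m_old, m_new)
int main() {
  const float momentum = 0.9f, lr = 0.01f;
  const float m_old = 0.3f, g = 0.5f;
  const float m_new = momentum * m_old + lr * g;
  const float original = (1.f + momentum) * m_new - momentum * m_old;
  const float regrouped = std::fmaf(momentum, m_new - m_old, m_new);
  std::printf("%.9g %.9g\n", original, regrouped);  // agree up to rounding
}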
...@@ -94,11 +94,12 @@ void RegisterModule(py::module& m) { ...@@ -94,11 +94,12 @@ void RegisterModule(py::module& m) {
}); });
/*! \brief Activate the CuDNN engine */ /*! \brief Activate the CuDNN engine */
m.def("cudaEnableDNN", [](bool enabled, bool benchmark) { m.def("cudaEnableDNN", [](bool enabled, bool benchmark, bool allow_tf32) {
#ifdef USE_CUDA #ifdef USE_CUDA
auto& cuda_objects = CUDAContext::objects(); auto& cuda_objects = CUDAContext::objects();
cuda_objects.cudnn_enabled_ = enabled; cuda_objects.cudnn_enabled_ = enabled;
cuda_objects.cudnn_benchmark_ = benchmark; cuda_objects.cudnn_benchmark_ = benchmark;
cuda_objects.cudnn_allow_tf32_ = allow_tf32;
#endif #endif
}); });
......
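The binding above gains an allow_tf32 flag stored on the CUDA objects. One plausible way such a flag is consumed, sketched below and not necessarily this commit's exact call site, is when picking the cuDNN math type for convolution descriptors: with cuDNN 8, CUDNN_DEFAULT_MATH lets Ampere use TF32 tensor cores, while CUDNN_FMA_MATH keeps plain FP32 FMA.

#include <cudnn.h>

// Sketch only (cuDNN 8 semantics); not taken from the codebase.
cudnnStatus_t SetConvMathType(cudnnConvolutionDescriptor_t desc,
                              bool allow_tf32) {
#if CUDNN_VERSION >= 8000
  // CUDNN_DEFAULT_MATH permits TF32 on Ampere; CUDNN_FMA_MATH forbids it.
  return cudnnSetConvolutionMathType(
      desc, allow_tf32 ? CUDNN_DEFAULT_MATH : CUDNN_FMA_MATH);
#else
  (void)allow_tf32;  // TF32 is a cuDNN 8 / Ampere feature
  return CUDNN_STATUS_SUCCESS;
#endif
}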
...@@ -40,7 +40,7 @@ void DropBlock2dOp<Context>::DoRunWithType() { ...@@ -40,7 +40,7 @@ void DropBlock2dOp<Context>::DoRunWithType() {
auto* scale = Buffer("scale") auto* scale = Buffer("scale")
->Reshape({}) ->Reshape({})
->template mutable_data<float, CPUContext>(); ->template mutable_data<float, CPUContext>();
auto scratches = ws()->template data<Context>({ auto scratches = ctx()->workspace()->template data<Context>({
X.dim(0) * seed_h * seed_w * sizeof(uint32_t), // seed points X.dim(0) * seed_h * seed_w * sizeof(uint32_t), // seed points
X.count() * sizeof(int), // int32 mask for seed growing X.count() * sizeof(int), // int32 mask for seed growing
}); });
...@@ -61,7 +61,7 @@ void DropBlock2dOp<Context>::DoRunWithType() { ...@@ -61,7 +61,7 @@ void DropBlock2dOp<Context>::DoRunWithType() {
(int*)scratches[1], (int*)scratches[1],
ctx()); ctx());
// Convert to uint8 mask // Convert to uint8 mask
kernel::Cast(X.count(), (int*)scratches[1], mask, ctx()); math::Cast(X.count(), (int*)scratches[1], mask, ctx());
// Count the number of zeros to compute scale factor // Count the number of zeros to compute scale factor
float normalizer = math::Sum(X.count(), 1.f, (int*)scratches[1], ctx()); float normalizer = math::Sum(X.count(), 1.f, (int*)scratches[1], ctx());
scale[0] = (float)X.count() / std::max(normalizer, 1.f); scale[0] = (float)X.count() / std::max(normalizer, 1.f);
......
...@@ -20,7 +20,7 @@ void DropoutOp<Context>::DoRunWithType() { ...@@ -20,7 +20,7 @@ void DropoutOp<Context>::DoRunWithType() {
X.template data<T, Context>(), X.template data<T, Context>(),
Buffer("mask")->template mutable_data<uint8_t, Context>(), Buffer("mask")->template mutable_data<uint8_t, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ws()->template data<uint32_t, Context>({X.count()})[0], ctx()->workspace()->template data<uint32_t, Context>({X.count()})[0],
ctx()); ctx());
} else { } else {
LOG(FATAL) << "Unknown Phase: " << phase(); LOG(FATAL) << "Unknown Phase: " << phase();
......
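The operator hunks above (DropBlock2d, Dropout, and more below) request scratch from ctx()->workspace() instead of the op's ws(), that is, from a workspace owned by the context's thread or stream, so concurrently running streams do not contend on one shared scratch buffer. An illustrative per-(device, stream) arena follows; it is not the library's Workspace class, only a sketch of the idea.

#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Illustrative per-(device, stream) scratch arena: each stream reuses its own
// growable buffer, so kernels on different streams never share scratch.
class LocalScratch {
 public:
  void* data(int device, int stream, size_t nbytes) {
    auto& buf = buffers_[Key(device, stream)];
    if (buf.size() < nbytes) buf.resize(nbytes);  // grow, never shrink
    return buf.data();
  }

 private:
  static uint64_t Key(int device, int stream) {
    return (uint64_t(uint32_t(device)) << 32) | uint32_t(stream);
  }
  std::unordered_map<uint64_t, std::vector<uint8_t>> buffers_;
};

int main() {
  LocalScratch scratch;
  // Two streams on device 0 get distinct buffers of the requested size.
  void* a = scratch.data(0, /*stream=*/1, 1024);
  void* b = scratch.data(0, /*stream=*/2, 2048);
  return a != b ? 0 : 1;
}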
...@@ -22,7 +22,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() { ...@@ -22,7 +22,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
CUDNN_CHECK( CUDNN_CHECK(
cudnnDropoutGetStatesSize(ctx()->cudnn_handle(), &states_size)); cudnnDropoutGetStatesSize(ctx()->cudnn_handle(), &states_size));
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
auto* X_states = ws()->CreateTensor( auto* X_states = workspace()->CreateTensor(
"/share/cudnn/dropout:" + str::to(rng_seed_) + "/states"); "/share/cudnn/dropout:" + str::to(rng_seed_) + "/states");
if (X_states->count() > 0) { if (X_states->count() > 0) {
CUDNN_CHECK(cudnnRestoreDropoutDescriptor( CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
...@@ -80,7 +80,7 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() { ...@@ -80,7 +80,7 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
CUDNN_CHECK( CUDNN_CHECK(
cudnnDropoutGetStatesSize(ctx()->cudnn_handle(), &states_size)); cudnnDropoutGetStatesSize(ctx()->cudnn_handle(), &states_size));
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
auto* X_states = ws()->CreateTensor( auto* X_states = workspace()->CreateTensor(
"/share/cudnn/dropout:" + str::to(rng_seed_) + "/states"); "/share/cudnn/dropout:" + str::to(rng_seed_) + "/states");
if (X_states->count() > 0) { if (X_states->count() > 0) {
CUDNN_CHECK(cudnnRestoreDropoutDescriptor( CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
......
#include "dragon/operators/array/cast_op.h" #include "dragon/operators/array/cast_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
#define ELIGIBLE_TENSOR_TYPES \ #define ELIGIBLE_TENSOR_TYPES \
{ "bool", "int8", "uint8", "int32", "int64", "float16", "float32", "float64" } { "bool", "int8", "uint8", "int32", "int64", "float16", "float32", "float64" }
#define DISPATCH_TYPE_TO(InputType, OutputType) \ #define DISPATCH_TYPE_TO(InputType, OutputType) \
if (dtype() == types::to_string<OutputType>()) { \ if (dtype() == types::to_string<OutputType>()) { \
if (InputSize() != 0) { \ if (InputSize() != 0) { \
Output(0)->ReshapeLike(Input(0)); \ Output(0)->ReshapeLike(Input(0)); \
auto* x = Input(0).template data<InputType, Context>(); \ auto* x = Input(0).template data<InputType, Context>(); \
auto* y = Output(0)->template mutable_data<OutputType, Context>(); \ auto* y = Output(0)->template mutable_data<OutputType, Context>(); \
kernel::Cast(Input(0).count(), x, y, ctx()); \ math::Cast(Input(0).count(), x, y, ctx()); \
} else { \ } else { \
auto n = Output(0)->count(); \ auto n = Output(0)->count(); \
auto* x = Output(0)->template data<InputType, Context>(); \ auto* x = Output(0)->template data<InputType, Context>(); \
auto* scratch = ws()->template data<OutputType, Context>({n})[0]; \ auto* scratch = \
kernel::Cast(n, x, scratch, ctx()); \ ctx()->workspace()->template data<OutputType, Context>({n})[0]; \
ctx()->FinishDeviceComputation(); \ math::Cast(n, x, scratch, ctx()); \
auto* y = Output(0)->template mutable_data<OutputType, Context>(); \ ctx()->FinishDeviceComputation(); \
math::Copy(n, scratch, y, ctx()); \ auto* y = Output(0)->template mutable_data<OutputType, Context>(); \
} \ math::Copy(n, scratch, y, ctx()); \
return; \ } \
return; \
} }
#define DISPATCH_TYPE_TO_ALL(InputType) \ #define DISPATCH_TYPE_TO_ALL(InputType) \
......
#include "dragon/operators/math/affine_op.h" #include "dragon/operators/array/channel_affine_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
...@@ -19,7 +19,7 @@ namespace dragon { ...@@ -19,7 +19,7 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void AffineOp<Context>::DoRunWithType() { void ChannelAffineOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0, {0}); auto &X = Input(0), &W = Input(1), *Y = Output(0, {0});
CANONICALIZE_AXES_WITH_TENSOR(X); CANONICALIZE_AXES_WITH_TENSOR(X);
...@@ -37,7 +37,7 @@ void AffineOp<Context>::DoRunWithType() { ...@@ -37,7 +37,7 @@ void AffineOp<Context>::DoRunWithType() {
<< ", got " << Input(2).DimString() << "."; << ", got " << Input(2).DimString() << ".";
} }
kernel::Affine( kernel::ChannelAffine(
X.count(0, axis), X.count(0, axis),
X.count(axis, axis + num_axes), X.count(axis, axis + num_axes),
X.count(axis + num_axes), X.count(axis + num_axes),
...@@ -49,21 +49,22 @@ void AffineOp<Context>::DoRunWithType() { ...@@ -49,21 +49,22 @@ void AffineOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void AffineOp<Context>::RunOnDevice() { void ChannelAffineOp<Context>::RunOnDevice() {
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0)); DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
} }
template <class Context> template <class Context>
template <typename T> template <typename T>
void AffineGradientOp<Context>::DoRunWithType() { void ChannelAffineGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2); auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
CANONICALIZE_AXES_WITH_TENSOR(X); CANONICALIZE_AXES_WITH_TENSOR(X);
// Reduce parameters for weight and bias // Reduce parameters for weight and bias
vec32_t dims = {(int)X.count(0, axis), vec32_t dims = {
(int)X.count(axis, axis + num_axes), (int)X.count(0, axis),
(int)X.count(axis + num_axes)}; (int)X.count(axis, axis + num_axes),
(int)X.count(axis + num_axes)};
vec32_t axes = {0, 2}; vec32_t axes = {0, 2};
// dW = dY * X // dW = dY * X
...@@ -79,7 +80,8 @@ void AffineGradientOp<Context>::DoRunWithType() { ...@@ -79,7 +80,8 @@ void AffineGradientOp<Context>::DoRunWithType() {
dW->ReshapeLike(W)->template mutable_data<T, Context>(), dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx()); ctx());
} else { } else {
T* scratch = ws()->template data<T, Context>({X.count()})[0]; T* scratch =
ctx()->workspace()->template data<T, Context>({X.count()})[0];
math::Mul( math::Mul(
X.count(), X.count(),
dY.template data<T, Context>(), dY.template data<T, Context>(),
...@@ -118,7 +120,7 @@ void AffineGradientOp<Context>::DoRunWithType() { ...@@ -118,7 +120,7 @@ void AffineGradientOp<Context>::DoRunWithType() {
// dX = dY * W // dX = dY * W
if (dX->has_name()) { if (dX->has_name()) {
Output(0)->ReshapeLike(Input(-1)); Output(0)->ReshapeLike(Input(-1));
kernel::Affine( kernel::ChannelAffine(
X.count(0, axis), X.count(0, axis),
X.count(axis, axis + num_axes), X.count(axis, axis + num_axes),
X.count(axis + num_axes), X.count(axis + num_axes),
...@@ -131,21 +133,21 @@ void AffineGradientOp<Context>::DoRunWithType() { ...@@ -131,21 +133,21 @@ void AffineGradientOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void AffineGradientOp<Context>::RunOnDevice() { void ChannelAffineGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
} }
DEPLOY_CPU_OPERATOR(Affine); DEPLOY_CPU_OPERATOR(ChannelAffine);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Affine); DEPLOY_CUDA_OPERATOR(ChannelAffine);
#endif #endif
DEPLOY_CPU_OPERATOR(AffineGradient); DEPLOY_CPU_OPERATOR(ChannelAffineGradient);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(AffineGradient); DEPLOY_CUDA_OPERATOR(ChannelAffineGradient);
#endif #endif
OPERATOR_SCHEMA(Affine) OPERATOR_SCHEMA(ChannelAffine)
/* X, W, B */ /* X, W, B */
.NumInputs(2, 3) .NumInputs(2, 3)
/* Y */ /* Y */
...@@ -153,7 +155,7 @@ OPERATOR_SCHEMA(Affine) ...@@ -153,7 +155,7 @@ OPERATOR_SCHEMA(Affine)
/* X => Y */ /* X => Y */
.AllowInplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(AffineGradient) OPERATOR_SCHEMA(ChannelAffineGradient)
/* X, W, dY */ /* X, W, dY */
.NumInputs(3) .NumInputs(3)
/* dX, dW, dB */ /* dX, dW, dB */
...@@ -177,7 +179,7 @@ class GradientMaker final : public GradientMakerBase { ...@@ -177,7 +179,7 @@ class GradientMaker final : public GradientMakerBase {
} // namespace } // namespace
REGISTER_GRADIENT(Affine, GradientMaker); REGISTER_GRADIENT(ChannelAffine, GradientMaker);
#undef CANONICALIZE_AXES_WITH_TENSOR #undef CANONICALIZE_AXES_WITH_TENSOR
......
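ChannelAffineGradient views X and dY as (outer, channel, inner) and reduces over axes {0, 2} to produce dW (from dY * X) and dB (from dY). A host reference of that reduction, assuming contiguous row-major layout (illustrative only):

#include <cstdio>
#include <vector>

// dW[c] = sum over (outer, inner) of dY * X at channel c; dB[c] = sum of dY.
// This mirrors the {0, 2} reduction over the (outer, channel, inner) view.
void ChannelAffineWGradRef(
    int outer_dim, int channel_dim, int inner_dim,
    const float* x, const float* dy, float* dw, float* db) {
  for (int c = 0; c < channel_dim; ++c) dw[c] = db[c] = 0.f;
  for (int o = 0; o < outer_dim; ++o) {
    for (int c = 0; c < channel_dim; ++c) {
      for (int i = 0; i < inner_dim; ++i) {
        const int idx = (o * channel_dim + c) * inner_dim + i;
        dw[c] += dy[idx] * x[idx];
        db[c] += dy[idx];
      }
    }
  }
}

int main() {
  // outer=1, channels=2, inner=2.
  std::vector<float> x = {1, 2, 3, 4}, dy = {1, 1, 1, 1}, dw(2), db(2);
  ChannelAffineWGradRef(1, 2, 2, x.data(), dy.data(), dw.data(), db.data());
  std::printf("dW = %g %g, dB = %g %g\n", dw[0], dw[1], db[0], db[1]);
  // dW = 3 7, dB = 2 2
}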
...@@ -10,17 +10,17 @@ ...@@ -10,17 +10,17 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_MATH_AFFINE_OP_H_ #ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_
#define DRAGON_OPERATORS_MATH_AFFINE_OP_H_ #define DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_
#include "dragon/core/operator.h" #include "dragon/core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class AffineOp final : public Operator<Context> { class ChannelAffineOp final : public Operator<Context> {
public: public:
SIMPLE_CTOR_DTOR(AffineOp); SIMPLE_CTOR_DTOR(ChannelAffineOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -30,9 +30,9 @@ class AffineOp final : public Operator<Context> { ...@@ -30,9 +30,9 @@ class AffineOp final : public Operator<Context> {
}; };
template <class Context> template <class Context>
class AffineGradientOp final : public Operator<Context> { class ChannelAffineGradientOp final : public Operator<Context> {
public: public:
SIMPLE_CTOR_DTOR(AffineGradientOp); SIMPLE_CTOR_DTOR(ChannelAffineGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -43,4 +43,4 @@ class AffineGradientOp final : public Operator<Context> { ...@@ -43,4 +43,4 @@ class AffineGradientOp final : public Operator<Context> {
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_MATH_AFFINE_OP_H_ #endif // DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_
...@@ -18,17 +18,6 @@ void MaskedSelectOp<Context>::DoRunWithType() { ...@@ -18,17 +18,6 @@ void MaskedSelectOp<Context>::DoRunWithType() {
STORE_INPUT_SPEC(0); STORE_INPUT_SPEC(0);
auto* X_index = Buffer("X_index")->Reshape({X.count() + 1}); auto* X_index = Buffer("X_index")->Reshape({X.count() + 1});
// Determine the scratch requirement
size_t scratch_size = 0;
kernel::Flagged(
X.count(),
(const uint8_t*)X_mask.template raw_data<Context>(),
X_index->template mutable_data<int, Context>(),
nullptr,
nullptr,
scratch_size,
ctx());
// Select the index of values matching the criteria // Select the index of values matching the criteria
// The first ``num_selected`` indices are valid // The first ``num_selected`` indices are valid
int num_selected; int num_selected;
...@@ -37,8 +26,6 @@ void MaskedSelectOp<Context>::DoRunWithType() { ...@@ -37,8 +26,6 @@ void MaskedSelectOp<Context>::DoRunWithType() {
(const uint8_t*)X_mask.template raw_data<Context>(), (const uint8_t*)X_mask.template raw_data<Context>(),
X_index->template mutable_data<int, Context>(), X_index->template mutable_data<int, Context>(),
&num_selected, &num_selected,
ws()->template data<Context>({scratch_size})[0],
scratch_size,
ctx()); ctx());
// Select the values according to the flat indices // Select the values according to the flat indices
......
...@@ -19,17 +19,6 @@ void NonZeroOp<Context>::DoRunWithType() { ...@@ -19,17 +19,6 @@ void NonZeroOp<Context>::DoRunWithType() {
(bool*)X_mask->template mutable_data<uint8_t, Context>(), (bool*)X_mask->template mutable_data<uint8_t, Context>(),
ctx()); ctx());
// Determine the scratch requirement
size_t scratch_size = 0;
kernel::Flagged(
X.count(),
X_mask->template mutable_data<uint8_t, Context>(),
X_index->template mutable_data<int, Context>(),
nullptr,
nullptr,
scratch_size,
ctx());
// Select the index of values matching the criteria // Select the index of values matching the criteria
// The first ``num_selected`` indices are valid // The first ``num_selected`` indices are valid
int num_selected; int num_selected;
...@@ -38,8 +27,6 @@ void NonZeroOp<Context>::DoRunWithType() { ...@@ -38,8 +27,6 @@ void NonZeroOp<Context>::DoRunWithType() {
X_mask->template mutable_data<uint8_t, Context>(), X_mask->template mutable_data<uint8_t, Context>(),
X_index->template mutable_data<int, Context>(), X_index->template mutable_data<int, Context>(),
&num_selected, &num_selected,
ws()->template data<Context>({scratch_size})[0],
scratch_size,
ctx()); ctx());
// Convert the flat indices into n-dimension coordinates // Convert the flat indices into n-dimension coordinates
......
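Both MaskedSelect and NonZero drop the explicit scratch-size query; as far as the diff shows, kernel::Flagged now obtains its temporary storage internally, so callers only pass the mask, the index buffer and a pointer for num_selected. For reference, a host sketch of what Flagged computes: the first num_selected entries of index are the flat positions whose mask is nonzero.

#include <cstdint>
#include <cstdio>
#include <vector>

// Host reference for a "flagged select": writes the flat indices of nonzero
// mask entries to the front of `index` and reports how many were selected.
int Flagged(int count, const uint8_t* mask, int* index) {
  int num_selected = 0;
  for (int i = 0; i < count; ++i) {
    if (mask[i]) index[num_selected++] = i;
  }
  return num_selected;  // only the first num_selected entries are valid
}

int main() {
  std::vector<uint8_t> mask = {0, 1, 1, 0, 1};
  std::vector<int> index(mask.size() + 1);  // the ops reserve count + 1 ints
  const int n = Flagged((int)mask.size(), mask.data(), index.data());
  std::printf("num_selected = %d:", n);
  for (int i = 0; i < n; ++i) std::printf(" %d", index[i]);
  std::printf("\n");  // num_selected = 3: 1 2 4
}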
...@@ -11,7 +11,7 @@ void PermutationOp<Context>::DoRunWithType() { ...@@ -11,7 +11,7 @@ void PermutationOp<Context>::DoRunWithType() {
kernel::Permutation( kernel::Permutation(
Y->count(), Y->count(),
Y->template mutable_data<T, Context>(), Y->template mutable_data<T, Context>(),
ws()->template data<uint32_t, Context>({Y->count()})[0], ctx()->workspace()->template data<uint32_t, Context>({Y->count()})[0],
ctx()); ctx());
} }
......
...@@ -39,6 +39,7 @@ void ReduceMaxOp<Context>::DoRunWithType() { ...@@ -39,6 +39,7 @@ void ReduceMaxOp<Context>::DoRunWithType() {
X_dims.data(), X_dims.data(),
reduce_axes.size(), reduce_axes.size(),
reduce_axes.data(), reduce_axes.data(),
1.f,
X.template data<T, Context>(), X.template data<T, Context>(),
Y->Reshape(Y_shape)->template mutable_data<T, Context>(), Y->Reshape(Y_shape)->template mutable_data<T, Context>(),
ctx()); ctx());
......
...@@ -55,7 +55,7 @@ void ReduceMeanOp<Context>::DoRunWithType() { ...@@ -55,7 +55,7 @@ void ReduceMeanOp<Context>::DoRunWithType() {
template <class Context> template <class Context>
void ReduceMeanOp<Context>::RunOnDevice() { void ReduceMeanOp<Context>::RunOnDevice() {
STORE_INPUT_SPEC(0); STORE_INPUT_SPEC(0);
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
} }
template <class Context> template <class Context>
......
...@@ -39,6 +39,7 @@ void ReduceMinOp<Context>::DoRunWithType() { ...@@ -39,6 +39,7 @@ void ReduceMinOp<Context>::DoRunWithType() {
X_dims.data(), X_dims.data(),
reduce_axes.size(), reduce_axes.size(),
reduce_axes.data(), reduce_axes.data(),
1.f,
X.template data<T, Context>(), X.template data<T, Context>(),
Y->Reshape(Y_shape)->template mutable_data<T, Context>(), Y->Reshape(Y_shape)->template mutable_data<T, Context>(),
ctx()); ctx());
......
...@@ -42,12 +42,12 @@ void TileGradientOp<Context>::DoRunWithType() { ...@@ -42,12 +42,12 @@ void TileGradientOp<Context>::DoRunWithType() {
const T* dy; const T* dy;
T* dx; T* dx;
if (src_ == &nav_) { if (src_ == &nav_) {
dy = ws()->template data<T, Context>({src_->count()})[0]; dy = ctx()->workspace()->template data<T, Context>({src_->count()})[0];
} else { } else {
dy = src_->template data<T, Context>(); dy = src_->template data<T, Context>();
} }
if (dest_ == &nav_) { if (dest_ == &nav_) {
dx = ws()->template data<T, Context>({dest_->count()})[0]; dx = ctx()->workspace()->template data<T, Context>({dest_->count()})[0];
} else { } else {
dx = dest_->template mutable_data<T, Context>(); dx = dest_->template mutable_data<T, Context>();
} }
......
...@@ -66,7 +66,7 @@ void WhereGradientOp<Context>::DoRunWithType() { ...@@ -66,7 +66,7 @@ void WhereGradientOp<Context>::DoRunWithType() {
} }
if (scratch_size > 0) { if (scratch_size > 0) {
scratch = ws()->template data<T, Context>({scratch_size})[0]; scratch = ctx()->workspace()->template data<T, Context>({scratch_size})[0];
zeros = scratch + (scratch_size - 1); zeros = scratch + (scratch_size - 1);
math::Set(1, cast::to<T>(0.f), zeros, ctx()); math::Set(1, cast::to<T>(0.f), zeros, ctx());
} }
......
...@@ -49,8 +49,8 @@ void AssignOp<Context>::DoRunWithType() { ...@@ -49,8 +49,8 @@ void AssignOp<Context>::DoRunWithType() {
<< Tensor::DimString(X_dims); << Tensor::DimString(X_dims);
utils::math::ComputeBinaryBroadcastDims(X.dims(), X_dims, dims1, dims2); utils::math::ComputeBinaryBroadcastDims(X.dims(), X_dims, dims1, dims2);
if (dims1 != dims2) { if (dims1 != dims2) {
auto* scratch = auto* scratch = ctx()->workspace()->template data<T, Context>(
ws()->template data<T, Context>({X_broadcast.count()})[0]; {X_broadcast.count()})[0];
math::Set( math::Set(
X.ndim(), X.ndim(),
X.dims().data(), X.dims().data(),
......
...@@ -27,7 +27,7 @@ void CollectiveOp<Context>::AllReduceMPI() { ...@@ -27,7 +27,7 @@ void CollectiveOp<Context>::AllReduceMPI() {
auto from = (comm_rank_ - 1 + comm_size_) % comm_size_; auto from = (comm_rank_ - 1 + comm_size_) % comm_size_;
auto* data = src_tensor_->template mutable_data<T, Context>(); auto* data = src_tensor_->template mutable_data<T, Context>();
auto* scratch = ws()->template data<T, Context>({sizes[0]})[0]; auto* scratch = ctx()->workspace()->template data<T, Context>({sizes[0]})[0];
// Scatter-Reduce // Scatter-Reduce
MPI_Request recv_req; MPI_Request recv_req;
...@@ -129,25 +129,10 @@ void CollectiveOp<Context>::RunOnDevice() { ...@@ -129,25 +129,10 @@ void CollectiveOp<Context>::RunOnDevice() {
// Otherwise, data corruption will happen through GPUDirect(UVA) // Otherwise, data corruption will happen through GPUDirect(UVA)
// while executing collectives asynchronously. // while executing collectives asynchronously.
ctx()->FinishDeviceComputation(); ctx()->FinishDeviceComputation();
#ifdef USE_NCCL
#if NCCL_VERSION_MIN(2, 2, 0)
if (enable_nccl_ && InputSize() <= 2048) {
this->nccl_comm(); // Ensure the comm created
NCCL_CHECK(ncclGroupStart());
}
#endif
#endif
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); i++) {
src_tensor_ = &Input(i); src_tensor_ = &Input(i);
DispatchHelper<NumericalTensorTypes>::Call(this, *src_tensor_); DispatchHelper<NumericalTensorTypes>::Call(this, *src_tensor_);
} }
#ifdef USE_NCCL
#if NCCL_VERSION_MIN(2, 2, 0)
if (enable_nccl_ && InputSize() <= 2048) {
NCCL_CHECK(ncclGroupEnd());
}
#endif
#endif
src_tensor_ = nullptr; src_tensor_ = nullptr;
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); i++) {
dest_tensor_ = &Input(i); dest_tensor_ = &Input(i);
......
...@@ -52,7 +52,8 @@ void CuDNNCTCLossOp<Context>::DoRunWithType() { ...@@ -52,7 +52,8 @@ void CuDNNCTCLossOp<Context>::DoRunWithType() {
ctc_desc_, ctc_desc_,
&workspace_size_)); &workspace_size_));
auto* scratch = (uint8_t*)ws()->template data<Context>({workspace_size_})[0]; auto* scratch = (uint8_t*)ctx()->workspace()->template data<Context>(
{workspace_size_})[0];
auto* g = Buffer("grad") auto* g = Buffer("grad")
->ReshapeLike(Input(0)) ->ReshapeLike(Input(0))
......
...@@ -18,7 +18,7 @@ void L1LossOp<Context>::DoRunWithType() { ...@@ -18,7 +18,7 @@ void L1LossOp<Context>::DoRunWithType() {
} }
// Allocate a temporary error buffer // Allocate a temporary error buffer
auto* x_error = ws()->template data<T, Context>({X.count()})[0]; auto* x_error = ctx()->workspace()->template data<T, Context>({X.count()})[0];
// Compute the error of inputs // Compute the error of inputs
if (InputSize() > 1) { if (InputSize() > 1) {
...@@ -55,7 +55,7 @@ void L1LossOp<Context>::DoRunWithType() { ...@@ -55,7 +55,7 @@ void L1LossOp<Context>::DoRunWithType() {
0, 0,
normalizer, normalizer,
x_error, x_error,
nullptr, (T*)nullptr,
Y->Reshape({})->template mutable_data<T, Context>(), Y->Reshape({})->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -99,7 +99,8 @@ void L1LossGradientOp<Context>::DoRunWithType() { ...@@ -99,7 +99,8 @@ void L1LossGradientOp<Context>::DoRunWithType() {
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer *= dX->count(); normalizer *= dX->count();
} }
kernel::ReduceLossGrad(dX->count(), 0, normalizer, dy, nullptr, dx, ctx()); kernel::ReduceLossGrad(
dX->count(), 0, normalizer, dy, (T*)nullptr, dx, ctx());
} }
// Gradient w.r.t. the second input // Gradient w.r.t. the second input
......
...@@ -18,7 +18,7 @@ void L2LossOp<Context>::DoRunWithType() { ...@@ -18,7 +18,7 @@ void L2LossOp<Context>::DoRunWithType() {
} }
// Allocate a temporary error buffer // Allocate a temporary error buffer
auto* x_error = ws()->template data<T, Context>({X.count()})[0]; auto* x_error = ctx()->workspace()->template data<T, Context>({X.count()})[0];
// Compute the error of inputs // Compute the error of inputs
if (InputSize() > 1) { if (InputSize() > 1) {
...@@ -55,7 +55,7 @@ void L2LossOp<Context>::DoRunWithType() { ...@@ -55,7 +55,7 @@ void L2LossOp<Context>::DoRunWithType() {
0, 0,
normalizer, normalizer,
x_error, x_error,
nullptr, (T*)nullptr,
Y->Reshape({})->template mutable_data<T, Context>(), Y->Reshape({})->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -98,7 +98,7 @@ void L2LossGradientOp<Context>::DoRunWithType() { ...@@ -98,7 +98,7 @@ void L2LossGradientOp<Context>::DoRunWithType() {
normalizer *= dX->count(); normalizer *= dX->count();
} }
kernel::ReduceLossGrad( kernel::ReduceLossGrad(
dX->count(), 0, float(normalizer) * 0.5f, dy, nullptr, dx, ctx()); dX->count(), 0, float(normalizer) * 0.5f, dy, (T*)nullptr, dx, ctx());
} }
// Gradient w.r.t. the second input // Gradient w.r.t. the second input
......
...@@ -18,12 +18,12 @@ void NLLLossOp<Context>::DoRunWithType() { ...@@ -18,12 +18,12 @@ void NLLLossOp<Context>::DoRunWithType() {
CHECK_EQ(num_preds, Input(1).count()) CHECK_EQ(num_preds, Input(1).count())
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
auto scratches = ws()->template data<Context>({ auto scratches = ctx()->workspace()->template data<Context>({
num_preds * sizeof(LogitType), // loss (size_t)num_preds * sizeof(LogitType), // loss
num_preds * sizeof(int), // mask (size_t)num_preds * sizeof(LogitType) + sizeof(LogitType), // mask
}); });
auto* loss = static_cast<LogitType*>(scratches[0]); auto* loss = static_cast<LogitType*>(scratches[0]);
auto* mask = static_cast<int*>(scratches[1]); auto* mask = static_cast<LogitType*>(scratches[1]);
kernel::NLLLoss( kernel::NLLLoss(
outer_dim, outer_dim,
...@@ -101,9 +101,10 @@ void NLLLossGradientOp<Context>::DoRunWithType() { ...@@ -101,9 +101,10 @@ void NLLLossGradientOp<Context>::DoRunWithType() {
auto inner_dim = dX->count(axis + 1); auto inner_dim = dX->count(axis + 1);
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
auto* mask = ws()->template data<int, Context>({num_preds})[0];
auto* dy = dY.template data<LogitType, Context>(); auto* dy = dY.template data<LogitType, Context>();
auto* dx = dX->template mutable_data<LogitType, Context>(); auto* dx = dX->template mutable_data<LogitType, Context>();
auto* mask =
ctx()->workspace()->template data<LogitType, Context>({num_preds + 1})[0];
math::Set(dX->count(), cast::to<LogitType>(0.f), dx, ctx()); math::Set(dX->count(), cast::to<LogitType>(0.f), dx, ctx());
kernel::NLLLossGrad( kernel::NLLLossGrad(
......
...@@ -13,12 +13,12 @@ void SigmoidCrossEntropyOp<Context>::DoRunWithType() { ...@@ -13,12 +13,12 @@ void SigmoidCrossEntropyOp<Context>::DoRunWithType() {
CHECK_EQ(X.count(), Input(1).count()) CHECK_EQ(X.count(), Input(1).count())
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
auto scratches = ws()->template data<Context>({ auto scratches = ctx()->workspace()->template data<Context>({
X.count() * sizeof(T), // loss X.size() * sizeof(T), // loss
X.count() * sizeof(int), // mask X.size() * sizeof(T) + sizeof(T), // mask
}); });
auto* loss = static_cast<T*>(scratches[0]); auto* loss = static_cast<T*>(scratches[0]);
auto* mask = static_cast<int*>(scratches[1]); auto* mask = static_cast<T*>(scratches[1]);
kernel::SigmoidCrossEntropy( kernel::SigmoidCrossEntropy(
X.count(), X.count(),
...@@ -64,9 +64,10 @@ template <typename T> ...@@ -64,9 +64,10 @@ template <typename T>
void SigmoidCrossEntropyGradientOp<Context>::DoRunWithType() { void SigmoidCrossEntropyGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(-1), *dX = Output(0); auto &X = Input(0), &dY = Input(-1), *dX = Output(0);
auto* mask = ws()->template data<int, Context>({dX->count()})[0];
auto* dy = dY.template data<T, Context>(); auto* dy = dY.template data<T, Context>();
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
auto* mask =
ctx()->workspace()->template data<T, Context>({dX->count() + 1})[0];
kernel::SigmoidCrossEntropyGrad( kernel::SigmoidCrossEntropyGrad(
dX->count(), dX->count(),
......
...@@ -17,12 +17,12 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() { ...@@ -17,12 +17,12 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() {
CHECK_EQ(outer_dim * inner_dim, Input(1).count()) CHECK_EQ(outer_dim * inner_dim, Input(1).count())
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
auto scratches = ws()->template data<Context>({ auto scratches = ctx()->workspace()->template data<Context>({
X.count() * sizeof(LogitType), // loss X.size() * sizeof(LogitType), // loss
X.count() * sizeof(int), // mask X.size() * sizeof(LogitType) + sizeof(LogitType), // mask
}); });
auto* loss = static_cast<LogitType*>(scratches[0]); auto* loss = static_cast<LogitType*>(scratches[0]);
auto* mask = static_cast<int*>(scratches[1]); auto* mask = static_cast<LogitType*>(scratches[1]);
kernel::SigmoidFocalLoss( kernel::SigmoidFocalLoss(
outer_dim, outer_dim,
...@@ -100,9 +100,10 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() { ...@@ -100,9 +100,10 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() {
auto outer_dim = dX->count(0, axis); auto outer_dim = dX->count(0, axis);
auto inner_dim = dX->count(axis + 1); auto inner_dim = dX->count(axis + 1);
auto* mask = ws()->template data<int, Context>({dX->count()})[0];
auto* dy = dY.template data<LogitType, Context>(); auto* dy = dY.template data<LogitType, Context>();
auto* dx = dX->template mutable_data<LogitType, Context>(); auto* dx = dX->template mutable_data<LogitType, Context>();
auto* mask = ctx()->workspace()->template data<LogitType, Context>(
{dX->count() + 1})[0];
kernel::SigmoidFocalLossGrad( kernel::SigmoidFocalLossGrad(
outer_dim, outer_dim,
......
...@@ -18,7 +18,7 @@ void SmoothL1LossOp<Context>::DoRunWithType() { ...@@ -18,7 +18,7 @@ void SmoothL1LossOp<Context>::DoRunWithType() {
} }
// Allocate a temporary error buffer // Allocate a temporary error buffer
auto* x_error = ws()->template data<T, Context>({X.count()})[0]; auto* x_error = ctx()->workspace()->template data<T, Context>({X.count()})[0];
// Compute the error of inputs // Compute the error of inputs
if (InputSize() > 1) { if (InputSize() > 1) {
...@@ -55,7 +55,7 @@ void SmoothL1LossOp<Context>::DoRunWithType() { ...@@ -55,7 +55,7 @@ void SmoothL1LossOp<Context>::DoRunWithType() {
0, 0,
normalizer, normalizer,
x_error, x_error,
nullptr, (T*)nullptr,
Y->Reshape({})->template mutable_data<T, Context>(), Y->Reshape({})->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -99,7 +99,8 @@ void SmoothL1LossGradientOp<Context>::DoRunWithType() { ...@@ -99,7 +99,8 @@ void SmoothL1LossGradientOp<Context>::DoRunWithType() {
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer *= dX->count(); normalizer *= dX->count();
} }
kernel::ReduceLossGrad(dX->count(), 0, normalizer, dy, nullptr, dx, ctx()); kernel::ReduceLossGrad(
dX->count(), 0, normalizer, dy, (T*)nullptr, dx, ctx());
} }
// Gradient w.r.t. the second input // Gradient w.r.t. the second input
......
...@@ -19,7 +19,7 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -19,7 +19,7 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() {
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
Buffer("prob")->ReshapeLike(X); Buffer("prob")->ReshapeLike(X);
auto* loss = ws()->template data<T, Context>({X.count()})[0]; auto* loss = ctx()->workspace()->template data<T, Context>({X.count()})[0];
auto* prob = Buffer("prob")->template mutable_data<T, Context>(); auto* prob = Buffer("prob")->template mutable_data<T, Context>();
kernel::Softmax( kernel::Softmax(
...@@ -59,7 +59,7 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -59,7 +59,7 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() {
0, 0,
normalizer, normalizer,
loss, loss,
nullptr, (T*)nullptr,
Y->Reshape({})->template mutable_data<T, Context>(), Y->Reshape({})->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -98,7 +98,8 @@ void SoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -98,7 +98,8 @@ void SoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = num_preds; normalizer = num_preds;
} }
kernel::ReduceLossGrad(dX->count(), 0, normalizer, dy, nullptr, dx, ctx()); kernel::ReduceLossGrad(
dX->count(), 0, normalizer, dy, (T*)nullptr, dx, ctx());
} }
} }
......
...@@ -20,12 +20,12 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -20,12 +20,12 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
auto* X_prob = Buffer("prob")->ReshapeLike(X); auto* X_prob = Buffer("prob")->ReshapeLike(X);
auto* prob = X_prob->template mutable_data<LogitType, Context>(); auto* prob = X_prob->template mutable_data<LogitType, Context>();
auto scratches = ws()->template data<Context>({ auto scratches = ctx()->workspace()->template data<Context>({
num_preds * sizeof(LogitType), // loss (size_t)num_preds * sizeof(LogitType), // loss
num_preds * sizeof(int), // mask (size_t)num_preds * sizeof(LogitType) + sizeof(LogitType), // mask
}); });
auto* loss = static_cast<LogitType*>(scratches[0]); auto* loss = static_cast<LogitType*>(scratches[0]);
auto* mask = static_cast<int*>(scratches[1]); auto* mask = static_cast<LogitType*>(scratches[1]);
kernel::Softmax( kernel::Softmax(
outer_dim, outer_dim,
...@@ -111,9 +111,10 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -111,9 +111,10 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
auto* prob = Buffer("prob")->template data<LogitType, Context>(); auto* prob = Buffer("prob")->template data<LogitType, Context>();
auto* mask = ws()->template data<int, Context>({num_preds})[0];
auto* dy = Input(-1).template data<LogitType, Context>(); auto* dy = Input(-1).template data<LogitType, Context>();
auto* dx = Output(0)->template mutable_data<LogitType, Context>(); auto* dx = Output(0)->template mutable_data<LogitType, Context>();
auto* mask =
ctx()->workspace()->template data<LogitType, Context>({num_preds + 1})[0];
math::Copy(dX->count(), prob, dx, ctx()); math::Copy(dX->count(), prob, dx, ctx());
......
...@@ -83,7 +83,7 @@ void DivGradientOp<Context>::DoRunWithType() { ...@@ -83,7 +83,7 @@ void DivGradientOp<Context>::DoRunWithType() {
ctx()); ctx());
} }
} else { } else {
scratch = ws()->template data<T, Context>({dY.count()})[0]; scratch = ctx()->workspace()->template data<T, Context>({dY.count()})[0];
if (B_broadcast_axes.empty()) { if (B_broadcast_axes.empty()) {
math::Div( math::Div(
B_ref.count(), B_ref.count(),
...@@ -136,7 +136,8 @@ void DivGradientOp<Context>::DoRunWithType() { ...@@ -136,7 +136,8 @@ void DivGradientOp<Context>::DoRunWithType() {
} }
} else { } else {
if (scratch == nullptr) { if (scratch == nullptr) {
scratch = ws()->template data<T, Context>({dY.count()})[0]; scratch =
ctx()->workspace()->template data<T, Context>({dY.count()})[0];
} }
if (A_broadcast_axes.empty()) { if (A_broadcast_axes.empty()) {
math::Mul( math::Mul(
......
...@@ -21,7 +21,7 @@ void MaximumGradientOp<Context>::DoRunWithType() { ...@@ -21,7 +21,7 @@ void MaximumGradientOp<Context>::DoRunWithType() {
T* scratch = nullptr; T* scratch = nullptr;
if (dA->has_name()) { if (dA->has_name()) {
auto scratches = ws()->template data<Context>( auto scratches = ctx()->workspace()->template data<Context>(
{dY.size() * sizeof(T), dY.size() * sizeof(bool)}); {dY.size() * sizeof(T), dY.size() * sizeof(bool)});
mask = (bool*)scratches[1], scratch = (T*)scratches[0]; mask = (bool*)scratches[1], scratch = (T*)scratches[0];
if (A_broadcast_axes.empty()) { if (A_broadcast_axes.empty()) {
...@@ -43,7 +43,7 @@ void MaximumGradientOp<Context>::DoRunWithType() { ...@@ -43,7 +43,7 @@ void MaximumGradientOp<Context>::DoRunWithType() {
mask, mask,
ctx()); ctx());
} }
kernel::Cast(dY.count(), mask, scratch, ctx()); math::Cast(dY.count(), mask, scratch, ctx());
math::Mul( math::Mul(
dY.count(), dY.count(),
dY.template data<T, Context>(), dY.template data<T, Context>(),
...@@ -60,7 +60,7 @@ void MaximumGradientOp<Context>::DoRunWithType() { ...@@ -60,7 +60,7 @@ void MaximumGradientOp<Context>::DoRunWithType() {
B.template data<T, Context>(), B.template data<T, Context>(),
mask, mask,
ctx()); ctx());
kernel::Cast(dY.count(), mask, scratch, ctx()); math::Cast(dY.count(), mask, scratch, ctx());
math::Mul( math::Mul(
dY.count(), dY.template data<T, Context>(), scratch, scratch, ctx()); dY.count(), dY.template data<T, Context>(), scratch, scratch, ctx());
math::ReduceSum( math::ReduceSum(
...@@ -77,7 +77,7 @@ void MaximumGradientOp<Context>::DoRunWithType() { ...@@ -77,7 +77,7 @@ void MaximumGradientOp<Context>::DoRunWithType() {
if (dB->has_name()) { if (dB->has_name()) {
if (mask == nullptr) { if (mask == nullptr) {
auto scratches = ws()->template data<Context>( auto scratches = ctx()->workspace()->template data<Context>(
{dY.size() * sizeof(T), dY.size() * sizeof(bool)}); {dY.size() * sizeof(T), dY.size() * sizeof(bool)});
mask = (bool*)scratches[1], scratch = (T*)scratches[0]; mask = (bool*)scratches[1], scratch = (T*)scratches[0];
} }
...@@ -100,7 +100,7 @@ void MaximumGradientOp<Context>::DoRunWithType() { ...@@ -100,7 +100,7 @@ void MaximumGradientOp<Context>::DoRunWithType() {
mask, mask,
ctx()); ctx());
} }
kernel::Cast(dY.count(), mask, scratch, ctx()); math::Cast(dY.count(), mask, scratch, ctx());
math::Mul( math::Mul(
dY.count(), dY.count(),
dY.template data<T, Context>(), dY.template data<T, Context>(),
...@@ -117,7 +117,7 @@ void MaximumGradientOp<Context>::DoRunWithType() { ...@@ -117,7 +117,7 @@ void MaximumGradientOp<Context>::DoRunWithType() {
B.template data<T, Context>(), B.template data<T, Context>(),
mask, mask,
ctx()); ctx());
kernel::Cast(dY.count(), mask, scratch, ctx()); math::Cast(dY.count(), mask, scratch, ctx());
math::Mul( math::Mul(
dY.count(), dY.template data<T, Context>(), scratch, scratch, ctx()); dY.count(), dY.template data<T, Context>(), scratch, scratch, ctx());
math::ReduceSum( math::ReduceSum(
......
...@@ -21,7 +21,7 @@ void MinimumGradientOp<Context>::DoRunWithType() { ...@@ -21,7 +21,7 @@ void MinimumGradientOp<Context>::DoRunWithType() {
T* scratch = nullptr; T* scratch = nullptr;
if (dA->has_name()) { if (dA->has_name()) {
auto scratches = ws()->template data<Context>( auto scratches = ctx()->workspace()->template data<Context>(
{dY.size() * sizeof(T), dY.size() * sizeof(bool)}); {dY.size() * sizeof(T), dY.size() * sizeof(bool)});
mask = (bool*)scratches[1], scratch = (T*)scratches[0]; mask = (bool*)scratches[1], scratch = (T*)scratches[0];
if (A_broadcast_axes.empty()) { if (A_broadcast_axes.empty()) {
...@@ -43,7 +43,7 @@ void MinimumGradientOp<Context>::DoRunWithType() { ...@@ -43,7 +43,7 @@ void MinimumGradientOp<Context>::DoRunWithType() {
mask, mask,
ctx()); ctx());
} }
kernel::Cast(dY.count(), mask, scratch, ctx()); math::Cast(dY.count(), mask, scratch, ctx());
math::Mul( math::Mul(
dY.count(), dY.count(),
dY.template data<T, Context>(), dY.template data<T, Context>(),
...@@ -60,7 +60,7 @@ void MinimumGradientOp<Context>::DoRunWithType() { ...@@ -60,7 +60,7 @@ void MinimumGradientOp<Context>::DoRunWithType() {
B.template data<T, Context>(), B.template data<T, Context>(),
mask, mask,
ctx()); ctx());
kernel::Cast(dY.count(), mask, scratch, ctx()); math::Cast(dY.count(), mask, scratch, ctx());
math::Mul( math::Mul(
dY.count(), dY.template data<T, Context>(), scratch, scratch, ctx()); dY.count(), dY.template data<T, Context>(), scratch, scratch, ctx());
math::ReduceSum( math::ReduceSum(
...@@ -77,7 +77,7 @@ void MinimumGradientOp<Context>::DoRunWithType() { ...@@ -77,7 +77,7 @@ void MinimumGradientOp<Context>::DoRunWithType() {
if (dB->has_name()) { if (dB->has_name()) {
if (mask == nullptr) { if (mask == nullptr) {
auto scratches = ws()->template data<Context>( auto scratches = ctx()->workspace()->template data<Context>(
{dY.size() * sizeof(T), dY.size() * sizeof(bool)}); {dY.size() * sizeof(T), dY.size() * sizeof(bool)});
mask = (bool*)scratches[1], scratch = (T*)scratches[0]; mask = (bool*)scratches[1], scratch = (T*)scratches[0];
} }
...@@ -100,7 +100,7 @@ void MinimumGradientOp<Context>::DoRunWithType() { ...@@ -100,7 +100,7 @@ void MinimumGradientOp<Context>::DoRunWithType() {
mask, mask,
ctx()); ctx());
} }
kernel::Cast(dY.count(), mask, scratch, ctx()); math::Cast(dY.count(), mask, scratch, ctx());
math::Mul( math::Mul(
dY.count(), dY.count(),
dY.template data<T, Context>(), dY.template data<T, Context>(),
...@@ -117,7 +117,7 @@ void MinimumGradientOp<Context>::DoRunWithType() { ...@@ -117,7 +117,7 @@ void MinimumGradientOp<Context>::DoRunWithType() {
B.template data<T, Context>(), B.template data<T, Context>(),
mask, mask,
ctx()); ctx());
kernel::Cast(dY.count(), mask, scratch, ctx()); math::Cast(dY.count(), mask, scratch, ctx());
math::Mul( math::Mul(
dY.count(), dY.template data<T, Context>(), scratch, scratch, ctx()); dY.count(), dY.template data<T, Context>(), scratch, scratch, ctx());
math::ReduceSum( math::ReduceSum(
......
...@@ -33,7 +33,7 @@ void MomentsOp<Context>::DoRunWithType() { ...@@ -33,7 +33,7 @@ void MomentsOp<Context>::DoRunWithType() {
} }
if (X.count() == 1) { if (X.count() == 1) {
kernel::Cast( math::Cast(
1, 1,
X.template data<Tx, Context>(), X.template data<Tx, Context>(),
Y1->Reshape(Y_shape)->template mutable_data<Ty, Context>(), Y1->Reshape(Y_shape)->template mutable_data<Ty, Context>(),
......
...@@ -83,7 +83,7 @@ void MulGradientOp<Context>::DoRunWithType() { ...@@ -83,7 +83,7 @@ void MulGradientOp<Context>::DoRunWithType() {
ctx()); ctx());
} }
} else { } else {
scratch = ws()->template data<T, Context>({dY.count()})[0]; scratch = ctx()->workspace()->template data<T, Context>({dY.count()})[0];
if (B_broadcast_axes.empty()) { if (B_broadcast_axes.empty()) {
math::Mul( math::Mul(
B_ref.count(), B_ref.count(),
...@@ -136,7 +136,8 @@ void MulGradientOp<Context>::DoRunWithType() { ...@@ -136,7 +136,8 @@ void MulGradientOp<Context>::DoRunWithType() {
} }
} else { } else {
if (scratch == nullptr) { if (scratch == nullptr) {
scratch = ws()->template data<T, Context>({dY.count()})[0]; scratch =
ctx()->workspace()->template data<T, Context>({dY.count()})[0];
} }
if (A_broadcast_axes.empty()) { if (A_broadcast_axes.empty()) {
math::Mul( math::Mul(
......
...@@ -33,7 +33,8 @@ void PowGradientOp<Context>::DoRunWithType() { ...@@ -33,7 +33,8 @@ void PowGradientOp<Context>::DoRunWithType() {
dB->template mutable_data<T, Context>(), dB->template mutable_data<T, Context>(),
ctx()); ctx());
} else { } else {
scratch = ws()->template data<T, Context>({dY.count()})[0]; scratch =
ctx()->workspace()->template data<T, Context>({dY.count()})[0];
math::Log(A.count(), A.template data<T, Context>(), scratch, ctx()); math::Log(A.count(), A.template data<T, Context>(), scratch, ctx());
math::Mul( math::Mul(
A.ndim(), A.ndim(),
...@@ -53,13 +54,14 @@ void PowGradientOp<Context>::DoRunWithType() { ...@@ -53,13 +54,14 @@ void PowGradientOp<Context>::DoRunWithType() {
ctx()); ctx());
} else { } else {
if (A_broadcast_axes.empty()) { if (A_broadcast_axes.empty()) {
scratch = ws()->template data<T, Context>({dY.count()})[0]; scratch =
ctx()->workspace()->template data<T, Context>({dY.count()})[0];
math::Log(A.count(), A.template data<T, Context>(), scratch, ctx()); math::Log(A.count(), A.template data<T, Context>(), scratch, ctx());
math::Mul( math::Mul(
Y.count(), scratch, Y.template data<T, Context>(), scratch, ctx()); Y.count(), scratch, Y.template data<T, Context>(), scratch, ctx());
} else { } else {
auto scratches = auto scratches = ctx()->workspace()->template data<T, Context>(
ws()->template data<T, Context>({dY.count(), A.count()}); {dY.count(), A.count()});
scratch = scratches[0]; scratch = scratches[0];
math::Log( math::Log(
A.count(), A.template data<T, Context>(), scratches[1], ctx()); A.count(), A.template data<T, Context>(), scratches[1], ctx());
...@@ -127,7 +129,8 @@ void PowGradientOp<Context>::DoRunWithType() { ...@@ -127,7 +129,8 @@ void PowGradientOp<Context>::DoRunWithType() {
ctx()); ctx());
} else { } else {
if (scratch == nullptr) { if (scratch == nullptr) {
scratch = ws()->template data<T, Context>({dY.count()})[0]; scratch =
ctx()->workspace()->template data<T, Context>({dY.count()})[0];
} }
math::Div( math::Div(
Y.ndim(), Y.ndim(),
......
...@@ -56,9 +56,9 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -56,9 +56,9 @@ void BatchNormOp<Context>::TrainingImpl() {
// Compute affine transformation // Compute affine transformation
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
kernel::Affine(N_, C_, S_, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
kernel::Affine(N_ * S_, C_, 1, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx());
} }
} }
...@@ -91,9 +91,9 @@ void BatchNormOp<Context>::InferenceImpl() { ...@@ -91,9 +91,9 @@ void BatchNormOp<Context>::InferenceImpl() {
// Compute affine transformation // Compute affine transformation
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
kernel::Affine(N_, C_, S_, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
kernel::Affine(N_ * S_, C_, 1, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx());
} }
} }
...@@ -102,7 +102,7 @@ void BatchNormOp<Context>::RunOnDevice() { ...@@ -102,7 +102,7 @@ void BatchNormOp<Context>::RunOnDevice() {
DetermineBaseArguments(); DetermineBaseArguments();
// Get the recomputing flag // Get the recomputing flag
auto* flag = ws()->GetTensor("/share/flag/recomputing"); auto* flag = workspace()->GetTensor("/share/flag/recomputing");
is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0; is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0;
// Dispatch the training or inference impl // Dispatch the training or inference impl
......
...@@ -73,7 +73,7 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() { ...@@ -73,7 +73,7 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() {
DetermineBaseArguments(); DetermineBaseArguments();
// Get the recomputing flag // Get the recomputing flag
auto* flag = ws()->GetTensor("/share/flag/recomputing"); auto* flag = workspace()->GetTensor("/share/flag/recomputing");
is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0; is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0;
// Dispatch the training or inference impl // Dispatch the training or inference impl
......
...@@ -88,9 +88,9 @@ void SyncBatchNormOp<Context>::TrainingImpl() { ...@@ -88,9 +88,9 @@ void SyncBatchNormOp<Context>::TrainingImpl() {
// Compute affine transformation // Compute affine transformation
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
kernel::Affine(N_, C_, S_, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
kernel::Affine(N_ * S_, C_, 1, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx());
} }
} }
...@@ -99,7 +99,7 @@ void SyncBatchNormOp<Context>::RunOnDevice() { ...@@ -99,7 +99,7 @@ void SyncBatchNormOp<Context>::RunOnDevice() {
DetermineBaseArguments(); DetermineBaseArguments();
// Get the recomputing flag // Get the recomputing flag
auto* flag = ws()->GetTensor("/share/flag/recomputing"); auto* flag = workspace()->GetTensor("/share/flag/recomputing");
is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0; is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0;
// Dispatch the training or inference impl // Dispatch the training or inference impl
......
...@@ -11,6 +11,7 @@ template <typename T> ...@@ -11,6 +11,7 @@ template <typename T>
void CuDNNRecurrentOpBase<Context>::ResetDesc() { void CuDNNRecurrentOpBase<Context>::ResetDesc() {
input_dims_ = Input(0).dims(); input_dims_ = Input(0).dims();
seq_length_ = Input(0).dim(0); seq_length_ = Input(0).dim(0);
auto input_type = TypeMeta::Id<T>();
auto batch_size = Input(0).dim(1); auto batch_size = Input(0).dim(1);
auto x_dim = Input(0).dim(2); auto x_dim = Input(0).dim(2);
auto ndirections = bidirectional_ ? 2 : 1; auto ndirections = bidirectional_ ? 2 : 1;
...@@ -24,7 +25,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -24,7 +25,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
CUDNN_CHECK( CUDNN_CHECK(
cudnnDropoutGetStatesSize(ctx()->cudnn_handle(), &states_size_)); cudnnDropoutGetStatesSize(ctx()->cudnn_handle(), &states_size_));
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
auto* states_tensor = ws()->CreateTensor( auto* states_tensor = workspace()->CreateTensor(
"/share/cudnn/dropout:" + str::to(rng_seed_) + "/states"); "/share/cudnn/dropout:" + str::to(rng_seed_) + "/states");
if (states_tensor->count() > 0) { if (states_tensor->count() > 0) {
auto* states = states_tensor->template mutable_data<uint8_t, Context>(); auto* states = states_tensor->template mutable_data<uint8_t, Context>();
...@@ -53,6 +54,13 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -53,6 +54,13 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
} }
// Setup RNN // Setup RNN
if (input_type == TypeMeta::Id<float16>()) {
compute_type_ = CUDNN_DATA_FLOAT;
} else if (input_type == TypeMeta::Id<float>()) {
compute_type_ = CUDNN_DATA_FLOAT;
} else if (input_type == TypeMeta::Id<double>()) {
compute_type_ = CUDNN_DATA_DOUBLE;
}
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetRNNDescriptor_v6( CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
...@@ -64,7 +72,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -64,7 +72,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
rnn_direction_, rnn_direction_,
rnn_mode_, rnn_mode_,
CUDNN_RNN_ALGO_STANDARD, CUDNN_RNN_ALGO_STANDARD,
CuDNNType<T>::type)); compute_type_));
#else #else
CUDNN_CHECK(cudnnSetRNNDescriptor( CUDNN_CHECK(cudnnSetRNNDescriptor(
rnn_desc_, rnn_desc_,
...@@ -74,7 +82,25 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -74,7 +82,25 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
rnn_input_mode_, rnn_input_mode_,
rnn_direction_, rnn_direction_,
rnn_mode_, rnn_mode_,
CuDNNType<T>::type)); compute_type_));
#endif
// Setup TensorCore
#if CUDNN_VERSION_MIN(7, 0, 0)
if (enable_tensor_core_ > 0) {
cudnnMathType_t math_type;
if (input_type == TypeMeta::Id<float16>()) {
math_type = CUDNN_TENSOR_OP_MATH;
} else {
math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) {
math_type = CUDNN_FMA_MATH;
}
#endif
}
CUDNN_CHECK(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
}
#endif #endif
// Setup X and Y // Setup X and Y
...@@ -151,7 +177,8 @@ void CuDNNRecurrentOp<Context>::DoRunWithType() { ...@@ -151,7 +177,8 @@ void CuDNNRecurrentOp<Context>::DoRunWithType() {
return Output(i)->template mutable_data<T, Context>(); return Output(i)->template mutable_data<T, Context>();
}; };
auto* scratch = ws()->template data<Context>({workspace_size_})[0]; auto* scratch =
ctx()->workspace()->template data<Context>({workspace_size_})[0];
if (phase() == "TRAIN") { if (phase() == "TRAIN") {
CUDNN_CHECK(cudnnGetRNNTrainingReserveSize( CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(
...@@ -235,7 +262,8 @@ void CuDNNRecurrentGradientOp<Context>::DoRunWithType() { ...@@ -235,7 +262,8 @@ void CuDNNRecurrentGradientOp<Context>::DoRunWithType() {
return Output(i)->template mutable_data<T, Context>(); return Output(i)->template mutable_data<T, Context>();
}; };
auto* scratch = ws()->template data<Context>({workspace_size_})[0]; auto* scratch =
ctx()->workspace()->template data<Context>({workspace_size_})[0];
// Check the ReserveSpace // Check the ReserveSpace
CUDNN_CHECK(cudnnGetRNNTrainingReserveSize( CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(
......
...@@ -57,7 +57,8 @@ class CuDNNRecurrentOpBase : public Operator<Context> { ...@@ -57,7 +57,8 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
hidden_size_(OP_SINGLE_ARG(int64_t, "hidden_size", 0)), hidden_size_(OP_SINGLE_ARG(int64_t, "hidden_size", 0)),
bidirectional_(OP_SINGLE_ARG(int64_t, "bidirectional", 0)), bidirectional_(OP_SINGLE_ARG(int64_t, "bidirectional", 0)),
dropout_ratio_(OP_SINGLE_ARG(float, "dropout_ratio", 1.f)), dropout_ratio_(OP_SINGLE_ARG(float, "dropout_ratio", 1.f)),
rng_seed_(def.device_option().random_seed()) { rng_seed_(def.device_option().random_seed()),
enable_tensor_core_(TENSOR_CORE_AVAILABLE() ? 1 : 0) {
// Determine the rnn direction // Determine the rnn direction
rnn_direction_ = rnn_direction_ =
bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
...@@ -111,11 +112,13 @@ class CuDNNRecurrentOpBase : public Operator<Context> { ...@@ -111,11 +112,13 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
public: public:
float dropout_ratio_; float dropout_ratio_;
unsigned long long rng_seed_; unsigned long long rng_seed_;
int64_t enable_tensor_core_;
int64_t bidirectional_, states_initialized_; int64_t bidirectional_, states_initialized_;
int64_t seq_length_, hidden_size_, num_layers_; int64_t seq_length_, hidden_size_, num_layers_;
vec64_t input_dims_, output_dims_, hidden_dims_; vec64_t input_dims_, output_dims_, hidden_dims_;
size_t workspace_size_, reserve_size_, states_size_; size_t workspace_size_, reserve_size_, states_size_;
cudnnDataType_t compute_type_;
cudnnRNNMode_t rnn_mode_; cudnnRNNMode_t rnn_mode_;
cudnnRNNDescriptor_t rnn_desc_; cudnnRNNDescriptor_t rnn_desc_;
cudnnDirectionMode_t rnn_direction_; cudnnDirectionMode_t rnn_direction_;
......
...@@ -12,8 +12,9 @@ Tensor* UpdateOpBase<Context>::Slot(const string& name) { ...@@ -12,8 +12,9 @@ Tensor* UpdateOpBase<Context>::Slot(const string& name) {
template <class Context> template <class Context>
float UpdateOpBase<Context>::Parameter(const string& name) const { float UpdateOpBase<Context>::Parameter(const string& name) const {
auto* P = ws()->GetTensor("/share/hyper/" + handle() + "/" + name); return workspace()
return P->template mutable_data<float, CPUContext>()[0]; ->GetTensor("/share/hyper/" + handle() + "/" + name)
->template mutable_data<float, CPUContext>()[0];
} }
template <class Context> template <class Context>
...@@ -36,42 +37,25 @@ void UpdateOpBase<Context>::AdjustGradient(Tensor* dX, Tensor* X) { ...@@ -36,42 +37,25 @@ void UpdateOpBase<Context>::AdjustGradient(Tensor* dX, Tensor* X) {
} }
// Penalty // Penalty
auto weight_decay = Parameter("weight_decay"); auto weight_decay = Parameter("weight_decay");
if (weight_decay > 0.f) { if (weight_decay > 0.f && decay_mult_ > 0.f) {
if (X->template IsType<float16>()) { math::Axpy(
kernel::MixedPrecL2Penalty( X->count(),
X->count(), weight_decay * decay_mult_,
weight_decay * decay_mult_, X->template data<T, Context>(),
X->template data<float16, Context>(), dX->template mutable_data<T, Context>(),
dX->template mutable_data<float, Context>(), ctx());
ctx());
} else {
math::Axpy(
X->count(),
weight_decay * decay_mult_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
ctx());
}
} }
} }
template <class Context> template <class Context>
template <typename T> template <typename T>
void UpdateOpBase<Context>::ApplyUpdate(Tensor* dX, Tensor* X) { void UpdateOpBase<Context>::ApplyUpdate(Tensor* dX, Tensor* X) {
if (X->template IsType<float16>()) { math::Sub(
kernel::MixedPrecUpdate( X->count(),
X->count(), X->template data<T, Context>(),
dX->template data<float, Context>(), dX->template data<T, Context>(),
X->template mutable_data<float16, Context>(), X->template mutable_data<T, Context>(),
ctx()); ctx());
} else {
math::Sub(
X->count(),
X->template data<T, Context>(),
dX->template data<T, Context>(),
X->template mutable_data<T, Context>(),
ctx());
}
} }
template <class Context> template <class Context>
...@@ -90,15 +74,28 @@ void UpdateOpBase<Context>::RunOnDevice() { ...@@ -90,15 +74,28 @@ void UpdateOpBase<Context>::RunOnDevice() {
ComputeUpdate(&dX); ComputeUpdate(&dX);
ApplyUpdate<float>(&dX, X); ApplyUpdate<float>(&dX, X);
} else if (dX.template IsType<float16>()) { } else if (dX.template IsType<float16>()) {
auto* dX_cast = ws()->CreateTensor(dX.name() + "[float32]"); auto* X_master = workspace()->CreateTensor(X->name() + "[float32]");
kernel::Cast( auto* dX_copy = ctx()->workspace()->CreateTensor("/share/data");
if (X_master->count() != X->count()) {
math::Cast(
X->count(),
X->template data<float16, Context>(),
X_master->ReshapeLike(*X)->template mutable_data<float, Context>(),
ctx());
}
math::Cast(
dX.count(), dX.count(),
dX.template data<float16, Context>(), dX.template data<float16, Context>(),
dX_cast->ReshapeLike(dX)->template mutable_data<float, Context>(), dX_copy->ReshapeLike(dX)->template mutable_data<float, Context>(),
ctx());
AdjustGradient<float>(dX_copy, X_master);
ComputeUpdate(dX_copy);
ApplyUpdate<float>(dX_copy, X_master);
math::Cast(
X->count(),
X_master->template data<float, Context>(),
X->template mutable_data<float16, Context>(),
ctx()); ctx());
AdjustGradient<float>(dX_cast, X);
ComputeUpdate(dX_cast);
ApplyUpdate<float>(dX_cast, X);
} else { } else {
LOG(FATAL) << MessageForUnsupported( LOG(FATAL) << MessageForUnsupported(
types::to_string(dX.meta()), {"float16", "float32"}); types::to_string(dX.meta()), {"float16", "float32"});
......
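The float16 path above now routes updates through a float32 master copy: the float16 gradient is cast into a float32 scratch tensor, weight decay and the optimizer update are applied against the float32 master weights, and the result is cast back to float16. The following is a rough NumPy sketch of that flow, using plain SGD for illustration; the function name, learning-rate handling, and shapes are assumptions, not Dragon's API:

import numpy as np

def fp16_update_step(x_fp16, dx_fp16, x_master, lr=0.01, weight_decay=1e-4):
    """Illustrative float16 update routed through a float32 master copy."""
    dx = dx_fp16.astype(np.float32)          # dX_copy <- Cast(dX)
    dx += weight_decay * x_master            # AdjustGradient: L2 penalty in float32
    dx *= lr                                 # stands in for ComputeUpdate
    x_master -= dx                           # ApplyUpdate: X_master <- X_master - dX
    return x_master.astype(np.float16)       # X <- Cast(X_master)

x = np.random.randn(8).astype(np.float16)
x_master = x.astype(np.float32)              # created once, reused across steps
dx = np.random.randn(8).astype(np.float16)
x = fp16_update_step(x, dx, x_master)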
...@@ -41,8 +41,19 @@ void CuDNNConv2dOp<Context>::SetConvDesc() { ...@@ -41,8 +41,19 @@ void CuDNNConv2dOp<Context>::SetConvDesc() {
#endif #endif
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_)); CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
if (enable_tensor_core_) { if (enable_tensor_core_ > 0) {
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH)); cudnnMathType_t math_type;
if (input_type == TypeMeta::Id<float16>()) {
math_type = CUDNN_TENSOR_OP_MATH;
} else {
math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) {
math_type = CUDNN_FMA_MATH;
}
#endif
}
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, math_type));
} }
#endif #endif
} }
...@@ -148,8 +159,8 @@ void CuDNNConv2dOp<Context>::DoRunWithType() { ...@@ -148,8 +159,8 @@ void CuDNNConv2dOp<Context>::DoRunWithType() {
// Find the appropriate algorithm if necessary // Find the appropriate algorithm if necessary
if (exhaustive_search_) { if (exhaustive_search_) {
scratch = scratch = ctx()->workspace()->template data<Context>(
ws()->template data<Context>({CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0]; {CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0];
auto algo = algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() { auto algo = algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() {
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS;
...@@ -188,7 +199,7 @@ void CuDNNConv2dOp<Context>::DoRunWithType() { ...@@ -188,7 +199,7 @@ void CuDNNConv2dOp<Context>::DoRunWithType() {
// Alloc the memory for workspace data // Alloc the memory for workspace data
if (cudnn_ws_nbytes_ > 0) { if (cudnn_ws_nbytes_ > 0) {
scratch = ws()->template data<Context>({cudnn_ws_nbytes_})[0]; scratch = ctx()->workspace()->template data<Context>({cudnn_ws_nbytes_})[0];
} }
for (int g = 0; g < cudnn_group_; g++) { for (int g = 0; g < cudnn_group_; g++) {
...@@ -279,8 +290,19 @@ void CuDNNConv2dGradientOp<Context>::SetConvDesc() { ...@@ -279,8 +290,19 @@ void CuDNNConv2dGradientOp<Context>::SetConvDesc() {
#endif #endif
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_)); CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
if (enable_tensor_core_) { if (enable_tensor_core_ > 0) {
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH)); cudnnMathType_t math_type;
if (input_type == TypeMeta::Id<float16>()) {
math_type = CUDNN_TENSOR_OP_MATH;
} else {
math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) {
math_type = CUDNN_FMA_MATH;
}
#endif
}
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, math_type));
} }
#endif #endif
} }
...@@ -418,8 +440,8 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() { ...@@ -418,8 +440,8 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
// Find the appropriate algorithm if necessary // Find the appropriate algorithm if necessary
if (dW->has_name() && exhaustive_search_filter_) { if (dW->has_name() && exhaustive_search_filter_) {
scratch = scratch = ctx()->workspace()->template data<Context>(
ws()->template data<Context>({CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0]; {CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0];
x = X.template data<T, Context>(); x = X.template data<T, Context>();
dw = dW->template mutable_data<T, Context>(); dw = dW->template mutable_data<T, Context>();
auto algo = auto algo =
...@@ -448,8 +470,8 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() { ...@@ -448,8 +470,8 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
} }
if (dX->has_name() && exhaustive_search_data_) { if (dX->has_name() && exhaustive_search_data_) {
scratch = scratch = ctx()->workspace()->template data<Context>(
ws()->template data<Context>({CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0]; {CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0];
w = W.template data<T, Context>(); w = W.template data<T, Context>();
dx = dX->template mutable_data<T, Context>(); dx = dX->template mutable_data<T, Context>();
auto algo = data_algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() { auto algo = data_algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() {
...@@ -500,7 +522,7 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() { ...@@ -500,7 +522,7 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
// Alloc the memory for workspace data // Alloc the memory for workspace data
if (cudnn_ws_nbytes_ > 0) { if (cudnn_ws_nbytes_ > 0) {
scratch = ws()->template data<Context>({cudnn_ws_nbytes_})[0]; scratch = ctx()->workspace()->template data<Context>({cudnn_ws_nbytes_})[0];
} }
if (Output(2)->has_name()) { if (Output(2)->has_name()) {
......
...@@ -41,8 +41,19 @@ void CuDNNConvTranspose2dOp<Context>::SetConvDesc() { ...@@ -41,8 +41,19 @@ void CuDNNConvTranspose2dOp<Context>::SetConvDesc() {
#endif #endif
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_)); CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
if (enable_tensor_core_) { if (enable_tensor_core_ > 0) {
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH)); cudnnMathType_t math_type;
if (input_type == TypeMeta::Id<float16>()) {
math_type = CUDNN_TENSOR_OP_MATH;
} else {
math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) {
math_type = CUDNN_FMA_MATH;
}
#endif
}
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, math_type));
} }
#endif #endif
} }
...@@ -146,8 +157,8 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() { ...@@ -146,8 +157,8 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() {
// Find the appropriate algorithm if necessary // Find the appropriate algorithm if necessary
if (exhaustive_search_) { if (exhaustive_search_) {
scratch = scratch = ctx()->workspace()->template data<Context>(
ws()->template data<Context>({CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0]; {CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0];
auto algo = algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() { auto algo = algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() {
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS;
...@@ -186,7 +197,7 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() { ...@@ -186,7 +197,7 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() {
// Alloc the memory for workspace data // Alloc the memory for workspace data
if (cudnn_ws_nbytes_ > 0) { if (cudnn_ws_nbytes_ > 0) {
scratch = ws()->template data<Context>({cudnn_ws_nbytes_})[0]; scratch = ctx()->workspace()->template data<Context>({cudnn_ws_nbytes_})[0];
} }
for (int g = 0; g < cudnn_group_; g++) { for (int g = 0; g < cudnn_group_; g++) {
...@@ -277,8 +288,19 @@ void CuDNNConvTranspose2dGradientOp<Context>::SetConvDesc() { ...@@ -277,8 +288,19 @@ void CuDNNConvTranspose2dGradientOp<Context>::SetConvDesc() {
#endif #endif
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_)); CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
if (enable_tensor_core_) { if (enable_tensor_core_ > 0) {
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH)); cudnnMathType_t math_type;
if (input_type == TypeMeta::Id<float16>()) {
math_type = CUDNN_TENSOR_OP_MATH;
} else {
math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) {
math_type = CUDNN_FMA_MATH;
}
#endif
}
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, math_type));
} }
#endif #endif
} }
...@@ -413,8 +435,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() { ...@@ -413,8 +435,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
// Find the appropriate algorithm if necessary // Find the appropriate algorithm if necessary
if (dW->has_name() && exhaustive_search_filter_) { if (dW->has_name() && exhaustive_search_filter_) {
scratch = scratch = ctx()->workspace()->template data<Context>(
ws()->template data<Context>({CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0]; {CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0];
x = X.template data<T, Context>(); x = X.template data<T, Context>();
dw = dW->template mutable_data<T, Context>(); dw = dW->template mutable_data<T, Context>();
auto algo = auto algo =
...@@ -443,8 +465,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() { ...@@ -443,8 +465,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
} }
if (dX->has_name() && exhaustive_search_data_) { if (dX->has_name() && exhaustive_search_data_) {
scratch = scratch = ctx()->workspace()->template data<Context>(
ws()->template data<Context>({CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0]; {CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0];
w = W.template data<T, Context>(); w = W.template data<T, Context>();
dx = dX->template mutable_data<T, Context>(); dx = dX->template mutable_data<T, Context>();
auto algo = data_algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() { auto algo = data_algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() {
...@@ -495,7 +517,7 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() { ...@@ -495,7 +517,7 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
// Alloc the memory for workspace data // Alloc the memory for workspace data
if (cudnn_ws_nbytes_ > 0) { if (cudnn_ws_nbytes_ > 0) {
scratch = ws()->template data<Context>({cudnn_ws_nbytes_})[0]; scratch = ctx()->workspace()->template data<Context>({cudnn_ws_nbytes_})[0];
} }
if (Output(2)->has_name()) { if (Output(2)->has_name()) {
......
...@@ -79,10 +79,11 @@ class CuDNNConv2dOp final : public Conv2dOp<Context> { ...@@ -79,10 +79,11 @@ class CuDNNConv2dOp final : public Conv2dOp<Context> {
CuDNNCreateTensorDesc(&output2b_desc_); CuDNNCreateTensorDesc(&output2b_desc_);
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_)); CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
if (data_format() == "NCHW") if (data_format() == "NCHW") {
format_ = CUDNN_TENSOR_NCHW; format_ = CUDNN_TENSOR_NCHW;
else if (data_format() == "NHWC") } else if (data_format() == "NHWC") {
format_ = CUDNN_TENSOR_NHWC; format_ = CUDNN_TENSOR_NHWC;
}
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS; USE_CONVOLUTION_FUNCTIONS;
...@@ -140,10 +141,11 @@ class CuDNNConv2dGradientOp final : public Conv2dGradientOp<Context> { ...@@ -140,10 +141,11 @@ class CuDNNConv2dGradientOp final : public Conv2dGradientOp<Context> {
CuDNNCreateTensorDesc(&input2b_desc_); CuDNNCreateTensorDesc(&input2b_desc_);
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_)); CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
if (data_format() == "NCHW") if (data_format() == "NCHW") {
format_ = CUDNN_TENSOR_NCHW; format_ = CUDNN_TENSOR_NCHW;
else if (data_format() == "NHWC") } else if (data_format() == "NHWC") {
format_ = CUDNN_TENSOR_NHWC; format_ = CUDNN_TENSOR_NHWC;
}
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS; USE_CONVOLUTION_FUNCTIONS;
......
...@@ -77,7 +77,8 @@ template <typename T> ...@@ -77,7 +77,8 @@ template <typename T>
void ConvOpBase<Context>::Wx(const T* x, const T* w, T* y, bool skip) { void ConvOpBase<Context>::Wx(const T* x, const T* w, T* y, bool skip) {
auto* col = x; auto* col = x;
if (!is_1x1_) { if (!is_1x1_) {
auto* scratch = ws()->template data<T, Context>({col_dim_})[0]; auto* scratch =
ctx()->workspace()->template data<T, Context>({col_dim_})[0];
if (!skip) Im2Col(x, scratch); if (!skip) Im2Col(x, scratch);
col = scratch; col = scratch;
} }
...@@ -127,7 +128,9 @@ void ConvOpBase<Context>::Pb(const T* bias, T* y) { ...@@ -127,7 +128,9 @@ void ConvOpBase<Context>::Pb(const T* bias, T* y) {
template <class Context> template <class Context>
template <typename T> template <typename T>
void ConvOpBase<Context>::Dx(const T* dy, const T* w, T* dx) { void ConvOpBase<Context>::Dx(const T* dy, const T* w, T* dx) {
auto* col = is_1x1_ ? dx : ws()->template data<T, Context>({col_dim_})[0]; auto* col = is_1x1_
? dx
: ctx()->workspace()->template data<T, Context>({col_dim_})[0];
for (int g = 0; g < group_; g++) { for (int g = 0; g < group_; g++) {
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
math::Gemm( math::Gemm(
...@@ -165,7 +168,8 @@ template <typename T> ...@@ -165,7 +168,8 @@ template <typename T>
void ConvOpBase<Context>::Dw(const T* dy, const T* x, T* dw, bool accum) { void ConvOpBase<Context>::Dw(const T* dy, const T* x, T* dw, bool accum) {
auto* col = x; auto* col = x;
if (!is_1x1_) { if (!is_1x1_) {
auto* scratch = ws()->template data<T, Context>({col_dim_})[0]; auto* scratch =
ctx()->workspace()->template data<T, Context>({col_dim_})[0];
Im2Col(x, scratch); Im2Col(x, scratch);
col = scratch; col = scratch;
} }
......
...@@ -142,10 +142,11 @@ class CuDNNConvTranspose2dGradientOp final ...@@ -142,10 +142,11 @@ class CuDNNConvTranspose2dGradientOp final
CuDNNCreateTensorDesc(&input2b_desc_); CuDNNCreateTensorDesc(&input2b_desc_);
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_)); CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
if (data_format() == "NCHW") if (data_format() == "NCHW") {
format_ = CUDNN_TENSOR_NCHW; format_ = CUDNN_TENSOR_NCHW;
else if (data_format() == "NHWC") } else if (data_format() == "NHWC") {
format_ = CUDNN_TENSOR_NHWC; format_ = CUDNN_TENSOR_NHWC;
}
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS; USE_CONVOLUTION_FUNCTIONS;
......
#include "dragon/operators/vision/resize_op.h" #include "dragon/operators/vision/resize_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -175,7 +176,8 @@ template <typename T> ...@@ -175,7 +176,8 @@ template <typename T>
void ResizeGradientOp<Context>::DoRunWithTypeAndCast() { void ResizeGradientOp<Context>::DoRunWithTypeAndCast() {
auto* dy = Input(0).template data<T, Context>(); auto* dy = Input(0).template data<T, Context>();
auto* dx = Output(0)->template mutable_data<T, Context>(); auto* dx = Output(0)->template mutable_data<T, Context>();
auto* scratch = ws()->template data<float, Context>({Output(0)->count()})[0]; auto* scratch = ctx()->workspace()->template data<float, Context>(
{Output(0)->count()})[0];
if (mode_ == "NEAREST") { if (mode_ == "NEAREST") {
NearestImpl(dy, scratch); NearestImpl(dy, scratch);
} else if (mode_ == "LINEAR") { } else if (mode_ == "LINEAR") {
...@@ -183,7 +185,7 @@ void ResizeGradientOp<Context>::DoRunWithTypeAndCast() { ...@@ -183,7 +185,7 @@ void ResizeGradientOp<Context>::DoRunWithTypeAndCast() {
} else { } else {
LOG(FATAL) << "Unknown interpolation mode: " << mode_; LOG(FATAL) << "Unknown interpolation mode: " << mode_;
} }
kernel::Cast(Output(0)->count(), scratch, dx, ctx()); math::Cast(Output(0)->count(), scratch, dx, ctx());
} }
template <class Context> template <class Context>
......
...@@ -67,7 +67,8 @@ void RoiAlignGradientOp<Context>::DoRunWithTypeAndCast() { ...@@ -67,7 +67,8 @@ void RoiAlignGradientOp<Context>::DoRunWithTypeAndCast() {
auto &RoI = Input(0), &dY = Input(1); auto &RoI = Input(0), &dY = Input(1);
auto* dX = Output(0)->ReshapeLike(RESTORE_INPUT_SPEC(0)); auto* dX = Output(0)->ReshapeLike(RESTORE_INPUT_SPEC(0));
auto* scratch = ws()->template data<float, Context>({dX->count()})[0]; auto* scratch =
ctx()->workspace()->template data<float, Context>({dX->count()})[0];
math::Set(dX->count(), 0.f, scratch, ctx()); math::Set(dX->count(), 0.f, scratch, ctx());
kernel::RoiAlignGrad( kernel::RoiAlignGrad(
dX->dim(1), dX->dim(1),
...@@ -82,7 +83,7 @@ void RoiAlignGradientOp<Context>::DoRunWithTypeAndCast() { ...@@ -82,7 +83,7 @@ void RoiAlignGradientOp<Context>::DoRunWithTypeAndCast() {
RoI.template data<float, Context>(), RoI.template data<float, Context>(),
scratch, scratch,
ctx()); ctx());
kernel::Cast( math::Cast(
dX->count(), scratch, dX->template mutable_data<T, Context>(), ctx()); dX->count(), scratch, dX->template mutable_data<T, Context>(), ctx());
} }
......
...@@ -68,7 +68,8 @@ void RoiPoolGradientOp<Context>::DoRunWithTypeAndCast() { ...@@ -68,7 +68,8 @@ void RoiPoolGradientOp<Context>::DoRunWithTypeAndCast() {
auto &RoI = Input(0), &dY = Input(1); auto &RoI = Input(0), &dY = Input(1);
auto* dX = Output(0)->ReshapeLike(RESTORE_INPUT_SPEC(0)); auto* dX = Output(0)->ReshapeLike(RESTORE_INPUT_SPEC(0));
auto* scratch = ws()->template data<float, Context>({dX->count()})[0]; auto* scratch =
ctx()->workspace()->template data<float, Context>({dX->count()})[0];
math::Set(dX->count(), 0.f, scratch, ctx()); math::Set(dX->count(), 0.f, scratch, ctx());
kernel::RoiPoolGrad( kernel::RoiPoolGrad(
...@@ -85,7 +86,7 @@ void RoiPoolGradientOp<Context>::DoRunWithTypeAndCast() { ...@@ -85,7 +86,7 @@ void RoiPoolGradientOp<Context>::DoRunWithTypeAndCast() {
scratch, scratch,
ctx()); ctx());
kernel::Cast( math::Cast(
dX->count(), scratch, dX->template mutable_data<T, Context>(), ctx()); dX->count(), scratch, dX->template mutable_data<T, Context>(), ctx());
} }
......
...@@ -56,6 +56,7 @@ from dragon.core.ops import tensorbind_eager as _ ...@@ -56,6 +56,7 @@ from dragon.core.ops import tensorbind_eager as _
from dragon.core.ops import tensorbind_symbol as _ from dragon.core.ops import tensorbind_symbol as _
from dragon.core.ops.array_ops import broadcast_to from dragon.core.ops.array_ops import broadcast_to
from dragon.core.ops.array_ops import cast from dragon.core.ops.array_ops import cast
from dragon.core.ops.array_ops import channel_affine
from dragon.core.ops.array_ops import channel_normalize from dragon.core.ops.array_ops import channel_normalize
from dragon.core.ops.array_ops import channel_shuffle from dragon.core.ops.array_ops import channel_shuffle
from dragon.core.ops.array_ops import concat from dragon.core.ops.array_ops import concat
......
...@@ -26,7 +26,6 @@ from dragon.core.ops.array_ops import sum ...@@ -26,7 +26,6 @@ from dragon.core.ops.array_ops import sum
from dragon.core.ops.array_ops import top_k from dragon.core.ops.array_ops import top_k
from dragon.core.ops.math_ops import abs from dragon.core.ops.math_ops import abs
from dragon.core.ops.math_ops import add from dragon.core.ops.math_ops import add
from dragon.core.ops.math_ops import affine
from dragon.core.ops.math_ops import axpby from dragon.core.ops.math_ops import axpby
from dragon.core.ops.math_ops import ceil from dragon.core.ops.math_ops import ceil
from dragon.core.ops.math_ops import clip from dragon.core.ops.math_ops import clip
......
...@@ -62,7 +62,7 @@ def current_device(): ...@@ -62,7 +62,7 @@ def current_device():
return backend.cudaGetDevice() return backend.cudaGetDevice()
def enable_cudnn(enabled=True, benchmark=False): def enable_cudnn(enabled=True, benchmark=False, allow_tf32=False):
"""Enable backend to use the cuDNN library. """Enable backend to use the cuDNN library.
Parameters Parameters
...@@ -71,9 +71,11 @@ def enable_cudnn(enabled=True, benchmark=False): ...@@ -71,9 +71,11 @@ def enable_cudnn(enabled=True, benchmark=False):
Use cuDNN library or not. Use cuDNN library or not.
benchmark : bool, optional, default=False benchmark : bool, optional, default=False
Select algorithms by benchmarking or not. Select algorithms by benchmarking or not.
allow_tf32 : bool, optional, default=False
Allow TF32 tensor core operations or not.
""" """
return backend.cudaEnableDNN(enabled, benchmark) return backend.cudaEnableDNN(enabled, benchmark, allow_tf32)
def get_device_capability(device_index=None): def get_device_capability(device_index=None):
......
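The new allow_tf32 flag is forwarded to backend.cudaEnableDNN and decides whether float32 cuDNN kernels may use TF32 tensor-core math on Ampere; when it is off, the convolution and RNN descriptors above fall back to CUDNN_FMA_MATH on cuDNN 8. A hedged usage sketch, assuming the function is exposed as dragon.cuda.enable_cudnn:

import dragon

# Default behavior: cuDNN enabled, TF32 disallowed for float32 kernels.
dragon.cuda.enable_cudnn(True, benchmark=False, allow_tf32=False)

# Opt in to TF32 tensor-core math (with benchmarked algorithm selection).
dragon.cuda.enable_cudnn(True, benchmark=True, allow_tf32=True)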
...@@ -14,6 +14,8 @@ from __future__ import absolute_import ...@@ -14,6 +14,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import atexit
from dragon import backend as _b from dragon import backend as _b
from dragon.core.util import nest from dragon.core.util import nest
from dragon.core.util import six from dragon.core.util import six
...@@ -278,8 +280,10 @@ def _maybe_initialize(): ...@@ -278,8 +280,10 @@ def _maybe_initialize():
class _MPIContext(object): class _MPIContext(object):
"""Context to finalize mpi under destruction.""" """Context to finalize mpi under destruction."""
def __del__(self): def __init__(self):
_b.mpiFinalize() # Register a callback to finalize MPI
# on program exit.
atexit.register(lambda: _b.mpiFinalize())
_GLOBAL_MPI_CONTEXT = None _GLOBAL_MPI_CONTEXT = None
......
...@@ -204,6 +204,46 @@ def cast(inputs, dtype, **kwargs): ...@@ -204,6 +204,46 @@ def cast(inputs, dtype, **kwargs):
return op_lib.blend(**args) return op_lib.blend(**args)
@OpSchema.num_inputs(2, 3)
def channel_affine(inputs, axis=1, num_axes=1, **kwargs):
r"""Apply affine transformation along the channels.
.. math:: \text{out} = \text{weight} * \text{input} + \text{bias}
The range of channels to transform is given by:
.. math:: [\text{axis}, \text{axis} + \text{num\_axes})
Set ``axis`` to specify the start axis.
Set ``num_axes`` to -1 to transform all remaining axes.
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input, weight and optional bias tensor.
axis : int, optional, default=1
The start axis, can be negative.
num_axes : int, optional, default=1
The number of axes to transform.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = array_ops_lib.ChannelAffine
if context.executing_eagerly():
return op_lib \
.instantiate(axis=axis, num_axes=num_axes) \
.apply(inputs, inplace=inplace)
else:
return op_lib.blend(**args)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
@ArgHelper.repeated_desc('perm') @ArgHelper.repeated_desc('perm')
def channel_normalize( def channel_normalize(
......
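A short usage sketch for the channel_affine operator documented above; the tensor values are illustrative, and dragon.constant is assumed to be available for building the inputs:

import dragon

x = dragon.constant([[1., 2.], [3., 4.]])  # input, shape (2, 2)
w = dragon.constant([0.5, 2.0])            # per-channel weight for axis 1
b = dragon.constant([1.0, -1.0])           # optional per-channel bias

# out[i, j] = w[j] * x[i, j] + b[j]; only axis 1 is transformed.
y = dragon.channel_affine([x, w, b], axis=1, num_axes=1)

The eager wrapper also accepts an inplace keyword (see ChannelAffine.forward below) to write the result back into the input buffer.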
...@@ -57,6 +57,26 @@ class Cast(Operator): ...@@ -57,6 +57,26 @@ class Cast(Operator):
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(inputs, [self.alloc()])
class ChannelAffine(Operator):
def __init__(self, key, dev, **kwargs):
super(ChannelAffine, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
self.num_axes = kwargs.get('num_axes', 1)
def attributes(self):
return {
'op_type': 'ChannelAffine',
'arguments': {
'axis': self.axis,
'num_axes': self.num_axes,
}
}
def forward(self, inputs, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs)
class ChannelNormalize(Operator): class ChannelNormalize(Operator):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ChannelNormalize, self).__init__(key, dev, **kwargs) super(ChannelNormalize, self).__init__(key, dev, **kwargs)
......
...@@ -88,45 +88,6 @@ def add(inputs, **kwargs): ...@@ -88,45 +88,6 @@ def add(inputs, **kwargs):
return op_lib.blend('Add', **args) return op_lib.blend('Add', **args)
@OpSchema.num_inputs(2, 3)
def affine(inputs, axis=1, num_axes=1, **kwargs):
r"""Compute the affine transformation along the given axes.
.. math:: y = Wx + b
The range of axes is defined as:
.. math:: [\text{Axis}, \text{Axis} + \text{NumAxes})
Set ``axis`` to specific the start axis.
Set ``num_axes`` to -1 will scale all remained axes.
Parameters
----------
inputs : Sequence[dragon.Tensor]
The tensor **x**, **W** and **b**.
axis : int, optional, default=1
The start axis, can be negative.
num_axes : int, optional, default=1
The number of axes to compute.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
op_lib = math_ops_lib.Affine
if context.executing_eagerly():
return op_lib \
.instantiate(axis=axis, num_axes=num_axes) \
.apply(inputs)
else:
return op_lib.blend(**args)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs): def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
r"""Compute the element-wise addition from input to output. r"""Compute the element-wise addition from input to output.
......
...@@ -17,25 +17,6 @@ from __future__ import print_function ...@@ -17,25 +17,6 @@ from __future__ import print_function
from dragon.core.framework.ops import Operator from dragon.core.framework.ops import Operator
class Affine(Operator):
def __init__(self, key, dev, **kwargs):
super(Affine, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
self.num_axes = kwargs.get('num_axes', 1)
def attributes(self):
return {
'op_type': 'Affine',
'arguments': {
'axis': self.axis,
'num_axes': self.num_axes,
}
}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()])
class Axpby(Operator): class Axpby(Operator):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Axpby, self).__init__(key, dev, **kwargs) super(Axpby, self).__init__(key, dev, **kwargs)
......
...@@ -51,6 +51,21 @@ def cast_exporter(op_def, shape_dict, ws): ...@@ -51,6 +51,21 @@ def cast_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('ChannelAffine')
def channel_affine_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelAffine')
for arg in op_def.arg:
if arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
elif arg.name == 'num_axes':
helper.add_attribute(node, 'num_axes', arg.i)
# Weights and biases
const_tensors = [helper.from_tensor(e, ws) for e in op_def.input[1:]]
return node, const_tensors
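For reference, the node emitted by the exporter above is roughly equivalent to the following hand-built one (a sketch; the tensor names 'x', 'w', 'b', 'y' are placeholders, the real names are derived from ``op_def``):

    from onnx import helper

    node = helper.make_node('ATen', ['x', 'w', 'b'], ['y'], axis=1, num_axes=1)
    # The custom 'op_type' attribute tells the importer which fallback kernel to run.
    node.attribute.extend([helper.make_attribute('op_type', 'ChannelAffine')])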
@exporter.register('ChannelNormalize') @exporter.register('ChannelNormalize')
def channel_normalize_exporter(op_def, shape_dict, ws): def channel_normalize_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = exporter.translate(**locals())
......
...@@ -31,21 +31,6 @@ def add_exporter(op_def, shape_dict, ws): ...@@ -31,21 +31,6 @@ def add_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Affine')
def affine_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'Affine')
for arg in op_def.arg:
if arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
elif arg.name == 'num_axes':
helper.add_attribute(node, 'num_axes', arg.i)
# Weights and biases
const_tensors = [helper.from_tensor(e, ws) for e in op_def.input[1:]]
return node, const_tensors
@exporter.register('Div') @exporter.register('Div')
def div_exporter(op_def, shape_dict, ws): def div_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = exporter.translate(**locals())
......
...@@ -4,11 +4,46 @@ ...@@ -4,11 +4,46 @@
#ifdef USE_CUDA #ifdef USE_CUDA
#include <cub/block/block_reduce.cuh> #include <cub/block/block_reduce.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_select.cuh> #include <cub/device/device_select.cuh>
#include <cub/iterator/counting_input_iterator.cuh> #include <cub/iterator/counting_input_iterator.cuh>
#include "dragon/utils/device/common_cuda.h" #include "dragon/utils/device/common_cuda.h"
namespace cub {
struct SumHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hadd(a, b);
#else
return __float2half(__half2float(a) + __half2float(b));
#endif
}
};
struct MinHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hlt(a, b) ? a : b;
#else
return __half2float(a) < __half2float(b) ? a : b;
#endif
}
};
struct MaxHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hgt(a, b) ? a : b;
#else
return __half2float(a) > __half2float(b) ? a : b;
#endif
}
};
} // namespace cub
namespace dragon { namespace dragon {
template <typename T> template <typename T>
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h> #include <thrust/sequence.h>
#include <thrust/sort.h> #include <thrust/sort.h>
......
...@@ -39,11 +39,18 @@ __global__ void _Axpby<half>( ...@@ -39,11 +39,18 @@ __global__ void _Axpby<half>(
const half* x, const half* x,
const half beta, const half beta,
half* y) { half* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[i] = __hadd(__hmul(alpha, x[i]), __hmul(beta, y[i])); CUDA_1D_KERNEL_LOOP(i, n) {
#endif y[i] = __hfma(alpha, x[i], __hmul(beta, y[i]));
}
#else
const float alpha_val = __half2float(alpha);
const float beta_val = __half2float(beta);
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = __float2half(
fmaf(alpha_val, __half2float(x[i]), beta_val * __half2float(y[i])));
} }
#endif
} }
template <> template <>
...@@ -53,10 +60,44 @@ __global__ void _Axpby<half2>( ...@@ -53,10 +60,44 @@ __global__ void _Axpby<half2>(
const half2* x, const half2* x,
const half2 beta, const half2 beta,
half2* y) { half2* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[i] = __hadd2(__hmul2(alpha, x[i]), __hmul2(beta, y[i])); CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = __hfma2(alpha, x[i], __hmul2(beta, y[i]));
}
#else
const float2 alpha_val = __half22float2(alpha);
const float2 beta_val = __half22float2(beta);
CUDA_1D_KERNEL_LOOP(i, n) {
const float2 v1 = __half22float2(x[i]);
const float2 v2 = __half22float2(y[i]);
y[i] = __floats2half2_rn(
fmaf(alpha_val.x, v1.x, beta_val.x * v2.x),
fmaf(alpha_val.y, v1.y, beta_val.y * v2.y));
}
#endif #endif
}
template <>
__global__ void _Axpby<float>(
const int n,
const float alpha,
const float* x,
const float beta,
float* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = fmaf(alpha, x[i], beta * y[i]);
}
}
template <>
__global__ void _Axpby<double>(
const int n,
const double alpha,
const double* x,
const double beta,
double* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = fma(alpha, x[i], beta * y[i]);
} }
} }
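Regardless of precision, every ``_Axpby`` specialization above evaluates the same fused expression; the fp16 variants merely pick ``__hfma``/``__hfma2`` on sm_53+ and fall back to fp32 arithmetic otherwise. A NumPy sketch of the contract:

    import numpy as np

    def axpby_ref(alpha, x, beta, y):
        """y <- alpha * x + beta * y, as computed by the _Axpby kernels."""
        return alpha * x + beta * y

    y = axpby_ref(2.0, np.arange(4, dtype='float32'), 0.5, np.ones(4, dtype='float32'))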
...@@ -68,8 +109,7 @@ __global__ void _Axpby<half2>( ...@@ -68,8 +109,7 @@ __global__ void _Axpby<half2>(
template <> \ template <> \
DRAGON_API void Scale<T, CUDAContext>( \ DRAGON_API void Scale<T, CUDAContext>( \
const int n, const float alpha, const T* x, T* y, CUDAContext* ctx) { \ const int n, const float alpha, const T* x, T* y, CUDAContext* ctx) { \
T _alpha_ = (T)alpha; \ if (alpha != 1.f) { \
if (_alpha_ == T(1)) { \
if (x != y) { \ if (x != y) { \
cudaMemcpyAsync( \ cudaMemcpyAsync( \
y, \ y, \
...@@ -81,7 +121,7 @@ __global__ void _Axpby<half2>( ...@@ -81,7 +121,7 @@ __global__ void _Axpby<half2>(
return; \ return; \
} \ } \
_Scale<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _Scale<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, _alpha_, x, y); \ n, static_cast<T>(alpha), x, y); \
} }
DEFINE_SCALE_FUNC(int8_t); DEFINE_SCALE_FUNC(int8_t);
...@@ -99,10 +139,10 @@ DEFINE_SCALE_FUNC(int64_t); ...@@ -99,10 +139,10 @@ DEFINE_SCALE_FUNC(int64_t);
y, x, sizeof(T) * n, cudaMemcpyDeviceToDevice, ctx->cuda_stream())); \ y, x, sizeof(T) * n, cudaMemcpyDeviceToDevice, ctx->cuda_stream())); \
} \ } \
if (alpha != 1.f) { \ if (alpha != 1.f) { \
T scale = (T)alpha; \ T alpha_val = static_cast<T>(alpha); \
CUBLAS_CHECK(cublasSetPointerMode( \ CUBLAS_CHECK(cublasSetPointerMode( \
ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func(ctx->cublas_handle(), n, &scale, y, 1)); \ CUBLAS_CHECK(cublas_func(ctx->cublas_handle(), n, &alpha_val, y, 1)); \
} \ } \
} }
...@@ -165,7 +205,7 @@ DEFINE_COPY_FUNC(double); ...@@ -165,7 +205,7 @@ DEFINE_COPY_FUNC(double);
DRAGON_API void Axpy<T, CUDAContext>( \ DRAGON_API void Axpy<T, CUDAContext>( \
const int n, const float alpha, const T* x, T* y, CUDAContext* ctx) { \ const int n, const float alpha, const T* x, T* y, CUDAContext* ctx) { \
_Axpy<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _Axpy<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, (T)alpha, x, y); \ n, static_cast<T>(alpha), x, y); \
} }
DEFINE_AXPY_FUNC(int8_t); DEFINE_AXPY_FUNC(int8_t);
...@@ -178,10 +218,11 @@ DEFINE_AXPY_FUNC(int64_t); ...@@ -178,10 +218,11 @@ DEFINE_AXPY_FUNC(int64_t);
template <> \ template <> \
DRAGON_API void Axpy<T, CUDAContext>( \ DRAGON_API void Axpy<T, CUDAContext>( \
const int n, const float alpha, const T* x, T* y, CUDAContext* ctx) { \ const int n, const float alpha, const T* x, T* y, CUDAContext* ctx) { \
T scale = (T)alpha; \ T alpha_val = static_cast<T>(alpha); \
CUBLAS_CHECK( \ CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func(ctx->cublas_handle(), n, &scale, x, 1, y, 1)); \ CUBLAS_CHECK( \
cublas_func(ctx->cublas_handle(), n, &alpha_val, x, 1, y, 1)); \
} }
template <> template <>
...@@ -221,7 +262,7 @@ DEFINE_AXPY_FUNC(double, cublasDaxpy); ...@@ -221,7 +262,7 @@ DEFINE_AXPY_FUNC(double, cublasDaxpy);
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_Axpby<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _Axpby<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, (T)alpha, x, (T)beta, y); \ n, static_cast<T>(alpha), x, static_cast<T>(beta), y); \
} }
template <> template <>
...@@ -268,11 +309,11 @@ DEFINE_AXPBY_FUNC(double); ...@@ -268,11 +309,11 @@ DEFINE_AXPBY_FUNC(double);
template <> \ template <> \
DRAGON_API T Dot<T, CUDAContext>( \ DRAGON_API T Dot<T, CUDAContext>( \
const int n, const T* a, const T* b, CUDAContext* ctx) { \ const int n, const T* a, const T* b, CUDAContext* ctx) { \
T y_host; \ T ret; \
CUBLAS_CHECK( \ CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func(ctx->cublas_handle(), n, a, 1, b, 1, &y_host)); \ CUBLAS_CHECK(cublas_func(ctx->cublas_handle(), n, a, 1, b, 1, &ret)); \
return y_host; \ return ret; \
} }
template <> template <>
...@@ -313,11 +354,11 @@ DEFINE_DOT_FUNC(double, cublasDdot); ...@@ -313,11 +354,11 @@ DEFINE_DOT_FUNC(double, cublasDdot);
template <> \ template <> \
DRAGON_API T ASum<T, CUDAContext>( \ DRAGON_API T ASum<T, CUDAContext>( \
const int n, const T* x, CUDAContext* ctx) { \ const int n, const T* x, CUDAContext* ctx) { \
T y_host; \ T ret; \
CUBLAS_CHECK( \ CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
cublas_func(ctx->cublas_handle(), n, x, 1, &y_host); \ cublas_func(ctx->cublas_handle(), n, x, 1, &ret); \
return y_host; \ return ret; \
} }
DEFINE_ASUM_FUNC(float, cublasSasum); DEFINE_ASUM_FUNC(float, cublasSasum);
...@@ -409,8 +450,8 @@ DRAGON_API void Gemv<float16, CUDAContext>( ...@@ -409,8 +450,8 @@ DRAGON_API void Gemv<float16, CUDAContext>(
LDC)); LDC));
#endif #endif
} else if (math_type == "float16") { } else if (math_type == "float16") {
const half alpha_half = cast::to<half>(alpha); const half alpha_val = cast::to<half>(alpha);
const half beta_half = cast::to<half>(beta); const half beta_val = cast::to<half>(beta);
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) { if (TENSOR_CORE_AVAILABLE()) {
// GEMV + MATH16 + TENSOR-CORE // GEMV + MATH16 + TENSOR-CORE
...@@ -421,14 +462,14 @@ DRAGON_API void Gemv<float16, CUDAContext>( ...@@ -421,14 +462,14 @@ DRAGON_API void Gemv<float16, CUDAContext>(
m, m,
1, 1,
k, k,
&alpha_half, &alpha_val,
A, A,
CUDA_R_16F, CUDA_R_16F,
LDA, LDA,
x, x,
CUDA_R_16F, CUDA_R_16F,
k, k,
&beta_half, &beta_val,
y, y,
CUDA_R_16F, CUDA_R_16F,
LDC, LDC,
...@@ -443,12 +484,12 @@ DRAGON_API void Gemv<float16, CUDAContext>( ...@@ -443,12 +484,12 @@ DRAGON_API void Gemv<float16, CUDAContext>(
m, m,
1, 1,
k, k,
&alpha_half, &alpha_val,
reinterpret_cast<const half*>(A), reinterpret_cast<const half*>(A),
LDA, LDA,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
k, k,
&beta_half, &beta_val,
reinterpret_cast<half*>(y), reinterpret_cast<half*>(y),
LDC)); LDC));
} }
...@@ -460,12 +501,12 @@ DRAGON_API void Gemv<float16, CUDAContext>( ...@@ -460,12 +501,12 @@ DRAGON_API void Gemv<float16, CUDAContext>(
m, m,
1, 1,
k, k,
&alpha_half, &alpha_val,
reinterpret_cast<const half*>(A), reinterpret_cast<const half*>(A),
LDA, LDA,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
k, k,
&beta_half, &beta_val,
reinterpret_cast<half*>(y), reinterpret_cast<half*>(y),
LDC)); LDC));
#endif #endif
...@@ -506,8 +547,8 @@ DRAGON_API void Gemv<double, CUDAContext>( ...@@ -506,8 +547,8 @@ DRAGON_API void Gemv<double, CUDAContext>(
CUDAContext* ctx, CUDAContext* ctx,
const string math_type) { const string math_type) {
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N; auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto alpha64 = static_cast<double>(alpha); const auto alpha_val = static_cast<double>(alpha);
const auto beta64 = static_cast<double>(beta); const auto beta_val = static_cast<double>(beta);
CUBLAS_CHECK( CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasDgemv( CUBLAS_CHECK(cublasDgemv(
...@@ -515,12 +556,12 @@ DRAGON_API void Gemv<double, CUDAContext>( ...@@ -515,12 +556,12 @@ DRAGON_API void Gemv<double, CUDAContext>(
cuTransA, cuTransA,
N, N,
M, M,
&alpha64, &alpha_val,
A, A,
N, N,
x, x,
1, 1,
&beta64, &beta_val,
y, y,
1)); 1));
} }
...@@ -611,8 +652,8 @@ DRAGON_API void Gemm<float16, CUDAContext>( ...@@ -611,8 +652,8 @@ DRAGON_API void Gemm<float16, CUDAContext>(
N)); N));
#endif #endif
} else if (math_type == "float16") { } else if (math_type == "float16") {
const half alpha_half = cast::to<half>(alpha); const half alpha_val = cast::to<half>(alpha);
const half beta_half = cast::to<half>(beta); const half beta_val = cast::to<half>(beta);
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) { if (TENSOR_CORE_AVAILABLE()) {
// GEMM + MATH16 + TENSOR-CORE // GEMM + MATH16 + TENSOR-CORE
...@@ -623,14 +664,14 @@ DRAGON_API void Gemm<float16, CUDAContext>( ...@@ -623,14 +664,14 @@ DRAGON_API void Gemm<float16, CUDAContext>(
N, N,
M, M,
K, K,
&alpha_half, &alpha_val,
B, B,
CUDA_R_16F, CUDA_R_16F,
ldb, ldb,
A, A,
CUDA_R_16F, CUDA_R_16F,
lda, lda,
&beta_half, &beta_val,
C, C,
CUDA_R_16F, CUDA_R_16F,
N, N,
...@@ -645,12 +686,12 @@ DRAGON_API void Gemm<float16, CUDAContext>( ...@@ -645,12 +686,12 @@ DRAGON_API void Gemm<float16, CUDAContext>(
N, N,
M, M,
K, K,
&alpha_half, &alpha_val,
reinterpret_cast<const half*>(B), reinterpret_cast<const half*>(B),
ldb, ldb,
reinterpret_cast<const half*>(A), reinterpret_cast<const half*>(A),
lda, lda,
&beta_half, &beta_val,
reinterpret_cast<half*>(C), reinterpret_cast<half*>(C),
N)); N));
} }
...@@ -662,12 +703,12 @@ DRAGON_API void Gemm<float16, CUDAContext>( ...@@ -662,12 +703,12 @@ DRAGON_API void Gemm<float16, CUDAContext>(
N, N,
M, M,
K, K,
&alpha_half, &alpha_val,
reinterpret_cast<const half*>(B), reinterpret_cast<const half*>(B),
ldb, ldb,
reinterpret_cast<const half*>(A), reinterpret_cast<const half*>(A),
lda, lda,
&beta_half, &beta_val,
reinterpret_cast<half*>(C), reinterpret_cast<half*>(C),
N)); N));
#endif #endif
...@@ -731,8 +772,8 @@ DRAGON_API void Gemm<double, CUDAContext>( ...@@ -731,8 +772,8 @@ DRAGON_API void Gemm<double, CUDAContext>(
int ldb = (TransB == CblasNoTrans) ? N : K; int ldb = (TransB == CblasNoTrans) ? N : K;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
const auto alpha64 = static_cast<double>(alpha); const auto alpha_val = static_cast<double>(alpha);
const auto beta64 = static_cast<double>(beta); const auto beta_val = static_cast<double>(beta);
CUBLAS_CHECK( CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasDgemm( CUBLAS_CHECK(cublasDgemm(
...@@ -742,12 +783,12 @@ DRAGON_API void Gemm<double, CUDAContext>( ...@@ -742,12 +783,12 @@ DRAGON_API void Gemm<double, CUDAContext>(
N, N,
M, M,
K, K,
&alpha64, &alpha_val,
B, B,
ldb, ldb,
A, A,
lda, lda,
&beta64, &beta_val,
C, C,
N)); N));
} }
......
#include "dragon/utils/cast.h" #include "dragon/utils/cast.h"
#include "dragon/utils/math/elementwise.h"
#include "dragon/utils/omp_utils.h" #include "dragon/utils/omp_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
namespace kernel { namespace math {
namespace { namespace {
template <typename Tx, typename Ty> template <typename Tx, typename Ty>
void _Cast(const int count, const Tx* x, Ty* y) { void _Cast(const int n, const Tx* x, Ty* y) {
#ifdef USE_OPENMP #ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count)) #pragma omp parallel for num_threads(OMP_THREADS(n))
#endif #endif
for (int i = 0; i < count; ++i) { for (int i = 0; i < n; ++i) {
y[i] = cast::to<Ty>(x[i]); y[i] = cast::to<Ty>(x[i]);
} }
} }
...@@ -22,23 +22,23 @@ void _Cast(const int count, const Tx* x, Ty* y) { ...@@ -22,23 +22,23 @@ void _Cast(const int count, const Tx* x, Ty* y) {
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_GENERIC_KERNEL_LAUNCHER(Tx, Ty) \ #define DEFINE_GENERIC_KERNEL_LAUNCHER(Tx, Ty) \
template <> \ template <> \
void Cast<Tx, Ty, CPUContext>( \ void Cast<Tx, Ty, CPUContext>( \
const int count, const Tx* x, Ty* y, CPUContext* ctx) { \ const int n, const Tx* x, Ty* y, CPUContext* ctx) { \
_Cast(count, x, y); \ _Cast(n, x, y); \
} }
#define DEFINE_FP16_KERNEL_LAUNCHER(T) \ #define DEFINE_FP16_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void Cast<float16, T, CPUContext>( \ void Cast<float16, T, CPUContext>( \
const int count, const float16* x, T* y, CPUContext* ctx) { \ const int n, const float16* x, T* y, CPUContext* ctx) { \
LOG(FATAL) << "Not Implemented: float16 -> " \ LOG(FATAL) << "Not Implemented: float16 -> " \
<< types::to_string(TypeMeta::Make<T>()); \ << types::to_string(TypeMeta::Make<T>()); \
} \ } \
template <> \ template <> \
void Cast<T, float16, CPUContext>( \ void Cast<T, float16, CPUContext>( \
const int count, const T* x, float16* y, CPUContext* ctx) { \ const int n, const T* x, float16* y, CPUContext* ctx) { \
LOG(FATAL) << "Not Implemented: " << types::to_string(TypeMeta::Make<T>()) \ LOG(FATAL) << "Not Implemented: " << types::to_string(TypeMeta::Make<T>()) \
<< " -> float16"; \ << " -> float16"; \
} }
...@@ -75,6 +75,6 @@ DEFINE_FP16_KERNEL_LAUNCHER(double); ...@@ -75,6 +75,6 @@ DEFINE_FP16_KERNEL_LAUNCHER(double);
#undef DEFINE_GENERIC_KERNEL_LAUNCHER #undef DEFINE_GENERIC_KERNEL_LAUNCHER
#undef DEFINE_FP16_KERNEL_LAUNCHER #undef DEFINE_FP16_KERNEL_LAUNCHER
} // namespace kernel } // namespace math
} // namespace dragon } // namespace dragon
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/math/elementwise.h"
namespace dragon { namespace dragon {
namespace kernel { namespace math {
namespace { namespace {
...@@ -45,40 +45,39 @@ __global__ void _Cast<half, half>(const int nthreads, const half* x, half* y) { ...@@ -45,40 +45,39 @@ __global__ void _Cast<half, half>(const int nthreads, const half* x, half* y) {
template <> template <>
void Cast<float16, float, CUDAContext>( void Cast<float16, float, CUDAContext>(
const int count, const int n,
const float16* x, const float16* x,
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Cast<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( _Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, reinterpret_cast<const half*>(x), y); n, reinterpret_cast<const half*>(x), y);
} }
template <> template <>
void Cast<float, float16, CUDAContext>( void Cast<float, float16, CUDAContext>(
const int count, const int n,
const float* x, const float* x,
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Cast<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( _Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, x, reinterpret_cast<half*>(y)); n, x, reinterpret_cast<half*>(y));
} }
template <> template <>
void Cast<float16, float16, CUDAContext>( void Cast<float16, float16, CUDAContext>(
const int count, const int n,
const float16* x, const float16* x,
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Cast<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( _Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, reinterpret_cast<const half*>(x), reinterpret_cast<half*>(y)); n, reinterpret_cast<const half*>(x), reinterpret_cast<half*>(y));
} }
#define DEFINE_GENERIC_KERNEL_LAUNCHER(Tx, Ty) \ #define DEFINE_GENERIC_KERNEL_LAUNCHER(Tx, Ty) \
template <> \ template <> \
void Cast<Tx, Ty, CUDAContext>( \ void Cast<Tx, Ty, CUDAContext>( \
const int count, const Tx* x, Ty* y, CUDAContext* ctx) { \ const int n, const Tx* x, Ty* y, CUDAContext* ctx) { \
_Cast<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(n, x, y); \
count, x, y); \
} }
#define DEFINE_KERNEL_LAUNCHER(Tx) \ #define DEFINE_KERNEL_LAUNCHER(Tx) \
...@@ -93,13 +92,13 @@ void Cast<float16, float16, CUDAContext>( ...@@ -93,13 +92,13 @@ void Cast<float16, float16, CUDAContext>(
#define DEFINE_FP16_KERNEL_LAUNCHER(T) \ #define DEFINE_FP16_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void Cast<float16, T, CUDAContext>( \ void Cast<float16, T, CUDAContext>( \
const int count, const float16* x, T* y, CUDAContext* ctx) { \ const int n, const float16* x, T* y, CUDAContext* ctx) { \
LOG(FATAL) << "Not Implemented: float16 -> " \ LOG(FATAL) << "Not Implemented: float16 -> " \
<< types::to_string(TypeMeta::Make<T>()); \ << types::to_string(TypeMeta::Make<T>()); \
} \ } \
template <> \ template <> \
void Cast<T, float16, CUDAContext>( \ void Cast<T, float16, CUDAContext>( \
const int count, const T* x, float16* y, CUDAContext* ctx) { \ const int n, const T* x, float16* y, CUDAContext* ctx) { \
LOG(FATAL) << "Not Implemented: " << types::to_string(TypeMeta::Make<T>()) \ LOG(FATAL) << "Not Implemented: " << types::to_string(TypeMeta::Make<T>()) \
<< " -> float16"; \ << " -> float16"; \
} }
...@@ -123,7 +122,7 @@ DEFINE_FP16_KERNEL_LAUNCHER(double); ...@@ -123,7 +122,7 @@ DEFINE_FP16_KERNEL_LAUNCHER(double);
#undef DEFINE_GENERIC_KERNEL_LAUNCHER #undef DEFINE_GENERIC_KERNEL_LAUNCHER
#undef DEFINE_FP16_KERNEL_LAUNCHER #undef DEFINE_FP16_KERNEL_LAUNCHER
} // namespace kernel } // namespace math
} // namespace dragon } // namespace dragon
......
...@@ -60,6 +60,9 @@ DRAGON_API void Rsqrt(const int n, const T* x, T* y, Context* ctx); ...@@ -60,6 +60,9 @@ DRAGON_API void Rsqrt(const int n, const T* x, T* y, Context* ctx);
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void Set(const int n, const T value, T* y, Context* ctx); DRAGON_API void Set(const int n, const T value, T* y, Context* ctx);
template <typename Tx, typename Ty, class Context>
DRAGON_API void Cast(const int n, const Tx* x, Ty* y, Context* ctx);
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void Sign(const int n, const T* x, T* y, Context* ctx); DRAGON_API void Sign(const int n, const T* x, T* y, Context* ctx);
......
...@@ -9,17 +9,25 @@ namespace math { ...@@ -9,17 +9,25 @@ namespace math {
namespace { namespace {
#define DEFINE_GLOBAL_REDUCE_FUNC(name, expr) \
template <typename T> \
void _GlobalReduce##name(const int n, const float scale, const T* x, T* y) { \
*y = ConstEigenVectorMap<T>(x, n).expr(); \
if (scale != 1.f) y[0] *= T(scale); \
}
DEFINE_GLOBAL_REDUCE_FUNC(Max, maxCoeff);
DEFINE_GLOBAL_REDUCE_FUNC(Min, minCoeff);
DEFINE_GLOBAL_REDUCE_FUNC(Sum, sum);
#undef DEFINE_GLOBAL_REDUCE_FUNC
#define DEFINE_ROWWISE_REDUCE_FUNC(name, expr) \ #define DEFINE_ROWWISE_REDUCE_FUNC(name, expr) \
template <typename T> \ template <typename T> \
void _RowwiseReduce##name( \ void _RowwiseReduce##name( \
const int rows, const int cols, const T* scale, const T* x, T* y) { \ const int rows, const int cols, const float scale, const T* x, T* y) { \
if (scale != nullptr) { \ EigenVectorMap<T>(y, cols) = \
EigenVectorMap<T>(y, cols) = \ ConstEigenMatrixMap<T>(x, cols, rows).rowwise().expr(); \
ConstEigenMatrixMap<T>(x, cols, rows).rowwise().expr() * (*scale); \ if (scale != 1.f) EigenVectorMap<T>(y, cols) *= T(scale); \
} else { \
EigenVectorMap<T>(y, cols) = \
ConstEigenMatrixMap<T>(x, cols, rows).rowwise().expr(); \
} \
} }
DEFINE_ROWWISE_REDUCE_FUNC(Max, maxCoeff); DEFINE_ROWWISE_REDUCE_FUNC(Max, maxCoeff);
...@@ -30,14 +38,10 @@ DEFINE_ROWWISE_REDUCE_FUNC(Sum, sum); ...@@ -30,14 +38,10 @@ DEFINE_ROWWISE_REDUCE_FUNC(Sum, sum);
#define DEFINE_COLWISE_REDUCE_FUNC(name, expr) \ #define DEFINE_COLWISE_REDUCE_FUNC(name, expr) \
template <typename T> \ template <typename T> \
void _ColwiseReduce##name( \ void _ColwiseReduce##name( \
const int rows, const int cols, const T* scale, const T* x, T* y) { \ const int rows, const int cols, const float scale, const T* x, T* y) { \
if (scale != nullptr) { \ EigenVectorMap<T>(y, rows) = \
EigenVectorMap<T>(y, rows) = \ ConstEigenMatrixMap<T>(x, cols, rows).colwise().expr(); \
ConstEigenMatrixMap<T>(x, cols, rows).colwise().expr() * (*scale); \ if (scale != 1.f) EigenVectorMap<T>(y, rows) *= T(scale); \
} else { \
EigenVectorMap<T>(y, rows) = \
ConstEigenMatrixMap<T>(x, cols, rows).colwise().expr(); \
} \
} }
DEFINE_COLWISE_REDUCE_FUNC(Max, maxCoeff); DEFINE_COLWISE_REDUCE_FUNC(Max, maxCoeff);
...@@ -52,7 +56,7 @@ void _GenericReduceMax( ...@@ -52,7 +56,7 @@ void _GenericReduceMax(
const int num_dims, const int num_dims,
const int* x_dims, const int* x_dims,
const int* x_strides, const int* x_strides,
const T* scale, const float scale,
const T* x, const T* x,
T* y) { T* y) {
#ifdef USE_OPENMP #ifdef USE_OPENMP
...@@ -70,7 +74,11 @@ void _GenericReduceMax( ...@@ -70,7 +74,11 @@ void _GenericReduceMax(
} }
val = std::max(x[xi], val); val = std::max(x[xi], val);
} }
y[i] = val; if (scale != 1.f) {
y[i] = static_cast<T>(static_cast<float>(val) * scale);
} else {
y[i] = val;
}
} }
} }
...@@ -81,7 +89,7 @@ void _GenericReduceMin( ...@@ -81,7 +89,7 @@ void _GenericReduceMin(
const int num_dims, const int num_dims,
const int* x_dims, const int* x_dims,
const int* x_strides, const int* x_strides,
const T* scale, const float scale,
const T* x, const T* x,
T* y) { T* y) {
#ifdef USE_OPENMP #ifdef USE_OPENMP
...@@ -99,7 +107,11 @@ void _GenericReduceMin( ...@@ -99,7 +107,11 @@ void _GenericReduceMin(
} }
val = std::min(x[xi], val); val = std::min(x[xi], val);
} }
y[i] = val; if (scale != 1.f) {
y[i] = static_cast<T>(static_cast<float>(val) * scale);
} else {
y[i] = val;
}
} }
} }
...@@ -110,7 +122,7 @@ void _GenericReduceSum( ...@@ -110,7 +122,7 @@ void _GenericReduceSum(
const int num_dims, const int num_dims,
const int* x_dims, const int* x_dims,
const int* x_strides, const int* x_strides,
const T* scale, const float scale,
const T* x, const T* x,
T* y) { T* y) {
#ifdef USE_OPENMP #ifdef USE_OPENMP
...@@ -128,56 +140,62 @@ void _GenericReduceSum( ...@@ -128,56 +140,62 @@ void _GenericReduceSum(
} }
val += x[xi]; val += x[xi];
} }
if (scale != nullptr) { if (scale != 1.f) {
y[i] = val * (*scale); y[i] = static_cast<T>(static_cast<float>(val) * scale);
} else { } else {
y[i] = val; y[i] = val;
} }
} }
} }
#define DEFINE_REDUCE_FUNC(name) \ #define DEFINE_REDUCE_FUNC(name) \
template <typename T> \ template <typename T> \
void _Reduce##name( \ void _Reduce##name( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
const int* axes, \ const int* axes, \
const T* scale, \ const float scale, \
const T* x, \ const T* x, \
T* y) { \ T* y) { \
int rows, cols; \ if (num_dims == num_axes) { \
vec32_t y_dims(dims, dims + num_dims); \ const int count = \
for (int i = 0; i < num_axes; ++i) \ std::accumulate(dims, dims + num_dims, 1, std::multiplies<int>()); \
y_dims[axes[i]] = 1; \ _GlobalReduce##name(count, scale, x, y); \
/* Case #1: Rowwise Reduce */ \ return; \
if (utils::math::IsRowwiseReduce( \ } \
num_dims, dims, y_dims.data(), &rows, &cols)) { \ int rows, cols; \
_RowwiseReduce##name(rows, cols, scale, x, y); \ vec32_t y_dims(dims, dims + num_dims); \
return; \ for (int i = 0; i < num_axes; ++i) \
} \ y_dims[axes[i]] = 1; \
/* Case #2: Colwise Reduce */ \ /* Case #1: Rowwise Reduce */ \
if (utils::math::IsColwiseReduce( \ if (utils::math::IsRowwiseReduce( \
num_dims, dims, y_dims.data(), &rows, &cols)) { \ num_dims, dims, y_dims.data(), &rows, &cols)) { \
_ColwiseReduce##name(rows, cols, scale, x, y); \ _RowwiseReduce##name(rows, cols, scale, x, y); \
return; \ return; \
} \ } \
/* Case #3: Generic Reduce */ \ /* Case #2: Colwise Reduce */ \
vec32_t axesT(num_dims), stridesT(num_dims), dimsT(num_dims); \ if (utils::math::IsColwiseReduce( \
utils::math::TransposeAxesForReduce( \ num_dims, dims, y_dims.data(), &rows, &cols)) { \
num_dims, num_axes, axes, axesT.data()); \ _ColwiseReduce##name(rows, cols, scale, x, y); \
utils::math::ComputeTransposeStrides( \ return; \
num_dims, dims, axesT.data(), stridesT.data()); \ } \
rows = cols = 1; \ /* Case #3: Generic Reduce */ \
const int pivot = num_dims - num_axes; \ vec32_t axesT(num_dims), stridesT(num_dims), dimsT(num_dims); \
for (int i = 0; i < pivot; ++i) \ utils::math::TransposeAxesForReduce( \
rows *= dims[axesT[i]]; \ num_dims, num_axes, axes, axesT.data()); \
for (int i = pivot; i < num_dims; ++i) \ utils::math::ComputeTransposeStrides( \
cols *= dims[axesT[i]]; \ num_dims, dims, axesT.data(), stridesT.data()); \
for (int i = 0; i < num_dims; ++i) \ rows = cols = 1; \
dimsT[i] = dims[axesT[i]]; \ const int pivot = num_dims - num_axes; \
_GenericReduce##name( \ for (int i = 0; i < pivot; ++i) \
rows, cols, num_dims, dimsT.data(), stridesT.data(), scale, x, y); \ rows *= dims[axesT[i]]; \
for (int i = pivot; i < num_dims; ++i) \
cols *= dims[axesT[i]]; \
for (int i = 0; i < num_dims; ++i) \
dimsT[i] = dims[axesT[i]]; \
_GenericReduce##name( \
rows, cols, num_dims, dimsT.data(), stridesT.data(), scale, x, y); \
} }
DEFINE_REDUCE_FUNC(Max); DEFINE_REDUCE_FUNC(Max);
...@@ -189,42 +207,24 @@ DEFINE_REDUCE_FUNC(Sum); ...@@ -189,42 +207,24 @@ DEFINE_REDUCE_FUNC(Sum);
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> #define DEFINE_KERNEL_LAUNCHER(name) \
void ReduceMax<float16, CPUContext>( template <> \
const int num_dims, void Reduce##name<float16, CPUContext>( \
const int* dims, const int num_dims, \
const int num_axes, const int* dims, \
const int* axes, const int num_axes, \
const float16* x, const int* axes, \
float16* y, const float scale, \
CPUContext* ctx) { const float16* x, \
CPU_FP16_NOT_SUPPORTED; float16* y, \
} CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \
template <> }
void ReduceMin<float16, CPUContext>(
const int num_dims,
const int* dims,
const int num_axes,
const int* axes,
const float16* x,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <> DEFINE_KERNEL_LAUNCHER(Max);
void ReduceSum<float16, CPUContext>( DEFINE_KERNEL_LAUNCHER(Min);
const int num_dims, DEFINE_KERNEL_LAUNCHER(Sum);
const int* dims, #undef DEFINE_KERNEL_LAUNCHER
const int num_axes,
const int* axes,
const float scale,
const float16* x,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <> template <>
DRAGON_API void Sum<float16, CPUContext>( DRAGON_API void Sum<float16, CPUContext>(
...@@ -246,17 +246,18 @@ DRAGON_API float16 Sum<float16, CPUContext>( ...@@ -246,17 +246,18 @@ DRAGON_API float16 Sum<float16, CPUContext>(
return float16(); return float16();
} }
#define DEFINE_KERNEL_LAUNCHER(name, T) \ #define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void Reduce##name<T, CPUContext>( \ void Reduce##name<T, CPUContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
const int* axes, \ const int* axes, \
const T* x, \ const float scale, \
T* y, \ const T* x, \
CPUContext* ctx) { \ T* y, \
_Reduce##name(num_dims, dims, num_axes, axes, (T*)nullptr, x, y); \ CPUContext* ctx) { \
_Reduce##name(num_dims, dims, num_axes, axes, scale, x, y); \
} }
DEFINE_KERNEL_LAUNCHER(Max, int8_t); DEFINE_KERNEL_LAUNCHER(Max, int8_t);
...@@ -271,23 +272,6 @@ DEFINE_KERNEL_LAUNCHER(Min, int); ...@@ -271,23 +272,6 @@ DEFINE_KERNEL_LAUNCHER(Min, int);
DEFINE_KERNEL_LAUNCHER(Min, int64_t); DEFINE_KERNEL_LAUNCHER(Min, int64_t);
DEFINE_KERNEL_LAUNCHER(Min, float); DEFINE_KERNEL_LAUNCHER(Min, float);
DEFINE_KERNEL_LAUNCHER(Min, double); DEFINE_KERNEL_LAUNCHER(Min, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void Reduce##name<T, CPUContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const float scale, \
const T* x, \
T* y, \
CPUContext* ctx) { \
T s = static_cast<T>(scale); \
_Reduce##name(num_dims, dims, num_axes, axes, &s, x, y); \
}
DEFINE_KERNEL_LAUNCHER(Sum, int8_t); DEFINE_KERNEL_LAUNCHER(Sum, int8_t);
DEFINE_KERNEL_LAUNCHER(Sum, uint8_t); DEFINE_KERNEL_LAUNCHER(Sum, uint8_t);
DEFINE_KERNEL_LAUNCHER(Sum, int); DEFINE_KERNEL_LAUNCHER(Sum, int);
...@@ -301,13 +285,13 @@ DEFINE_KERNEL_LAUNCHER(Sum, double); ...@@ -301,13 +285,13 @@ DEFINE_KERNEL_LAUNCHER(Sum, double);
DRAGON_API void Sum<T, CPUContext>( \ DRAGON_API void Sum<T, CPUContext>( \
const int n, const float scale, const T* x, T* y, CPUContext* ctx) { \ const int n, const float scale, const T* x, T* y, CPUContext* ctx) { \
T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \ T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \
*y = val * scale; \ *y = val * T(scale); \
} \ } \
template <> \ template <> \
T Sum<T, CPUContext>( \ T Sum<T, CPUContext>( \
const int n, const float scale, const T* x, CPUContext* ctx) { \ const int n, const float scale, const T* x, CPUContext* ctx) { \
T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \ T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \
return val * scale; \ return val * T(scale); \
} }
DEFINE_SUM_FUNC(int8_t); DEFINE_SUM_FUNC(int8_t);
......
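With this change, every Reduce{Max,Min,Sum} launcher takes an explicit ``scale`` that is applied to the reduced values (previously only the ReduceSum launchers took a scale, forwarded internally by pointer). A NumPy sketch of the contract, with illustrative axes:

    import numpy as np

    def reduce_sum_ref(x, axes, scale=1.0):
        """Reduce over `axes`, then multiply by `scale`; Max/Min follow the same pattern."""
        return np.sum(x, axis=tuple(axes)) * scale

    x = np.random.rand(2, 3, 4).astype('float32')
    out = reduce_sum_ref(x, axes=(0, 2), scale=1.0 / (2 * 4))  # e.g. a mean over the reduced dims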
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/workspace.h"
#include "dragon/utils/device/common_cub.h" #include "dragon/utils/device/common_cub.h"
#include "dragon/utils/device/common_thrust.h"
#include "dragon/utils/math/blas.h"
#include "dragon/utils/math/reduce.h" #include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/utils.h" #include "dragon/utils/math/utils.h"
...@@ -26,7 +29,9 @@ __global__ void _RowwiseReduce( ...@@ -26,7 +29,9 @@ __global__ void _RowwiseReduce(
val = reducer(val, x[j * cols + i]); val = reducer(val, x[j * cols + i]);
} }
val = BlockReduce<T>(storage).Reduce(val, reducer); val = BlockReduce<T>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) y[i] = val * scale; if (threadIdx.x == 0) {
y[i] = val * scale;
}
} }
} }
...@@ -35,18 +40,24 @@ __global__ void _RowwiseReduce( ...@@ -35,18 +40,24 @@ __global__ void _RowwiseReduce(
const int rows, const int rows,
const int cols, const int cols,
const Reducer reducer, const Reducer reducer,
const float init, const half init,
const float scale, const half scale,
const half* x, const half* x,
half* y) { half* y) {
__shared__ typename BlockReduce<float>::TempStorage storage; __shared__ typename BlockReduce<half>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, cols) { CUDA_2D_KERNEL_LOOP1(i, cols) {
float val = init; half val = init;
CUDA_2D_KERNEL_LOOP2(j, rows) { CUDA_2D_KERNEL_LOOP2(j, rows) {
val = reducer(val, __half2float(x[j * cols + i])); val = reducer(val, x[j * cols + i]);
}
val = BlockReduce<half>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(val, scale);
#else
y[i] = __float2half(__half2float(val) * __half2float(scale));
#endif
} }
val = BlockReduce<float>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) y[i] = __float2half(val * scale);
} }
} }
...@@ -66,7 +77,9 @@ __global__ void _ColwiseReduce( ...@@ -66,7 +77,9 @@ __global__ void _ColwiseReduce(
val = reducer(val, x[i * cols + j]); val = reducer(val, x[i * cols + j]);
} }
val = BlockReduce<T>(storage).Reduce(val, reducer); val = BlockReduce<T>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) y[i] = val * scale; if (threadIdx.x == 0) {
y[i] = val * scale;
}
} }
} }
...@@ -75,18 +88,24 @@ __global__ void _ColwiseReduce( ...@@ -75,18 +88,24 @@ __global__ void _ColwiseReduce(
const int rows, const int rows,
const int cols, const int cols,
const Reducer reducer, const Reducer reducer,
const float init, const half init,
const float scale, const half scale,
const half* x, const half* x,
half* y) { half* y) {
__shared__ typename BlockReduce<float>::TempStorage storage; __shared__ typename BlockReduce<half>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, rows) { CUDA_2D_KERNEL_LOOP1(i, rows) {
float val = init; half val = init;
CUDA_2D_KERNEL_LOOP2(j, cols) { CUDA_2D_KERNEL_LOOP2(j, cols) {
val = reducer(val, __half2float(x[i * cols + j])); val = reducer(val, x[i * cols + j]);
}
val = BlockReduce<half>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(val, scale);
#else
y[i] = __float2half(__half2float(val) * __half2float(scale));
#endif
} }
val = BlockReduce<float>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) y[i] = __float2half(val * scale);
} }
} }
...@@ -115,7 +134,9 @@ __global__ void _GenericReduce( ...@@ -115,7 +134,9 @@ __global__ void _GenericReduce(
val = reducer(val, x[xi]); val = reducer(val, x[xi]);
} }
val = BlockReduce<T>(storage).Reduce(val, reducer); val = BlockReduce<T>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) y[i] = val * scale; if (threadIdx.x == 0) {
y[i] = val * scale;
}
} }
} }
...@@ -127,13 +148,13 @@ __global__ void _GenericReduce( ...@@ -127,13 +148,13 @@ __global__ void _GenericReduce(
const SimpleArray<int, D> x_dims, const SimpleArray<int, D> x_dims,
const SimpleArray<int, D> x_strides, const SimpleArray<int, D> x_strides,
const Reducer reducer, const Reducer reducer,
const float init, const half init,
const float scale, const half scale,
const half* x, const half* x,
half* y) { half* y) {
__shared__ typename BlockReduce<float>::TempStorage storage; __shared__ typename BlockReduce<half>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, rows) { CUDA_2D_KERNEL_LOOP1(i, rows) {
float val = init; half val = init;
CUDA_2D_KERNEL_LOOP2(j, cols) { CUDA_2D_KERNEL_LOOP2(j, cols) {
int xi = 0, c = i * cols + j; int xi = 0, c = i * cols + j;
for (int d = num_dims - 1; d >= 0; --d) { for (int d = num_dims - 1; d >= 0; --d) {
...@@ -141,26 +162,56 @@ __global__ void _GenericReduce( ...@@ -141,26 +162,56 @@ __global__ void _GenericReduce(
FIXED_DIVISOR_DIV_MOD(x_dims.data[d], c, &c, &r); FIXED_DIVISOR_DIV_MOD(x_dims.data[d], c, &c, &r);
xi += r * x_strides.data[d]; xi += r * x_strides.data[d];
} }
val = reducer(val, __half2float(x[xi])); val = reducer(val, x[xi]);
}
val = BlockReduce<half>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(val, scale);
#else
y[i] = __float2half(__half2float(val) * __half2float(scale));
#endif
} }
val = BlockReduce<float>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) y[i] = __float2half(val * scale);
} }
} }
#define DEFINE_REDUCE_FUNCTION(name) \ #define DEFINE_REDUCE_FUNCTION(name) \
template <typename Tx, typename Tp, class Reducer> \ template <typename T, class Reducer> \
void _Reduce##name( \ int _Reduce##name( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
const int* axes, \ const int* axes, \
const Reducer reducer, \ const Reducer reducer, \
const Tp init, \ const T init, \
const Tp scale, \ const float scale, \
const Tx* x, \ const T* x, \
Tx* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int count = \
std::accumulate(dims, dims + num_dims, 1, std::multiplies<int>()); \
if (num_dims == num_axes && count > 10000) { \
size_t ws_nbytes = 0; \
cub::DeviceReduce::Reduce( \
nullptr, \
ws_nbytes, \
x, \
y, \
count, \
reducer, \
cast::to<T>(init), \
ctx->cuda_stream()); \
cub::DeviceReduce::Reduce( \
ctx->workspace()->data<CUDAContext>({ws_nbytes})[0], \
ws_nbytes, \
x, \
y, \
count, \
reducer, \
cast::to<T>(init), \
ctx->cuda_stream()); \
return 0; \
} \
int rows, cols; \ int rows, cols; \
vec32_t y_dims(dims, dims + num_dims); \ vec32_t y_dims(dims, dims + num_dims); \
for (int i = 0; i < num_axes; ++i) \ for (int i = 0; i < num_axes; ++i) \
...@@ -172,8 +223,9 @@ __global__ void _GenericReduce( ...@@ -172,8 +223,9 @@ __global__ void _GenericReduce(
CUDA_2D_BLOCKS(cols), \ CUDA_2D_BLOCKS(cols), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(rows, cols, reducer, init, scale, x, y); \ ctx->cuda_stream()>>>( \
return; \ rows, cols, reducer, init, cast::to<T>(scale), x, y); \
return 1; \
} \ } \
/*! Case #2: Colwise Reduce */ \ /*! Case #2: Colwise Reduce */ \
if (utils::math::IsColwiseReduce( \ if (utils::math::IsColwiseReduce( \
...@@ -182,8 +234,9 @@ __global__ void _GenericReduce( ...@@ -182,8 +234,9 @@ __global__ void _GenericReduce(
CUDA_2D_BLOCKS(rows), \ CUDA_2D_BLOCKS(rows), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(rows, cols, reducer, init, scale, x, y); \ ctx->cuda_stream()>>>( \
return; \ rows, cols, reducer, init, cast::to<T>(scale), x, y); \
return 2; \
} \ } \
/*! Case #3: Generic Reduce */ \ /*! Case #3: Generic Reduce */ \
CUDA_TENSOR_DIMS_CHECK(num_dims); \ CUDA_TENSOR_DIMS_CHECK(num_dims); \
...@@ -204,7 +257,17 @@ __global__ void _GenericReduce( ...@@ -204,7 +257,17 @@ __global__ void _GenericReduce(
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>( \ ctx->cuda_stream()>>>( \
rows, cols, num_dims, dimsT, stridesT, reducer, init, scale, x, y); \ rows, \
cols, \
num_dims, \
dimsT, \
stridesT, \
reducer, \
init, \
cast::to<T>(scale), \
x, \
y); \
return 3; \
} }
DEFINE_REDUCE_FUNCTION(Max); DEFINE_REDUCE_FUNCTION(Max);
...@@ -216,85 +279,54 @@ DEFINE_REDUCE_FUNCTION(Sum); ...@@ -216,85 +279,54 @@ DEFINE_REDUCE_FUNCTION(Sum);
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> #define DEFINE_KERNEL_LAUNCHER(name, Reducer, kInit) \
void ReduceMax<float16, CUDAContext>( template <> \
const int num_dims, void Reduce##name<float16, CUDAContext>( \
const int* dims, const int num_dims, \
const int num_axes, const int* dims, \
const int* axes, const int num_axes, \
const float16* x, const int* axes, \
float16* y, const float scale, \
CUDAContext* ctx) { const float16* x, \
_ReduceMax( float16* y, \
num_dims, CUDAContext* ctx) { \
dims, auto kind = _Reduce##name( \
num_axes, num_dims, \
axes, dims, \
cub::Max(), num_axes, \
std::numeric_limits<float>::lowest(), axes, \
1.f, Reducer(), \
reinterpret_cast<const half*>(x), cast::to<half>(kInit), \
reinterpret_cast<half*>(y), scale, \
ctx); reinterpret_cast<const half*>(x), \
} reinterpret_cast<half*>(y), \
ctx); \
template <> if (kind == 0) { \
void ReduceMin<float16, CUDAContext>( math::Scale(1, scale, y, y, ctx); \
const int num_dims, } \
const int* dims, }
const int num_axes,
const int* axes,
const float16* x,
float16* y,
CUDAContext* ctx) {
_ReduceMin(
num_dims,
dims,
num_axes,
axes,
cub::Min(),
std::numeric_limits<float>::max(),
1.f,
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y),
ctx);
}
template <> DEFINE_KERNEL_LAUNCHER(Max, cub::MaxHalf, -HFLT_MAX);
void ReduceSum<float16, CUDAContext>( DEFINE_KERNEL_LAUNCHER(Min, cub::MinHalf, HFLT_MAX);
const int num_dims, DEFINE_KERNEL_LAUNCHER(Sum, cub::SumHalf, 0.f);
const int* dims, #undef DEFINE_KERNEL_LAUNCHER
const int num_axes,
const int* axes,
const float scale,
const float16* x,
float16* y,
CUDAContext* ctx) {
_ReduceMin(
num_dims,
dims,
num_axes,
axes,
cub::Sum(),
0.f,
scale,
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y),
ctx);
}
#define DEFINE_KERNEL_LAUNCHER(name, T, Reducer, kInit) \ #define DEFINE_KERNEL_LAUNCHER(name, T, Reducer, kInit) \
template <> \ template <> \
void Reduce##name<T, CUDAContext>( \ void Reduce##name<T, CUDAContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
const int* axes, \ const int* axes, \
const T* x, \ const float scale, \
T* y, \ const T* x, \
CUDAContext* ctx) { \ T* y, \
_Reduce##name( \ CUDAContext* ctx) { \
num_dims, dims, num_axes, axes, Reducer(), kInit, T(1), x, y, ctx); \ auto kind = _Reduce##name( \
num_dims, dims, num_axes, axes, Reducer(), kInit, scale, x, y, ctx); \
if (kind == 0) { \
math::Scale(1, scale, y, y, ctx); \
} \
} }
DEFINE_KERNEL_LAUNCHER( DEFINE_KERNEL_LAUNCHER(
...@@ -345,32 +377,6 @@ DEFINE_KERNEL_LAUNCHER( ...@@ -345,32 +377,6 @@ DEFINE_KERNEL_LAUNCHER(
double, double,
cub::Min, cub::Min,
std::numeric_limits<double>::max()); std::numeric_limits<double>::max());
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T, Reducer, kInit) \
template <> \
void Reduce##name<T, CUDAContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const float scale, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
_Reduce##name( \
num_dims, \
dims, \
num_axes, \
axes, \
Reducer(), \
kInit, \
(T)scale, \
x, \
y, \
ctx); \
}
DEFINE_KERNEL_LAUNCHER(Sum, int8_t, cub::Sum, int8_t(0)); DEFINE_KERNEL_LAUNCHER(Sum, int8_t, cub::Sum, int8_t(0));
DEFINE_KERNEL_LAUNCHER(Sum, uint8_t, cub::Sum, uint8_t(0)); DEFINE_KERNEL_LAUNCHER(Sum, uint8_t, cub::Sum, uint8_t(0));
DEFINE_KERNEL_LAUNCHER(Sum, int, cub::Sum, int(0)); DEFINE_KERNEL_LAUNCHER(Sum, int, cub::Sum, int(0));
...@@ -384,18 +390,7 @@ DEFINE_KERNEL_LAUNCHER(Sum, double, cub::Sum, 0.); ...@@ -384,18 +390,7 @@ DEFINE_KERNEL_LAUNCHER(Sum, double, cub::Sum, 0.);
DRAGON_API void Sum<T, CUDAContext>( \ DRAGON_API void Sum<T, CUDAContext>( \
const int n, const float alpha, const T* x, T* y, CUDAContext* ctx) { \ const int n, const float alpha, const T* x, T* y, CUDAContext* ctx) { \
vec32_t dims = {n}, axes = {0}; \ vec32_t dims = {n}, axes = {0}; \
ReduceSum(1, dims.data(), 1, axes.data(), alpha, x, y, ctx); \ math::ReduceSum(1, dims.data(), 1, axes.data(), alpha, x, y, ctx); \
} \
template <> \
DRAGON_API T Sum<T, CUDAContext>( \
const int n, const float alpha, const T* x, CUDAContext* ctx) { \
T val, *y = (T*)ctx->New(sizeof(T)); \
Sum(n, alpha, x, y, ctx); \
CUDA_CHECK(cudaMemcpyAsync( \
&val, y, sizeof(T), cudaMemcpyDeviceToHost, ctx->cuda_stream())); \
ctx->FinishDeviceComputation(); \
ctx->Delete(y); \
return val; \
} }
DEFINE_SUM_FUNC(int8_t); DEFINE_SUM_FUNC(int8_t);
...@@ -407,6 +402,23 @@ DEFINE_SUM_FUNC(float); ...@@ -407,6 +402,23 @@ DEFINE_SUM_FUNC(float);
DEFINE_SUM_FUNC(double); DEFINE_SUM_FUNC(double);
#undef DEFINE_SUM_FUNC #undef DEFINE_SUM_FUNC
#define DEFINE_SUM_FUNC(T) \
template <> \
DRAGON_API T Sum<T, CUDAContext>( \
const int n, const float alpha, const T* x, CUDAContext* ctx) { \
auto policy = thrust::cuda::par.on(ctx->cuda_stream()); \
auto val = thrust::reduce(policy, x, x + n) * alpha; \
return static_cast<T>(val); \
}
DEFINE_SUM_FUNC(int8_t);
DEFINE_SUM_FUNC(uint8_t);
DEFINE_SUM_FUNC(int);
DEFINE_SUM_FUNC(int64_t);
DEFINE_SUM_FUNC(float);
DEFINE_SUM_FUNC(double);
#undef DEFINE_SUM_FUNC
} // namespace math } // namespace math
} // namespace dragon } // namespace dragon
......
...@@ -25,6 +25,7 @@ DRAGON_API void ReduceMax( ...@@ -25,6 +25,7 @@ DRAGON_API void ReduceMax(
const int* dims, const int* dims,
const int num_axes, const int num_axes,
const int* axes, const int* axes,
const float scale,
const T* x, const T* x,
T* y, T* y,
Context* ctx); Context* ctx);
...@@ -35,6 +36,7 @@ DRAGON_API void ReduceMin( ...@@ -35,6 +36,7 @@ DRAGON_API void ReduceMin(
const int* dims, const int* dims,
const int num_axes, const int num_axes,
const int* axes, const int* axes,
const float scale,
const T* x, const T* x,
T* y, T* y,
Context* ctx); Context* ctx);
......
...@@ -96,7 +96,7 @@ MATH_UTILS_DECL T Cube(const T x) { ...@@ -96,7 +96,7 @@ MATH_UTILS_DECL T Cube(const T x) {
} }
#if defined(__CUDACC__) #if defined(__CUDACC__)
MATH_UTILS_DECL bool IsInf(half x) { inline __device__ bool IsInf(half x) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
return __hisinf(x); return __hisinf(x);
#else #else
...@@ -105,7 +105,7 @@ MATH_UTILS_DECL bool IsInf(half x) { ...@@ -105,7 +105,7 @@ MATH_UTILS_DECL bool IsInf(half x) {
#endif #endif
} }
MATH_UTILS_DECL bool IsNaN(half x) { inline __device__ bool IsNaN(half x) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
return __hisnan(x); return __hisnan(x);
#else #else
...@@ -113,7 +113,7 @@ MATH_UTILS_DECL bool IsNaN(half x) { ...@@ -113,7 +113,7 @@ MATH_UTILS_DECL bool IsNaN(half x) {
#endif #endif
} }
MATH_UTILS_DECL half Square(half x) { inline __device__ half Square(half x) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
return __hmul(x, x); return __hmul(x, x);
#else #else
...@@ -121,7 +121,7 @@ MATH_UTILS_DECL half Square(half x) { ...@@ -121,7 +121,7 @@ MATH_UTILS_DECL half Square(half x) {
#endif #endif
} }
MATH_UTILS_DECL half2 Square(half2 x) { inline __device__ half2 Square(half2 x) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
return __hmul2(x, x); return __hmul2(x, x);
#else #else
...@@ -130,7 +130,7 @@ MATH_UTILS_DECL half2 Square(half2 x) { ...@@ -130,7 +130,7 @@ MATH_UTILS_DECL half2 Square(half2 x) {
#endif #endif
} }
MATH_UTILS_DECL half Cube(half x) { inline __device__ half Cube(half x) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
return __hmul(__hmul(x, x), x); return __hmul(__hmul(x, x), x);
#else #else
...@@ -138,7 +138,7 @@ MATH_UTILS_DECL half Cube(half x) { ...@@ -138,7 +138,7 @@ MATH_UTILS_DECL half Cube(half x) {
#endif #endif
} }
MATH_UTILS_DECL half2 Cube(half2 x) { inline __device__ half2 Cube(half2 x) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
return __hmul2(__hmul2(x, x), x); return __hmul2(__hmul2(x, x), x);
#else #else
......
...@@ -231,10 +231,18 @@ void ArgMin( ...@@ -231,10 +231,18 @@ void ArgMin(
int64_t* y, int64_t* y,
Context* ctx); Context* ctx);
/* array.cast */ /* array.channel_affine */
template <typename Tx, typename Ty, class Context> template <typename T, class Context>
void Cast(const int count, const Tx* x, Ty* y, Context* ctx); void ChannelAffine(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const T* x,
const T* w,
const T* b,
T* y,
Context* ctx);
/* array.channel_normalize */ /* array.channel_normalize */
...@@ -344,8 +352,6 @@ void Flagged( ...@@ -344,8 +352,6 @@ void Flagged(
const uint8_t* mask, const uint8_t* mask,
IndexType* index, IndexType* index,
int* num_selected, int* num_selected,
void* scratch,
size_t& scratch_size,
Context* ctx); Context* ctx);
template <typename IndexType, typename CoordType, class Context> template <typename IndexType, typename CoordType, class Context>
...@@ -574,7 +580,7 @@ void ReduceLoss( ...@@ -574,7 +580,7 @@ void ReduceLoss(
const int num_masks, const int num_masks,
const float normalizer, const float normalizer,
const T* x, const T* x,
const int* mask, const T* mask,
T* y, T* y,
Context* ctx); Context* ctx);
...@@ -584,7 +590,7 @@ void ReduceLossGrad( ...@@ -584,7 +590,7 @@ void ReduceLossGrad(
const int num_masks, const int num_masks,
const float normalizer, const float normalizer,
const T* dy, const T* dy,
const int* mask, const T* mask,
T* dx, T* dx,
Context* ctx); Context* ctx);
...@@ -608,7 +614,7 @@ void NLLLoss( ...@@ -608,7 +614,7 @@ void NLLLoss(
const LogitType* log_prob, const LogitType* log_prob,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask, LogitType* mask,
Context* ctx); Context* ctx);
template <typename LogitType, typename TargetType, class Context> template <typename LogitType, typename TargetType, class Context>
...@@ -620,7 +626,7 @@ void NLLLossGrad( ...@@ -620,7 +626,7 @@ void NLLLossGrad(
const LogitType* log_prob, const LogitType* log_prob,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dx,
int* mask, LogitType* mask,
Context* ctx); Context* ctx);
/* loss.sigmoid_ce_loss */ /* loss.sigmoid_ce_loss */
...@@ -631,7 +637,7 @@ void SigmoidCrossEntropy( ...@@ -631,7 +637,7 @@ void SigmoidCrossEntropy(
const T* logit, const T* logit,
const T* target, const T* target,
T* loss, T* loss,
int* mask, T* mask,
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
...@@ -640,7 +646,7 @@ void SigmoidCrossEntropyGrad( ...@@ -640,7 +646,7 @@ void SigmoidCrossEntropyGrad(
const T* logit, const T* logit,
const T* target, const T* target,
T* dlogit, T* dlogit,
int* mask, T* mask,
Context* ctx); Context* ctx);
/* loss.sigmoid_focal_loss */ /* loss.sigmoid_focal_loss */
...@@ -657,7 +663,7 @@ void SigmoidFocalLoss( ...@@ -657,7 +663,7 @@ void SigmoidFocalLoss(
const LogitType* logit, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask, LogitType* mask,
Context* ctx); Context* ctx);
template <typename LogitType, typename TargetType, class Context> template <typename LogitType, typename TargetType, class Context>
...@@ -672,7 +678,7 @@ void SigmoidFocalLossGrad( ...@@ -672,7 +678,7 @@ void SigmoidFocalLossGrad(
const LogitType* logit, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* dlogit, LogitType* dlogit,
int* mask, LogitType* mask,
Context* ctx); Context* ctx);
/* loss.smooth_l1_loss */ /* loss.smooth_l1_loss */
...@@ -714,7 +720,7 @@ void SparseSoftmaxCrossEntropy( ...@@ -714,7 +720,7 @@ void SparseSoftmaxCrossEntropy(
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
int* mask, LogitType* mask,
Context* ctx); Context* ctx);
template <typename LogitType, typename TargetType, class Context> template <typename LogitType, typename TargetType, class Context>
...@@ -726,7 +732,7 @@ void SparseSoftmaxCrossEntropyGrad( ...@@ -726,7 +732,7 @@ void SparseSoftmaxCrossEntropyGrad(
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dx,
int* mask, LogitType* mask,
Context* ctx); Context* ctx);
/* math.abs */ /* math.abs */
...@@ -734,19 +740,6 @@ void SparseSoftmaxCrossEntropyGrad( ...@@ -734,19 +740,6 @@ void SparseSoftmaxCrossEntropyGrad(
template <typename T, class Context> template <typename T, class Context>
void AbsGrad(const int count, const T* x, const T* dy, T* dx, Context* ctx); void AbsGrad(const int count, const T* x, const T* dy, T* dx, Context* ctx);
/* math.affine */
template <typename T, class Context>
void Affine(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const T* x,
const T* w,
const T* b,
T* y,
Context* ctx);
/* math.clip */ /* math.clip */
template <typename T, class Context> template <typename T, class Context>
...@@ -1044,19 +1037,6 @@ void SGDUpdate( ...@@ -1044,19 +1037,6 @@ void SGDUpdate(
T* m, T* m,
Context* ctx); Context* ctx);
/* training.mixed_prec_update */
template <typename T, class Context>
void MixedPrecL2Penalty(
const int count,
const float alpha,
const T* x,
float* dx,
Context* ctx);
template <typename T, class Context>
void MixedPrecUpdate(const int count, const float* dx, T* x, Context* ctx);
/* vision.bias_add */ /* vision.bias_add */
template <typename T, class Context> template <typename T, class Context>
......
...@@ -451,6 +451,32 @@ class TestArrayOps(OpTestCase): ...@@ -451,6 +451,32 @@ class TestArrayOps(OpTestCase):
with dragon.device('cuda'): with dragon.device('cuda'):
self.test_cast() self.test_cast()
def test_channel_affine(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data1 = arange((2, 3, 4, 5))
data2, data3 = arange((3, 4)), arange((3, 4))
data4 = arange(data1.shape)
grad1 = data4 * np.expand_dims(data2, -1)
grad2 = np.sum(data4 * data1, (0, 3))
grad3 = np.sum(data4, (0, 3))
x, w, b = new_tensor(data1), new_tensor(data2), new_tensor(data3)
with dragon.GradientTape() as tape:
tape.watch([x, w, b])
y = dragon.channel_affine([x, w, b], axis=1, num_axes=2)
dy = new_tensor(data4)
dx, dw, db = tape.gradient(y, [x, w, b], output_gradients=[dy])
self.assertEqual(
[y, dx, dw, db],
[data1 * np.expand_dims(data2, -1) +
np.expand_dims(data3, -1),
grad1, grad2, grad3])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_channel_affine_cuda(self):
with dragon.device('cuda'):
self.test_channel_affine()
def test_channel_normalize(self): def test_channel_normalize(self):
entries = [((2, 3, 4), [(1., 2., 3.), (3., 2., 1.), 1], {'perm': (0, 1, 2)}), entries = [((2, 3, 4), [(1., 2., 3.), (3., 2., 1.), 1], {'perm': (0, 1, 2)}),
((2, 3, 4), [(1., 2., 3.), (3., 2., 1.), 2], {'perm': (0, 2, 1)})] ((2, 3, 4), [(1., 2., 3.), (3., 2., 1.), 2], {'perm': (0, 2, 1)})]
...@@ -1448,32 +1474,6 @@ class TestMathOps(OpTestCase): ...@@ -1448,32 +1474,6 @@ class TestMathOps(OpTestCase):
with dragon.device('cuda'): with dragon.device('cuda'):
self.test_add() self.test_add()
def test_affine(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data1 = arange((2, 3, 4, 5))
data2, data3 = arange((3, 4)), arange((3, 4))
data4 = arange(data1.shape)
grad1 = data4 * np.expand_dims(data2, -1)
grad2 = np.sum(data4 * data1, (0, 3))
grad3 = np.sum(data4, (0, 3))
x, w, b = new_tensor(data1), new_tensor(data2), new_tensor(data3)
with dragon.GradientTape() as tape:
tape.watch([x, w, b])
y = dragon.math.affine([x, w, b], axis=1, num_axes=2)
dy = new_tensor(data4)
dx, dw, db = tape.gradient(y, [x, w, b], output_gradients=[dy])
self.assertEqual(
[y, dx, dw, db],
[data1 * np.expand_dims(data2, -1) +
np.expand_dims(data3, -1),
grad1, grad2, grad3])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_affine_cuda(self):
with dragon.device('cuda'):
self.test_affine()
def test_argmax(self): def test_argmax(self):
entries = [(0, True), (0, False), (1, True), (1, False)] entries = [(0, True), (0, False), (1, True), (1, False)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......
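The expectations in ``test_channel_affine`` above reduce to plain NumPy broadcasting. The sketch below is a hypothetical reference (``channel_affine_ref`` is not part of the test suite) that mirrors what the new operator computes for ``axis=1, num_axes=2``:
```python
import numpy as np

def channel_affine_ref(x, w, b=None, axis=1):
    """Hypothetical NumPy reference: scale/shift x over dims [axis, axis + w.ndim)."""
    # Reshape w/b with singleton leading/trailing dims so they broadcast against x.
    shape = (1,) * axis + w.shape + (1,) * (x.ndim - axis - w.ndim)
    y = x * w.reshape(shape)
    if b is not None:
        y = y + b.reshape(shape)
    return y

# Shapes mirror the test: x is (2, 3, 4, 5), w and b are (3, 4), axis=1, num_axes=2.
x = np.arange(2 * 3 * 4 * 5, dtype='float32').reshape(2, 3, 4, 5)
w = np.full((3, 4), 2., 'float32')
b = np.full((3, 4), 1., 'float32')
y = channel_affine_ref(x, w, b, axis=1)
assert np.allclose(y, x * np.expand_dims(w, -1) + np.expand_dims(b, -1))
```
The gradient checks follow the same broadcasting: ``dw`` and ``db`` reduce over the non-transformed dims ``(0, 3)``, while ``dx`` is the upstream gradient scaled by the broadcast weight.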
...@@ -97,12 +97,12 @@ class TestModule(unittest.TestCase): ...@@ -97,12 +97,12 @@ class TestModule(unittest.TestCase):
m.apply(lambda m: m.train()) m.apply(lambda m: m.train())
self.assertEqual(m.training, True) self.assertEqual(m.training, True)
logging.set_verbosity('FATAL') logging.set_verbosity('FATAL')
m.load_state_dict(m.state_dict(), verbose=True) m.load_state_dict(m.state_dict())
logging.set_verbosity('INFO') logging.set_verbosity('INFO')
m.load_state_dict(m.state_dict(to_numpy=True)) m.load_state_dict(m.state_dict(to_numpy=True))
try: try:
m.load_state_dict({'!@#$%^&*()': 1}) m.load_state_dict({'!@#$%^&*()': 1})
except KeyError: except RuntimeError:
pass pass
(m.sub3.weight + 1).sum().backward() (m.sub3.weight + 1).sum().backward()
m.zero_grad() m.zero_grad()
...@@ -156,10 +156,9 @@ class TestModule(unittest.TestCase): ...@@ -156,10 +156,9 @@ class TestModule(unittest.TestCase):
class TestModules(OpTestCase): class TestModules(OpTestCase):
"""Test the nn module class.""" """Test the nn module class."""
def test_affine(self): def test_affine_channel(self):
data1 = arange((2, 3, 4, 5)) data1 = arange((2, 3, 4, 5))
data2, data3 = arange((1, 3, 1, 1)), arange((1, 3, 1, 1)) data2, data3 = arange((1, 3, 1, 1)), arange((1, 3, 1, 1))
x = new_tensor(data1)
w, b = new_tensor(data2.flatten()), new_tensor(data3.flatten()) w, b = new_tensor(data2.flatten()), new_tensor(data3.flatten())
entries = [(True, False, False), entries = [(True, False, False),
(True, True, False), (True, True, False),
...@@ -167,8 +166,9 @@ class TestModules(OpTestCase): ...@@ -167,8 +166,9 @@ class TestModules(OpTestCase):
(False, False, False), (False, False, False),
(False, True, False)] (False, True, False)]
for bias, fix_weight, fix_bias in entries: for bias, fix_weight, fix_bias in entries:
x = new_tensor(data1)
try: try:
m = torch.nn.Affine( m = torch.nn.AffineChannel(
num_features=3, num_features=3,
bias=bias, bias=bias,
fix_weight=fix_weight, fix_weight=fix_weight,
...@@ -176,7 +176,7 @@ class TestModules(OpTestCase): ...@@ -176,7 +176,7 @@ class TestModules(OpTestCase):
inplace=True, inplace=True,
) )
except ValueError: except ValueError:
m = torch.nn.Affine( m = torch.nn.AffineChannel(
num_features=3, num_features=3,
bias=bias, bias=bias,
fix_weight=fix_weight, fix_weight=fix_weight,
......
...@@ -50,6 +50,7 @@ from dragon.vm.torch.core.ops.array.functional import argmax ...@@ -50,6 +50,7 @@ from dragon.vm.torch.core.ops.array.functional import argmax
from dragon.vm.torch.core.ops.array.functional import argmin from dragon.vm.torch.core.ops.array.functional import argmin
from dragon.vm.torch.core.ops.array.functional import assign from dragon.vm.torch.core.ops.array.functional import assign
from dragon.vm.torch.core.ops.array.functional import cat from dragon.vm.torch.core.ops.array.functional import cat
from dragon.vm.torch.core.ops.array.functional import channel_affine
from dragon.vm.torch.core.ops.array.functional import channel_normalize from dragon.vm.torch.core.ops.array.functional import channel_normalize
from dragon.vm.torch.core.ops.array.functional import channel_shuffle from dragon.vm.torch.core.ops.array.functional import channel_shuffle
from dragon.vm.torch.core.ops.array.functional import chunk from dragon.vm.torch.core.ops.array.functional import chunk
......
...@@ -30,7 +30,6 @@ from dragon.vm.torch.core.nn.modules.activation import SELU ...@@ -30,7 +30,6 @@ from dragon.vm.torch.core.nn.modules.activation import SELU
from dragon.vm.torch.core.nn.modules.activation import Sigmoid from dragon.vm.torch.core.nn.modules.activation import Sigmoid
from dragon.vm.torch.core.nn.modules.activation import Softmax from dragon.vm.torch.core.nn.modules.activation import Softmax
from dragon.vm.torch.core.nn.modules.activation import Tanh from dragon.vm.torch.core.nn.modules.activation import Tanh
from dragon.vm.torch.core.nn.modules.affine import Affine
from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm1d from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm1d
from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm2d from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm2d
from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm3d from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm3d
...@@ -55,6 +54,7 @@ from dragon.vm.torch.core.nn.modules.loss import NLLLoss ...@@ -55,6 +54,7 @@ from dragon.vm.torch.core.nn.modules.loss import NLLLoss
from dragon.vm.torch.core.nn.modules.loss import SigmoidFocalLoss from dragon.vm.torch.core.nn.modules.loss import SigmoidFocalLoss
from dragon.vm.torch.core.nn.modules.loss import SmoothL1Loss from dragon.vm.torch.core.nn.modules.loss import SmoothL1Loss
from dragon.vm.torch.core.nn.modules.module import Module from dragon.vm.torch.core.nn.modules.module import Module
from dragon.vm.torch.core.nn.modules.normalization import AffineChannel
from dragon.vm.torch.core.nn.modules.normalization import GroupNorm from dragon.vm.torch.core.nn.modules.normalization import GroupNorm
from dragon.vm.torch.core.nn.modules.normalization import LocalResponseNorm from dragon.vm.torch.core.nn.modules.normalization import LocalResponseNorm
from dragon.vm.torch.core.nn.modules.padding import ConstantPad1d from dragon.vm.torch.core.nn.modules.padding import ConstantPad1d
......
...@@ -14,7 +14,6 @@ from __future__ import absolute_import as _absolute_import ...@@ -14,7 +14,6 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
from dragon.vm.torch.core.nn.functional import affine
from dragon.vm.torch.core.nn.functional import avg_pool2d from dragon.vm.torch.core.nn.functional import avg_pool2d
from dragon.vm.torch.core.nn.functional import batch_norm from dragon.vm.torch.core.nn.functional import batch_norm
from dragon.vm.torch.core.nn.functional import binary_cross_entropy_with_logits from dragon.vm.torch.core.nn.functional import binary_cross_entropy_with_logits
......
...@@ -20,33 +20,6 @@ from dragon.vm.torch.core.nn import _reduction ...@@ -20,33 +20,6 @@ from dragon.vm.torch.core.nn import _reduction
from dragon.vm.torch.core.nn.modules import utils from dragon.vm.torch.core.nn.modules import utils
def affine(input, weight, bias=None):
r"""Apply the affine transformation to input.
.. math:: y = Ax + b
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.Affine(...)`_
"""
return _functions.Affine.instantiate(input.device).apply(input, weight, bias)
def avg_pool2d( def avg_pool2d(
input, input,
kernel_size, kernel_size,
......
...@@ -98,24 +98,6 @@ class _PoolNd(function.Function): ...@@ -98,24 +98,6 @@ class _PoolNd(function.Function):
return self.dispatch([input], [self.alloc()]) return self.dispatch([input], [self.alloc()])
class Affine(function.Function):
def __init__(self, key, dev, **kwargs):
super(Affine, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'Affine',
'arguments': {
'axis': 1,
'num_axes': 1,
}
}
def forward(self, input, weight, bias=None):
inputs = [input, weight] + ([bias] if bias else [])
return self.dispatch(inputs, [self.alloc()])
class BatchNorm(function.Function): class BatchNorm(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(BatchNorm, self).__init__(key, dev, **kwargs) super(BatchNorm, self).__init__(key, dev, **kwargs)
......
...@@ -20,7 +20,7 @@ from dragon.core import distributed ...@@ -20,7 +20,7 @@ from dragon.core import distributed
from dragon.vm.torch.core.nn import functional as F from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules.module import Module from dragon.vm.torch.core.nn.modules.module import Module
from dragon.vm.torch.core.nn.parameter import Parameter from dragon.vm.torch.core.nn.parameter import Parameter
from dragon.vm.torch.core.ops.init import functional as init from dragon.vm.torch.core.ops.init import functional as init_funcs
from dragon.vm.torch.core.tensor import Tensor from dragon.vm.torch.core.tensor import Tensor
...@@ -43,10 +43,10 @@ class _BatchNorm(Module): ...@@ -43,10 +43,10 @@ class _BatchNorm(Module):
self.weight = Parameter(Tensor(num_features)) self.weight = Parameter(Tensor(num_features))
self.bias = Parameter(Tensor(num_features)) self.bias = Parameter(Tensor(num_features))
else: else:
self.register_buffer('weight', init.ones(num_features)) self.register_buffer('weight', init_funcs.ones(num_features))
self.register_buffer('bias', init.zeros(num_features)) self.register_buffer('bias', init_funcs.zeros(num_features))
self.register_buffer('running_mean', init.zeros(num_features)) self.register_buffer('running_mean', init_funcs.zeros(num_features))
self.register_buffer('running_var', init.ones(num_features)) self.register_buffer('running_var', init_funcs.ones(num_features))
self.inputs = [self.running_mean, self.running_var, self.weight, self.bias] self.inputs = [self.running_mean, self.running_var, self.weight, self.bias]
self.reset_parameters() self.reset_parameters()
......
...@@ -15,10 +15,10 @@ from __future__ import division ...@@ -15,10 +15,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import itertools
import numpy import numpy
from dragon.core.framework import config from dragon.core.framework import config
from dragon.core.util import logging
from dragon.core.util import string from dragon.core.util import string
from dragon.vm.torch.core.nn.parameter import Parameter from dragon.vm.torch.core.nn.parameter import Parameter
from dragon.vm.torch.core.tensor import Tensor from dragon.vm.torch.core.tensor import Tensor
...@@ -231,7 +231,7 @@ class Module(object): ...@@ -231,7 +231,7 @@ class Module(object):
if t.is_floating_point() else t, if t.is_floating_point() else t,
) )
def load_state_dict(self, state_dict, strict=True, verbose=False): def load_state_dict(self, state_dict, strict=True):
"""Load the state dict from other module. """Load the state dict from other module.
Typically, states can only loaded from the same module class: Typically, states can only loaded from the same module class:
...@@ -255,49 +255,36 @@ class Module(object): ...@@ -255,49 +255,36 @@ class Module(object):
The state dict. The state dict.
strict : bool, optional, default=True strict : bool, optional, default=True
**True** to verify the names strictly. **True** to verify the names strictly.
verbose : bool, optional, default=False
**True** to print the state info.
""" """
if verbose: missing_keys = []
logging.info('Load the state dict.') unexpected_keys = []
unexpected = [] error_msgs = []
own_state = self.state_dict()
for name, param in state_dict.items(): def load(module, prefix=''):
if name in own_state: module._load_from_state_dict(
state_shape = own_state[name].shape state_dict, prefix, True,
param_shape = param.shape missing_keys, unexpected_keys, error_msgs)
if state_shape != param_shape: for name, child in module._modules.items():
raise ValueError( if child is not None:
'Size of state({}) is ({}), while load from: ({}).' load(child, prefix + name + '.')
.format(name, ', '.join(
[str(d) for d in state_shape]), load(self)
', '.join([str(d) for d in param_shape])))
if isinstance(param, Tensor):
own_state[name].copy_(param)
elif isinstance(param, numpy.ndarray):
own_state[name]._impl.FromNumpy(param.copy())
else:
raise ValueError(
'Excepted the type of source state is either '
'torch.Tensor or numpy.ndarray, got {}.'.format(type(param)))
if verbose:
logging.info(
'Tensor({}) loaded, size: ({})'
.format(name, ', '.join([str(d) for d in param_shape])))
else:
unexpected.append(name)
if strict: if strict:
missing = set(own_state.keys()) - set(state_dict.keys()) if len(unexpected_keys) > 0:
error_msg = '' error_msgs.insert(
if len(unexpected) > 0: 0, 'Unexpected key(s) in state_dict: {}. '
error_msg += 'Unexpected key(s) in state_dict: {}.\n'.format( .format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
', '.join('"{}"'.format(k) for k in unexpected)) if len(missing_keys) > 0:
if len(missing) > 0: error_msgs.insert(
error_msg += 'Missing key(s) in state_dict: {}.'.format( 0, 'Missing key(s) in state_dict: {}. '
', '.join('"{}"'.format(k) for k in missing)) .format(', '.join('"{}"'.format(k) for k in missing_keys)))
if len(error_msg) > 0:
raise KeyError(error_msg) if len(error_msgs) > 0:
raise RuntimeError(
'Error(s) in loading state_dict for {}:\n\t{}'
.format(self.__class__.__name__, "\n\t".join(error_msgs)))
def modules(self): def modules(self):
"""Return an iterator over all modules. """Return an iterator over all modules.
...@@ -577,6 +564,51 @@ class Module(object): ...@@ -577,6 +564,51 @@ class Module(object):
"""Return the class name.""" """Return the class name."""
return self.__class__.__name__ return self.__class__.__name__
def _load_from_state_dict(
self,
state_dict,
prefix,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Load buffers and parameters from the state dict for this module only."""
local_name_params = itertools.chain(
self._parameters.items(), self._buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if input_param.shape != param.shape:
error_msgs.append(
'Size of param({}) is ({}), while loading from: ({}).'
.format(name, ', '.join(
[str(d) for d in param.shape]),
', '.join([str(d) for d in input_param.shape])))
if isinstance(input_param, Tensor):
param.copy_(input_param)
elif isinstance(input_param, numpy.ndarray):
param._impl.FromNumpy(input_param.copy())
else:
error_msgs.append(
'Expected the input param to be either '
'torch.Tensor or numpy.ndarray, got {}.'
.format(type(input_param)))
elif strict:
missing_keys.append(key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix):
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0]
if input_name not in self._modules \
and input_name not in local_state:
unexpected_keys.append(key)
def _named_members(self, getter, prefix='', recurse=True): def _named_members(self, getter, prefix='', recurse=True):
"""Return the named members.""" """Return the named members."""
memo = set() memo = set()
......
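The refactored loader above delegates per-module work to ``_load_from_state_dict`` and walks children with a growing dotted prefix. The following framework-free toy (``ToyModule`` is purely illustrative, not part of Dragon) shows that traversal pattern in isolation:
```python
class ToyModule(object):
    """Illustrative stand-in that only models the prefix-based recursion."""

    def __init__(self):
        self._params = {}   # name -> value owned directly by this module
        self._modules = {}  # name -> child module

    def collect_keys(self, prefix=''):
        # Mirror load(module, prefix): handle local state first, then recurse
        # into children with ``prefix + name + '.'``.
        keys = [prefix + name for name in self._params]
        for name, child in self._modules.items():
            keys += child.collect_keys(prefix + name + '.')
        return keys

root, sub = ToyModule(), ToyModule()
root._params['weight'] = 1.
sub._params['bias'] = 0.
root._modules['sub3'] = sub
print(root.collect_keys())  # ['weight', 'sub3.bias']
```
Callers that previously caught ``KeyError`` (as in the updated test) now need to catch ``RuntimeError``, since strict mismatches are aggregated into ``error_msgs`` and raised once at the end.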
...@@ -19,10 +19,98 @@ import inspect ...@@ -19,10 +19,98 @@ import inspect
from dragon.vm.torch.core.nn import functional as F from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules.module import Module from dragon.vm.torch.core.nn.modules.module import Module
from dragon.vm.torch.core.nn.parameter import Parameter from dragon.vm.torch.core.nn.parameter import Parameter
from dragon.vm.torch.core.ops.init import functional as init from dragon.vm.torch.core.ops.array import functional as array_funcs
from dragon.vm.torch.core.ops.init import functional as init_funcs
from dragon.vm.torch.core.tensor import Tensor from dragon.vm.torch.core.tensor import Tensor
class AffineChannel(Module):
"""Apply affine transformation along the channels.
An affine transformation is often applied as a post-processing step after normalization.
Examples:
```python
m = torch.nn.AffineChannel(5)
# Apply a 2d transformation
x2d = torch.ones(3, 5)
y2d = m(x2d)
# Apply a 3d transformation
x3d = torch.ones(3, 5, 4)
y3d = m(x3d)
# Apply a 4d transformation
x4d = torch.ones(3, 5, 2, 2)
y4d = m(x4d)
```
See Also
--------
`torch.channel_affine(...)`_
"""
def __init__(
self,
num_features,
bias=True,
fix_weight=False,
fix_bias=False,
inplace=False,
):
"""Create an ``Affine`` module.
Parameters
----------
num_features : int
The number of channels.
bias : bool, optional, default=True
**True** to attach a bias.
fix_weight : bool, optional, default=False
**True** to freeze the ``weight``.
fix_bias : bool, optional, default=False
**True** to freeze the ``bias``.
inplace : bool, optional, default=False
Whether to do the operation in-place.
"""
super(AffineChannel, self).__init__()
self.num_features = num_features
self.inplace = inplace
if not fix_weight:
self.weight = Parameter(init_funcs.ones(num_features))
if inplace:
raise ValueError('In-place operation requires fixed weight.')
else:
self.register_buffer('weight', init_funcs.ones(num_features))
if bias:
if not fix_bias:
self.bias = Parameter(init_funcs.zeros(num_features))
else:
self.register_buffer('bias', init_funcs.zeros(num_features))
else:
self.bias = None
def extra_repr(self):
s = '{num_features}, ' \
'inplace={inplace}'.format(**self.__dict__)
if self.bias is None:
s += ', bias=False'
return s
def forward(self, input):
return array_funcs.channel_affine(
input,
self.weight,
self.bias,
dim=1,
out=input if self.inplace else None,
)
class GroupNorm(Module): class GroupNorm(Module):
r"""Apply the group normalization. r"""Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_. `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
...@@ -76,8 +164,8 @@ class GroupNorm(Module): ...@@ -76,8 +164,8 @@ class GroupNorm(Module):
self.weight = Parameter(Tensor(num_channels)) self.weight = Parameter(Tensor(num_channels))
self.bias = Parameter(Tensor(num_channels)) self.bias = Parameter(Tensor(num_channels))
else: else:
self.register_buffer('weight', init.ones(num_channels)) self.register_buffer('weight', init_funcs.ones(num_channels))
self.register_buffer('bias', init.zeros(num_channels)) self.register_buffer('bias', init_funcs.zeros(num_channels))
self.inputs = [self.weight, self.bias] self.inputs = [self.weight, self.bias]
self.reset_parameters() self.reset_parameters()
......
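A minimal usage sketch of the module above, assuming a Dragon build that exposes the ``vm.torch`` frontend; it also exercises the ``inplace`` check, which is why the test instantiates the module inside a ``try``/``except ValueError`` fallback:
```python
from dragon.vm import torch  # assumption: Dragon's torch-style frontend is importable

m = torch.nn.AffineChannel(num_features=5)   # trainable weight and bias
x = torch.ones(3, 5, 4)                      # channels live on dim 1
y = m(x)                                     # y = x * weight + bias, broadcast over dim 1

# In-place output is only allowed when the weight is fixed (see __init__ above).
try:
    m = torch.nn.AffineChannel(5, fix_weight=False, inplace=True)
except ValueError:
    m = torch.nn.AffineChannel(5, fix_weight=False, inplace=False)
```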
...@@ -24,7 +24,7 @@ from dragon.vm.torch.core.nn import functional as F ...@@ -24,7 +24,7 @@ from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules import _functions as nn_funcs from dragon.vm.torch.core.nn.modules import _functions as nn_funcs
from dragon.vm.torch.core.nn.modules.module import Module from dragon.vm.torch.core.nn.modules.module import Module
from dragon.vm.torch.core.nn.parameter import Parameter from dragon.vm.torch.core.nn.parameter import Parameter
from dragon.vm.torch.core.ops.init import functional as init from dragon.vm.torch.core.ops.init import functional as init_funcs
from dragon.vm.torch.core.tensor import Tensor from dragon.vm.torch.core.tensor import Tensor
...@@ -141,8 +141,8 @@ class RNNBase(Module): ...@@ -141,8 +141,8 @@ class RNNBase(Module):
num_cols = shape[-1] num_cols = shape[-1]
flat_shape = (num_cols, num_rows) if num_rows < num_cols \ flat_shape = (num_cols, num_rows) if num_rows < num_cols \
else (num_rows, num_cols) else (num_rows, num_cols)
W = numpy.random.randn(*flat_shape) w = numpy.random.randn(*flat_shape)
q, r = numpy.linalg.qr(W) q, r = numpy.linalg.qr(w)
# Make Q uniform # Make Q uniform
d = numpy.diag(r) d = numpy.diag(r)
q *= numpy.sign(d) q *= numpy.sign(d)
...@@ -423,7 +423,7 @@ class LSTMCell(RNNCellBase): ...@@ -423,7 +423,7 @@ class LSTMCell(RNNCellBase):
def forward(self, input, hx=None): def forward(self, input, hx=None):
if hx is None: if hx is None:
zeros = init.zeros( zeros = init_funcs.zeros(
input.size(0), input.size(0),
self.hidden_size, self.hidden_size,
dtype=input.dtype, dtype=input.dtype,
......
...@@ -94,6 +94,26 @@ class Cast(function.Function): ...@@ -94,6 +94,26 @@ class Cast(function.Function):
return self.dispatch([input], [self.alloc()]) return self.dispatch([input], [self.alloc()])
class ChannelAffine(function.Function):
def __init__(self, key, dev, **kwargs):
super(ChannelAffine, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
self.num_axes = kwargs.get('num_axes', 1)
def attributes(self):
return {
'op_type': 'ChannelAffine',
'arguments': {
'axis': self.axis,
'num_axes': self.num_axes,
}
}
def forward(self, input, weight, bias=None, out=None):
inputs = [input, weight] + ([bias] if bias else [])
return self.dispatch(inputs, [self.alloc(out)])
class ChannelNormalize(function.Function): class ChannelNormalize(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ChannelNormalize, self).__init__(key, dev, **kwargs) super(ChannelNormalize, self).__init__(key, dev, **kwargs)
......
...@@ -150,6 +150,36 @@ def cat(seq, dim=0, out=None): ...@@ -150,6 +150,36 @@ def cat(seq, dim=0, out=None):
.apply(seq, out) .apply(seq, out)
def channel_affine(input, weight, bias=None, dim=0, out=None):
"""Apply affine transformation along the channels.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias.
dim : int, optional, default=0
The start dimension to transform.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return _functions.ChannelAffine \
.instantiate(
input.device,
axis=dim,
num_axes=weight.ndimension(),
).apply(input, weight, bias, out)
def channel_normalize( def channel_normalize(
input, input,
mean, mean,
......
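For completeness, a hedged usage sketch of the functional form documented above (again assuming the ``vm.torch`` frontend is importable): ``dim`` marks the first transformed dimension and ``weight.ndimension()`` fixes how many dimensions follow it.
```python
from dragon.vm import torch  # assumption: Dragon's torch-style frontend is importable

x = torch.ones(2, 3, 4, 5)
w = torch.ones(3, 4)                       # num_axes = w.ndimension() = 2
b = torch.ones(3, 4)
y = torch.channel_affine(x, w, b, dim=1)   # transforms dims [1, 3); y keeps x's shape
```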