Commit c9db9eee by Ting PAN

Fix/Refactor the GroupConvolution on cuDNN

1 parent 6f2751b1
...@@ -28,7 +28,7 @@ class CPUContext { ...@@ -28,7 +28,7 @@ class CPUContext {
public: public:
CPUContext(): random_seed_(3) { generator(); } CPUContext(): random_seed_(3) { generator(); }
CPUContext(unsigned int random_seed): random_seed_(random_seed) { generator(); } CPUContext(unsigned int random_seed): random_seed_(random_seed) { generator(); }
CPUContext(const DeviceOption& option): random_seed_(option.has_random_seed() ? CPUContext(const DeviceOption& option): random_seed_(option.has_random_seed() ?
option.random_seed() : 3) { generator(); } option.random_seed() : 3) { generator(); }
virtual ~CPUContext() {} virtual ~CPUContext() {}
...@@ -51,6 +51,9 @@ class CPUContext { ...@@ -51,6 +51,9 @@ class CPUContext {
inline static void Memcpy(size_t nbytes, void* dst, const void* src) { memcpy(dst, src, nbytes); } inline static void Memcpy(size_t nbytes, void* dst, const void* src) { memcpy(dst, src, nbytes); }
inline static void Delete(void* data) { free(data); } inline static void Delete(void* data) { free(data); }
template<class DstContext, class SrcContext>
inline static void MemcpyAsync(size_t nbytes, void* dst, const void* src) { NOT_IMPLEMENTED; }
template<typename T, class DstContext, class SrcContext> template<typename T, class DstContext, class SrcContext>
inline static void Copy(int n, T* dst, const T* src) { inline static void Copy(int n, T* dst, const T* src) {
if (dst == src) return; if (dst == src) return;
...@@ -62,7 +65,7 @@ class CPUContext { ...@@ -62,7 +65,7 @@ class CPUContext {
inline std::mt19937* generator() { inline std::mt19937* generator() {
auto& generator = cpu_object_.rand_generator; auto& generator = cpu_object_.rand_generator;
if (!generator.get()) if (!generator.get())
generator.reset(new std::mt19937(random_seed_)); generator.reset(new std::mt19937(random_seed_));
return generator.get(); return generator.get();
} }
...@@ -79,4 +82,4 @@ static inline std::mt19937* rand_generator() { ...@@ -79,4 +82,4 @@ static inline std::mt19937* rand_generator() {
} // namepsace dragon } // namepsace dragon
#endif // DRAGON_CORE_CONTEXT_H_ #endif // DRAGON_CORE_CONTEXT_H_
\ No newline at end of file
...@@ -60,7 +60,7 @@ class CUDAObject { ...@@ -60,7 +60,7 @@ class CUDAObject {
class CUDAContext { class CUDAContext {
public: public:
CUDAContext(const DeviceOption& option) CUDAContext(const DeviceOption& option)
: gpu_id_(option.gpu_id()), : gpu_id_(option.gpu_id()),
random_seed_(option.has_random_seed() ? option.random_seed() : 3) { random_seed_(option.has_random_seed() ? option.random_seed() : 3) {
CPUContext context(option); CPUContext context(option);
...@@ -72,7 +72,7 @@ class CUDAContext { ...@@ -72,7 +72,7 @@ class CUDAContext {
#endif #endif
} }
CUDAContext(const int gpu_id = 0) CUDAContext(const int gpu_id = 0)
: gpu_id_(gpu_id), random_seed_(3) { : gpu_id_(gpu_id), random_seed_(3) {
CPUContext context; CPUContext context;
cublas_handle(); cublas_handle();
...@@ -90,7 +90,7 @@ class CUDAContext { ...@@ -90,7 +90,7 @@ class CUDAContext {
void FinishDeviceCompution() { void FinishDeviceCompution() {
cudaStreamSynchronize(cudaStreamDefault); cudaStreamSynchronize(cudaStreamDefault);
cudaError_t error = cudaGetLastError(); cudaError_t error = cudaGetLastError();
CHECK_EQ(error, cudaSuccess) CHECK_EQ(error, cudaSuccess)
<< "CUDA Error: " << cudaGetErrorString(error); << "CUDA Error: " << cudaGetErrorString(error);
} }
...@@ -108,11 +108,11 @@ class CUDAContext { ...@@ -108,11 +108,11 @@ class CUDAContext {
CUDA_CHECK(cudaMemcpy(dst, src, nbytes, cudaMemcpyDefault)); CUDA_CHECK(cudaMemcpy(dst, src, nbytes, cudaMemcpyDefault));
} }
template<class DstContext, class SrcContext>
inline static void MemcpyAsync(size_t nbytes, void* dst, const void* src) { inline static void MemcpyAsync(size_t nbytes, void* dst, const void* src) {
cudaStream_t stream; cudaStream_t stream;
CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDefault, stream)); CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDefault, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaStreamDestroy(stream)); CUDA_CHECK(cudaStreamDestroy(stream));
} }
...@@ -205,4 +205,4 @@ class CUDAContext { ...@@ -205,4 +205,4 @@ class CUDAContext {
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_CONTEXT_CUDA_H_ #endif // DRAGON_CORE_CONTEXT_CUDA_H_
\ No newline at end of file
...@@ -17,12 +17,12 @@ class MixedMemory { ...@@ -17,12 +17,12 @@ class MixedMemory {
public: public:
enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED }; enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED };
MixedMemory() MixedMemory()
: state_(UNINITIALIZED), : state_(UNINITIALIZED),
cpu_ptr_(nullptr), cuda_ptr_(nullptr), cpu_ptr_(nullptr), cuda_ptr_(nullptr),
nbytes_(0) {} nbytes_(0) {}
MixedMemory(const TypeMeta& meta, const size_t nbytes) MixedMemory(const TypeMeta& meta, const size_t nbytes)
: state_(UNINITIALIZED), meta_(meta), : state_(UNINITIALIZED), meta_(meta),
cpu_ptr_(nullptr), cuda_ptr_(nullptr), cpu_ptr_(nullptr), cuda_ptr_(nullptr),
nbytes_(nbytes) {} nbytes_(nbytes) {}
~MixedMemory(); ~MixedMemory();
...@@ -55,4 +55,4 @@ class MixedMemory { ...@@ -55,4 +55,4 @@ class MixedMemory {
} // namespace dragon } // namespace dragon
#endif #endif
\ No newline at end of file
...@@ -37,7 +37,7 @@ class Tensor { ...@@ -37,7 +37,7 @@ class Tensor {
capacity_ = 0; capacity_ = 0;
} }
} else { } else {
if (ex_memory_ && TIndex(ex_memory_->nbytes()) < if (ex_memory_ && TIndex(ex_memory_->nbytes()) <
TIndex(new_size * meta_.itemsize())) { TIndex(new_size * meta_.itemsize())) {
delete ex_memory_; delete ex_memory_;
ex_memory_ = nullptr; ex_memory_ = nullptr;
...@@ -72,7 +72,7 @@ class Tensor { ...@@ -72,7 +72,7 @@ class Tensor {
inline TIndex count() const { return size_; } inline TIndex count() const { return size_; }
inline TIndex count(const TIndex start) const { return count(start, ndim()); } inline TIndex count(const TIndex start) const { return count(start, ndim()); }
inline TIndex offset(const TIndex n, const TIndex c = 0, inline TIndex offset(const TIndex n, const TIndex c = 0,
const TIndex h = 0, const TIndex w = 0) { const TIndex h = 0, const TIndex w = 0) {
CHECK_LE(n, dim(0)); CHECK_LE(n, dim(0));
CHECK_LE(c, dim(1)); CHECK_LE(c, dim(1));
...@@ -103,13 +103,13 @@ class Tensor { ...@@ -103,13 +103,13 @@ class Tensor {
inline void Corrupt() { is_corrupted_ = true; } inline void Corrupt() { is_corrupted_ = true; }
MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; } MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; }
MixedMemory::State memory_state() const { MixedMemory::State memory_state() const {
MixedMemory* mem = memory(); MixedMemory* mem = memory();
CHECK(mem) << "\nMemory access before allowcating."; CHECK(mem) << "\nMemory access before allowcating.";
return memory()->state(); return memory()->state();
} }
void SwitchToDevice() { void SwitchToDevice() {
MixedMemory* mem = own_mem_ ? memory_.get() : ex_memory_; MixedMemory* mem = own_mem_ ? memory_.get() : ex_memory_;
if (mem) mem->SwitchToDevice(); if (mem) mem->SwitchToDevice();
} }
...@@ -166,15 +166,15 @@ class Tensor { ...@@ -166,15 +166,15 @@ class Tensor {
template <class Context> template <class Context>
void* raw_mutable_data() { void* raw_mutable_data() {
CHECK_NE(meta_.id(), 0) CHECK_NE(meta_.id(), 0)
<< "\nTensor(" << name_ << "): unknown type, " << "\nTensor(" << name_ << "): unknown type, "
<< "or does not have a type."; << "or does not have a type.";
return raw_mutable_data<Context>(meta_); return raw_mutable_data<Context>(meta_);
} }
template <class Context> template <class Context>
const void* raw_data() const { const void* raw_data() const {
return const_data_ptr<Context>(); return const_data_ptr<Context>();
} }
template <typename T, class Context> template <typename T, class Context>
...@@ -186,8 +186,8 @@ class Tensor { ...@@ -186,8 +186,8 @@ class Tensor {
} }
template <typename T, class Context> template <typename T, class Context>
const T* data() const { const T* data() const {
return static_cast<const T*>(raw_data<Context>()); return static_cast<const T*>(raw_data<Context>());
} }
inline void Share(const Tensor& other) { inline void Share(const Tensor& other) {
...@@ -198,7 +198,7 @@ class Tensor { ...@@ -198,7 +198,7 @@ class Tensor {
} }
inline void Move(MixedMemory* mem) { inline void Move(MixedMemory* mem) {
if (mem != nullptr) ex_memory_ = mem; if (mem != nullptr) ex_memory_ = mem;
else ex_memory_ = new MixedMemory(TypeMeta::Make<float>(), 4); else ex_memory_ = new MixedMemory(TypeMeta::Make<float>(), 4);
own_mem_ = false; own_mem_ = false;
} }
...@@ -215,11 +215,11 @@ class Tensor { ...@@ -215,11 +215,11 @@ class Tensor {
TIndex size_ = 0, capacity_ = 0; TIndex size_ = 0, capacity_ = 0;
TypeMeta meta_; TypeMeta meta_;
string name_; string name_;
shared_ptr<MixedMemory> memory_; shared_ptr<MixedMemory> memory_, host_memory_;
MixedMemory* ex_memory_ = nullptr; MixedMemory* ex_memory_ = nullptr;
bool is_corrupted_ = false, own_mem_ = true; bool is_corrupted_ = false, own_mem_ = true;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAONG_CORE_TENSOR_H_ #endif // DRAONG_CORE_TENSOR_H_
\ No newline at end of file
...@@ -18,7 +18,8 @@ class L2NormOp final : public Operator<Context> { ...@@ -18,7 +18,8 @@ class L2NormOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-5))) {} eps(OperatorBase::GetSingleArg<float>("eps", float(1e-5))),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -26,6 +27,7 @@ class L2NormOp final : public Operator<Context> { ...@@ -26,6 +27,7 @@ class L2NormOp final : public Operator<Context> {
protected: protected:
float eps; float eps;
TIndex axis, num_axes, end_axis; TIndex axis, num_axes, end_axis;
string mode;
bool across_inner; bool across_inner;
Tensor* norm, *buffer, *multiplier; Tensor* norm, *buffer, *multiplier;
TIndex outer_dim, dim, inner_dim, spatial_dim; TIndex outer_dim, dim, inner_dim, spatial_dim;
......
...@@ -30,7 +30,7 @@ class Conv2dOp : public ConvOpBase<Context> { ...@@ -30,7 +30,7 @@ class Conv2dOp : public ConvOpBase<Context> {
template <class Context> template <class Context>
class Conv2dGradientOp : public Conv2dOp<Context> { class Conv2dGradientOp : public Conv2dOp<Context> {
public: public:
Conv2dGradientOp(const OperatorDef& def, Workspace* ws) Conv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) {} : Conv2dOp<Context>(def, ws) {}
bool HasBias() override { return output(2)->name() != "ignore"; } bool HasBias() override { return output(2)->name() != "ignore"; }
...@@ -48,10 +48,15 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -48,10 +48,15 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
public: public:
CuDNNConv2dOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) { : Conv2dOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group]; #if CUDNN_VERSION_MIN(7, 0, 0)
stream = new cudaStream_t[this->group]; cudnn_group = 1;
#else
cudnn_group = this->group;
#endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
ctx().SwitchToDevice(); ctx().SwitchToDevice();
for (int g = 0; g < this->group; g++) { for (int g = 0; g < cudnn_group; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g])); CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g])); CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g])); CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
...@@ -78,17 +83,22 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -78,17 +83,22 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size; size_t workspace_fwd_data_size;
TIndex bias_offset; TIndex bias_offset, cudnn_group;
}; };
template <class Context> template <class Context>
class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
public: public:
CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dGradientOp<Context>(def, ws) { : Conv2dGradientOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group * 3]; #if CUDNN_VERSION_MIN(7, 0, 0)
stream = new cudaStream_t[this->group * 3]; cudnn_group = 1;
for (int g = 0; g < this->group * 3; g++) { #else
cudnn_group = this->group;
#endif
handle = new cudnnHandle_t[cudnn_group * 3];
stream = new cudaStream_t[cudnn_group * 3];
for (int g = 0; g < cudnn_group * 3; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g])); CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g])); CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g])); CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
...@@ -116,7 +126,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { ...@@ -116,7 +126,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_bwd_filter_size, workspace_bwd_data_size; size_t workspace_bwd_filter_size, workspace_bwd_data_size;
int bias_offset; TIndex bias_offset, cudnn_group;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
......
...@@ -52,8 +52,13 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -52,8 +52,13 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
public: public:
CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws) { : Conv2dTransposeOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group]; #if CUDNN_VERSION_MIN(7, 0, 0)
stream = new cudaStream_t[this->group]; cudnn_group = 1;
#else
cudnn_group = this->group;
#endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
for (int g = 0; g < this->group; g++) { for (int g = 0; g < this->group; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g])); CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g])); CUDNN_CHECK(cudnnCreate(&handle[g]));
...@@ -80,7 +85,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -80,7 +85,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size; size_t workspace_fwd_data_size;
int bias_offset; TIndex bias_offset, cudnn_group;
}; };
template <class Context> template <class Context>
...@@ -88,9 +93,14 @@ class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context> ...@@ -88,9 +93,14 @@ class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context>
public: public:
CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeGradientOp<Context>(def, ws) { : Conv2dTransposeGradientOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group * 3]; #if CUDNN_VERSION_MIN(7, 0, 0)
stream = new cudaStream_t[this->group * 3]; cudnn_group = 1;
for (int g = 0; g < this->group * 3; g++) { #else
cudnn_group = this->group;
#endif
handle = new cudnnHandle_t[cudnn_group * 3];
stream = new cudaStream_t[cudnn_group * 3];
for (int g = 0; g < cudnn_group * 3; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g])); CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g])); CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g])); CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
...@@ -117,7 +127,7 @@ public: ...@@ -117,7 +127,7 @@ public:
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_bwd_filter_size, workspace_bwd_data_size; size_t workspace_bwd_filter_size, workspace_bwd_data_size;
int bias_offset; TIndex bias_offset, cudnn_group;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
......
...@@ -18,11 +18,12 @@ class LRNOp : public Operator<Context> { ...@@ -18,11 +18,12 @@ class LRNOp : public Operator<Context> {
public: public:
LRNOp(const OperatorDef& op_def, Workspace* ws) LRNOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
mode((LRNMode)OperatorBase::GetSingleArg<int>("mode", ACROSS_CHANNELS)),
local_size(OperatorBase::GetSingleArg<int>("local_size", 5)), local_size(OperatorBase::GetSingleArg<int>("local_size", 5)),
alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))), alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))),
beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))), beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))),
k(OperatorBase::GetSingleArg<float>("k", float(2.0))) {} k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -34,9 +35,9 @@ class LRNOp : public Operator<Context> { ...@@ -34,9 +35,9 @@ class LRNOp : public Operator<Context> {
template <typename T> void ProdRunWithType(); template <typename T> void ProdRunWithType();
protected: protected:
LRNMode mode;
int local_size; int local_size;
float alpha, beta, k; float alpha, beta, k;
string mode, data_format;
unique_ptr<OperatorBase> sqr_op, pool_op, pow_op, prod_op; unique_ptr<OperatorBase> sqr_op, pool_op, pow_op, prod_op;
Tensor* sqr_in, *prod_in, *sqr_out, *pool_out, *pow_out; Tensor* sqr_in, *prod_in, *sqr_out, *pool_out, *pow_out;
Tensor* scale; Tensor* scale;
...@@ -47,11 +48,12 @@ class LRNGradientOp : public Operator<Context> { ...@@ -47,11 +48,12 @@ class LRNGradientOp : public Operator<Context> {
public: public:
LRNGradientOp(const OperatorDef& op_def, Workspace* ws) LRNGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
mode((LRNMode)OperatorBase::GetSingleArg<int>("mode", ACROSS_CHANNELS)),
local_size(OperatorBase::GetSingleArg<int>("local_size", 5)), local_size(OperatorBase::GetSingleArg<int>("local_size", 5)),
alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))), alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))),
beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))), beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))),
k(OperatorBase::GetSingleArg<float>("k", float(2.0))) {} k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -63,9 +65,9 @@ class LRNGradientOp : public Operator<Context> { ...@@ -63,9 +65,9 @@ class LRNGradientOp : public Operator<Context> {
template <typename T> void ProdRunWithType(); template <typename T> void ProdRunWithType();
protected: protected:
LRNMode mode;
int local_size; int local_size;
float alpha, beta, k; float alpha, beta, k;
string mode, data_format;
unique_ptr<OperatorBase> sqr_op, pool_op, pow_op, prod_op; unique_ptr<OperatorBase> sqr_op, pool_op, pow_op, prod_op;
Tensor* sqr_in, *prod_in, *sqr_out, *pool_out, *pow_out; Tensor* sqr_in, *prod_in, *sqr_out, *pool_out, *pow_out;
Tensor* scale; Tensor* scale;
......
...@@ -76,6 +76,9 @@ template <typename T> ...@@ -76,6 +76,9 @@ template <typename T>
void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims); void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
template <typename T> template <typename T>
void cudnnSetTensor4dDescWithGroup(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims, const int64_t group);
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims); void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
template <typename T> template <typename T>
......
...@@ -156,29 +156,39 @@ class Tensor(object): ...@@ -156,29 +156,39 @@ class Tensor(object):
""" """
return self.Normal(mu=mean, sigma=std) return self.Normal(mu=mean, sigma=std)
def Xavier(self): def Xavier(self, scale=3.0):
""" """
Register as a variable with xavier initializer. Register as a variable with xavier initializer.
""" """
return self._no_parameter_filler('xavier') filler = pb.TensorFiller()
filler.tensor = self.name
filler.type = 'xavier'
filler.scale = scale
ws.CreateFiller(filler)
return self
def MSRA(self): def MSRA(self, scale=2.0):
""" """
Register as a variable with msra initializer. Register as a variable with msra initializer.
""" """
return self._no_parameter_filler('msra') filler = pb.TensorFiller()
filler.tensor = self.name
filler.type = 'msra'
filler.scale = scale
ws.CreateFiller(filler)
return self
def GlorotUniform(self): def GlorotUniform(self, scale=3.0):
""" """
Register as a variable with glorot uniform initializer. Register as a variable with glorot uniform initializer.
""" """
return self.Xavier() return self.Xavier(scale)
def GlorotNormal(self): def GlorotNormal(self, scale=2.0):
""" """
Register as a variable with glorot normal initializer. Register as a variable with glorot normal initializer.
""" """
return self.MSRA() return self.MSRA(scale)
############################################## ##############################################
# # # #
......
...@@ -19,10 +19,18 @@ Installation - Linux (Normal, CPU) ...@@ -19,10 +19,18 @@ Installation - Linux (Normal, CPU)
**Step 1:** Install C++ Dependencies **Step 1:** Install C++ Dependencies
**$** Setup Python Development Environment
.. code-block:: shell .. code-block:: shell
sudo apt-get install libpython-dev sudo apt-get install libpython-dev
**Note:** You can also use `Anaconda`_, A powerful toolkit for Data Science.
**$** Setup C++ Development Environment
sudo apt-get install libprotobuf-dev sudo apt-get install libprotobuf-dev
sudo apt-get install protobuf-compiler
sudo apt-get install libopenblas-dev sudo apt-get install libopenblas-dev
**Step 2:** Install Python Requirements **Step 2:** Install Python Requirements
...@@ -83,10 +91,18 @@ Installation - Linux (Normal, GPU) ...@@ -83,10 +91,18 @@ Installation - Linux (Normal, GPU)
**Step 2:** Install C++ Dependencies **Step 2:** Install C++ Dependencies
**$** Setup Python Development Environment
.. code-block:: shell .. code-block:: shell
sudo apt-get install libpython-dev sudo apt-get install libpython-dev
**Note:** You can also use `Anaconda`_, A powerful toolkit for Data Science.
**$** Setup C++ Development Environment
sudo apt-get install libprotobuf-dev sudo apt-get install libprotobuf-dev
sudo apt-get install protobuf-compiler
sudo apt-get install libopenblas-dev sudo apt-get install libopenblas-dev
**Step 3:** Install Python Requirements **Step 3:** Install Python Requirements
...@@ -149,10 +165,18 @@ Installation - Linux (Distributed, CPU) ...@@ -149,10 +165,18 @@ Installation - Linux (Distributed, CPU)
**Step 2:** Install C++ Dependencies **Step 2:** Install C++ Dependencies
**$** Setup Python Development Environment
.. code-block:: shell .. code-block:: shell
sudo apt-get install libpython-dev sudo apt-get install libpython-dev
**Note:** You can also use `Anaconda`_, A powerful toolkit for Data Science.
**$** Setup C++ Development Environment
sudo apt-get install libprotobuf-dev sudo apt-get install libprotobuf-dev
sudo apt-get install protobuf-compiler
sudo apt-get install libopenblas-dev sudo apt-get install libopenblas-dev
**Step 3:** Install Python Requirements **Step 3:** Install Python Requirements
...@@ -229,10 +253,18 @@ Installation - Linux (Distributed, GPU) ...@@ -229,10 +253,18 @@ Installation - Linux (Distributed, GPU)
**Step 3:** Install C++ Dependencies **Step 3:** Install C++ Dependencies
**$** Setup Python Development Environment
.. code-block:: shell .. code-block:: shell
sudo apt-get install libpython-dev sudo apt-get install libpython-dev
**Note:** You can also use `Anaconda`_, A powerful toolkit for Data Science.
**$** Setup C++ Development Environment
sudo apt-get install libprotobuf-dev sudo apt-get install libprotobuf-dev
sudo apt-get install protobuf-compiler
sudo apt-get install libopenblas-dev sudo apt-get install libopenblas-dev
**Step 4:** Install Python Requirements **Step 4:** Install Python Requirements
...@@ -564,6 +596,7 @@ Add ``REPO_ROOT/3rdparty/bin`` to system environment variables ...@@ -564,6 +596,7 @@ Add ``REPO_ROOT/3rdparty/bin`` to system environment variables
python setup.py install --user python setup.py install --user
.. _Anaconda: https://www.anaconda.com/download
.. _CUDA: https://developer.nvidia.com/cuda-toolkit .. _CUDA: https://developer.nvidia.com/cuda-toolkit
.. _CUDNN: https://developer.nvidia.com/cudnn .. _CUDNN: https://developer.nvidia.com/cudnn
.. _NCCL: https://developer.nvidia.com/nccl .. _NCCL: https://developer.nvidia.com/nccl
......
...@@ -673,6 +673,7 @@ def Reshape(inputs, shape, **kwargs): ...@@ -673,6 +673,7 @@ def Reshape(inputs, shape, **kwargs):
output.shape = [1] * len(shape) output.shape = [1] * len(shape)
for i, s in enumerate(shape): for i, s in enumerate(shape):
if s == -1: output.shape[i] = 1 if s == -1: output.shape[i] = 1
elif s == 0: output.shape[i] = inputs.shape[i]
else: output.shape[i] = s else: output.shape[i] = s
return output return output
......
...@@ -189,7 +189,7 @@ def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs): ...@@ -189,7 +189,7 @@ def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs):
return output return output
def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, **kwargs): def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, mode='SUM', **kwargs):
"""L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_. """L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_.
Parameters Parameters
...@@ -202,6 +202,8 @@ def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, **kwargs): ...@@ -202,6 +202,8 @@ def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, **kwargs):
The number of axes of stats region. Default is ``-1`` (Till End). The number of axes of stats region. Default is ``-1`` (Till End).
eps : float eps : float
The eps. The eps.
mode : str
The mode on computing normalizer. ``SUM`` or ``MEAN``.
Returns Returns
------- -------
......
...@@ -61,6 +61,12 @@ def Conv2d(inputs, num_output, kernel_size, ...@@ -61,6 +61,12 @@ def Conv2d(inputs, num_output, kernel_size,
""" """
CheckInputs(inputs, 2, 3) CheckInputs(inputs, 2, 3)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
if padding not in ('VALID', 'SAME'):
raise ValueError('Unsupported padding algorithm: {}'.format(padding))
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format))
if not isinstance(arguments['kernel_size'], list): if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']] arguments['kernel_size'] = [arguments['kernel_size']]
if not isinstance(arguments['stride'], list): if not isinstance(arguments['stride'], list):
...@@ -154,6 +160,11 @@ def Conv2dTranspose(inputs, num_output, kernel_size, ...@@ -154,6 +160,11 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
CheckInputs(inputs, 2, 3) CheckInputs(inputs, 2, 3)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
if padding not in ('VALID', 'SAME'):
raise ValueError('Unsupported padding algorithm: {}'.format(padding))
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format))
arguments['output_shape'] = None arguments['output_shape'] = None
if output_shape is not None: if output_shape is not None:
if not isinstance(output_shape, list): if not isinstance(output_shape, list):
...@@ -170,17 +181,43 @@ def Conv2dTranspose(inputs, num_output, kernel_size, ...@@ -170,17 +181,43 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
if not isinstance(arguments['kernel_size'], list): if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']] arguments['kernel_size'] = [arguments['kernel_size']]
if not isinstance(arguments['stride'], list): if not isinstance(arguments['stride'], list):
arguments['stride'] = [arguments['stride']] arguments['stride'] = [arguments['stride']]
if not isinstance(arguments['pad'], list): if not isinstance(arguments['pad'], list):
arguments['pad'] = [arguments['pad']] arguments['pad'] = [arguments['pad']]
if not isinstance(arguments['dilation'], list): if not isinstance(arguments['dilation'], list):
arguments['dilation'] = [arguments['dilation']] arguments['dilation'] = [arguments['dilation']]
return Tensor.CreateOperator(nout=1, op_type='Conv2dTranspose', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Conv2dTranspose', **arguments)
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
channel_axis = 1 if data_format == 'NCHW' else -1
spatial_axis = 2 if data_format == 'NCHW' else 1
output.shape[channel_axis] = num_output
for i in xrange(2):
k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \
else arguments['kernel_size'][-1]
s = arguments['stride'][i] if i < len(arguments['stride']) \
else arguments['stride'][-1]
p = arguments['pad'][i] if i < len(arguments['pad']) \
else arguments['pad'][-1]
d = arguments['dilation'][i] if i < len(arguments['dilation']) \
else arguments['dilation'][-1]
dk = d * (k - 1) + 1
dp = 2 * p
input_size = output.shape[i + spatial_axis]
if padding != 'SAME':
output.shape[i + spatial_axis] = s * (input_size - 1) + dk - dp
else:
if output_shape is None:
raise ValueError('The output shape must be specified if using SAME padding algorithm.')
if 'dynamic_dsize' in arguments:
output.shape = None
return output
output.shape[i + spatial_axis] = output_shape[i + spatial_axis]
return output
def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID', def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
...@@ -222,6 +259,14 @@ def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID', ...@@ -222,6 +259,14 @@ def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
if mode not in ('MAX', 'AVG'):
raise ValueError('Unsupported lrn mode: {}'.format(mode))
if padding not in ('VALID', 'SAME'):
raise ValueError('Unsupported padding algorithm: {}'.format(padding))
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format))
if not isinstance(arguments['kernel_size'], list): if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']] arguments['kernel_size'] = [arguments['kernel_size']]
if not isinstance(arguments['stride'], list): if not isinstance(arguments['stride'], list):
...@@ -311,7 +356,8 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs): ...@@ -311,7 +356,8 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
return Tensor.CreateOperator(nout=1, op_type='ROIAlign', **arguments) return Tensor.CreateOperator(nout=1, op_type='ROIAlign', **arguments)
def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANNELS', **kwargs): def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0,
mode='ACROSS_CHANNELS', data_format='NCHW', **kwargs):
"""Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_. """Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_.
Parameters Parameters
...@@ -328,17 +374,22 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANN ...@@ -328,17 +374,22 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANN
The k of LRN. The k of LRN.
mode : str mode : str
The mode, ``ACROSS_CHANNELS`` or ``WITHIN_CHANNEL``. The mode, ``ACROSS_CHANNELS`` or ``WITHIN_CHANNEL``.
data_format : str
The data format. ``NCHW`` or ``NHWC``.
Returns Returns
------- -------
Tensor Tensor
The normalized tensor. The output tensor.
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
SUPPORT_MODES = {'ACROSS_CHANNELS': 0, 'WITHIN_CHANNEL': 1}
arguments['mode'] = SUPPORT_MODES[mode] if mode not in ('ACROSS_CHANNELS', 'WITHIN_CHANNEL'):
raise ValueError('Unsupported lrn mode: {}'.format(mode))
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format))
output = Tensor.CreateOperator(nout=1, op_type='LRN', **arguments) output = Tensor.CreateOperator(nout=1, op_type='LRN', **arguments)
...@@ -356,9 +407,9 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -356,9 +407,9 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
The input tenosr. The input tensor.
dsize : tuple, list, Tensor or None dsize : tuple, list, Tensor or None
The output size. The output size, formats as (h, w).
fy : float fy : float
The scale factor based on src height. Default is ``-1.0`` (Discarded). The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float fx : float
...@@ -374,6 +425,10 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -374,6 +425,10 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format))
if arguments['dsize'] is not None: if arguments['dsize'] is not None:
if isinstance(arguments['dsize'][0], Tensor): if isinstance(arguments['dsize'][0], Tensor):
arguments['dynamic_dsize'] = [arguments['dsize'][0].name, arguments['dynamic_dsize'] = [arguments['dsize'][0].name,
...@@ -388,6 +443,20 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -388,6 +443,20 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
output = Tensor.CreateOperator(nout=1, op_type='NNResize', **arguments) output = Tensor.CreateOperator(nout=1, op_type='NNResize', **arguments)
if inputs.shape is not None:
if len(inputs.shape) != 4:
raise ValueError('The inputs should be a 4d Tensor.')
if 'dynamic_dsize' not in arguments:
output.shape = inputs.shape[:]
spatial_axis = 2 if data_format == 'NCHW' else 1
for i in xrange(2):
output_dim = output.shape[spatial_axis + i]
if 'static_size' in arguments:
output_dim = dsize[i]
else:
output_dim = int(float(output_dim) * ([fy, fx])[i])
output.shape[spatial_axis + i] = output_dim
return output return output
...@@ -399,9 +468,9 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs ...@@ -399,9 +468,9 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
The input tenosr. The input tensor.
dsize : tuple, list, Tensor or None dsize : tuple, list, Tensor or None
The dest output size. The output size, formats as (h, w).
fy : float fy : float
The scale factor based on src height. Default is ``-1.0`` (Discarded). The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float fx : float
...@@ -417,6 +486,10 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs ...@@ -417,6 +486,10 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format))
if arguments['dsize'] is not None: if arguments['dsize'] is not None:
if isinstance(arguments['dsize'][0], Tensor): if isinstance(arguments['dsize'][0], Tensor):
arguments['dynamic_dsize'] = [arguments['dsize'][0].name, arguments['dynamic_dsize'] = [arguments['dsize'][0].name,
...@@ -431,6 +504,20 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs ...@@ -431,6 +504,20 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
output = Tensor.CreateOperator(nout=1, op_type='BilinearResize', **arguments) output = Tensor.CreateOperator(nout=1, op_type='BilinearResize', **arguments)
if inputs.shape is not None:
if len(inputs.shape) != 4:
raise ValueError('The inputs should be a 4d Tensor.')
if 'dynamic_dsize' not in arguments:
output.shape = inputs.shape[:]
spatial_axis = 2 if data_format == 'NCHW' else 1
for i in xrange(2):
output_dim = output.shape[spatial_axis + i]
if 'static_size' in arguments:
output_dim = dsize[i]
else:
output_dim = int(float(output_dim) * ([fy, fx])[i])
output.shape[spatial_axis + i] = output_dim
return output return output
...@@ -453,6 +540,9 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs): ...@@ -453,6 +540,9 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
CheckInputs(inputs, 2) CheckInputs(inputs, 2)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format))
output = Tensor.CreateOperator(nout=1, op_type='BiasAdd', **arguments) output = Tensor.CreateOperator(nout=1, op_type='BiasAdd', **arguments)
if inputs[0].shape is not None: if inputs[0].shape is not None:
......
...@@ -229,7 +229,9 @@ class LRNLayer(Layer): ...@@ -229,7 +229,9 @@ class LRNLayer(Layer):
self._param = {'local_size': param.local_size, self._param = {'local_size': param.local_size,
'alpha': param.alpha, 'alpha': param.alpha,
'beta': param.beta, 'beta': param.beta,
'mode': {0: 'ACROSS_CHANNELS', 1: 'WITHIN_CHANNEL'}[param.norm_region]} 'mode': {0: 'ACROSS_CHANNELS', 1: 'WITHIN_CHANNEL'}[param.norm_region],
'data_format': 'NCHW'}
def Setup(self, bottom): def Setup(self, bottom):
super(LRNLayer, self).Setup(bottom) super(LRNLayer, self).Setup(bottom)
input = bottom[0] if isinstance(bottom, list) else bottom input = bottom[0] if isinstance(bottom, list) else bottom
......
...@@ -18,7 +18,7 @@ GraphBase::GraphBase(const GraphDef& meta_graph, Workspace* ws) ...@@ -18,7 +18,7 @@ GraphBase::GraphBase(const GraphDef& meta_graph, Workspace* ws)
// check inputs // check inputs
for (auto& in : op.input()) for (auto& in : op.input())
CHECK(known_tensors.count(in) || ws_->HasTensor(in)) CHECK(known_tensors.count(in) || ws_->HasTensor(in))
<< "\nInput: " << in << " for op: " << "\nInput: " << in << " for op: "
<< op.name() << " is unknown."; << op.name() << " is unknown.";
// add outputs // add outputs
for (auto& out : op.output()) known_tensors.insert(out); for (auto& out : op.output()) known_tensors.insert(out);
...@@ -55,13 +55,13 @@ void Graph::ForwardShareDyeing(string u, string ancestor) { ...@@ -55,13 +55,13 @@ void Graph::ForwardShareDyeing(string u, string ancestor) {
auto* schema = OpSchemaRegistry::Schema(op_type); auto* schema = OpSchemaRegistry::Schema(op_type);
if (schema->AllowInplace()) if (schema->AllowInplace())
ForwardShareDyeing(dag_[u].childs[0], ancestor); ForwardShareDyeing(dag_[u].childs[0], ancestor);
} }
} }
void Graph::ForwardPruneDyeing(string u, string leaf, vector<string> path) { void Graph::ForwardPruneDyeing(string u, string leaf, vector<string> path) {
if (visited_.count(u)) { if (visited_.count(u)) {
if (visited_[u]) if (visited_[u])
for (auto& node : path) for (auto& node : path)
visited_[node] = colored_[node] = true; visited_[node] = colored_[node] = true;
return; return;
} }
...@@ -71,7 +71,7 @@ void Graph::ForwardPruneDyeing(string u, string leaf, vector<string> path) { ...@@ -71,7 +71,7 @@ void Graph::ForwardPruneDyeing(string u, string leaf, vector<string> path) {
vector<string> new_path(path); vector<string> new_path(path);
new_path.push_back(v); new_path.push_back(v);
if (v == leaf) { if (v == leaf) {
for (auto& node : new_path) for (auto& node : new_path)
visited_[node] = colored_[node] = true; visited_[node] = colored_[node] = true;
return; return;
} }
...@@ -260,8 +260,8 @@ GraphDef Graph::MakeUpdate(const GraphDef& meta_graph) { ...@@ -260,8 +260,8 @@ GraphDef Graph::MakeUpdate(const GraphDef& meta_graph) {
collective_ops.push_back(op_def); collective_ops.push_back(op_def);
} else if (this->args_["parallel_mode"].s() == "MIXED") { } else if (this->args_["parallel_mode"].s() == "MIXED") {
/* /*
See: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour See: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour
Links: http://arxiv.org/abs/1706.02677 Link: http://arxiv.org/abs/1706.02677
*/ */
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
...@@ -282,11 +282,11 @@ bool Graph::Create(const GraphDef& optimized_graph, Workspace* ws) { ...@@ -282,11 +282,11 @@ bool Graph::Create(const GraphDef& optimized_graph, Workspace* ws) {
bool has_share_grads = optimized_graph.has_share_grads(); bool has_share_grads = optimized_graph.has_share_grads();
for (const OperatorDef& plain_op_def : optimized_graph.op()) { for (const OperatorDef& plain_op_def : optimized_graph.op()) {
OperatorDef op_def(plain_op_def); OperatorDef op_def(plain_op_def);
LOG(DEBUG) << "Create Operator " << plain_op_def.name() LOG(DEBUG) << "Create Operator " << plain_op_def.name()
<< ": " << plain_op_def.type(); << ": " << plain_op_def.type();
// inherit device option if necessary // inherit device option if necessary
if (!op_def.has_device_option() && has_device_option) if (!op_def.has_device_option() && has_device_option)
op_def.mutable_device_option()->CopyFrom(optimized_graph.device_option()); op_def.mutable_device_option()->CopyFrom(optimized_graph.device_option());
// inherit debug mode if necessary // inherit debug mode if necessary
...@@ -316,7 +316,7 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) { ...@@ -316,7 +316,7 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) {
bool mirror_stage = ops_[i]->GetSingleArg<bool>("mirror_stage", false); bool mirror_stage = ops_[i]->GetSingleArg<bool>("mirror_stage", false);
for (auto& u : optimized_graph.op(i).input()) { for (auto& u : optimized_graph.op(i).input()) {
bool inplace_flag = false; bool inplace_flag = false;
for (auto& v : optimized_graph.op(i).output()) for (auto& v : optimized_graph.op(i).output())
if (u == v) inplace_flag = true; if (u == v) inplace_flag = true;
mirror_stage &= (!inplace_flag); mirror_stage &= (!inplace_flag);
if (!inplace_flag) multi_use_count[u]++; if (!inplace_flag) multi_use_count[u]++;
...@@ -324,7 +324,7 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) { ...@@ -324,7 +324,7 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) {
if (mirror_stage) { if (mirror_stage) {
// TODO(PhyscalX): we assume input(0)->output(0) as a in-place currently // TODO(PhyscalX): we assume input(0)->output(0) as a in-place currently
OperatorDef* op = fake_graph.mutable_op(i); OperatorDef* op = fake_graph.mutable_op(i);
if (rename_map.count(op->input(0))) if (rename_map.count(op->input(0)))
*op->mutable_input(0) = rename_map[op->input(0)]; *op->mutable_input(0) = rename_map[op->input(0)];
rename_map[op->output(0)] = op->input(0); rename_map[op->output(0)] = op->input(0);
*op->mutable_output(0) = op->input(0); *op->mutable_output(0) = op->input(0);
...@@ -339,19 +339,19 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) { ...@@ -339,19 +339,19 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) {
OperatorDef op = optimized_graph.op(i); OperatorDef op = optimized_graph.op(i);
for (int j = 0; j < op.output_size(); j++) { for (int j = 0; j < op.output_size(); j++) {
string v = op.output(j); string v = op.output(j);
string fake_v = fake_op.output(j); string fake_v = fake_op.output(j);
if (!fake_recompute_map.count(fake_v)) if (!fake_recompute_map.count(fake_v))
fake_recompute_map[fake_v] = vector<OperatorBase*>(); fake_recompute_map[fake_v] = vector<OperatorBase*>();
if (v != fake_v) { if (v != fake_v) {
if (multi_use_count[fake_v] >= 2) if (multi_use_count[fake_v] >= 2)
fake_recompute_map[fake_v] = recompute_map[fake_v]; fake_recompute_map[fake_v] = recompute_map[fake_v];
} }
fake_recompute_map[fake_v].push_back(ops_[i]); fake_recompute_map[fake_v].push_back(ops_[i]);
for (int k = 0; k < fake_recompute_map[fake_v].size(); k++) { for (int k = 0; k < fake_recompute_map[fake_v].size(); k++) {
if (!hash_map.count(v)) hash_map[v] = Set<string>(); if (!hash_map.count(v)) hash_map[v] = Set<string>();
string op_name = fake_recompute_map[fake_v][k]->name(); string op_name = fake_recompute_map[fake_v][k]->name();
if (!hash_map[v].count(op_name)) { if (!hash_map[v].count(op_name)) {
if (!recompute_map.count(v)) if (!recompute_map.count(v))
recompute_map[v] = vector<OperatorBase*>(); recompute_map[v] = vector<OperatorBase*>();
recompute_map[v].push_back(fake_recompute_map[fake_v][k]); recompute_map[v].push_back(fake_recompute_map[fake_v][k]);
hash_map[v].insert(op_name); hash_map[v].insert(op_name);
...@@ -359,7 +359,7 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) { ...@@ -359,7 +359,7 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) {
} }
} }
} }
// prepare resources // prepare resources
for (auto& ops : ops_) ops->set_recompute_map(recompute_map); for (auto& ops : ops_) ops->set_recompute_map(recompute_map);
Tensor* head = ws->CreateTensor("/opt/mirror_stage/head"); Tensor* head = ws->CreateTensor("/opt/mirror_stage/head");
...@@ -403,7 +403,7 @@ Graph::Graph(const GraphDef& meta_graph, Workspace* ws) ...@@ -403,7 +403,7 @@ Graph::Graph(const GraphDef& meta_graph, Workspace* ws)
bool Graph::Run(const string& include, const string& exclude) { bool Graph::Run(const string& include, const string& exclude) {
LOG(DEBUG) << "Run Graph: " << name(); LOG(DEBUG) << "Run Graph: " << name();
for (auto op : ops_) { for (auto op : ops_) {
if (!include.empty()) if (!include.empty())
if (op->type().find(include) == string::npos) continue; if (op->type().find(include) == string::npos) continue;
if (!exclude.empty()) if (!exclude.empty())
if (op->type().find(exclude) != string::npos) continue; if (op->type().find(exclude) != string::npos) continue;
...@@ -422,4 +422,4 @@ GraphBase* NewGraph(const GraphDef& meta_graph, Workspace* ws) { ...@@ -422,4 +422,4 @@ GraphBase* NewGraph(const GraphDef& meta_graph, Workspace* ws) {
return GraphRegistry()->Create(meta_graph.graph_type(), meta_graph, ws); return GraphRegistry()->Create(meta_graph.graph_type(), meta_graph, ws);
} }
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -112,4 +112,4 @@ void MixedMemory::SwitchToDevice() { ...@@ -112,4 +112,4 @@ void MixedMemory::SwitchToDevice() {
} }
} }
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
namespace dragon { namespace dragon {
OperatorBase::OperatorBase(const OperatorDef& op_def, Workspace* ws) OperatorBase::OperatorBase(const OperatorDef& op_def, Workspace* ws)
: op_def_(op_def), ws_(ws) { : op_def_(op_def), ws_(ws) {
for (auto& arg : this->op_def_.arg()) { for (auto& arg : this->op_def_.arg()) {
CHECK_GT(arg.name().size(), 0); CHECK_GT(arg.name().size(), 0);
...@@ -39,7 +39,7 @@ OperatorBase* TryCreateOperator(const string& key, const OperatorDef& op_def, Wo ...@@ -39,7 +39,7 @@ OperatorBase* TryCreateOperator(const string& key, const OperatorDef& op_def, Wo
case CPU: case CPU:
return CPUOperatorRegistry()->Create(key, op_def, ws); return CPUOperatorRegistry()->Create(key, op_def, ws);
case CUDA: case CUDA:
if (op_def.device_option().has_engine() && if (op_def.device_option().has_engine() &&
op_def.device_option().engine() == "CUDNN" && op_def.device_option().engine() == "CUDNN" &&
CUDNNOperatorRegistry()->Has(key)) CUDNNOperatorRegistry()->Has(key))
return CUDNNOperatorRegistry()->Create(key, op_def, ws); return CUDNNOperatorRegistry()->Create(key, op_def, ws);
...@@ -59,15 +59,15 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws) { ...@@ -59,15 +59,15 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws) {
Gradient MakeGradientForOp(const OperatorDef& def, const vector<string>& g_outputs) { Gradient MakeGradientForOp(const OperatorDef& def, const vector<string>& g_outputs) {
unique_ptr<GradientMakerBase> maker(GradientRegistry()->Create(def.type(), def, g_outputs)); unique_ptr<GradientMakerBase> maker(GradientRegistry()->Create(def.type(), def, g_outputs));
if (maker.get() == nullptr) if (maker.get() == nullptr)
LOG(FATAL) << "Gradient maker for operator " << def.type() << "not implemented."; LOG(FATAL) << "Gradient maker for operator " << def.type() << "not implemented.";
Gradient grad = maker->Make(); Gradient grad = maker->Make();
// copy device option, engine, and arguments if needed // copy device option, engine, and arguments if needed
if (maker->CopyDeviceOption() && def.has_device_option()) if (maker->CopyDeviceOption() && def.has_device_option())
for (auto& grad_def : grad.ops) for (auto& grad_def : grad.ops)
grad_def.mutable_device_option()->CopyFrom(def.device_option()); grad_def.mutable_device_option()->CopyFrom(def.device_option());
// copy arguments if needed // copy arguments if needed
if (maker->CopyArguments() && def.arg_size()) if (maker->CopyArguments() && def.arg_size())
for (auto& grad_def : grad.ops) grad_def.mutable_arg()->MergeFrom(def.arg()); for (auto& grad_def : grad.ops) grad_def.mutable_arg()->MergeFrom(def.arg());
return grad; return grad;
} }
...@@ -95,7 +95,7 @@ void Operator<Context>::ElimateCorruption() { ...@@ -95,7 +95,7 @@ void Operator<Context>::ElimateCorruption() {
all_heads.clear(); all_heads.clear();
for (int i = 0; i < head->count(); i++) { for (int i = 0; i < head->count(); i++) {
bool safe = true; bool safe = true;
for (int j = 0; j < InputSize(); j++) for (int j = 0; j < InputSize(); j++)
if (head_data[i] == input(j).name()) safe = false; if (head_data[i] == input(j).name()) safe = false;
if (safe) safe_heads.push(i); if (safe) safe_heads.push(i);
all_heads.insert(head_data[i]); all_heads.insert(head_data[i]);
...@@ -149,7 +149,9 @@ void Operator<Context>::CleanResource() { ...@@ -149,7 +149,9 @@ void Operator<Context>::CleanResource() {
Tensor* buffer = ws()->GetTensor(used); Tensor* buffer = ws()->GetTensor(used);
if (output(i)->memory() != buffer->memory()) buffer->Move(output(i)->memory()); if (output(i)->memory() != buffer->memory()) buffer->Move(output(i)->memory());
} }
} }
// post-process for sharing grads
if (allow_share_grads_) { if (allow_share_grads_) {
// TODO(PhyscalX): we preset input(-1)->output(0) to share // TODO(PhyscalX): we preset input(-1)->output(0) to share
Tensor* dY = &input(-1); Tensor* dY = &input(-1);
...@@ -201,4 +203,4 @@ template void Operator<CUDAContext>::MakeResource(); ...@@ -201,4 +203,4 @@ template void Operator<CUDAContext>::MakeResource();
template void Operator<CPUContext>::CleanResource(); template void Operator<CPUContext>::CleanResource();
template void Operator<CUDAContext>::CleanResource(); template void Operator<CUDAContext>::CleanResource();
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -30,24 +30,25 @@ void L2NormOp<Context>::RunWithType() { ...@@ -30,24 +30,25 @@ void L2NormOp<Context>::RunWithType() {
if (across_inner) { if (across_inner) {
auto* Ndata_ = norm->template mutable_data<float, CPUContext>(); auto* Ndata_ = norm->template mutable_data<float, CPUContext>();
float sum_of_sqr = math::Dot<T, Context>(buffer->count(), Xdata, Xdata); float sum_of_sqr = math::Dot<T, Context>(buffer->count(), Xdata, Xdata);
if (mode == "MEAN") sum_of_sqr = sum_of_sqr / dim;
Ndata_[n] = pow(sum_of_sqr + eps, 0.5); Ndata_[n] = pow(sum_of_sqr + eps, 0.5);
math::Scale<T, Context>(buffer->count(), 1.0 / Ndata_[n], Xdata, Ydata); math::Scale<T, Context>(buffer->count(), 1.0 / Ndata_[n], Xdata, Ydata);
} else { } else {
math::Set<T, Context>(norm->count(), dragon_cast<T, float>(eps), Ndata); math::Set<T, Context>(norm->count(), dragon_cast<T, float>(eps), Ndata);
math::Square<T, Context>(buffer->count(), Xdata, Bdata); math::Square<T, Context>(buffer->count(), Xdata, Bdata);
// compute T1 = \sum_{i} x_{i,j}^{2} // compute T1 = \sum_{i} x_{i,j}^{2}
math::Gemv<T, Context>(CblasTrans, dim, inner_dim, math::Gemv<T, Context>(CblasTrans, dim, inner_dim,
mode == "MEAN" ? 1.0 / dim : 1.0,
Bdata, DMuldata,
1.0, 1.0,
Bdata, DMuldata,
1.0,
Ndata); Ndata);
// compute T2 = \sqrt{T1} // compute T2 = \sqrt{T1}
math::Sqrt<T, Context>(inner_dim, Ndata, Ndata); math::Sqrt<T, Context>(inner_dim, Ndata, Ndata);
// compute T3 = x / [(T2)]_{dim} // compute T3 = x / [(T2)]_{dim}
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, dim, inner_dim, 1, math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, dim, inner_dim, 1,
1.0, 1.0,
DMuldata, Ndata, DMuldata, Ndata,
0.0, 0.0,
Bdata); Bdata);
math::Div<T, Context>(buffer->count(), Xdata, Bdata, Ydata); math::Div<T, Context>(buffer->count(), Xdata, Bdata, Ydata);
Ndata += inner_dim; Ndata += inner_dim;
......
...@@ -6,30 +6,33 @@ namespace dragon { ...@@ -6,30 +6,33 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNLRNOp<Context>::RunWithType() { void CuDNNLRNOp<Context>::RunWithType() {
cudnnSetTensorDesc<T>(&input_desc, &input(0)); if (this->data_format == "NCHW") {
cudnnSetTensorDesc<T>(&output_desc, output(0)); cudnnSetTensorDesc<T>(&input_desc, &input(0));
cudnnSetTensorDesc<T>(&output_desc, output(0));
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnLRNCrossChannelForward(cudnn_handle(), CUDNN_CHECK(cudnnLRNCrossChannelForward(cudnn_handle(),
norm_desc, norm_desc,
CUDNN_LRN_CROSS_CHANNEL_DIM1, CUDNN_LRN_CROSS_CHANNEL_DIM1,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
} else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
template <class Context> template <class Context>
void CuDNNLRNOp<Context>::RunOnDevice() { void CuDNNLRNOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
if (this->mode == ACROSS_CHANNELS) { if (this->mode == "ACROSS_CHANNELS") {
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
else if (input(0).template IsType<float16>()) RunWithType<float16>(); else if (input(0).template IsType<float16>()) RunWithType<float16>();
#endif #endif
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} else { } else if (this->mode == "WITHIN_CHANNEL") {
LRNOp<Context>::RunOnDevice(); LRNOp<Context>::RunOnDevice();
} else {
LOG(FATAL) << "Unsupported lrn mode: " << this->mode;
} }
} }
...@@ -37,34 +40,38 @@ DEPLOY_CUDNN(LRN); ...@@ -37,34 +40,38 @@ DEPLOY_CUDNN(LRN);
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNLRNGradientOp<Context>::RunWithType() { void CuDNNLRNGradientOp<Context>::RunWithType() {
cudnnSetTensorDesc<T>(&input_desc, &input(-1)); if (this->data_format == "NCHW") {
cudnnSetTensorDesc<T>(&output_desc, output(0)); cudnnSetTensorDesc<T>(&input_desc, &input(-1));
cudnnSetTensorDesc<T>(&output_desc, output(0));
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = input(1).template data<T, Context>(); auto* Ydata = input(1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnLRNCrossChannelBackward(cudnn_handle(), CUDNN_CHECK(cudnnLRNCrossChannelBackward(cudnn_handle(),
norm_desc, norm_desc,
CUDNN_LRN_CROSS_CHANNEL_DIM1, CUDNN_LRN_CROSS_CHANNEL_DIM1,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, input_desc, dYdata,
output_desc, Xdata, output_desc, Xdata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
} else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
template <class Context> template <class Context>
void CuDNNLRNGradientOp<Context>::RunOnDevice() { void CuDNNLRNGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
if (this->mode == ACROSS_CHANNELS) { if (this->mode == "ACROSS_CHANNELS") {
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
else if (input(0).template IsType<float16>()) RunWithType<float16>(); else if (input(0).template IsType<float16>()) RunWithType<float16>();
#endif #endif
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} else { } else if (this->mode == "WITHIN_CHANNEL") {
LRNGradientOp<Context>::RunOnDevice(); LRNGradientOp<Context>::RunOnDevice();
} else {
LOG(FATAL) << "Unsupported lrn mode: " << this->mode;
} }
} }
......
...@@ -45,15 +45,16 @@ template <class Context> template <typename T> ...@@ -45,15 +45,16 @@ template <class Context> template <typename T>
void LRNOp<Context>::PoolRunWithType() { void LRNOp<Context>::PoolRunWithType() {
pool_out = ws()->CreateTensor("/mnt/" + anchor() + "/pool_out"); pool_out = ws()->CreateTensor("/mnt/" + anchor() + "/pool_out");
if (!pool_op) { if (!pool_op) {
Argument ks, s, p, mode; Argument ks, s, p, m, df;
ks.set_name("kernel_size"); ks.add_ints(local_size); ks.set_name("kernel_size"); ks.add_ints(local_size);
s.set_name("stride"); s.add_ints(1); s.set_name("stride"); s.add_ints(1);
p.set_name("pad"); p.add_ints((local_size - 1) / 2); p.set_name("pad"); p.add_ints((local_size - 1) / 2);
mode.set_name("mode"); mode.set_s("AVG"); m.set_name("mode"); m.set_s("AVG");
OperatorDef pool_op_def = MakeOperatorDef("Pooling", "", df.set_name("data_format"); df.set_s(data_format);
OperatorDef pool_op_def = MakeOperatorDef("Pooling2d", "",
vector<string>({ sqr_out->name() }), vector<string>({ sqr_out->name() }),
vector<string>({ pool_out->name() }), vector<string>({ pool_out->name() }),
vector<Argument>({ ks, s, p, mode })); vector<Argument>({ ks, s, p, m, df }));
if (this->op_def().has_device_option()) if (this->op_def().has_device_option())
pool_op_def.mutable_device_option()->CopyFrom(this->op_def().device_option()); pool_op_def.mutable_device_option()->CopyFrom(this->op_def().device_option());
pool_op.reset(CreateOperator(pool_op_def, ws())); pool_op.reset(CreateOperator(pool_op_def, ws()));
...@@ -99,12 +100,11 @@ void LRNOp<Context>::ProdRunWithType() { ...@@ -99,12 +100,11 @@ void LRNOp<Context>::ProdRunWithType() {
template <class Context> template <class Context>
void LRNOp<Context>::RunOnDevice() { void LRNOp<Context>::RunOnDevice() {
if (mode == ACROSS_CHANNELS) { if (mode == "ACROSS_CHANNELS") {
if (input(0).template IsType<float>()) { if (input(0).template IsType<float>()) {
AcrossRunWithType<float>(); AcrossRunWithType<float>();
} else { LOG(FATAL) << "Unsupported input types."; } } else { LOG(FATAL) << "Unsupported input types."; }
} } else if (mode == "WITHIN_CHANNEL") {
else {
if (input(0).template IsType<float>()) { if (input(0).template IsType<float>()) {
SplitRunWithType<float>(); SplitRunWithType<float>();
SquareRunWithType<float>(); SquareRunWithType<float>();
...@@ -112,6 +112,8 @@ void LRNOp<Context>::RunOnDevice() { ...@@ -112,6 +112,8 @@ void LRNOp<Context>::RunOnDevice() {
PowRunWithType<float>(); PowRunWithType<float>();
ProdRunWithType<float>(); ProdRunWithType<float>();
} else { LOG(FATAL) << "Unsupported input types."; } } else { LOG(FATAL) << "Unsupported input types."; }
} else {
LOG(FATAL) << "Unsupported lrn mode: " << mode;
} }
} }
...@@ -135,10 +137,10 @@ void LRNGradientOp<Context>::ProdRunWithType() { ...@@ -135,10 +137,10 @@ void LRNGradientOp<Context>::ProdRunWithType() {
Argument operation; Argument operation;
operation.set_name("operation"); operation.set_s("PROD"); operation.set_name("operation"); operation.set_s("PROD");
OperatorDef prod_op_def = MakeOperatorDef("EltwiseGradient", "", OperatorDef prod_op_def = MakeOperatorDef("EltwiseGradient", "",
vector<string>({ prod_in->name(), vector<string>({ prod_in->name(),
pow_out->name(), pow_out->name(),
input(-1).name() }), input(-1).name() }),
vector<string>({ prod_in->name() + "_grad", vector<string>({ prod_in->name() + "_grad",
pow_out->name() + "_grad" }), pow_out->name() + "_grad" }),
vector<Argument>({ operation })); vector<Argument>({ operation }));
if (this->op_def().has_device_option()) if (this->op_def().has_device_option())
...@@ -173,17 +175,18 @@ template <class Context> template <typename T> ...@@ -173,17 +175,18 @@ template <class Context> template <typename T>
void LRNGradientOp<Context>::PoolRunWithType() { void LRNGradientOp<Context>::PoolRunWithType() {
sqr_out = ws()->GetTensor("/mnt/" + anchor() + "/sqr_out"); sqr_out = ws()->GetTensor("/mnt/" + anchor() + "/sqr_out");
if (!pool_op) { if (!pool_op) {
Argument ks, s, p, mode; Argument ks, s, p, m, df;
ks.set_name("kernel_size"); ks.add_ints(local_size); ks.set_name("kernel_size"); ks.add_ints(local_size);
s.set_name("stride"); s.add_ints(1); s.set_name("stride"); s.add_ints(1);
p.set_name("pad"); p.add_ints((local_size - 1) / 2); p.set_name("pad"); p.add_ints((local_size - 1) / 2);
mode.set_name("mode"); mode.set_s("AVG"); m.set_name("mode"); m.set_s("AVG");
OperatorDef pool_op_def = MakeOperatorDef("PoolingGradient", "", df.set_name("data_format"); df.set_s(data_format);
OperatorDef pool_op_def = MakeOperatorDef("Pooling2dGradient", "",
vector<string>({ sqr_out->name(), vector<string>({ sqr_out->name(),
pool_out->name(), pool_out->name(),
pool_out->name() + "_grad" }), pool_out->name() + "_grad" }),
vector<string>({ sqr_out->name() + "_grad" }), vector<string>({ sqr_out->name() + "_grad" }),
vector<Argument>({ ks, s, p, mode })); vector<Argument>({ ks, s, p, m, df }));
if (this->op_def().has_device_option()) if (this->op_def().has_device_option())
pool_op_def.mutable_device_option()->CopyFrom(this->op_def().device_option()); pool_op_def.mutable_device_option()->CopyFrom(this->op_def().device_option());
pool_op.reset(CreateOperator(pool_op_def, ws())); pool_op.reset(CreateOperator(pool_op_def, ws()));
...@@ -224,12 +227,11 @@ void LRNGradientOp<Context>::SplitRunWithType() { ...@@ -224,12 +227,11 @@ void LRNGradientOp<Context>::SplitRunWithType() {
template <class Context> template <class Context>
void LRNGradientOp<Context>::RunOnDevice() { void LRNGradientOp<Context>::RunOnDevice() {
if (mode == ACROSS_CHANNELS) { if (mode == "ACROSS_CHANNELS") {
if (input(0).template IsType<float>()) { if (input(0).template IsType<float>()) {
AcrossRunWithType<float>(); AcrossRunWithType<float>();
} else { LOG(FATAL) << "Unsupported input types."; } } else { LOG(FATAL) << "Unsupported input types."; }
} } else if (mode == "WITHIN_CHANNEL") {
else {
if (input(0).template IsType<float>()) { if (input(0).template IsType<float>()) {
ProdRunWithType<float>(); ProdRunWithType<float>();
PowRunWithType<float>(); PowRunWithType<float>();
...@@ -237,6 +239,8 @@ void LRNGradientOp<Context>::RunOnDevice() { ...@@ -237,6 +239,8 @@ void LRNGradientOp<Context>::RunOnDevice() {
SquareRunWithType<float>(); SquareRunWithType<float>();
SplitRunWithType<float>(); SplitRunWithType<float>();
} else { LOG(FATAL) << "Unsupported input types."; } } else { LOG(FATAL) << "Unsupported input types."; }
} else {
LOG(FATAL) << "Unsupported lrn mode: " << mode;
} }
} }
......
...@@ -65,7 +65,35 @@ void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, ...@@ -65,7 +65,35 @@ void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc,
dims[3], dims[3],
dims[1], dims[1],
dims[2])); dims[2]));
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
}
template <typename T>
void cudnnSetTensor4dDescWithGroup(cudnnTensorDescriptor_t* desc,
const string& data_format,
const vector<TIndex>& dims,
const TIndex group) {
if (data_format == "NCHW") {
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, CUDNNType<T>::type,
dims[0],
dims[1] / group,
dims[2],
dims[3],
dims[1] * dims[2] * dims[3],
dims[2] * dims[3],
dims[3],
1));
} else if (data_format == "NHWC") {
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, CUDNNType<T>::type,
dims[0],
dims[3] / group,
dims[1],
dims[2],
dims[1] * dims[2] * dims[3],
1,
dims[2] * dims[3],
dims[3]));
} else LOG(FATAL) << "Unknown data format: " << data_format;
} }
template <typename T> template <typename T>
...@@ -87,7 +115,7 @@ void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, ...@@ -87,7 +115,7 @@ void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc,
5, 5,
fake_dims.data(), fake_dims.data(),
fake_strides.data())); fake_strides.data()));
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
} }
template <typename T> template <typename T>
...@@ -169,6 +197,7 @@ template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<T ...@@ -169,6 +197,7 @@ template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<T
template void cudnnSetTensor4dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor4dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor5dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor3dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor4dDescWithGroup<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&, const TIndex);
template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&); template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
...@@ -180,6 +209,7 @@ template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector< ...@@ -180,6 +209,7 @@ template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<
template void cudnnSetTensor4dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor4dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor5dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor3dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor4dDescWithGroup<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&, const TIndex);
template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&); template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
...@@ -192,9 +222,10 @@ template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector ...@@ -192,9 +222,10 @@ template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector
template void cudnnSetTensor4dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor4dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor5dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&); template void cudnnSetTensor3dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor4dDescWithGroup<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&, const TIndex);
template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&); template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
#endif #endif
} // namespace dragon } // namespace dragon
#endif // WITH_CUDNN #endif // WITH_CUDNN
\ No newline at end of file
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!