Commit d1f714ea by Ting PAN

Apply the dispatcher to RunImpl

1 parent bd84b7fd
Showing with 668 additions and 868 deletions
@@ -35,6 +35,7 @@
 #include "core/types.h"
 #include "proto/dragon.pb.h"
+#include "utils/string.h"
 #include "utils/logging.h"
 namespace dragon {
...
@@ -85,6 +85,8 @@ GraphBase* NewGraph(
 const GraphDef& def,
 Workspace* ws);
+/* Macros */
 DECLARE_REGISTRY(
 GraphRegistry,
 GraphBase,
...
@@ -43,7 +43,7 @@ class GraphGradientMaker {
 bool CheckGrad(
 const OperatorDef& forward_op,
 const Set<string>& targets,
-vector< pair<string, int> >& gen_grads);
+vector<pair<string, int>>& gen_grads);
 string GetOperatorName();
...
@@ -100,7 +100,7 @@ class OperatorBase {
 /*! \brief Return the specified argument */
 const Argument& arg(const string& name) { return *(args_[name]); }
-typedef Map<string, vector<OperatorBase*> > SubGraph;
+typedef Map<string, vector<OperatorBase*>> SubGraph;
 /*! \brief Return the recomputing subgraph of this operator */
 SubGraph& subgraph() { return subgraph_; }
@@ -221,7 +221,7 @@ OperatorBase* NewOperator(
 const OperatorDef& def,
 Workspace* ws);
-/*! Macros */
+/* Macros */
 #define OpArg OperatorBase::Arg
 #define OpArgs OperatorBase::Args
@@ -266,7 +266,7 @@ DECLARE_REGISTRY(
 const OperatorDef&,
 Workspace*);
-/*! NVIDIA's Accelerated Library - CUDNN */
+/* NVIDIA's Accelerated Library - CUDNN */
 DECLARE_REGISTRY(
 CUDNNOperatorRegistry,
@@ -274,7 +274,7 @@ DECLARE_REGISTRY(
 const OperatorDef&,
 Workspace*);
-/*! CAMBRICON's Accelerated Library - CNML */
+/* CAMBRICON's Accelerated Library - CNML */
 DECLARE_REGISTRY(
 CNMLOperatorRegistry,
@@ -282,13 +282,60 @@ DECLARE_REGISTRY(
 const OperatorDef&,
 Workspace*);
+/* Dispatcher for Runtime Typed-Implementation */
+#define XIsType(x, dtype) \
+x.template IsType<dtype>()
+template <typename... Types>
+struct TensorTypes {};
+template <typename Sizes, typename... Args>
+struct DispatchHelper;
+#define DEFINE_TENSOR_TYPES_DISPATCHER(TensorTypes, Impl) \
+template <typename T, typename... Types, typename... Args> \
+struct DispatchHelper<TensorTypes<T, Types...>, Args...> { \
+template <typename Op> \
+static void Call(Op* op, const TypeMeta& meta, string& types) { \
+if (meta.Match<T>()) return op->template Impl<T, Args...>(); \
+types += " * " + TypeToString<T>() + ",\n"; \
+return DispatchHelper<TensorTypes<Types...>, Args...> \
+::Call(op, meta, types); \
+} \
+template <typename Op> \
+static void Call(Op* op, const Tensor& tensor) { \
+string types; return Call(op, tensor.meta(), types); \
+} \
+}; \
+template <typename... Args> \
+struct DispatchHelper<TensorTypes<>, Args...> { \
+template <typename Op> \
+static void Call(Op* op, const TypeMeta& meta, string& types) { \
+LOG(FATAL) << "Unsupported DType: " \
+<< TypeMetaToString(meta) << "\n" \
+<< "<" << op->type() << "Op>" \
+<< " supports the following dtypes: {\n" \
+<< types << "}"; \
+} \
+template <typename Op> \
+static void Call(Op* op, const Tensor& tensor) { \
+return Call(op, tensor.meta(), ""); \
+} \
+};
+DEFINE_TENSOR_TYPES_DISPATCHER(TensorTypes, RunImpl);
+#undef DEFINE_TENSOR_TYPES_DISPATCHER
+/* TensorFiller */
 #define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \
 if (tensor.count() == 0) { \
 CHECK(ws()->GetFiller(tensor.name())) \
 << "\nTensor(" << tensor.name() << ") is empty. \n" \
 << "may be specify a filler for it ?"; \
 tensor.Reshape(shape); \
-unique_ptr< Filler<type, Context> > filler( \
+unique_ptr<Filler<type, Context>> filler( \
 CreateFiller<type, Context>(*ws()->GetFiller(tensor.name()))); \
 filler->Fill(&tensor, ctx()); \
 } else { \
@@ -308,7 +355,7 @@ DECLARE_REGISTRY(
 << "\nTensor(" << tensor.name() << ") is empty. \n" \
 << "may be specify a filler for it ?"; \
 tensor.Reshape(shape); \
-unique_ptr< Filler<T, Context> > filler( \
+unique_ptr<Filler<T, Context>> filler( \
 CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \
 filler->Fill(&tensor, ctx()); \
 } else { \
@@ -322,6 +369,8 @@ DECLARE_REGISTRY(
 tensor.Reshape(shape); \
 }
+/* Shared Multiplier */
 #define DECLARE_MULTIPLIER(name, size) \
 const T* name; \
 { \
@@ -335,6 +384,8 @@ DECLARE_REGISTRY(
 name = mp->template data<T, Context>(); \
 }
+/* Dynamic Arguments */
 #define DECLARE_ARG_WITH_DESC(type, arg) \
 type arg##_; \
 string arg##_desc_; \
@@ -393,8 +444,7 @@ DECLARE_REGISTRY(
 #define GET_ARGS_SIZE(arg) \
 (int)std::max(arg##_.size(), arg##_desc_.size())
-#define XIsType(x, dtype) \
-x.template IsType<dtype>()
+/* Registers */
 #define INSTANTIATE_OPERATOR(name, context) \
 template class name##Op<context>;
...
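The block above is the core of this commit: DEFINE_TENSOR_TYPES_DISPATCHER(TensorTypes, RunImpl) expands into a DispatchHelper that walks the TensorTypes<...> pack, matches each candidate type against a tensor's TypeMeta, and forwards to the operator's typed RunImpl<T>(); if nothing matches, the empty-pack specialization aborts with the supported-dtype list assembled via the new TypeToString<T>(). A minimal sketch of the intended call pattern follows; MyOp, the Input(0) accessor, and the chosen type list are illustrative assumptions, not code from this diff.

    // Sketch only: routing RunOnDevice() through the new dispatcher.
    template <class Context>
    void MyOp<Context>::RunOnDevice() {
        // Selects RunImpl<float> or RunImpl<int64_t> from the dtype of the
        // first input; unsupported dtypes hit the LOG(FATAL) specialization.
        DispatchHelper<TensorTypes<float, int64_t>>::Call(this, Input(0));
    }

    template <class Context> template <typename T>
    void MyOp<Context>::RunImpl() {
        // The per-dtype implementation body lives here.
    }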
@@ -42,7 +42,7 @@ class OpSchema {
 return *this;
 }
-OpSchema& Inplace(set<pair<int, int> > inplace);
+OpSchema& Inplace(set<pair<int, int>> inplace);
 std::function<bool(int, int)> CheckInplace;
 bool AllowInplace() const { return allow_inplace_; }
...
@@ -73,6 +73,11 @@ inline const std::string TypeMetaToString(
 m2s_type_map[meta.id()] : "unknown";
 }
+template<typename T>
+inline const std::string TypeToString() {
+return TypeMetaToString(TypeMeta::Make<T>());
+}
 } // namespace dragon
 #endif // DRAGON_CORE_TYPES_H_
\ No newline at end of file
@@ -13,22 +13,18 @@
 #ifndef DRAGON_CORE_WORKSPACE_H_
 #define DRAGON_CORE_WORKSPACE_H_
-#include "core/common.h"
 #include "core/graph.h"
-#include "utils/string.h"
 namespace dragon {
 class Workspace {
 public:
-typedef Map<string, Map<string, int64_t> > DummyNameMap;
-typedef Map<string, unique_ptr<Tensor> > TensorMap;
+typedef Map<string, Map<string, int64_t>> DummyNameMap;
+typedef Map<string, unique_ptr<Tensor>> TensorMap;
 typedef Map<string, string> TensorAliasMap;
 typedef Map<string, TensorFillerProto> TensorFillerMap;
-typedef Map<string, unique_ptr<OperatorBase> > OperatorMap;
-typedef Map<string, unique_ptr<GraphBase> > GraphMap;
+typedef Map<string, unique_ptr<OperatorBase>> OperatorMap;
+typedef Map<string, unique_ptr<GraphBase>> GraphMap;
 /*! \brief Constructor */
 Workspace(const string& name) : name_(name) { Initialize(); }
...
@@ -28,6 +28,7 @@ class FullyConnectedOp final : public Operator<Context> {
 USE_OPERATOR_FUNCTIONS;
 void RunOnDevice();
+template <typename T> void RunImpl();
 template <typename T> void TransRunImpl();
 template <typename T> void NoTransRunImpl();
...
@@ -22,6 +22,7 @@ class MultinomialOp final : public Operator<Context> {
 public:
 MultinomialOp(const OperatorDef& def, Workspace* ws)
 : Operator<Context>(def, ws),
+eps_(OpArg<float>("eps", 0.f)),
 normalize_(OpArg<int64_t>("normalize", 0)),
 num_samples_(OpArg<int64_t>("num_samples", 1)) {}
 USE_OPERATOR_FUNCTIONS;
@@ -32,6 +33,7 @@ class MultinomialOp final : public Operator<Context> {
 template <typename T> void RunImpl();
 protected:
+float eps_;
 int64_t outer_dim_, axis_;
 int64_t normalize_, num_samples_;
 unique_ptr<OperatorBase> softmax_op_;
...
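MultinomialOp picks up both halves of the change: a unified RunImpl<T>() that the dispatcher can target, and a new "eps" argument whose docstring below reads "The prob to a uniform sampling". Read that way, eps_ mixes a uniform draw into the multinomial draw. The standalone sketch below illustrates that sampling rule only; it is an assumption about intent, not the kernel added by this commit, and SampleRow is a hypothetical helper name.

    #include <random>
    #include <vector>

    // With probability eps pick a class uniformly at random; otherwise
    // sample proportionally to the given row of probabilities.
    int64_t SampleRow(const std::vector<float>& probs, float eps, std::mt19937& rng) {
        std::uniform_real_distribution<float> coin(0.f, 1.f);
        if (coin(rng) < eps) {
            std::uniform_int_distribution<int64_t> uniform(0, (int64_t)probs.size() - 1);
            return uniform(rng);  // uniform fallback
        }
        std::discrete_distribution<int64_t> multinomial(probs.begin(), probs.end());
        return multinomial(rng);  // multinomial draw
    }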
@@ -26,22 +26,24 @@ inline void LoadCaffeModel(
 LOG(INFO) << "Restore From Model @: " << file << "......";
 LOG(INFO) << "Model Format: CaffeModel";
 for (int i = 0; i < net_param.layer_size(); i++) {
-const LayerParameter& layer = net_param.layer(i);
-const string& layer_name = layer.name();
-string prefix = layer_name + "/param:";
+const auto& layer = net_param.layer(i);
+const auto& layer_name = layer.name();
+auto prefix = layer_name + "/param:";
 for (int j = 0; j < layer.blobs_size(); j++) {
-string tensor_name = prefix + std::to_string(j);
-if (!ws->HasTensor(tensor_name))
-LOG(WARNING) << "Tensor(" << tensor_name << ") "
-<< "does not exist in any Graphs, skip.";
-else{
-BlobProto blob = layer.blobs(j);
-vector<int64_t> dims;
-for (auto dim : blob.shape().dim()) dims.push_back(dim);
-Tensor* tensor = ws->GetTensor(tensor_name);
+auto tensor_name = prefix + std::to_string(j);
+if (!ws->HasTensor(tensor_name)) {
+LOG(WARNING)
+<< "Tensor(" << tensor_name << ") "
+<< "does not exist in any Graphs, skip.";
+} else {
+auto blob = layer.blobs(j);
+vec64_t tensor_shape;
+for (auto dim : blob.shape().dim())
+tensor_shape.push_back(dim);
+auto* tensor = ws->GetTensor(tensor_name);
 std::stringstream DimString;
-if (dims.size() > 0) {
-tensor->Reshape(dims);
+if (tensor_shape.size() > 0) {
+tensor->Reshape(tensor_shape);
 CHECK_EQ(tensor->count(), blob.data_size())
 << "\nTensor(" << tensor_name << ") "
 << "failed to load, except size: "
@@ -52,9 +54,9 @@ inline void LoadCaffeModel(
 tensor->Reshape({ blob.data_size() });
 DimString << "(missing)";
 }
-float* Xdata = tensor->mutable_data<float, CPUContext>();
-for (int idx = 0; idx < blob.data_size(); idx++)
-Xdata[idx] = blob.data(idx);
+auto* x = tensor->mutable_data<float, CPUContext>();
+for (int xi = 0; xi < blob.data_size(); ++xi)
+x[xi] = blob.data(xi);
 LOG(INFO) << "Tensor(" << tensor_name << ") "
 << "loaded, shape: " << DimString.str()
 << ", size: " << blob.data_size();
@@ -66,32 +68,33 @@ inline void LoadCaffeModel(
 inline void SavaCaffeModel(
 string file,
 const vector<Tensor*>& tensors) {
-NetParameter net_param;
+int j = -1;
+NetParameter net;
 Map<string, int> layer_hash;
-int layer_idx = -1;
 for (int i = 0; i < tensors.size(); i++) {
 if (tensors[i]->count() <= 0) continue;
-vector<string> splits = str::split(
+auto splits = str::split(
 tensors[i]->name(), "/param:");
 if (layer_hash.count(splits[0]) == 0) {
-layer_hash[splits[0]] = ++layer_idx;
-LayerParameter* layer = net_param.add_layer();
+layer_hash[splits[0]] = ++j;
+auto* layer = net.add_layer();
 layer->set_name(splits[0]);
 }
-BlobProto* blob = net_param.mutable_layer(layer_idx)->add_blobs();
-for (auto dim : tensors[i]->dims()) blob->mutable_shape()->add_dim(dim);
+auto* blob = net.mutable_layer(j)->add_blobs();
+for (auto dim : tensors[i]->dims())
+blob->mutable_shape()->add_dim(dim);
 if (XIsType((*tensors[i]), float)) {
-auto* Xdata = tensors[i]->data<float, CPUContext>();
-for (int id = 0; id < tensors[i]->count(); id++)
-blob->mutable_data()->Add(Xdata[id]);
+auto* x = tensors[i]->data<float, CPUContext>();
+for (int xi = 0; xi < tensors[i]->count(); ++xi)
+blob->mutable_data()->Add(x[xi]);
 } else if (XIsType((*tensors[i]), float16)) {
-auto* Xdata = tensors[i]->data<float16, CPUContext>();
-for (int id = 0; id < tensors[i]->count(); id++)
+auto* x = tensors[i]->data<float16, CPUContext>();
+for (int xi = 0; xi < tensors[i]->count(); ++xi)
 blob->mutable_data()->Add(
-cast::to<float>(Xdata[id]));
+cast::to<float>(x[xi]));
 }
 }
-WriteProtoToBinaryFile(net_param, file.c_str());
+WriteProtoToBinaryFile(net, file.c_str());
 LOG(INFO) << "Save the model @: " << file << "......";
 LOG(INFO) << "Model format: Caffe";
 }
...
@@ -748,7 +748,7 @@ def Arange(start, stop=None, step=1, dtype='float32', **kwargs):
 @OpSchema.Inputs(1)
-def Multinomial(inputs, num_samples=1, normalize=False, **kwargs):
+def Multinomial(inputs, num_samples=1, eps=0., normalize=False, **kwargs):
 """Return a tensor where each row contains ``num_samples``,
 sampled from the multinomial distribution.
@@ -765,6 +765,8 @@ def Multinomial(inputs, num_samples=1, normalize=False, **kwargs):
 The input tensor.
 num_samples : int, optional, default=1
 The number of samples.
+eps : float, optional, default=0.
+The prob to a uniform sampling.
 normalize : boolean, optional, default=False
 Whether to normalize the inputs.
...
@@ -987,7 +987,7 @@ def one_hot(input, depth):
 return module.forward(input)
-def multinomial(input, num_samples, out=None):
+def multinomial(input, num_samples, eps=0., out=None):
 """Return a tensor where each row contains ``num_samples``,
 sampled from the multinomial distribution.
@@ -997,8 +997,8 @@ def multinomial(input, num_samples, out=None):
 The input tensor.
 num_samples : int
 The number of samples.
-normalize : boolean, optional, default=False
-Whether to normalize the inputs.
+eps : float, optional, default=0.
+The prob to a uniform sampling.
 Returns
 -------
@@ -1008,9 +1008,11 @@ def multinomial(input, num_samples, out=None):
 """
 dev = MakeDevice(inputs=[input])
 key = 'Multinomial/{}' \
-'/num_samples:{}'.format(dev, num_samples)
+'/num_samples:{}' \
+'/eps:{}'.format(dev, num_samples, eps)
 module = get_module(
 Multinomial, key, dev,
+eps=eps,
 num_samples=num_samples,
 )
 return module.forward(input, out)
...
@@ -377,6 +377,7 @@ class Cast(BaseModule):
 class Multinomial(BaseModule):
 def __init__(self, key, dev, **kwargs):
 super(Multinomial, self).__init__(key, dev, **kwargs)
+self.eps = kwargs.get('eps', 0)
 self.num_samples = kwargs.get('num_samples', 1)
 self.register_op()
@@ -384,6 +385,7 @@ class Multinomial(BaseModule):
 self.op_meta = {
 'op_type': 'Multinomial',
 'arguments': {
+'eps': float(self.eps),
 'num_samples': self.num_samples,
 'normalize': False,
 },
...
@@ -980,7 +980,7 @@ class Tensor(object):
 """
 raise NotImplementedError('Refer torch.ops.tensor.normal_')
-def multinomial(self, num_samples, normalize=False):
+def multinomial(self, num_samples, eps=0.):
 """Return a tensor where each row contains ``num_samples``,
 sampled from the multinomial distribution.
@@ -988,8 +988,8 @@
 ----------
 num_samples : int
 The number of samples.
-normalize : boolean, optional, default=False
-Whether to normalize the inputs.
+eps : float, optional, default=0.
+The prob to a uniform sampling.
 Returns
 -------
...
@@ -81,8 +81,8 @@ void _ApplyNMS(
 CUDA_CHECK(cudaMemcpy(boxes_dev, boxes,
 boxes_nbytes, cudaMemcpyHostToDevice));
 nms_mask<T>
-<< < blocks, NMS_BLOCK_SIZE,
-0, ctx->cuda_stream() >> > (num_boxes,
+<<< blocks, NMS_BLOCK_SIZE,
+0, ctx->cuda_stream() >>> (num_boxes,
 thresh, (T*)boxes_dev, (uint64_t*)mask_dev);
 ctx->FinishDeviceCompution();
...
@@ -347,7 +347,7 @@ inline void CollectRoIs(
 const int canonical_level,
 const int canonical_scale,
 const T* rois,
-vector< vector<int64_t> >& roi_bins) {
+vector<vec64_t>& roi_bins) {
 const T* roi = rois;
 for (int i = 0; i < num_rois; ++i) {
 int bin_idx = roi_level(min_level, max_level,
@@ -360,7 +360,7 @@ inline void CollectRoIs(
 template <typename T>
 inline void DistributeRoIs(
-const vector< vector<int64_t> >& roi_bins,
+const vector<vec64_t>& roi_bins,
 const T* rois,
 vector<T*> outputs) {
 for (int i = 0; i < roi_bins.size(); i++) {
...
@@ -123,7 +123,7 @@ Graph::Graph(const GraphDef& def, Workspace* ws)
 // Recomputing-aware
 if (subgraph_indices.size() > 0) {
-Map< string, vector<OperatorBase*> > subgraph;
+Map<string, vector<OperatorBase*>> subgraph;
 for (const auto& it : subgraph_indices) {
 subgraph[it.first] = vector<OperatorBase*>();
 for (const auto& idx : subgraph_indices[it.first])
...
@@ -7,7 +7,7 @@ namespace dragon {
 bool GraphGradientMaker::CheckGrad(
 const OperatorDef& forward_op,
 const Set<string>& targets,
-vector< pair<string, int> >& gen_grads) {
+vector<pair<string, int>>& gen_grads) {
 if (NoGradientRegistry()->Has(forward_op.type())) {
 for (auto& input : forward_op.input())
 blacklist_set_.insert(input);
@@ -81,7 +81,7 @@ void GraphGradientMaker::Make(
 for (int i = (int)forward_def.size() - 1; i >= 0; --i) {
 // Collect inputs & outputs, generate RAW grad ops
 const OperatorDef& op = *forward_def[i];
-vector< pair<string, int> > gen_grads;
+vector<pair<string, int>> gen_grads;
 bool is_skip = CheckGrad(op, targets_set, gen_grads);
 vector<string> g_outputs;
 for (auto& output : op.output()) {
@@ -214,7 +214,7 @@ void GraphGradientMaker::Make(
 GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
 Set<int> invalid_ops;
 Map<string, int> ref_count;
-Map< string, pair<int, string> > ssa_map;
+Map<string, pair<int, string>> ssa_map;
 // Count the refs for detecting leaf nodes
 for (int i = 0; i < input_def.op_size(); ++i) {
 const OperatorDef& op = input_def.op(i);
...
@@ -174,7 +174,7 @@ GraphDef GraphOptimizer::MirrorStage(
 const GraphDef& input_def,
 Map<string, vec32_t >& op_indices) {
 GraphDef output_def(input_def);
-Map<string, set<int> > fake_op_indices;
+Map<string, set<int>> fake_op_indices;
 Map<string, string> rename_map;
 Map<string, int> versions;
...
@@ -54,7 +54,7 @@ OpSchema& OpSchema::NumOutputs(int n) {
 return NumOutputs(n, n);
 }
-OpSchema& OpSchema::Inplace(set< pair<int, int> > inplace) {
+OpSchema& OpSchema::Inplace(set<pair<int, int>> inplace) {
 CheckInplace = [inplace](int in, int out)->bool {
 return (inplace.count(std::make_pair(in, out)) > 0);
 };
...
...@@ -37,14 +37,10 @@ template<> void Dropout<float, CUDAContext>( ...@@ -37,14 +37,10 @@ template<> void Dropout<float, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
auto thresh = (uint32_t)(UINT_MAX * prob); auto thresh = (uint32_t)(UINT_MAX * prob);
math::RandomUniform( math::RandomUniform(count, 0.f, 1.f, mask32, ctx);
count,
0.f, (float)UINT_MAX,
mask32, ctx
);
_Dropout _Dropout
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
thresh, thresh,
scale, scale,
...@@ -85,14 +81,10 @@ template<> void Dropout<float16, CUDAContext>( ...@@ -85,14 +81,10 @@ template<> void Dropout<float16, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
auto thresh = (uint32_t)(UINT_MAX * prob); auto thresh = (uint32_t)(UINT_MAX * prob);
math::RandomUniform( math::RandomUniform(count, 0.f, 1.f, mask32, ctx);
count,
0.f, (float)UINT_MAX,
mask32, ctx
);
_Dropout _Dropout
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
thresh, thresh,
cast::to<half>(scale), cast::to<half>(scale),
...@@ -124,8 +116,8 @@ template <> void ApplyMask<float, uint8_t, CUDAContext>( ...@@ -124,8 +116,8 @@ template <> void ApplyMask<float, uint8_t, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_ApplyMask _ApplyMask
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, scale, x, mask, y count, scale, x, mask, y
); );
} }
...@@ -157,8 +149,8 @@ template <> void ApplyMask<float16, uint8_t, CUDAContext>( ...@@ -157,8 +149,8 @@ template <> void ApplyMask<float16, uint8_t, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_ApplyMaskHalf _ApplyMaskHalf
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
cast::to<half>(scale), cast::to<half>(scale),
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
......
...@@ -44,8 +44,8 @@ template<> void DropPath<float, CUDAContext>( ...@@ -44,8 +44,8 @@ template<> void DropPath<float, CUDAContext>(
auto nthreads = rows * cols; auto nthreads = rows * cols;
auto thresh = 1.f - (1.f / scale); auto thresh = 1.f - (1.f / scale);
_DropPath _DropPath
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, cols, thresh, scale, x, mask, y nthreads, cols, thresh, scale, x, mask, y
); );
} }
...@@ -85,8 +85,8 @@ template<> void DropPath<float16, CUDAContext>( ...@@ -85,8 +85,8 @@ template<> void DropPath<float16, CUDAContext>(
auto nthreads = rows * cols; auto nthreads = rows * cols;
auto thresh = 1.f - (1.f / scale); auto thresh = 1.f - (1.f / scale);
_DropPath _DropPath
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, cols, nthreads, cols,
thresh, thresh,
cast::to<half>(scale), cast::to<half>(scale),
......
...@@ -28,8 +28,8 @@ template<> void Elu<float, CUDAContext>( ...@@ -28,8 +28,8 @@ template<> void Elu<float, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Elu _Elu
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, x, alpha, y count, x, alpha, y
); );
} }
...@@ -58,8 +58,8 @@ template<> void EluGrad<float, CUDAContext>( ...@@ -58,8 +58,8 @@ template<> void EluGrad<float, CUDAContext>(
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_EluGrad _EluGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, alpha, dy, y, dx count, alpha, dy, y, dx
); );
} }
......
...@@ -66,21 +66,21 @@ template<> void PRelu<float, CUDAContext>( ...@@ -66,21 +66,21 @@ template<> void PRelu<float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
if (channel_shared) { if (channel_shared) {
_PRelu _PRelu
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, channels, dim, x, w, y count, channels, dim, x, w, y
); );
} else { } else {
if (data_format == "NCHW") { if (data_format == "NCHW") {
_PReluNCHW _PReluNCHW
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, channels, dim, x, w, y count, channels, dim, x, w, y
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_PReluNHWC _PReluNHWC
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, channels, dim, x, w, y count, channels, dim, x, w, y
); );
} else { } else {
...@@ -152,21 +152,21 @@ template<> void PReluGrad<float, CUDAContext>( ...@@ -152,21 +152,21 @@ template<> void PReluGrad<float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
if (channel_shared) { if (channel_shared) {
_PReluGrad _PReluGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, channels, dim, dy, x, w, dx count, channels, dim, dy, x, w, dx
); );
} else { } else {
if (data_format == "NCHW") { if (data_format == "NCHW") {
_PReluGradNCHW _PReluGradNCHW
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, channels, dim, dy, x, w, dx count, channels, dim, dy, x, w, dx
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_PReluGradNHWC _PReluGradNHWC
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, channels, dim, dy, x, w, dx count, channels, dim, dy, x, w, dx
); );
} else { } else {
...@@ -210,8 +210,8 @@ template<> void PReluWGrad<float, CUDAContext>( ...@@ -210,8 +210,8 @@ template<> void PReluWGrad<float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto cdim = channels * dim; auto cdim = channels * dim;
_PReluWGradBcast _PReluWGradBcast
<< < CUDA_BLOCKS(cdim), CUDA_THREADS, <<< CUDA_BLOCKS(cdim), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
cdim, rows, row_offset, dy, x, bcast_dw cdim, rows, row_offset, dy, x, bcast_dw
); );
if (channel_shared) { if (channel_shared) {
......
...@@ -35,8 +35,8 @@ template<> void Relu<float, CUDAContext>( ...@@ -35,8 +35,8 @@ template<> void Relu<float, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Relu _Relu
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, slope, x, y count, slope, x, y
); );
} }
...@@ -83,8 +83,8 @@ template<> void Relu<float16, CUDAContext>( ...@@ -83,8 +83,8 @@ template<> void Relu<float16, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
if ((count & 1) == 0) { if ((count & 1) == 0) {
_Relu _Relu
<< < CUDA_BLOCKS(count >> 1), CUDA_THREADS, <<< CUDA_BLOCKS(count >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count >> 1, count >> 1,
cast::to<half2>(slope), cast::to<half2>(slope),
reinterpret_cast<const half2*>(x), reinterpret_cast<const half2*>(x),
...@@ -92,8 +92,8 @@ template<> void Relu<float16, CUDAContext>( ...@@ -92,8 +92,8 @@ template<> void Relu<float16, CUDAContext>(
); );
} else { } else {
_Relu _Relu
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
cast::to<half>(slope), cast::to<half>(slope),
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
...@@ -134,8 +134,8 @@ template<> void ReluGrad<float, CUDAContext>( ...@@ -134,8 +134,8 @@ template<> void ReluGrad<float, CUDAContext>(
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_ReluGrad _ReluGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, slope, dy, y, dx count, slope, dy, y, dx
); );
} }
...@@ -170,8 +170,8 @@ template<> void ReluGrad<float16, CUDAContext>( ...@@ -170,8 +170,8 @@ template<> void ReluGrad<float16, CUDAContext>(
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_ReluGrad _ReluGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, slope, count, slope,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(y), reinterpret_cast<const half*>(y),
......
...@@ -34,8 +34,8 @@ template<> void SElu<float, CUDAContext>( ...@@ -34,8 +34,8 @@ template<> void SElu<float, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_SElu _SElu
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, x, y count, x, y
); );
} }
...@@ -63,8 +63,8 @@ template<> void SElu<float16, CUDAContext>( ...@@ -63,8 +63,8 @@ template<> void SElu<float16, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_SElu _SElu
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y) reinterpret_cast<half*>(y)
...@@ -99,8 +99,8 @@ template<> void SEluGrad<float, CUDAContext>( ...@@ -99,8 +99,8 @@ template<> void SEluGrad<float, CUDAContext>(
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_SEluGrad _SEluGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, dy, y, dx count, dy, y, dx
); );
} }
...@@ -131,8 +131,8 @@ template<> void SEluGrad<float16, CUDAContext>( ...@@ -131,8 +131,8 @@ template<> void SEluGrad<float16, CUDAContext>(
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_SEluGrad _SEluGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(y), reinterpret_cast<const half*>(y),
......
...@@ -25,8 +25,8 @@ template<> void Sigmoid<float, CUDAContext>( ...@@ -25,8 +25,8 @@ template<> void Sigmoid<float, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Sigmoid _Sigmoid
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, x, y count, x, y
); );
} }
...@@ -51,8 +51,8 @@ template<> void SigmoidGrad<float, CUDAContext>( ...@@ -51,8 +51,8 @@ template<> void SigmoidGrad<float, CUDAContext>(
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_SigmoidGrad _SigmoidGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, dy, y, dx count, dy, y, dx
); );
} }
......
...@@ -96,26 +96,26 @@ template<> void Softmax<float, CUDAContext>( ...@@ -96,26 +96,26 @@ template<> void Softmax<float, CUDAContext>(
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
auto nelements = num_preds * axis_dim; auto nelements = num_preds * axis_dim;
_SoftmaxReduceMax _SoftmaxReduceMax
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS, <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
num_preds, axis_dim, inner_dim, x, scale num_preds, axis_dim, inner_dim, x, scale
); );
_SoftmaxSub _SoftmaxSub
<< < CUDA_BLOCKS(nelements), CUDA_THREADS, <<< CUDA_BLOCKS(nelements), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nelements, axis_dim, inner_dim, scale, y nelements, axis_dim, inner_dim, scale, y
); );
math::Exp(nelements, y, y, ctx); math::Exp(nelements, y, y, ctx);
_SoftmaxReduceSum _SoftmaxReduceSum
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS, <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
num_preds, axis_dim, inner_dim, y, scale num_preds, axis_dim, inner_dim, y, scale
); );
_SoftmaxDiv _SoftmaxDiv
<< < CUDA_BLOCKS(nelements), CUDA_THREADS, <<< CUDA_BLOCKS(nelements), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nelements, axis_dim, inner_dim, scale, y nelements, axis_dim, inner_dim, scale, y
); );
} }
...@@ -159,13 +159,13 @@ template<> void SoftmaxGrad<float, CUDAContext>( ...@@ -159,13 +159,13 @@ template<> void SoftmaxGrad<float, CUDAContext>(
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
auto nelements = num_preds * axis_dim; auto nelements = num_preds * axis_dim;
_SoftmaxDot _SoftmaxDot
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS, <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
num_preds, axis_dim, inner_dim, dy, y, scale num_preds, axis_dim, inner_dim, dy, y, scale
); );
_SoftmaxSub _SoftmaxSub
<< < CUDA_BLOCKS(nelements), CUDA_THREADS, <<< CUDA_BLOCKS(nelements), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nelements, axis_dim, inner_dim, scale, dx nelements, axis_dim, inner_dim, scale, dx
); );
math::Mul(nelements, dx, y, dx, ctx); math::Mul(nelements, dx, y, dx, ctx);
......
...@@ -25,8 +25,8 @@ template<> void Tanh<float, CUDAContext>( ...@@ -25,8 +25,8 @@ template<> void Tanh<float, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Tanh _Tanh
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, x, y count, x, y
); );
} }
...@@ -51,8 +51,8 @@ template<> void TanhGrad<float, CUDAContext>( ...@@ -51,8 +51,8 @@ template<> void TanhGrad<float, CUDAContext>(
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_TanhGrad _TanhGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, dy, y, dx count, dy, y, dx
); );
} }
......
...@@ -60,15 +60,15 @@ template<> void Affine<float, CUDAContext>( ...@@ -60,15 +60,15 @@ template<> void Affine<float, CUDAContext>(
auto nthreads = outer_dim * axis_dim * inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
if (beta != nullptr) { if (beta != nullptr) {
_Affine _Affine
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nthreads, axis_dim, inner_dim,
x, alpha, beta, y x, alpha, beta, y
); );
} else { } else {
_AffineNoBias _AffineNoBias
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, x, alpha, y nthreads, axis_dim, inner_dim, x, alpha, y
); );
} }
...@@ -124,8 +124,8 @@ template<> void Affine<float16, CUDAContext>( ...@@ -124,8 +124,8 @@ template<> void Affine<float16, CUDAContext>(
auto nthreads = outer_dim * axis_dim * inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
if (beta != nullptr) { if (beta != nullptr) {
_Affine _Affine
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nthreads, axis_dim, inner_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(alpha), reinterpret_cast<const half*>(alpha),
...@@ -134,8 +134,8 @@ template<> void Affine<float16, CUDAContext>( ...@@ -134,8 +134,8 @@ template<> void Affine<float16, CUDAContext>(
); );
} else { } else {
_AffineNoBias _AffineNoBias
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nthreads, axis_dim, inner_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(alpha), reinterpret_cast<const half*>(alpha),
...@@ -156,8 +156,8 @@ template <> void AffineGrad<float, CUDAContext>( ...@@ -156,8 +156,8 @@ template <> void AffineGrad<float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * axis_dim * inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
_AffineNoBias _AffineNoBias
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, dy, alpha, dx nthreads, axis_dim, inner_dim, dy, alpha, dx
); );
} }
...@@ -174,8 +174,8 @@ template <> void AffineGrad<float16, CUDAContext>( ...@@ -174,8 +174,8 @@ template <> void AffineGrad<float16, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * axis_dim * inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
_AffineNoBias _AffineNoBias
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nthreads, axis_dim, inner_dim,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(alpha), reinterpret_cast<const half*>(alpha),
......
...@@ -83,8 +83,8 @@ template<> __global__ void _ClipGrad<half>( ...@@ -83,8 +83,8 @@ template<> __global__ void _ClipGrad<half>(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_Clip<T> \ _Clip<T> \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, \ count, \
cast::to<T>(low), \ cast::to<T>(low), \
cast::to<T>(high), \ cast::to<T>(high), \
...@@ -102,8 +102,8 @@ template<> __global__ void _ClipGrad<half>( ...@@ -102,8 +102,8 @@ template<> __global__ void _ClipGrad<half>(
T* dx, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_ClipGrad<T> \ _ClipGrad<T> \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, \ count, \
cast::to<T>(low), \ cast::to<T>(low), \
cast::to<T>(high), \ cast::to<T>(high), \
...@@ -133,8 +133,8 @@ template <> void Clip<float16, CUDAContext>( ...@@ -133,8 +133,8 @@ template <> void Clip<float16, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Clip _Clip
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
cast::to<half>(low), cast::to<half>(low),
cast::to<half>(high), cast::to<half>(high),
...@@ -152,8 +152,8 @@ template <> void ClipGrad<float16, CUDAContext>( ...@@ -152,8 +152,8 @@ template <> void ClipGrad<float16, CUDAContext>(
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_ClipGrad _ClipGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
cast::to<half>(low), cast::to<half>(low),
cast::to<half>(high), cast::to<half>(high),
......
...@@ -139,8 +139,8 @@ template<> __global__ void _BroadcastMaximumGrad<half>( ...@@ -139,8 +139,8 @@ template<> __global__ void _BroadcastMaximumGrad<half>(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##name \ _##name \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, x1, x2, y \ count, x1, x2, y \
); \ ); \
} }
...@@ -155,8 +155,8 @@ template<> __global__ void _BroadcastMaximumGrad<half>( ...@@ -155,8 +155,8 @@ template<> __global__ void _BroadcastMaximumGrad<half>(
T* dx2, \ T* dx2, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##name \ _##name \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, x1, x2, dy, dx1, dx2 \ count, x1, x2, dy, dx1, dx2 \
); \ ); \
} }
...@@ -196,8 +196,8 @@ template <> void Maximum<float16, CUDAContext>( ...@@ -196,8 +196,8 @@ template <> void Maximum<float16, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Maximum \ _Maximum \
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x1), reinterpret_cast<const half*>(x1),
reinterpret_cast<const half*>(x2), reinterpret_cast<const half*>(x2),
...@@ -212,8 +212,8 @@ template <> void BroadcastMaximum<float16, CUDAContext>( ...@@ -212,8 +212,8 @@ template <> void BroadcastMaximum<float16, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_BroadcastMaximum \ _BroadcastMaximum \
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x1), reinterpret_cast<const half*>(x1),
cast::to<half>(x2), cast::to<half>(x2),
...@@ -230,8 +230,8 @@ template <> void MaximumGrad<float16, CUDAContext>( ...@@ -230,8 +230,8 @@ template <> void MaximumGrad<float16, CUDAContext>(
float16* dx2, float16* dx2,
CUDAContext* ctx) { CUDAContext* ctx) {
_MaximumGrad \ _MaximumGrad \
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x1), reinterpret_cast<const half*>(x1),
reinterpret_cast<const half*>(x2), reinterpret_cast<const half*>(x2),
...@@ -250,8 +250,8 @@ template <> void BroadcastMaximumGrad<float16, CUDAContext>( ...@@ -250,8 +250,8 @@ template <> void BroadcastMaximumGrad<float16, CUDAContext>(
float16* dx2, float16* dx2,
CUDAContext* ctx) { CUDAContext* ctx) {
_BroadcastMaximumGrad \ _BroadcastMaximumGrad \
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x1), reinterpret_cast<const half*>(x1),
cast::to<half>(x2), cast::to<half>(x2),
......
...@@ -139,8 +139,8 @@ template<> __global__ void _BroadcastMinimumGrad<half>( ...@@ -139,8 +139,8 @@ template<> __global__ void _BroadcastMinimumGrad<half>(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##name \ _##name \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, x1, x2, y \ count, x1, x2, y \
); \ ); \
} }
...@@ -155,8 +155,8 @@ template<> __global__ void _BroadcastMinimumGrad<half>( ...@@ -155,8 +155,8 @@ template<> __global__ void _BroadcastMinimumGrad<half>(
T* dx2, \ T* dx2, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##name \ _##name \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, x1, x2, dy, dx1, dx2 \ count, x1, x2, dy, dx1, dx2 \
); \ ); \
} }
...@@ -196,8 +196,8 @@ template <> void Minimum<float16, CUDAContext>( ...@@ -196,8 +196,8 @@ template <> void Minimum<float16, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Minimum \ _Minimum \
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x1), reinterpret_cast<const half*>(x1),
reinterpret_cast<const half*>(x2), reinterpret_cast<const half*>(x2),
...@@ -212,8 +212,8 @@ template <> void BroadcastMinimum<float16, CUDAContext>( ...@@ -212,8 +212,8 @@ template <> void BroadcastMinimum<float16, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_BroadcastMinimum \ _BroadcastMinimum \
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x1), reinterpret_cast<const half*>(x1),
cast::to<half>(x2), cast::to<half>(x2),
...@@ -230,8 +230,8 @@ template <> void MinimumGrad<float16, CUDAContext>( ...@@ -230,8 +230,8 @@ template <> void MinimumGrad<float16, CUDAContext>(
float16* dx2, float16* dx2,
CUDAContext* ctx) { CUDAContext* ctx) {
_MinimumGrad \ _MinimumGrad \
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x1), reinterpret_cast<const half*>(x1),
reinterpret_cast<const half*>(x2), reinterpret_cast<const half*>(x2),
...@@ -250,8 +250,8 @@ template <> void BroadcastMinimumGrad<float16, CUDAContext>( ...@@ -250,8 +250,8 @@ template <> void BroadcastMinimumGrad<float16, CUDAContext>(
float16* dx2, float16* dx2,
CUDAContext* ctx) { CUDAContext* ctx) {
_BroadcastMinimumGrad \ _BroadcastMinimumGrad \
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(x1), reinterpret_cast<const half*>(x1),
cast::to<half>(x2), cast::to<half>(x2),
......
...@@ -251,8 +251,8 @@ void _Moments( ...@@ -251,8 +251,8 @@ void _Moments(
ndims, x_dims, y_dims, ndims, x_dims, y_dims,
&rows, &cols)) { &rows, &cols)) {
_ColwiseMoments _ColwiseMoments
<< < CUDA_2D_BLOCKS(rows), CUDA_THREADS, <<< CUDA_2D_BLOCKS(rows), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
rows, cols, x, mean, var rows, cols, x, mean, var
); return; ); return;
} }
...@@ -262,8 +262,8 @@ void _Moments( ...@@ -262,8 +262,8 @@ void _Moments(
ndims, x_dims, y_dims, ndims, x_dims, y_dims,
&rows, &cols)) { &rows, &cols)) {
_RowwiseMoments _RowwiseMoments
<< < CUDA_2D_BLOCKS(cols), CUDA_THREADS, <<< CUDA_2D_BLOCKS(cols), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
rows, cols, x, mean, var rows, cols, x, mean, var
); return; ); return;
} }
...@@ -294,8 +294,8 @@ void _Moments( ...@@ -294,8 +294,8 @@ void _Moments(
ctx->Memcpy<CUDAContext, CPUContext>(dbytes, YDS, dimsT.data()); ctx->Memcpy<CUDAContext, CPUContext>(dbytes, YDS, dimsT.data());
_GenericMoments _GenericMoments
<< < CUDA_2D_BLOCKS(outer_dim), CUDA_THREADS, <<< CUDA_2D_BLOCKS(outer_dim), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
ndims, outer_dim, inner_dim, ndims, outer_dim, inner_dim,
XSS, YDS, x, mean, var XSS, YDS, x, mean, var
); );
......
...@@ -30,8 +30,8 @@ __global__ void _Arange( ...@@ -30,8 +30,8 @@ __global__ void _Arange(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_Arange \ _Arange \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, start, step, y \ count, start, step, y \
); \ ); \
} }
...@@ -64,8 +64,8 @@ template <> void Arange<float16, CUDAContext>( ...@@ -64,8 +64,8 @@ template <> void Arange<float16, CUDAContext>(
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_Arange _Arange
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, start, step, count, start, step,
reinterpret_cast<half*>(y) reinterpret_cast<half*>(y)
); );
......
...@@ -20,12 +20,12 @@ void _ArgMax( ...@@ -20,12 +20,12 @@ void _ArgMax(
for (int iix = 0; iix < inner_dim; ++iix) { for (int iix = 0; iix < inner_dim; ++iix) {
const T* X = x + (oix * axis_dim * inner_dim + iix); const T* X = x + (oix * axis_dim * inner_dim + iix);
const int y_offset = oix * top_k * inner_dim + iix; const int y_offset = oix * top_k * inner_dim + iix;
vector< pair<T, int64_t> > vec(axis_dim); vector<pair<T, int64_t>> vec(axis_dim);
for (int j = 0; j < axis_dim; ++j) for (int j = 0; j < axis_dim; ++j)
vec[j] = std::make_pair(X[j * inner_dim], j); vec[j] = std::make_pair(X[j * inner_dim], j);
std::partial_sort( std::partial_sort(
vec.begin(), vec.begin() + top_k, vec.end(), vec.begin(), vec.begin() + top_k, vec.end(),
std::greater< pair<T, int64_t> >()); std::greater<pair<T, int64_t>>());
for (int j = 0; j < top_k; ++j) { for (int j = 0; j < top_k; ++j) {
indices[y_offset + j * inner_dim] = vec[j].second; indices[y_offset + j * inner_dim] = vec[j].second;
if (values) values[y_offset + j * inner_dim] = vec[j].first; if (values) values[y_offset + j * inner_dim] = vec[j].first;
...@@ -49,7 +49,7 @@ void _ArgMin( ...@@ -49,7 +49,7 @@ void _ArgMin(
for (int iix = 0; iix < inner_dim; ++iix) { for (int iix = 0; iix < inner_dim; ++iix) {
const T* X = x + (oix * axis_dim * inner_dim + iix); const T* X = x + (oix * axis_dim * inner_dim + iix);
const int y_offset = oix * top_k * inner_dim + iix; const int y_offset = oix * top_k * inner_dim + iix;
vector< pair<T, int64_t> > vec(axis_dim); vector<pair<T, int64_t>> vec(axis_dim);
for (int j = 0; j < axis_dim; ++j) for (int j = 0; j < axis_dim; ++j)
vec[j] = std::make_pair(X[j * inner_dim], j); vec[j] = std::make_pair(X[j * inner_dim], j);
std::partial_sort(vec.begin(), vec.begin() + top_k, vec.end()); std::partial_sort(vec.begin(), vec.begin() + top_k, vec.end());
......
...@@ -133,8 +133,8 @@ template<> __global__ void _ArgMin<half>( ...@@ -133,8 +133,8 @@ template<> __global__ void _ArgMin<half>(
CHECK_EQ(top_k, 1) << "\nRequired top_k == 1."; \ CHECK_EQ(top_k, 1) << "\nRequired top_k == 1."; \
auto nthreads = outer_dim * inner_dim; \ auto nthreads = outer_dim * inner_dim; \
_##name \ _##name \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, inner_dim, axis_dim, \ nthreads, inner_dim, axis_dim, \
x, indices, values \ x, indices, values \
); \ ); \
...@@ -168,8 +168,8 @@ template<> void ArgMax<float16, CUDAContext>( ...@@ -168,8 +168,8 @@ template<> void ArgMax<float16, CUDAContext>(
CHECK_EQ(top_k, 1) << "\nRequired top_k == 1."; CHECK_EQ(top_k, 1) << "\nRequired top_k == 1.";
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_ArgMax _ArgMax
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, inner_dim, axis_dim, nthreads, inner_dim, axis_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
indices, indices,
...@@ -189,8 +189,8 @@ template<> void ArgMin<float16, CUDAContext>( ...@@ -189,8 +189,8 @@ template<> void ArgMin<float16, CUDAContext>(
CHECK_EQ(top_k, 1) << "\nRequired top_k == 1."; CHECK_EQ(top_k, 1) << "\nRequired top_k == 1.";
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_ArgMin _ArgMin
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, inner_dim, axis_dim, nthreads, inner_dim, axis_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
indices, indices,
......
...@@ -43,8 +43,8 @@ __global__ void _Concat( ...@@ -43,8 +43,8 @@ __global__ void _Concat(
auto cols = axis_dim * inner_dim; \ auto cols = axis_dim * inner_dim; \
auto nthreads = outer_dim * axis_dim * inner_dim; \ auto nthreads = outer_dim * axis_dim * inner_dim; \
_##name \ _##name \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, \ nthreads, \
inner_dim, \ inner_dim, \
cols, \ cols, \
......
...@@ -83,8 +83,8 @@ __global__ void _CropGrad( ...@@ -83,8 +83,8 @@ __global__ void _CropGrad(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##name \ _##name \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, ndims, \ count, ndims, \
x_strides, y_dims, \ x_strides, y_dims, \
starts, x, y \ starts, x, y \
......
...@@ -115,8 +115,8 @@ template <> __global__ void _IndexSelectGrad<half>( ...@@ -115,8 +115,8 @@ template <> __global__ void _IndexSelectGrad<half>(
CUDAContext* ctx) { \ CUDAContext* ctx) { \
auto nthreads = outer_dim * num_indices * inner_dim; \ auto nthreads = outer_dim * num_indices * inner_dim; \
_IndexSelect \ _IndexSelect \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, inner_dim, \ nthreads, inner_dim, \
axis_dim, num_indices, \ axis_dim, num_indices, \
indices, x, y \ indices, x, y \
...@@ -135,8 +135,8 @@ template <> __global__ void _IndexSelectGrad<half>( ...@@ -135,8 +135,8 @@ template <> __global__ void _IndexSelectGrad<half>(
CUDAContext* ctx) { \ CUDAContext* ctx) { \
auto nthreads = outer_dim * inner_dim; \ auto nthreads = outer_dim * inner_dim; \
_IndexSelectGrad \ _IndexSelectGrad \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, inner_dim, \ nthreads, inner_dim, \
axis_dim, num_indices, \ axis_dim, num_indices, \
indices, dy, dx \ indices, dy, dx \
...@@ -170,8 +170,8 @@ template <> void IndexSelectGrad<float16, CUDAContext>( ...@@ -170,8 +170,8 @@ template <> void IndexSelectGrad<float16, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_IndexSelectGrad _IndexSelectGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, inner_dim, nthreads, inner_dim,
axis_dim, num_indices, axis_dim, num_indices,
indices, indices,
......
...@@ -32,8 +32,8 @@ template <> void OneHot<float, CUDAContext>( ...@@ -32,8 +32,8 @@ template <> void OneHot<float, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_OneHot _OneHot
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, depth, on_value, x, y count, depth, on_value, x, y
); );
} }
...@@ -48,8 +48,8 @@ template <> void OneHot<int, CUDAContext>( ...@@ -48,8 +48,8 @@ template <> void OneHot<int, CUDAContext>(
int* y, int* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_OneHot _OneHot
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, depth, on_value, x, y count, depth, on_value, x, y
); );
} }
...@@ -64,8 +64,8 @@ template <> void OneHot<int64_t, CUDAContext>( ...@@ -64,8 +64,8 @@ template <> void OneHot<int64_t, CUDAContext>(
int64_t* y, int64_t* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_OneHot _OneHot
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, depth, on_value, x, y count, depth, on_value, x, y
); );
} }
......
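The three OneHot launches above differ only in the element type; the kernel itself is a one-line scatter. A sketch of what such a kernel typically looks like, assuming y holds count * depth elements pre-filled with the off value (_OneHotSketch is an illustrative name):

    template <typename T>
    __global__ void _OneHotSketch(
        const int count, const int depth,
        const int on_value, const T* x, T* y) {
        for (int i = blockIdx.x * blockDim.x + threadIdx.x;
             i < count; i += blockDim.x * gridDim.x) {
            // each input index selects one column of its output row
            y[i * depth + (int)x[i]] = (T)on_value;
        }
    }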
...@@ -130,8 +130,8 @@ __global__ void _EdgePad( ...@@ -130,8 +130,8 @@ __global__ void _EdgePad(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_ConstPad \ _ConstPad \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, ndims, \ count, ndims, \
x_dims, x_strides, \ x_dims, x_strides, \
y_dims, l_pads, \ y_dims, l_pads, \
...@@ -152,8 +152,8 @@ __global__ void _EdgePad( ...@@ -152,8 +152,8 @@ __global__ void _EdgePad(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##name \ _##name \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, ndims, \ count, ndims, \
x_dims, x_strides, \ x_dims, x_strides, \
y_dims, l_pads, \ y_dims, l_pads, \
......
...@@ -202,8 +202,8 @@ void _ReduceSum( ...@@ -202,8 +202,8 @@ void _ReduceSum(
ndims, x_dims, y_dims, ndims, x_dims, y_dims,
&rows, &cols)) { &rows, &cols)) {
_ColwiseReduceSum _ColwiseReduceSum
<< < CUDA_2D_BLOCKS(rows), CUDA_THREADS, <<< CUDA_2D_BLOCKS(rows), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
rows, cols, scale, x, y rows, cols, scale, x, y
); return; ); return;
} }
...@@ -213,8 +213,8 @@ void _ReduceSum( ...@@ -213,8 +213,8 @@ void _ReduceSum(
ndims, x_dims, y_dims, ndims, x_dims, y_dims,
&rows, &cols)) { &rows, &cols)) {
_RowwiseReduceSum _RowwiseReduceSum
<< < CUDA_2D_BLOCKS(cols), CUDA_THREADS, <<< CUDA_2D_BLOCKS(cols), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
rows, cols, scale, x, y rows, cols, scale, x, y
); return; ); return;
} }
...@@ -245,8 +245,8 @@ void _ReduceSum( ...@@ -245,8 +245,8 @@ void _ReduceSum(
ctx->Memcpy<CUDAContext, CPUContext>(dbytes, YDS, dimsT.data()); ctx->Memcpy<CUDAContext, CPUContext>(dbytes, YDS, dimsT.data());
_GenericReduceSum _GenericReduceSum
<< < CUDA_2D_BLOCKS(outer_dim), CUDA_THREADS, <<< CUDA_2D_BLOCKS(outer_dim), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
ndims, outer_dim, inner_dim, ndims, outer_dim, inner_dim,
XSS, YDS, scale, x, y XSS, YDS, scale, x, y
); );
...@@ -372,8 +372,8 @@ template <> __global__ void _ReduceSumGrad<half>( ...@@ -372,8 +372,8 @@ template <> __global__ void _ReduceSumGrad<half>(
T* dx, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_ReduceSumGrad \ _ReduceSumGrad \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, ndim, x_dims, \ count, ndim, x_dims, \
y_dims, y_strides, \ y_dims, y_strides, \
scale, dy, dx \ scale, dy, dx \
...@@ -398,8 +398,8 @@ template<> void ReduceSumGrad<float16, CUDAContext>( ...@@ -398,8 +398,8 @@ template<> void ReduceSumGrad<float16, CUDAContext>(
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_ReduceSumGrad _ReduceSumGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, ndim, x_dims, count, ndim, x_dims,
y_dims, y_strides, y_dims, y_strides,
scale, scale,
......
...@@ -93,8 +93,8 @@ template<> __global__ void _RepeatGrad<half>( ...@@ -93,8 +93,8 @@ template<> __global__ void _RepeatGrad<half>(
auto y_inner_dim = inner_dim * repeats; \ auto y_inner_dim = inner_dim * repeats; \
auto nthreads = outer_dim * axis_dim * y_inner_dim; \ auto nthreads = outer_dim * axis_dim * y_inner_dim; \
_Repeat \ _Repeat \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, axis_dim, \ nthreads, axis_dim, \
inner_dim, y_inner_dim, \ inner_dim, y_inner_dim, \
x, y \ x, y \
...@@ -113,8 +113,8 @@ template<> __global__ void _RepeatGrad<half>( ...@@ -113,8 +113,8 @@ template<> __global__ void _RepeatGrad<half>(
auto y_inner_dim = inner_dim * repeats; \ auto y_inner_dim = inner_dim * repeats; \
auto nthreads = outer_dim * axis_dim * inner_dim; \ auto nthreads = outer_dim * axis_dim * inner_dim; \
_RepeatGrad \ _RepeatGrad \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, \ nthreads, \
axis_dim, \ axis_dim, \
inner_dim, \ inner_dim, \
...@@ -151,8 +151,8 @@ template<> void RepeatGrad<float16, CUDAContext>( ...@@ -151,8 +151,8 @@ template<> void RepeatGrad<float16, CUDAContext>(
auto y_inner_dim = inner_dim * repeats; auto y_inner_dim = inner_dim * repeats;
auto nthreads = outer_dim * axis_dim * inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
_RepeatGrad _RepeatGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
axis_dim, axis_dim,
inner_dim, inner_dim,
......
...@@ -64,8 +64,8 @@ __global__ void _SliceGrad( ...@@ -64,8 +64,8 @@ __global__ void _SliceGrad(
auto cols = slice_dim * inner_dim; \ auto cols = slice_dim * inner_dim; \
auto nthreads = outer_dim * cols; \ auto nthreads = outer_dim * cols; \
_##name \ _##name \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, \ nthreads, \
inner_dim, \ inner_dim, \
axis_dim, \ axis_dim, \
...@@ -126,8 +126,8 @@ template <> void SliceGrad<float16, CUDAContext>( ...@@ -126,8 +126,8 @@ template <> void SliceGrad<float16, CUDAContext>(
auto cols = slice_dim * inner_dim; auto cols = slice_dim * inner_dim;
auto nthreads = outer_dim * cols; auto nthreads = outer_dim * cols;
_SliceGrad _SliceGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
inner_dim, inner_dim,
axis_dim, axis_dim,
......
...@@ -98,8 +98,8 @@ template<> __global__ void _TileGrad<half>( ...@@ -98,8 +98,8 @@ template<> __global__ void _TileGrad<half>(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_Tile \ _Tile \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, \ count, \
ndims, \ ndims, \
x_dims, \ x_dims, \
...@@ -120,8 +120,8 @@ template<> __global__ void _TileGrad<half>( ...@@ -120,8 +120,8 @@ template<> __global__ void _TileGrad<half>(
auto nthreads = rows * cols; \ auto nthreads = rows * cols; \
auto tiled_cols = multiple * cols; \ auto tiled_cols = multiple * cols; \
_TileGrad \ _TileGrad \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, \ nthreads, \
cols, \ cols, \
tiled_cols, \ tiled_cols, \
...@@ -156,8 +156,8 @@ template<> void TileGrad<float16, CUDAContext>( ...@@ -156,8 +156,8 @@ template<> void TileGrad<float16, CUDAContext>(
auto nthreads = rows * cols; auto nthreads = rows * cols;
auto tiled_cols = multiple * cols; auto tiled_cols = multiple * cols;
_TileGrad _TileGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
cols, cols,
tiled_cols, tiled_cols,
......
...@@ -80,8 +80,8 @@ __global__ void _TransposeGrad( ...@@ -80,8 +80,8 @@ __global__ void _TransposeGrad(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##name \ _##name \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, ndims, x_strides, y_dims, x, y \ count, ndims, x_strides, y_dims, x, y \
); \ ); \
} }
......
...@@ -55,8 +55,8 @@ __global__ void _Assign( ...@@ -55,8 +55,8 @@ __global__ void _Assign(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_Assign \ _Assign \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, \ count, \
ndims, \ ndims, \
x_dims, \ x_dims, \
......
...@@ -153,8 +153,8 @@ __global__ void _GreaterEqualHalf( ...@@ -153,8 +153,8 @@ __global__ void _GreaterEqualHalf(
bool* y, \ bool* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
IMPL \ IMPL \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, a, b, y \ count, a, b, y \
); \ ); \
} }
...@@ -167,8 +167,8 @@ __global__ void _GreaterEqualHalf( ...@@ -167,8 +167,8 @@ __global__ void _GreaterEqualHalf(
bool* y, \ bool* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_##OP##Half \ _##OP##Half \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, \ count, \
reinterpret_cast<const half*>(a), \ reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \ reinterpret_cast<const half*>(b), \
......
...@@ -30,8 +30,8 @@ __global__ void _MaskedAssign( ...@@ -30,8 +30,8 @@ __global__ void _MaskedAssign(
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_MaskedAssign \ _MaskedAssign \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, mask, x, y \ count, mask, x, y \
); \ ); \
} }
......
...@@ -27,8 +27,8 @@ template<> void AbsGrad<float, CUDAContext>( ...@@ -27,8 +27,8 @@ template<> void AbsGrad<float, CUDAContext>(
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_AbsGrad _AbsGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, dy, dx count, dy, dx
); );
} }
......
...@@ -55,8 +55,8 @@ template <> void NLLLoss<float, float, CUDAContext>( ...@@ -55,8 +55,8 @@ template <> void NLLLoss<float, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_NLLLoss _NLLLoss
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nignores, nthreads, axis_dim, inner_dim, nignores,
ignore, log_prob, target, loss, flag ignore, log_prob, target, loss, flag
); );
...@@ -77,8 +77,8 @@ template <> void NLLLoss<float, int64_t, CUDAContext>( ...@@ -77,8 +77,8 @@ template <> void NLLLoss<float, int64_t, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_NLLLoss _NLLLoss
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nignores, nthreads, axis_dim, inner_dim, nignores,
ignore, log_prob, target, loss, flag ignore, log_prob, target, loss, flag
); );
...@@ -129,8 +129,8 @@ template<> void NLLLossGrad<float, float, CUDAContext>( ...@@ -129,8 +129,8 @@ template<> void NLLLossGrad<float, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_NLLLossGrad _NLLLossGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nignores, nthreads, axis_dim, inner_dim, nignores,
ignore, log_prob, target, dx, flag ignore, log_prob, target, dx, flag
); );
...@@ -151,8 +151,8 @@ template<> void NLLLossGrad<float, int64_t, CUDAContext>( ...@@ -151,8 +151,8 @@ template<> void NLLLossGrad<float, int64_t, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_NLLLossGrad _NLLLossGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nignores, nthreads, axis_dim, inner_dim, nignores,
ignore, log_prob, target, dx, flag ignore, log_prob, target, dx, flag
); );
......
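The NLLLoss kernels above take precomputed log-probabilities, so each of the outer_dim * inner_dim positions reduces to a single lookup; labels found in the ignore list are assumed to contribute zero loss and to be excluded from the flag count used for normalization. In that reading:

    \ell_{n,s} =
    \begin{cases}
      -\log p_{n,\,t_{n,s},\,s}, & t_{n,s} \notin \mathcal{I} \\
      0, & t_{n,s} \in \mathcal{I}
    \end{cases}

with n indexing outer_dim, s indexing inner_dim, t the target labels, and \mathcal{I} the ignored label set; the Grad variants scatter the corresponding derivative back to the target entry only.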
...@@ -42,8 +42,8 @@ template <> void SigmoidCrossEntropy<float, CUDAContext>( ...@@ -42,8 +42,8 @@ template <> void SigmoidCrossEntropy<float, CUDAContext>(
int* flag, int* flag,
CUDAContext* ctx) { CUDAContext* ctx) {
_SigmoidCrossEntropy _SigmoidCrossEntropy
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, logit, target, loss, flag count, logit, target, loss, flag
); );
} }
...@@ -77,8 +77,8 @@ template <> void SigmoidCrossEntropyGrad<float, CUDAContext>( ...@@ -77,8 +77,8 @@ template <> void SigmoidCrossEntropyGrad<float, CUDAContext>(
int* flag, int* flag,
CUDAContext* ctx) { CUDAContext* ctx) {
_SigmoidCrossEntropyGrad _SigmoidCrossEntropyGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, logit, target, dlogit, flag count, logit, target, dlogit, flag
); );
} }
......
...@@ -71,8 +71,8 @@ template <> void SigmoidFocalLoss<float, float, CUDAContext>( ...@@ -71,8 +71,8 @@ template <> void SigmoidFocalLoss<float, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * axis_dim * inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
_SigmoidFocalLoss _SigmoidFocalLoss
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nthreads, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
logits, targets, losses, flags logits, targets, losses, flags
...@@ -96,8 +96,8 @@ template <> void SigmoidFocalLoss<float, int64_t, CUDAContext>( ...@@ -96,8 +96,8 @@ template <> void SigmoidFocalLoss<float, int64_t, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * axis_dim * inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
_SigmoidFocalLoss _SigmoidFocalLoss
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nthreads, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
logits, targets, losses, flags logits, targets, losses, flags
...@@ -171,8 +171,8 @@ template <> void SigmoidFocalLossGrad<float, float, CUDAContext>( ...@@ -171,8 +171,8 @@ template <> void SigmoidFocalLossGrad<float, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto count = outer_dim * axis_dim * inner_dim; auto count = outer_dim * axis_dim * inner_dim;
_SigmoidFocalLossGrad _SigmoidFocalLossGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, axis_dim, inner_dim, count, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
logits, targets, dlogits, flags logits, targets, dlogits, flags
...@@ -196,8 +196,8 @@ template <> void SigmoidFocalLossGrad<float, int64_t, CUDAContext>( ...@@ -196,8 +196,8 @@ template <> void SigmoidFocalLossGrad<float, int64_t, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto count = outer_dim * axis_dim * inner_dim; auto count = outer_dim * axis_dim * inner_dim;
_SigmoidFocalLossGrad _SigmoidFocalLossGrad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, axis_dim, inner_dim, count, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
logits, targets, dlogits, flags logits, targets, dlogits, flags
......
...@@ -33,8 +33,8 @@ template<> void SmoothL1<float, CUDAContext>( ...@@ -33,8 +33,8 @@ template<> void SmoothL1<float, CUDAContext>(
float* y, float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
_SmoothL1 _SmoothL1
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, beta, x, y count, beta, x, y
); );
} }
...@@ -63,8 +63,8 @@ template<> void SmoothL1Grad<float, CUDAContext>( ...@@ -63,8 +63,8 @@ template<> void SmoothL1Grad<float, CUDAContext>(
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_SmoothL1Grad _SmoothL1Grad
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, beta, dy, dx count, beta, dy, dx
); );
} }
......
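SmoothL1 and SmoothL1Grad above are elementwise in x (respectively dy) with a single beta parameter; the conventional piecewise definition, assumed here to be what the kernels implement, is:

    \mathrm{SmoothL1}_{\beta}(x) =
    \begin{cases}
      0.5\,x^{2} / \beta, & |x| < \beta \\
      |x| - 0.5\,\beta,   & \text{otherwise}
    \end{cases}
    \qquad
    \mathrm{SmoothL1}_{\beta}'(x) =
    \begin{cases}
      x / \beta,              & |x| < \beta \\
      \operatorname{sign}(x), & \text{otherwise}
    \end{cases}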
...@@ -29,8 +29,8 @@ template <> void SoftmaxCrossEntropy<float, CUDAContext>( ...@@ -29,8 +29,8 @@ template <> void SoftmaxCrossEntropy<float, CUDAContext>(
float* losses, float* losses,
CUDAContext* ctx) { CUDAContext* ctx) {
_SoftmaxCrossEntropy _SoftmaxCrossEntropy
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, prob, targets, losses count, prob, targets, losses
); );
} }
......
...@@ -67,8 +67,8 @@ template <> void SoftmaxFocalLoss<float, float, CUDAContext>( ...@@ -67,8 +67,8 @@ template <> void SoftmaxFocalLoss<float, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
_SoftmaxFocalLoss _SoftmaxFocalLoss
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS, <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
num_preds, axis_dim, inner_dim, num_preds, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
nignores, ignores, nignores, ignores,
...@@ -95,8 +95,8 @@ template <> void SoftmaxFocalLoss<float, int64_t, CUDAContext>( ...@@ -95,8 +95,8 @@ template <> void SoftmaxFocalLoss<float, int64_t, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
_SoftmaxFocalLoss _SoftmaxFocalLoss
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS, <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
num_preds, axis_dim, inner_dim, num_preds, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
nignores, ignores, nignores, ignores,
...@@ -179,8 +179,8 @@ template<> void SoftmaxFocalLossGrad<float, float, CUDAContext>( ...@@ -179,8 +179,8 @@ template<> void SoftmaxFocalLossGrad<float, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
_SoftmaxFocalLossGrad _SoftmaxFocalLossGrad
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS, <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
num_preds, axis_dim, inner_dim, num_preds, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
nignores, ignores, nignores, ignores,
...@@ -207,8 +207,8 @@ template<> void SoftmaxFocalLossGrad<float, int64_t, CUDAContext>( ...@@ -207,8 +207,8 @@ template<> void SoftmaxFocalLossGrad<float, int64_t, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
_SoftmaxFocalLossGrad _SoftmaxFocalLossGrad
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS, <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
num_preds, axis_dim, inner_dim, num_preds, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
nignores, ignores, nignores, ignores,
......
...@@ -59,8 +59,8 @@ template <> void SparseSoftmaxCrossEntropy<float, float, CUDAContext>( ...@@ -59,8 +59,8 @@ template <> void SparseSoftmaxCrossEntropy<float, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_SparseSoftmaxCrossEntropy _SparseSoftmaxCrossEntropy
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nignores, nthreads, axis_dim, inner_dim, nignores,
ignore, prob, target, loss, flag ignore, prob, target, loss, flag
); );
...@@ -81,8 +81,8 @@ template <> void SparseSoftmaxCrossEntropy<float, int64_t, CUDAContext>( ...@@ -81,8 +81,8 @@ template <> void SparseSoftmaxCrossEntropy<float, int64_t, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_SparseSoftmaxCrossEntropy _SparseSoftmaxCrossEntropy
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nignores, nthreads, axis_dim, inner_dim, nignores,
ignore, prob, target, loss, flag ignore, prob, target, loss, flag
); );
...@@ -136,8 +136,8 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, float, CUDAContext>( ...@@ -136,8 +136,8 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_SparseSoftmaxCrossEntropyGrad _SparseSoftmaxCrossEntropyGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nignores, nthreads, axis_dim, inner_dim, nignores,
ignore, prob, target, dx, flag ignore, prob, target, dx, flag
); );
...@@ -158,8 +158,8 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, int64_t, CUDAContext>( ...@@ -158,8 +158,8 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, int64_t, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * inner_dim; auto nthreads = outer_dim * inner_dim;
_SparseSoftmaxCrossEntropyGrad _SparseSoftmaxCrossEntropyGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, nignores, nthreads, axis_dim, inner_dim, nignores,
ignore, prob, target, dx, flag ignore, prob, target, dx, flag
); );
......
...@@ -26,8 +26,8 @@ __global__ void _TypeA2B( ...@@ -26,8 +26,8 @@ __global__ void _TypeA2B(
Tb* b, \ Tb* b, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_TypeA2B \ _TypeA2B \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, a, b \ count, a, b \
); \ ); \
} }
...@@ -66,8 +66,8 @@ template <> void TypeA2B<float16, float, CUDAContext>( ...@@ -66,8 +66,8 @@ template <> void TypeA2B<float16, float, CUDAContext>(
float* b, float* b,
CUDAContext* ctx) { CUDAContext* ctx) {
_TypeA2B _TypeA2B
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, reinterpret_cast<const half*>(a), b count, reinterpret_cast<const half*>(a), b
); );
} }
...@@ -89,8 +89,8 @@ template <> void TypeA2B<float, float16, CUDAContext>( ...@@ -89,8 +89,8 @@ template <> void TypeA2B<float, float16, CUDAContext>(
float16* b, float16* b,
CUDAContext* ctx) { CUDAContext* ctx) {
_TypeA2B _TypeA2B
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, a, reinterpret_cast<half*>(b) count, a, reinterpret_cast<half*>(b)
); );
} }
...@@ -112,8 +112,8 @@ template <> void TypeA2B<float16, float16, CUDAContext>( ...@@ -112,8 +112,8 @@ template <> void TypeA2B<float16, float16, CUDAContext>(
float16* b, float16* b,
CUDAContext* ctx) { CUDAContext* ctx) {
_TypeA2B _TypeA2B
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(a), reinterpret_cast<const half*>(a),
reinterpret_cast<half*>(b) reinterpret_cast<half*>(b)
......
...@@ -62,8 +62,8 @@ template <> __global__ void _GradientTwoSum<half2>( ...@@ -62,8 +62,8 @@ template <> __global__ void _GradientTwoSum<half2>(
T* dx, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_GradientTwoSum \ _GradientTwoSum \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \ <<< CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
count, dy1, dy2, dx \ count, dy1, dy2, dx \
); \ ); \
} }
...@@ -83,8 +83,8 @@ template <> void GradientTwoSum<float16, CUDAContext>( ...@@ -83,8 +83,8 @@ template <> void GradientTwoSum<float16, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
if ((count & 1) == 0) { if ((count & 1) == 0) {
_GradientTwoSum _GradientTwoSum
<< < CUDA_BLOCKS(count >> 2), CUDA_THREADS, <<< CUDA_BLOCKS(count >> 2), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count >> 2, count >> 2,
reinterpret_cast<const half2*>(dy1), reinterpret_cast<const half2*>(dy1),
reinterpret_cast<const half2*>(dy2), reinterpret_cast<const half2*>(dy2),
...@@ -92,8 +92,8 @@ template <> void GradientTwoSum<float16, CUDAContext>( ...@@ -92,8 +92,8 @@ template <> void GradientTwoSum<float16, CUDAContext>(
); );
} else { } else {
_GradientTwoSum _GradientTwoSum
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
reinterpret_cast<const half*>(dy1), reinterpret_cast<const half*>(dy1),
reinterpret_cast<const half*>(dy2), reinterpret_cast<const half*>(dy2),
......
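The float16 GradientTwoSum above takes a packed path when the length is even, reinterpreting the buffers as half2 so one instruction adds two fp16 lanes. A sketch of such a packed kernel, assuming compute capability 5.3+ for __hadd2 and with n2 denoting the number of half2 packs:

    #include <cuda_fp16.h>

    __global__ void _TwoSumHalf2Sketch(
        const int n2,                  // number of half2 packs
        const half2* dy1, const half2* dy2, half2* dx) {
        for (int i = blockIdx.x * blockDim.x + threadIdx.x;
             i < n2; i += blockDim.x * gridDim.x) {
            // one __hadd2 adds both fp16 lanes of the pack at once
            dx[i] = __hadd2(dy1[i], dy2[i]);
        }
    }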
...@@ -76,14 +76,14 @@ template <> void ImageData<float, float, CUDAContext>( ...@@ -76,14 +76,14 @@ template <> void ImageData<float, float, CUDAContext>(
auto nthreads = N * C * H * W; auto nthreads = N * C * H * W;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_ImageDataNCHW _ImageDataNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, mean, std, x, y nthreads, C, H, W, mean, std, x, y
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_ImageDataNHWC _ImageDataNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, mean, std, x, y nthreads, C, H, W, mean, std, x, y
); );
} else { } else {
...@@ -107,14 +107,14 @@ template <> void ImageData<uint8_t, float, CUDAContext>( ...@@ -107,14 +107,14 @@ template <> void ImageData<uint8_t, float, CUDAContext>(
auto nthreads = N * C * H * W; auto nthreads = N * C * H * W;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_ImageDataNCHW _ImageDataNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, mean, std, x, y nthreads, C, H, W, mean, std, x, y
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_ImageDataNHWC _ImageDataNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, mean, std, x, y nthreads, C, H, W, mean, std, x, y
); );
} else { } else {
...@@ -191,15 +191,15 @@ template <> void ImageData<float, float16, CUDAContext>( ...@@ -191,15 +191,15 @@ template <> void ImageData<float, float16, CUDAContext>(
auto nthreads = N * C * H * W; auto nthreads = N * C * H * W;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_ImageDataHalfNCHW _ImageDataHalfNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, mean, std, nthreads, C, H, W, mean, std,
x, reinterpret_cast<half*>(y) x, reinterpret_cast<half*>(y)
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_ImageDataHalfNHWC _ImageDataHalfNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, mean, std, nthreads, C, H, W, mean, std,
x, reinterpret_cast<half*>(y) x, reinterpret_cast<half*>(y)
); );
...@@ -222,15 +222,15 @@ template <> void ImageData<uint8_t, float16, CUDAContext>( ...@@ -222,15 +222,15 @@ template <> void ImageData<uint8_t, float16, CUDAContext>(
auto nthreads = N * C * H * W; auto nthreads = N * C * H * W;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_ImageDataHalfNCHW _ImageDataHalfNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, mean, std, nthreads, C, H, W, mean, std,
x, reinterpret_cast<half*>(y) x, reinterpret_cast<half*>(y)
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_ImageDataHalfNHWC _ImageDataHalfNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, mean, std, nthreads, C, H, W, mean, std,
x, reinterpret_cast<half*>(y) x, reinterpret_cast<half*>(y)
); );
......
...@@ -190,27 +190,27 @@ __global__ void _BatchNormInferenceGrad( ...@@ -190,27 +190,27 @@ __global__ void _BatchNormInferenceGrad(
auto nthreads = N * C * S; \ auto nthreads = N * C * S; \
if (data_format == "NCHW") { \ if (data_format == "NCHW") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \ _BatchNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \
<< < CUDA_2D_BLOCKS(C), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(C), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, C, S, x, mu, rsig, gamma, \ N, C, S, x, mu, rsig, gamma, \
dy, ds, db, dgamma, dbeta \ dy, ds, db, dgamma, dbeta \
); \ ); \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NCHW> \ _BatchNormTrainingGrad<Tx, Tp, StorageOrder::NCHW> \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, N, C, S, x, mu, \ nthreads, N, C, S, x, mu, \
rsig, gamma, ds, db, dy, dx \ rsig, gamma, ds, db, dy, dx \
); \ ); \
} else if (data_format == "NHWC") { \ } else if (data_format == "NHWC") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \ _BatchNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \
<< < CUDA_2D_BLOCKS(C), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(C), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, C, S, x, mu, rsig, gamma, \ N, C, S, x, mu, rsig, gamma, \
dy, ds, db, dgamma, dbeta \ dy, ds, db, dgamma, dbeta \
); \ ); \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NHWC> \ _BatchNormTrainingGrad<Tx, Tp, StorageOrder::NHWC> \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, N, C, S, x, mu, \ nthreads, N, C, S, x, mu, \
rsig, gamma, ds, db, dy, dx \ rsig, gamma, ds, db, dy, dx \
); \ ); \
...@@ -234,24 +234,24 @@ __global__ void _BatchNormInferenceGrad( ...@@ -234,24 +234,24 @@ __global__ void _BatchNormInferenceGrad(
if (data_format == "NCHW") { \ if (data_format == "NCHW") { \
if (dgamma != nullptr) { \ if (dgamma != nullptr) { \
_BatchNormWGrad<Tx, Tp, StorageOrder::NCHW> \ _BatchNormWGrad<Tx, Tp, StorageOrder::NCHW> \
<< < CUDA_2D_BLOCKS(C), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(C), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \ 0, ctx->cuda_stream() >>> \
(N, C, S, x, mu, rsig, dy, dgamma, dbeta); \ (N, C, S, x, mu, rsig, dy, dgamma, dbeta); \
} \ } \
_BatchNormInferenceGrad<Tx, Tp, StorageOrder::NCHW> \ _BatchNormInferenceGrad<Tx, Tp, StorageOrder::NCHW> \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \ 0, ctx->cuda_stream() >>> \
(nthreads, C, S, rsig, gamma, dy, dx); \ (nthreads, C, S, rsig, gamma, dy, dx); \
} else if (data_format == "NHWC") { \ } else if (data_format == "NHWC") { \
if (dgamma != nullptr) { \ if (dgamma != nullptr) { \
_BatchNormWGrad<Tx, Tp, StorageOrder::NHWC> \ _BatchNormWGrad<Tx, Tp, StorageOrder::NHWC> \
<< < CUDA_2D_BLOCKS(C), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(C), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \ 0, ctx->cuda_stream() >>> \
(N, C, S, x, mu, rsig, dy, dgamma, dbeta); \ (N, C, S, x, mu, rsig, dy, dgamma, dbeta); \
} \ } \
_BatchNormInferenceGrad<Tx, Tp, StorageOrder::NHWC> \ _BatchNormInferenceGrad<Tx, Tp, StorageOrder::NHWC> \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \ 0, ctx->cuda_stream() >>> \
(nthreads, C, S, rsig, gamma, dy, dx); \ (nthreads, C, S, rsig, gamma, dy, dx); \
} \ } \
} }
......
...@@ -408,20 +408,20 @@ __global__ void _GroupNormGradHalf( ...@@ -408,20 +408,20 @@ __global__ void _GroupNormGradHalf(
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int C = G * D; \ const int C = G * D; \
_GroupNormFusedParams<Tp> \ _GroupNormFusedParams<Tp> \
<< < CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, G, D, mu, rsig, gamma, beta, scale, bias \ N, G, D, mu, rsig, gamma, beta, scale, bias \
); \ ); \
if (data_format == "NCHW") { \ if (data_format == "NCHW") { \
_GroupNormForwardNCHW<Tx, Tp> \ _GroupNormForwardNCHW<Tx, Tp> \
<< < CUDA_2D_BLOCKS(N * C), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(N * C), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, C, S, x, scale, bias, y \ N, C, S, x, scale, bias, y \
); \ ); \
} else if (data_format == "NHWC") { \ } else if (data_format == "NHWC") { \
_GroupNormForwardNHWC<Tx, Tp> \ _GroupNormForwardNHWC<Tx, Tp> \
<< < CUDA_2D_BLOCKS(N * C), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(N * C), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, C, S, x, scale, bias, y \ N, C, S, x, scale, bias, y \
); \ ); \
} \ } \
...@@ -448,35 +448,35 @@ __global__ void _GroupNormGradHalf( ...@@ -448,35 +448,35 @@ __global__ void _GroupNormGradHalf(
auto nthreads = N * G * D * S; \ auto nthreads = N * G * D * S; \
if (data_format == "NCHW") { \ if (data_format == "NCHW") { \
_GroupNormWGrad<Tx, Tp, StorageOrder::NCHW> \ _GroupNormWGrad<Tx, Tp, StorageOrder::NCHW> \
<< < CUDA_2D_BLOCKS(G * D), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(G * D), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, G, D, S, x, mu, rsig, dy, dgamma, dbeta \ N, G, D, S, x, mu, rsig, dy, dgamma, dbeta \
); \ ); \
_GroupNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \ _GroupNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \
<< < CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, G, D, S, x, gamma, dy, ds, db \ N, G, D, S, x, gamma, dy, ds, db \
); \ ); \
_GroupNormGrad<Tx, Tp, StorageOrder::NCHW> \ _GroupNormGrad<Tx, Tp, StorageOrder::NCHW> \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
nthreads, G, D, S, x, mu, rsig, \ nthreads, G, D, S, x, mu, rsig, \
gamma, ds, db, dy, dx \ gamma, ds, db, dy, dx \
); \ ); \
} else if (data_format == "NHWC") { \ } else if (data_format == "NHWC") { \
_GroupNormWGrad<Tx, Tp, StorageOrder::NHWC> \ _GroupNormWGrad<Tx, Tp, StorageOrder::NHWC> \
<< < CUDA_2D_BLOCKS(G * D), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(G * D), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, G, D, S, x, mu, rsig, dy, dgamma, dbeta \ N, G, D, S, x, mu, rsig, dy, dgamma, dbeta \
); \ ); \
_GroupNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \ _GroupNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \
<< < CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \ <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
0, ctx->cuda_stream() >> >( \ 0, ctx->cuda_stream() >>>( \
N, G, D, S, x, gamma, dy, ds, db \ N, G, D, S, x, gamma, dy, ds, db \
); \ ); \
_GroupNormGrad<Tx, Tp, StorageOrder::NHWC> \ _GroupNormGrad<Tx, Tp, StorageOrder::NHWC> \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \ <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> > ( \ 0, ctx->cuda_stream() >>> ( \
nthreads, G, D, S, x, mu, rsig, \ nthreads, G, D, S, x, mu, rsig, \
gamma, ds, db, dy, dx \ gamma, ds, db, dy, dx \
); \ ); \
...@@ -503,14 +503,14 @@ template <> void GroupNormForward<float16, float, CUDAContext>( ...@@ -503,14 +503,14 @@ template <> void GroupNormForward<float16, float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
const int C = G * D; const int C = G * D;
_GroupNormFusedParams<float> _GroupNormFusedParams<float>
<< < CUDA_2D_BLOCKS(N * G), CUDA_THREADS, <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, G, D, mu, rsig, gamma, beta, scale, bias N, G, D, mu, rsig, gamma, beta, scale, bias
); );
if (data_format == "NCHW") { if (data_format == "NCHW") {
_GroupNormForwardNCHW<half, float> _GroupNormForwardNCHW<half, float>
<< < CUDA_2D_BLOCKS(N * C), CUDA_THREADS, <<< CUDA_2D_BLOCKS(N * C), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, C, S, N, C, S,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
scale, bias, scale, bias,
...@@ -518,8 +518,8 @@ template <> void GroupNormForward<float16, float, CUDAContext>( ...@@ -518,8 +518,8 @@ template <> void GroupNormForward<float16, float, CUDAContext>(
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_GroupNormForwardNHWC<half, float> _GroupNormForwardNHWC<half, float>
<< < CUDA_2D_BLOCKS(N * C), CUDA_THREADS, <<< CUDA_2D_BLOCKS(N * C), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, C, S, N, C, S,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
scale, bias, scale, bias,
...@@ -548,8 +548,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>( ...@@ -548,8 +548,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
auto nthreads = N * G * D * S; auto nthreads = N * G * D * S;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_GroupNormWGradHalf<StorageOrder::NCHW> _GroupNormWGradHalf<StorageOrder::NCHW>
<< < CUDA_2D_BLOCKS(G * D), CUDA_THREADS, <<< CUDA_2D_BLOCKS(G * D), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, G, D, S, N, G, D, S,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
mu, rsig, mu, rsig,
...@@ -557,8 +557,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>( ...@@ -557,8 +557,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
dgamma, dbeta dgamma, dbeta
); );
_GroupNormInternalGradHalf<StorageOrder::NCHW> _GroupNormInternalGradHalf<StorageOrder::NCHW>
<< < CUDA_2D_BLOCKS(N * G), CUDA_THREADS, <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, G, D, S, N, G, D, S,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
gamma, gamma,
...@@ -566,8 +566,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>( ...@@ -566,8 +566,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
ds, db ds, db
); );
_GroupNormGradHalf<StorageOrder::NCHW> _GroupNormGradHalf<StorageOrder::NCHW>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, G, D, S, nthreads, G, D, S,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
mu, rsig, gamma, ds, db, mu, rsig, gamma, ds, db,
...@@ -576,8 +576,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>( ...@@ -576,8 +576,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
); );
} else if (data_format == "NHWC") { \ } else if (data_format == "NHWC") { \
_GroupNormWGradHalf<StorageOrder::NHWC> _GroupNormWGradHalf<StorageOrder::NHWC>
<< < CUDA_2D_BLOCKS(G * D), CUDA_THREADS, <<< CUDA_2D_BLOCKS(G * D), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, G, D, S, N, G, D, S,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
mu, rsig, mu, rsig,
...@@ -585,8 +585,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>( ...@@ -585,8 +585,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
dgamma, dbeta dgamma, dbeta
); );
_GroupNormInternalGradHalf<StorageOrder::NHWC> _GroupNormInternalGradHalf<StorageOrder::NHWC>
<< < CUDA_2D_BLOCKS(N * G), CUDA_THREADS, <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, G, D, S, N, G, D, S,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
gamma, gamma,
...@@ -594,8 +594,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>( ...@@ -594,8 +594,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
ds, db ds, db
); );
_GroupNormGradHalf<StorageOrder::NHWC> _GroupNormGradHalf<StorageOrder::NHWC>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, G, D, S, nthreads, G, D, S,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
mu, rsig, gamma, ds, db, mu, rsig, gamma, ds, db,
......
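The GroupNorm forward above first launches _GroupNormFusedParams and then a single affine pass over the data; the natural reading (an assumption about what the fused-params kernel computes) is that gamma/beta are folded with the per-group statistics into per-(n, c) scale and bias:

    y_{n,c,s} = \gamma_{c}\,\frac{x_{n,c,s} - \mu_{n,g(c)}}{\sigma_{n,g(c)}} + \beta_{c}
              = \underbrace{\gamma_{c}\,r_{n,g(c)}}_{\text{scale}_{n,c}}\, x_{n,c,s}
                + \underbrace{\beta_{c} - \gamma_{c}\,\mu_{n,g(c)}\,r_{n,g(c)}}_{\text{bias}_{n,c}}

where g(c) maps channel c to its group and \mu, r = rsig are the per-(n, group) mean and reciprocal standard deviation; the backward launches first form d\gamma, d\beta and the per-group ds/db partial sums before the final dx kernel.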
...@@ -58,13 +58,13 @@ template <> void LSTMCell<float, CUDAContext>( ...@@ -58,13 +58,13 @@ template <> void LSTMCell<float, CUDAContext>(
auto o_offset = 2 * C, c_offset = 3 * C, auto o_offset = 2 * C, c_offset = 3 * C,
x_offset = 4 * C, NC = N * C; x_offset = 4 * C, NC = N * C;
_LSTMCellAct _LSTMCellAct
<< < CUDA_BLOCKS(NC * 4), CUDA_THREADS, <<< CUDA_BLOCKS(NC * 4), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
NC * 4, c_offset, x_offset, actx NC * 4, c_offset, x_offset, actx
); );
_LSTMCellGate _LSTMCellGate
<< < CUDA_BLOCKS(NC), CUDA_THREADS, <<< CUDA_BLOCKS(NC), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
NC, C, o_offset, c_offset, NC, C, o_offset, c_offset,
x_offset, cx, actx, c, h x_offset, cx, actx, c, h
); );
...@@ -138,14 +138,14 @@ template <> void LSTMCellGrad<float, CUDAContext>( ...@@ -138,14 +138,14 @@ template <> void LSTMCellGrad<float, CUDAContext>(
auto o_offset = 2 * C, c_offset = 3 * C, auto o_offset = 2 * C, c_offset = 3 * C,
x_offset = 4 * C, NC = N * C; x_offset = 4 * C, NC = N * C;
_LSTMCellGateGrad _LSTMCellGateGrad
<< < CUDA_BLOCKS(NC), CUDA_THREADS, <<< CUDA_BLOCKS(NC), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
NC, C, o_offset, c_offset, x_offset, NC, C, o_offset, c_offset, x_offset,
cx, actx, c, dc, dh, dcx, dx cx, actx, c, dc, dh, dcx, dx
); );
_LSTMCellActGrad _LSTMCellActGrad
<< < CUDA_BLOCKS(NC * 4), CUDA_THREADS, <<< CUDA_BLOCKS(NC * 4), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
NC * 4, c_offset, x_offset, actx, dx NC * 4, c_offset, x_offset, actx, dx
); );
} }
......
...@@ -39,8 +39,8 @@ template <> void AdamUpdate<float, CUDAContext>( ...@@ -39,8 +39,8 @@ template <> void AdamUpdate<float, CUDAContext>(
float* v, float* v,
CUDAContext* ctx) { CUDAContext* ctx) {
_AdamUpdate _AdamUpdate
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, lr, beta1, beta2, eps, g, m, v count, lr, beta1, beta2, eps, g, m, v
); );
} }
......
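AdamUpdate above is a single elementwise launch over (g, m, v). One common per-element formulation, sketched under two assumptions: any bias correction is already folded into lr, and the scaled step is written back into g for the caller to apply:

    __global__ void _AdamUpdateSketch(
        const int n, const float lr,
        const float beta1, const float beta2, const float eps,
        float* g, float* m, float* v) {
        for (int i = blockIdx.x * blockDim.x + threadIdx.x;
             i < n; i += blockDim.x * gridDim.x) {
            const float gi = g[i];
            const float mi = m[i] = beta1 * m[i] + (1.f - beta1) * gi;
            const float vi = v[i] = beta2 * v[i] + (1.f - beta2) * gi * gi;
            g[i] = lr * mi / (sqrtf(vi) + eps);   // step consumed by the updater
        }
    }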
...@@ -29,8 +29,8 @@ template <> void MixedPrecL2Decay<float16, CUDAContext>( ...@@ -29,8 +29,8 @@ template <> void MixedPrecL2Decay<float16, CUDAContext>(
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_MixedPrecL2DecayHalf _MixedPrecL2DecayHalf
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
alpha, alpha,
reinterpret_cast<const half*>(w), reinterpret_cast<const half*>(w),
...@@ -58,8 +58,8 @@ template <> void MixedPrecUpdate<float16, CUDAContext>( ...@@ -58,8 +58,8 @@ template <> void MixedPrecUpdate<float16, CUDAContext>(
float16* w, float16* w,
CUDAContext* ctx) { CUDAContext* ctx) {
_MixedPrecUpdateHalf _MixedPrecUpdateHalf
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, count,
updates, updates,
reinterpret_cast<half*>(w) reinterpret_cast<half*>(w)
......
...@@ -32,8 +32,8 @@ template <> void NesterovUpdate<float, CUDAContext>( ...@@ -32,8 +32,8 @@ template <> void NesterovUpdate<float, CUDAContext>(
float* h, float* h,
CUDAContext* ctx) { CUDAContext* ctx) {
_NesterovUpdate _NesterovUpdate
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, lr, momentum, g, h count, lr, momentum, g, h
); );
} }
......
...@@ -34,8 +34,8 @@ template <> void RMSPropUpdate<float, CUDAContext>( ...@@ -34,8 +34,8 @@ template <> void RMSPropUpdate<float, CUDAContext>(
float* h, float* h,
CUDAContext* ctx) { CUDAContext* ctx) {
_RMSPropUpdate _RMSPropUpdate
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, lr, decay, eps, g, h count, lr, decay, eps, g, h
); );
} }
......
...@@ -31,8 +31,8 @@ template <> void SGDUpdate<float, CUDAContext>( ...@@ -31,8 +31,8 @@ template <> void SGDUpdate<float, CUDAContext>(
float* h, float* h,
CUDAContext* ctx) { CUDAContext* ctx) {
_SGDUpdate _SGDUpdate
<< < CUDA_BLOCKS(count), CUDA_THREADS, <<< CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
count, lr, momentum, g, h count, lr, momentum, g, h
); );
} }
......
...@@ -52,14 +52,14 @@ template<> void BiasAdd<float, CUDAContext>( ...@@ -52,14 +52,14 @@ template<> void BiasAdd<float, CUDAContext>(
auto nthreads = outer_dim * axis_dim * inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_BiasAddNCHW _BiasAddNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, inner_dim, bias, y nthreads, axis_dim, inner_dim, bias, y
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_BiasAddNHWC _BiasAddNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, axis_dim, bias, y nthreads, axis_dim, bias, y
); );
} else { } else {
......
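The two BiasAdd launches differ only in how a flat index maps to a channel: NCHW keeps inner_dim contiguous per channel (so that launch passes inner_dim), while NHWC keeps channels innermost (so it does not). A sketch of both index computations, assuming y already contains the input; the *Sketch names are illustrative:

    __global__ void _BiasAddNCHWSketch(
        const int nthreads, const int axis_dim,
        const int inner_dim, const float* bias, float* y) {
        for (int i = blockIdx.x * blockDim.x + threadIdx.x;
             i < nthreads; i += blockDim.x * gridDim.x) {
            y[i] += bias[(i / inner_dim) % axis_dim];  // channel is the middle axis
        }
    }

    __global__ void _BiasAddNHWCSketch(
        const int nthreads, const int axis_dim,
        const float* bias, float* y) {
        for (int i = blockIdx.x * blockDim.x + threadIdx.x;
             i < nthreads; i += blockDim.x * gridDim.x) {
            y[i] += bias[i % axis_dim];                // channel is the innermost axis
        }
    }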
...@@ -109,15 +109,15 @@ template <> void BilinearResize<float, CUDAContext>( ...@@ -109,15 +109,15 @@ template <> void BilinearResize<float, CUDAContext>(
auto scale_w = (float)W / (float)out_w; auto scale_w = (float)W / (float)out_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_BilinearResizeNCHW _BilinearResizeNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, out_h, out_w, nthreads, C, H, W, out_h, out_w,
scale_h, scale_w, x, y scale_h, scale_w, x, y
); );
} else if(data_format == "NHWC") { } else if(data_format == "NHWC") {
_BilinearResizeNHWC _BilinearResizeNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, out_h, out_w, nthreads, C, H, W, out_h, out_w,
scale_h, scale_w, x, y scale_h, scale_w, x, y
); );
...@@ -224,15 +224,15 @@ template <> void BilinearResizeGrad<float, CUDAContext>( ...@@ -224,15 +224,15 @@ template <> void BilinearResizeGrad<float, CUDAContext>(
auto scale_w = (float)W / (float)out_w; auto scale_w = (float)W / (float)out_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_BilinearResizeGradNCHW _BilinearResizeGradNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, out_h, out_w, nthreads, C, H, W, out_h, out_w,
scale_h, scale_w, dy, dx scale_h, scale_w, dy, dx
); );
} else if(data_format == "NHWC") { } else if(data_format == "NHWC") {
_BilinearResizeGradNHWC _BilinearResizeGradNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, out_h, out_w, nthreads, C, H, W, out_h, out_w,
scale_h, scale_w, dy, dx scale_h, scale_w, dy, dx
); );
......
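BilinearResize computes scale_h = H / out_h and scale_w = W / out_w on the host and passes them to the kernels; each output pixel is then assumed to be sampled at the scaled source coordinate and blended from its four neighbors (one common convention, without a half-pixel offset):

    h_0 = \lfloor y\,s_h \rfloor,\quad w_0 = \lfloor x\,s_w \rfloor,\quad
    u = y\,s_h - h_0,\quad v = x\,s_w - w_0

    \mathrm{out}(y, x) = (1-u)(1-v)\,I(h_0, w_0) + (1-u)\,v\,I(h_0, w_0{+}1)
                       + u\,(1-v)\,I(h_0{+}1, w_0) + u\,v\,I(h_0{+}1, w_0{+}1)

with the neighbor indices clamped to the input bounds; the Grad variants scatter the same four weights back into dx.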
...@@ -123,8 +123,8 @@ template <> void Im2Col2d<float, CUDAContext>( ...@@ -123,8 +123,8 @@ template <> void Im2Col2d<float, CUDAContext>(
auto nthreads = C * out_h * out_w; auto nthreads = C * out_h * out_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_Im2Col2dNCHW _Im2Col2dNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
H, W, H, W,
out_h, out_w, out_h, out_w,
...@@ -136,8 +136,8 @@ template <> void Im2Col2d<float, CUDAContext>( ...@@ -136,8 +136,8 @@ template <> void Im2Col2d<float, CUDAContext>(
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_Im2Col2dNHWC _Im2Col2dNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -286,8 +286,8 @@ template <> void Col2Im2d<float, CUDAContext>( ...@@ -286,8 +286,8 @@ template <> void Col2Im2d<float, CUDAContext>(
const int nthreads = C * H * W; const int nthreads = C * H * W;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_Col2Im2dNCHW _Col2Im2dNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
H, W, H, W,
out_h, out_w, out_h, out_w,
...@@ -299,8 +299,8 @@ template <> void Col2Im2d<float, CUDAContext>( ...@@ -299,8 +299,8 @@ template <> void Col2Im2d<float, CUDAContext>(
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_Col2Im2dNHWC _Col2Im2dNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
......
...@@ -144,8 +144,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>( ...@@ -144,8 +144,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
if (data_format == "NCHW") { if (data_format == "NCHW") {
if (kernel_h == 3 && kernel_w == 3) { if (kernel_h == 3 && kernel_w == 3) {
_DepthwiseConv2dNCHW<float, 3, 3> _DepthwiseConv2dNCHW<float, 3, 3>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -157,8 +157,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>( ...@@ -157,8 +157,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
); );
} else if (kernel_h == 5 && kernel_w == 5) { } else if (kernel_h == 5 && kernel_w == 5) {
_DepthwiseConv2dNCHW<float, 5, 5> _DepthwiseConv2dNCHW<float, 5, 5>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -170,8 +170,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>( ...@@ -170,8 +170,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
); );
} else if (kernel_h == 7 && kernel_w == 7) { } else if (kernel_h == 7 && kernel_w == 7) {
_DepthwiseConv2dNCHW<float, 7, 7> _DepthwiseConv2dNCHW<float, 7, 7>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -183,8 +183,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>( ...@@ -183,8 +183,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
); );
} else { } else {
_DepthwiseConv2dNCHW<float, -1, -1> _DepthwiseConv2dNCHW<float, -1, -1>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -198,8 +198,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>( ...@@ -198,8 +198,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
if (kernel_h == 3 && kernel_w == 3) { if (kernel_h == 3 && kernel_w == 3) {
_DepthwiseConv2dNHWC<float, 3, 3> _DepthwiseConv2dNHWC<float, 3, 3>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -211,8 +211,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>( ...@@ -211,8 +211,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
); );
} else if (kernel_h == 5 && kernel_w == 5) { } else if (kernel_h == 5 && kernel_w == 5) {
_DepthwiseConv2dNHWC<float, 5, 5> _DepthwiseConv2dNHWC<float, 5, 5>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -224,8 +224,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>( ...@@ -224,8 +224,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
); );
} else if (kernel_h == 7 && kernel_w == 7) { } else if (kernel_h == 7 && kernel_w == 7) {
_DepthwiseConv2dNHWC<float, 7, 7> _DepthwiseConv2dNHWC<float, 7, 7>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -237,8 +237,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>( ...@@ -237,8 +237,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
); );
} else { } else {
_DepthwiseConv2dNHWC<float, -1, -1> _DepthwiseConv2dNHWC<float, -1, -1>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -394,8 +394,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>( ...@@ -394,8 +394,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
if (data_format == "NCHW") { if (data_format == "NCHW") {
if (kernel_h == 3 && kernel_w == 3) { if (kernel_h == 3 && kernel_w == 3) {
_DepthwiseConv2dGradNCHW<float, 3, 3> _DepthwiseConv2dGradNCHW<float, 3, 3>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -407,8 +407,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>( ...@@ -407,8 +407,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
); );
} else if (kernel_h == 5 && kernel_w == 5) { } else if (kernel_h == 5 && kernel_w == 5) {
_DepthwiseConv2dGradNCHW<float, 5, 5> _DepthwiseConv2dGradNCHW<float, 5, 5>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -420,8 +420,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>( ...@@ -420,8 +420,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
); );
} else if (kernel_h == 7 && kernel_w == 7) { } else if (kernel_h == 7 && kernel_w == 7) {
_DepthwiseConv2dGradNCHW<float, 7, 7> _DepthwiseConv2dGradNCHW<float, 7, 7>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -433,8 +433,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>( ...@@ -433,8 +433,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
); );
} else { } else {
_DepthwiseConv2dGradNCHW<float, -1, -1> _DepthwiseConv2dGradNCHW<float, -1, -1>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -448,8 +448,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>( ...@@ -448,8 +448,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
if (kernel_h == 3 && kernel_w == 3) { if (kernel_h == 3 && kernel_w == 3) {
_DepthwiseConv2dGradNHWC<float, 3, 3> _DepthwiseConv2dGradNHWC<float, 3, 3>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -461,8 +461,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>( ...@@ -461,8 +461,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
); );
} else if (kernel_h == 5 && kernel_w == 5) { } else if (kernel_h == 5 && kernel_w == 5) {
_DepthwiseConv2dGradNHWC<float, 5, 5> _DepthwiseConv2dGradNHWC<float, 5, 5>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -474,8 +474,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>( ...@@ -474,8 +474,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
); );
} else if (kernel_h == 7 && kernel_w == 7) { } else if (kernel_h == 7 && kernel_w == 7) {
_DepthwiseConv2dGradNHWC<float, 7, 7> _DepthwiseConv2dGradNHWC<float, 7, 7>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -487,8 +487,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>( ...@@ -487,8 +487,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
); );
} else { } else {
_DepthwiseConv2dGradNHWC<float, -1, -1> _DepthwiseConv2dGradNHWC<float, -1, -1>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
out_h, out_w, out_h, out_w,
...@@ -634,8 +634,8 @@ template <> void DepthwiseConv2dWGrad<float, CUDAContext>( ...@@ -634,8 +634,8 @@ template <> void DepthwiseConv2dWGrad<float, CUDAContext>(
auto nblocks = C * kernel_h * kernel_w; auto nblocks = C * kernel_h * kernel_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_DepthwiseConv2dWGradNCHW _DepthwiseConv2dWGradNCHW
<< < nblocks, nthreads, <<< nblocks, nthreads,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, C, H, W, N, C, H, W,
out_h, out_w, out_h, out_w,
kernel_h, kernel_w, kernel_h, kernel_w,
...@@ -646,8 +646,8 @@ template <> void DepthwiseConv2dWGrad<float, CUDAContext>( ...@@ -646,8 +646,8 @@ template <> void DepthwiseConv2dWGrad<float, CUDAContext>(
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_DepthwiseConv2dWGradNHWC _DepthwiseConv2dWGradNHWC
<< < nblocks, nthreads, <<< nblocks, nthreads,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
N, C, H, W, N, C, H, W,
out_h, out_w, out_h, out_w,
kernel_h, kernel_w, kernel_h, kernel_w,
......
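Note: the kernel hunks above only change the launch spelling from "<< < ... >> >" to the conventional "<<< ... >>>"; the launch configuration itself (grid, block, dynamic shared memory, stream) is unchanged. Below is a minimal standalone CUDA sketch of that four-argument launch pattern; MY_THREADS, MY_BLOCKS and the _Scale kernel are illustrative stand-ins, not the repository's CUDA_BLOCKS/CUDA_THREADS helpers.

#include <cuda_runtime.h>

#define MY_THREADS 256
#define MY_BLOCKS(n) (((n) + MY_THREADS - 1) / MY_THREADS)

// A trivial elementwise kernel standing in for the conv/pool kernels above.
__global__ void _Scale(const int nthreads, const float alpha,
                       const float* x, float* y) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < nthreads) y[i] = alpha * x[i];
}

int main() {
  const int n = 1 << 20;
  float *x = nullptr, *y = nullptr;
  cudaMalloc(&x, n * sizeof(float));
  cudaMalloc(&y, n * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // grid, block, 0 bytes of dynamic shared memory, explicit stream --
  // the same configuration shape as "<<< CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream() >>>".
  _Scale<<<MY_BLOCKS(n), MY_THREADS, 0, stream>>>(n, 2.f, x, y);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  cudaFree(x);
  cudaFree(y);
  return 0;
}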
...@@ -77,16 +77,12 @@ template <> void DropBlock2d<CUDAContext>( ...@@ -77,16 +77,12 @@ template <> void DropBlock2d<CUDAContext>(
int* mask, int* mask,
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = N * C * seed_h * seed_w; auto nthreads = N * C * seed_h * seed_w;
math::RandomUniform( math::RandomUniform(nthreads, 0.f, 1.f, seed, ctx);
nthreads,
0.f, float(UINT_MAX),
seed, ctx
);
auto mask_thresh = (uint32_t)(UINT_MAX * gamma); auto mask_thresh = (uint32_t)(UINT_MAX * gamma);
if (data_format == "NCHW") { if (data_format == "NCHW") {
_DropBlock2dNCHW _DropBlock2dNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
seed_h, seed_w, seed_h, seed_w,
...@@ -96,8 +92,8 @@ template <> void DropBlock2d<CUDAContext>( ...@@ -96,8 +92,8 @@ template <> void DropBlock2d<CUDAContext>(
); );
} else if(data_format == "NHWC") { } else if(data_format == "NHWC") {
_DropBlock2dNHWC _DropBlock2dNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
seed_h, seed_w, seed_h, seed_w,
......
...@@ -81,15 +81,15 @@ template <> void NNResize<float, CUDAContext>( ...@@ -81,15 +81,15 @@ template <> void NNResize<float, CUDAContext>(
auto scale_w = (float)W / (float)out_w; auto scale_w = (float)W / (float)out_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_NNResizeNCHW _NNResizeNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, out_h, out_w, nthreads, C, H, W, out_h, out_w,
scale_h, scale_w, x, y scale_h, scale_w, x, y
); );
} else if(data_format == "NHWC") { } else if(data_format == "NHWC") {
_NNResizeNHWC _NNResizeNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, out_h, out_w, nthreads, C, H, W, out_h, out_w,
scale_h, scale_w, x, y scale_h, scale_w, x, y
); );
...@@ -116,8 +116,8 @@ template <> void NNResize<float16, CUDAContext>( ...@@ -116,8 +116,8 @@ template <> void NNResize<float16, CUDAContext>(
auto scale_w = (float)W / (float)out_w; auto scale_w = (float)W / (float)out_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_NNResizeNCHW _NNResizeNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, nthreads, C, H, W,
out_h, out_w, scale_h, scale_w, out_h, out_w, scale_h, scale_w,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
...@@ -125,8 +125,8 @@ template <> void NNResize<float16, CUDAContext>( ...@@ -125,8 +125,8 @@ template <> void NNResize<float16, CUDAContext>(
); );
} else if(data_format == "NHWC") { } else if(data_format == "NHWC") {
_NNResizeNHWC _NNResizeNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, nthreads, C, H, W,
out_h, out_w, scale_h, scale_w, out_h, out_w, scale_h, scale_w,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
...@@ -209,15 +209,15 @@ template <> void NNResizeGrad<float, CUDAContext>( ...@@ -209,15 +209,15 @@ template <> void NNResizeGrad<float, CUDAContext>(
auto scale_w = (float)W / (float)out_w; auto scale_w = (float)W / (float)out_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_NNResizeGradNCHW _NNResizeGradNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, out_h, out_w, nthreads, C, H, W, out_h, out_w,
scale_h, scale_w, dy, dx scale_h, scale_w, dy, dx
); );
} else if(data_format == "NHWC") { } else if(data_format == "NHWC") {
_NNResizeGradNHWC _NNResizeGradNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, C, H, W, out_h, out_w, nthreads, C, H, W, out_h, out_w,
scale_h, scale_w, dy, dx scale_h, scale_w, dy, dx
); );
......
...@@ -120,8 +120,8 @@ template<> void MaxPool2d<float, CUDAContext>( ...@@ -120,8 +120,8 @@ template<> void MaxPool2d<float, CUDAContext>(
auto nthreads = N * C * pool_h * pool_w; auto nthreads = N * C * pool_h * pool_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_MaxPool2dNCHW _MaxPool2dNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -132,8 +132,8 @@ template<> void MaxPool2d<float, CUDAContext>( ...@@ -132,8 +132,8 @@ template<> void MaxPool2d<float, CUDAContext>(
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_MaxPool2dNHWC _MaxPool2dNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -256,8 +256,8 @@ template<> void AvgPool2d<float, CUDAContext>( ...@@ -256,8 +256,8 @@ template<> void AvgPool2d<float, CUDAContext>(
auto nthreads = N * C * pool_h * pool_w; auto nthreads = N * C * pool_h * pool_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_AvgPool2dNCHW _AvgPool2dNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -268,8 +268,8 @@ template<> void AvgPool2d<float, CUDAContext>( ...@@ -268,8 +268,8 @@ template<> void AvgPool2d<float, CUDAContext>(
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_AvgPool2dNHWC _AvgPool2dNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -392,8 +392,8 @@ template<> void MaxPool2dGrad<float, CUDAContext>( ...@@ -392,8 +392,8 @@ template<> void MaxPool2dGrad<float, CUDAContext>(
auto nthreads = N * C * H * W; auto nthreads = N * C * H * W;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_MaxPool2dGrad_NCHW _MaxPool2dGrad_NCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -404,8 +404,8 @@ template<> void MaxPool2dGrad<float, CUDAContext>( ...@@ -404,8 +404,8 @@ template<> void MaxPool2dGrad<float, CUDAContext>(
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_MaxPool2dGradNHWC _MaxPool2dGradNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -531,8 +531,8 @@ template<> void AvgPool2dGrad<float, CUDAContext>( ...@@ -531,8 +531,8 @@ template<> void AvgPool2dGrad<float, CUDAContext>(
auto nthreads = N * C * H * W; auto nthreads = N * C * H * W;
if (data_format == "NCHW") { if (data_format == "NCHW") {
_AvgPool2dGradNCHW _AvgPool2dGradNCHW
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -543,8 +543,8 @@ template<> void AvgPool2dGrad<float, CUDAContext>( ...@@ -543,8 +543,8 @@ template<> void AvgPool2dGrad<float, CUDAContext>(
); );
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_AvgPool2dGradNHWC _AvgPool2dGradNHWC
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
......
...@@ -132,8 +132,8 @@ template<> void ROIAlign<float, CUDAContext>( ...@@ -132,8 +132,8 @@ template<> void ROIAlign<float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = num_rois * C * pool_h * pool_w; auto nthreads = num_rois * C * pool_h * pool_w;
_ROIAlign _ROIAlign
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -283,8 +283,8 @@ template<> void ROIAlignGrad<float, CUDAContext>( ...@@ -283,8 +283,8 @@ template<> void ROIAlignGrad<float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = num_rois * C * pool_h * pool_w; auto nthreads = num_rois * C * pool_h * pool_w;
_ROIAlignGrad _ROIAlignGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
......
...@@ -134,8 +134,8 @@ template<> void ROIAlign<float16, CUDAContext>( ...@@ -134,8 +134,8 @@ template<> void ROIAlign<float16, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = num_rois * C * pool_h * pool_w; auto nthreads = num_rois * C * pool_h * pool_w;
_ROIAlignHalf _ROIAlignHalf
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> > 0, ctx->cuda_stream() >>>
(nthreads, C, H, W, pool_h, pool_w, (nthreads, C, H, W, pool_h, pool_w,
sampling_ratio, spatial_scale, sampling_ratio, spatial_scale,
reinterpret_cast<const half*>(x), rois, reinterpret_cast<const half*>(x), rois,
......
...@@ -92,8 +92,8 @@ template<> void ROIPool<float, CUDAContext>( ...@@ -92,8 +92,8 @@ template<> void ROIPool<float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = num_rois * C * pool_h * pool_w; auto nthreads = num_rois * C * pool_h * pool_w;
_ROIPool _ROIPool
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -185,8 +185,8 @@ template<> void ROIPool<float16, CUDAContext>( ...@@ -185,8 +185,8 @@ template<> void ROIPool<float16, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = num_rois * C * pool_h * pool_w; auto nthreads = num_rois * C * pool_h * pool_w;
_ROIPoolHalf _ROIPoolHalf
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
C, H, W, C, H, W,
pool_h, pool_w, pool_h, pool_w,
...@@ -286,8 +286,8 @@ template<> void ROIPoolGrad<float, CUDAContext>( ...@@ -286,8 +286,8 @@ template<> void ROIPoolGrad<float, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = N * C * H * W; auto nthreads = N * C * H * W;
_ROIPoolGrad _ROIPoolGrad
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> >( 0, ctx->cuda_stream() >>>(
nthreads, nthreads,
num_rois, num_rois,
C, H, W, C, H, W,
......
...@@ -180,9 +180,9 @@ ONNXBackend::get_special_nodes() const { ...@@ -180,9 +180,9 @@ ONNXBackend::get_special_nodes() const {
}; return kSpecialNodes; }; return kSpecialNodes;
} }
const Map< string, Map<string, string> >& const Map<string, Map<string, string>>&
ONNXBackend::get_node_renamed_attrs() const { ONNXBackend::get_node_renamed_attrs() const {
const static Map< string, Map<string, string> > const static Map<string, Map<string, string>>
kPerNodeRenamedAttrs = { kPerNodeRenamedAttrs = {
{ "Gemm", { { "transB", "transW" } } }, { "Gemm", { { "transB", "transW" } } },
{ "BatchNormalization", { { "epsilon", "eps" } } }, { "BatchNormalization", { { "epsilon", "eps" } } },
......
...@@ -221,7 +221,7 @@ class ONNXBackend { ...@@ -221,7 +221,7 @@ class ONNXBackend {
const Map<string, SpecialNodeConverter>& get_special_nodes() const; const Map<string, SpecialNodeConverter>& get_special_nodes() const;
const Map<string, string>& get_renamed_attrs() const; const Map<string, string>& get_renamed_attrs() const;
const Map< string, Map<string, string> >& get_node_renamed_attrs() const; const Map<string, Map<string, string>>& get_node_renamed_attrs() const;
}; };
} // namespace onnx } // namespace onnx
......
...@@ -77,15 +77,8 @@ template <class Context> ...@@ -77,15 +77,8 @@ template <class Context>
void CuDNNDropoutOp<Context>::RunOnDevice() { void CuDNNDropoutOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -147,15 +140,8 @@ template <class Context> ...@@ -147,15 +140,8 @@ template <class Context>
void CuDNNDropoutGradientOp<Context>::RunOnDevice() { void CuDNNDropoutGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CUDNN(Dropout); DEPLOY_CUDNN(Dropout);
......
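Note: from this file on, the commit replaces the hand-written per-dtype if/else chains with a single DispatchHelper<TensorTypes<...>>::Call(this, X(0)) that picks the matching RunImpl<T>() at runtime. A minimal standalone C++ sketch of that dispatch pattern follows; ToyOp, MyTensorTypes and MyDispatch are illustrative names, and std::type_index stands in for the real TypeMeta-based matching.

#include <iostream>
#include <typeindex>
#include <typeinfo>

template <typename... Types> struct MyTensorTypes {};

template <typename Sizes> struct MyDispatch;

// Peel one candidate type at a time; call RunImpl<T>() on the first match.
template <typename T, typename... Rest>
struct MyDispatch<MyTensorTypes<T, Rest...>> {
  template <typename Op>
  static void Call(Op* op, const std::type_index& type) {
    if (type == std::type_index(typeid(T))) return op->template RunImpl<T>();
    MyDispatch<MyTensorTypes<Rest...>>::Call(op, type);
  }
};

// Exhausted the candidate list: the runtime type is unsupported.
template <>
struct MyDispatch<MyTensorTypes<>> {
  template <typename Op>
  static void Call(Op*, const std::type_index&) {
    std::cerr << "Unsupported tensor type.\n";
  }
};

struct ToyOp {
  template <typename T>
  void RunImpl() { std::cout << "RunImpl<" << typeid(T).name() << ">()\n"; }
};

int main() {
  ToyOp op;
  // Dispatch on a runtime type id (e.g. the dtype stored in a tensor's meta).
  MyDispatch<MyTensorTypes<float, double>>::Call(&op, std::type_index(typeid(float)));
  return 0;
}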
...@@ -26,15 +26,8 @@ template <class Context> ...@@ -26,15 +26,8 @@ template <class Context>
void CuDNNEluOp<Context>::RunOnDevice() { void CuDNNEluOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -60,15 +53,8 @@ template <class Context> ...@@ -60,15 +53,8 @@ template <class Context>
void CuDNNEluGradientOp<Context>::RunOnDevice() { void CuDNNEluGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CUDNN(Elu); DEPLOY_CUDNN(Elu);
......
...@@ -40,15 +40,8 @@ void CuDNNReluOp<Context>::RunOnDevice() { ...@@ -40,15 +40,8 @@ void CuDNNReluOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -92,15 +85,8 @@ void CuDNNReluGradientOp<Context>::RunOnDevice() { ...@@ -92,15 +85,8 @@ void CuDNNReluGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CUDNN(Relu); DEPLOY_CUDNN(Relu);
......
...@@ -35,15 +35,8 @@ template <class Context> ...@@ -35,15 +35,8 @@ template <class Context>
void CuDNNSigmoidOp<Context>::RunOnDevice() { void CuDNNSigmoidOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -82,15 +75,8 @@ template <class Context> ...@@ -82,15 +75,8 @@ template <class Context>
void CuDNNSigmoidGradientOp<Context>::RunOnDevice() { void CuDNNSigmoidGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CUDNN(Sigmoid); DEPLOY_CUDNN(Sigmoid);
......
...@@ -45,15 +45,8 @@ void CuDNNSoftmaxOp<Context>::RunOnDevice() { ...@@ -45,15 +45,8 @@ void CuDNNSoftmaxOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -91,15 +84,8 @@ void CuDNNSoftmaxGradientOp<Context>::RunOnDevice() { ...@@ -91,15 +84,8 @@ void CuDNNSoftmaxGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CUDNN(Softmax); DEPLOY_CUDNN(Softmax);
......
...@@ -35,15 +35,8 @@ template <class Context> ...@@ -35,15 +35,8 @@ template <class Context>
void CuDNNTanhOp<Context>::RunOnDevice() { void CuDNNTanhOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -82,15 +75,8 @@ template <class Context> ...@@ -82,15 +75,8 @@ template <class Context>
void CuDNNTanhGradientOp<Context>::RunOnDevice() { void CuDNNTanhGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CUDNN(Tanh); DEPLOY_CUDNN(Tanh);
......
...@@ -44,15 +44,8 @@ template <class Context> ...@@ -44,15 +44,8 @@ template <class Context>
void DropoutOp<Context>::RunOnDevice() { void DropoutOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -83,15 +76,8 @@ template <class Context> ...@@ -83,15 +76,8 @@ template <class Context>
void DropoutGradientOp<Context>::RunOnDevice() { void DropoutGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CPU(Dropout); DEPLOY_CPU(Dropout);
......
...@@ -52,15 +52,8 @@ void DropPathOp<Context>::RunOnDevice() { ...@@ -52,15 +52,8 @@ void DropPathOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -97,15 +90,8 @@ void DropPathGradientOp<Context>::RunOnDevice() { ...@@ -97,15 +90,8 @@ void DropPathGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CPU(DropPath); DEPLOY_CPU(DropPath);
......
...@@ -20,13 +20,8 @@ template <class Context> ...@@ -20,13 +20,8 @@ template <class Context>
void EluOp<Context>::RunOnDevice() { void EluOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(
X(0), { "float32" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -46,13 +41,8 @@ template <class Context> ...@@ -46,13 +41,8 @@ template <class Context>
void EluGradientOp<Context>::RunOnDevice() { void EluGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(
X(0), { "float32" }
);
}
} }
DEPLOY_CPU(Elu); DEPLOY_CPU(Elu);
......
...@@ -40,13 +40,8 @@ void PReluOp<Context>::RunOnDevice() { ...@@ -40,13 +40,8 @@ void PReluOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -98,13 +93,8 @@ void PReluGradientOp<Context>::RunOnDevice() { ...@@ -98,13 +93,8 @@ void PReluGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
Y(1)->ReshapeLike(X(1)); Y(1)->ReshapeLike(X(1));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(
X(0), { "float32" }
);
}
} }
DEPLOY_CPU(PRelu); DEPLOY_CPU(PRelu);
......
...@@ -20,15 +20,8 @@ template <class Context> ...@@ -20,15 +20,8 @@ template <class Context>
void ReluOp<Context>::RunOnDevice() { void ReluOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -48,15 +41,8 @@ template <class Context> ...@@ -48,15 +41,8 @@ template <class Context>
void ReluGradientOp<Context>::RunOnDevice() { void ReluGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CPU(Relu); DEPLOY_CPU(Relu);
......
...@@ -19,15 +19,8 @@ template <class Context> ...@@ -19,15 +19,8 @@ template <class Context>
void SEluOp<Context>::RunOnDevice() { void SEluOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -47,15 +40,8 @@ template <class Context> ...@@ -47,15 +40,8 @@ template <class Context>
void SEluGradientOp<Context>::RunOnDevice() { void SEluGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
DEPLOY_CPU(SElu); DEPLOY_CPU(SElu);
......
...@@ -15,13 +15,8 @@ template <class Context> ...@@ -15,13 +15,8 @@ template <class Context>
void SigmoidOp<Context>::RunOnDevice() { void SigmoidOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(
X(0), { "float32" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -41,13 +36,8 @@ template <class Context> ...@@ -41,13 +36,8 @@ template <class Context>
void SigmoidGradientOp<Context>::RunOnDevice() { void SigmoidGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32" }
);
}
} }
DEPLOY_CPU(Sigmoid); DEPLOY_CPU(Sigmoid);
......
...@@ -43,13 +43,8 @@ void SoftmaxOp<Context>::RunOnDevice() { ...@@ -43,13 +43,8 @@ void SoftmaxOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(
X(0), { "float32" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -86,13 +81,8 @@ void SoftmaxGradientOp<Context>::RunOnDevice() { ...@@ -86,13 +81,8 @@ void SoftmaxGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(
X(0), { "float32" }
);
}
} }
DEPLOY_CPU(Softmax); DEPLOY_CPU(Softmax);
......
...@@ -15,13 +15,8 @@ template <class Context> ...@@ -15,13 +15,8 @@ template <class Context>
void TanhOp<Context>::RunOnDevice() { void TanhOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(
X(0), { "float32" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -41,13 +36,8 @@ template <class Context> ...@@ -41,13 +36,8 @@ template <class Context>
void TanhGradientOp<Context>::RunOnDevice() { void TanhGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float>>::Call(this, X(0));
} else {
LOG(FATAL) << DTypeString(
X(0), { "float32" }
);
}
} }
DEPLOY_CPU(Tanh); DEPLOY_CPU(Tanh);
......
...@@ -46,15 +46,8 @@ void AffineOp<Context>::RunOnDevice() { ...@@ -46,15 +46,8 @@ void AffineOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(0));
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float32", "float16" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -111,9 +104,7 @@ void AffineGradientOp<Context>::RunImpl() { ...@@ -111,9 +104,7 @@ void AffineGradientOp<Context>::RunImpl() {
} }
template <class Context> template <typename T> template <class Context> template <typename T>
void AffineGradientOp<Context>::Reduce( void AffineGradientOp<Context>::Reduce(T* x, T* y) {
T* x,
T* y) {
vec32_t dims = { vec32_t dims = {
(int)outer_dim_, (int)outer_dim_,
(int)scale_dim_, (int)scale_dim_,
...@@ -138,15 +129,8 @@ void AffineGradientOp<Context>::RunOnDevice() { ...@@ -138,15 +129,8 @@ void AffineGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(-1)); Y(0)->ReshapeLike(X(-1));
if (XIsType(X(-1), float)) { DispatchHelper<TensorTypes
RunImpl<float>(); <float, float16>>::Call(this, X(-1));
} else if (XIsType(X(-1), float16)) {
RunImpl<float16>();
} else {
LOG(FATAL) << DTypeString(X(-1),
{ "float32", "float16" }
);
}
} }
DEPLOY_CPU(Affine); DEPLOY_CPU(Affine);
......
...@@ -108,13 +108,6 @@ void CuDNNAffineOp<Context>::RunOnDevice() { ...@@ -108,13 +108,6 @@ void CuDNNAffineOp<Context>::RunOnDevice() {
template <class Context> template <typename DT, typename CT> template <class Context> template <typename DT, typename CT>
void CuDNNAffineGradientOp<Context>::RunImpl() { void CuDNNAffineGradientOp<Context>::RunImpl() {
this->template ResetDesc<DT>(X(-1)); this->template ResetDesc<DT>(X(-1));
scale_dim_ = X(1).count();
outer_dim_ = X(-1).count(0, axis_);
inner_dim_ = X(-1).count(axis_ + num_axes_);
dim_ = scale_dim_ * inner_dim_;
reduce_dim_ = std::max(outer_dim_, inner_dim_);
Y(0)->ReshapeLike(X(-1));
auto* alpha = X(1).template data<DT, Context>(); auto* alpha = X(1).template data<DT, Context>();
auto* dy = X(-1).template mutable_data<DT, Context>(); auto* dy = X(-1).template mutable_data<DT, Context>();
...@@ -230,9 +223,7 @@ void CuDNNAffineGradientOp<Context>::CuDNNReduce( ...@@ -230,9 +223,7 @@ void CuDNNAffineGradientOp<Context>::CuDNNReduce(
} }
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNAffineGradientOp<Context>::Reduce( void CuDNNAffineGradientOp<Context>::Reduce(T* x, T* y) {
T* x,
T* y) {
vec32_t dims = { vec32_t dims = {
(int)outer_dim_, (int)outer_dim_,
(int)scale_dim_, (int)scale_dim_,
...@@ -248,6 +239,14 @@ void CuDNNAffineGradientOp<Context>::Reduce( ...@@ -248,6 +239,14 @@ void CuDNNAffineGradientOp<Context>::Reduce(
template <class Context> template <class Context>
void CuDNNAffineGradientOp<Context>::RunOnDevice() { void CuDNNAffineGradientOp<Context>::RunOnDevice() {
scale_dim_ = X(1).count();
outer_dim_ = X(-1).count(0, axis_);
inner_dim_ = X(-1).count(axis_ + num_axes_);
dim_ = scale_dim_ * inner_dim_;
reduce_dim_ = std::max(outer_dim_, inner_dim_);
Y(0)->ReshapeLike(X(-1));
if (XIsType(X(-1), float)) { if (XIsType(X(-1), float)) {
RunImpl<float, float>(); RunImpl<float, float>();
} else if (XIsType(X(-1), float16)) { } else if (XIsType(X(-1), float16)) {
......
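Note: the Affine hunks also hoist the type-independent setup (counting scale_dim_/outer_dim_/inner_dim_ and reshaping the output) out of the typed RunImpl and into RunOnDevice, so the typed body only touches data. A toy C++ sketch of that split, with illustrative names and a hard-coded float call where the real code dispatches on the dtype:

#include <cstdio>
#include <vector>

struct ToyAffineOp {
  std::vector<int> x_dims{2, 3, 4};  // toy input shape
  int axis = 1, num_axes = 1;
  int outer_dim = 0, scale_dim = 0, inner_dim = 0;

  template <typename T>
  void RunImpl() {
    // Typed computation only; all shape bookkeeping happened in RunOnDevice().
    std::printf("affine over %d x %d x %d\n", outer_dim, scale_dim, inner_dim);
  }

  void RunOnDevice() {
    // Type-independent setup done once, before dispatching on the dtype.
    outer_dim = x_dims[0];
    scale_dim = x_dims[axis];
    inner_dim = x_dims[axis + num_axes];
    RunImpl<float>();  // a real dispatcher would pick T from the tensor's dtype
  }
};

int main() { ToyAffineOp().RunOnDevice(); return 0; }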
...@@ -36,6 +36,13 @@ void EltwiseOp<Context>::ProdRunImpl() { ...@@ -36,6 +36,13 @@ void EltwiseOp<Context>::ProdRunImpl() {
template <class Context> template <typename T> template <class Context> template <typename T>
void EltwiseOp<Context>::RunImpl() { void EltwiseOp<Context>::RunImpl() {
if (operation_ == "SUM") SumRunImpl<T>();
else if (operation_ == "PROD") ProdRunImpl<T>();
else LOG(FATAL) << "Unknown Operation: " << operation_;
}
template <class Context>
void EltwiseOp<Context>::RunOnDevice() {
for (int i = 1; i < XSize(); i++) { for (int i = 1; i < XSize(); i++) {
CHECK(X(i).dims() == X(0).dims()) CHECK(X(i).dims() == X(0).dims())
<< "\nExcepted Input(" << i << ")'s dims as " << "\nExcepted Input(" << i << ")'s dims as "
...@@ -45,33 +52,10 @@ void EltwiseOp<Context>::RunImpl() { ...@@ -45,33 +52,10 @@ void EltwiseOp<Context>::RunImpl() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (operation_ == "SUM") SumRunImpl<T>(); DispatchHelper<TensorTypes
else if (operation_ == "PROD") ProdRunImpl<T>(); <int8_t, uint8_t, int, int64_t,
else LOG(FATAL) << "Unknown Operation: " << operation_; float16, float, double>
} >::Call(this, X(0));
template <class Context>
void EltwiseOp<Context>::RunOnDevice() {
if (XIsType(X(0), int8_t)) {
RunImpl<int8_t>();
} else if (XIsType(X(0), uint8_t)) {
RunImpl<uint8_t>();
} else if (XIsType(X(0), int)) {
RunImpl<int>();
} else if (XIsType(X(0), int64_t)) {
RunImpl<int64_t>();
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else if (XIsType(X(0), float)) {
RunImpl<float>();
} else if (XIsType(X(0), double)) {
RunImpl<double>();
} else {
LOG(FATAL) << DTypeString(X(0), {
"int8", "uint8", "int32", "int64",
"float16", "float32", "float64",
});
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -133,26 +117,10 @@ void EltwiseGradientOp<Context>::RunImpl() { ...@@ -133,26 +117,10 @@ void EltwiseGradientOp<Context>::RunImpl() {
template <class Context> template <class Context>
void EltwiseGradientOp<Context>::RunOnDevice() { void EltwiseGradientOp<Context>::RunOnDevice() {
if (XIsType(X(0), int8_t)) { DispatchHelper<TensorTypes
RunImpl<int8_t>(); <int8_t, uint8_t, int, int64_t,
} else if (XIsType(X(0), uint8_t)) { float16, float, double>
RunImpl<uint8_t>(); >::Call(this, X(0));
} else if (XIsType(X(0), int)) {
RunImpl<int>();
} else if (XIsType(X(0), int64_t)) {
RunImpl<int64_t>();
} else if (XIsType(X(0), float16)) {
RunImpl<float16>();
} else if (XIsType(X(0), float)) {
RunImpl<float>();
} else if (XIsType(X(0), double)) {
RunImpl<double>();
} else {
LOG(FATAL) << DTypeString(X(0), {
"int8", "uint8", "int32", "int64",
"float16", "float32", "float64",
});
}
} }
DEPLOY_CPU(Eltwise); DEPLOY_CPU(Eltwise);
......
...@@ -15,17 +15,9 @@ template <class Context> ...@@ -15,17 +15,9 @@ template <class Context>
void ExpOp<Context>::RunOnDevice() { void ExpOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float16)) { DispatchHelper<TensorTypes
RunImpl<float16>(); <float, float16, double>
} else if (XIsType(X(0), float)) { >::Call(this, X(0));
RunImpl<float>();
} else if (XIsType(X(0), double)) {
RunImpl<double>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float16", "float32", "float64" }
);
}
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -40,17 +32,9 @@ template <class Context> ...@@ -40,17 +32,9 @@ template <class Context>
void ExpGradientOp<Context>::RunOnDevice() { void ExpGradientOp<Context>::RunOnDevice() {
Y(0)->ReshapeLike(X(0)); Y(0)->ReshapeLike(X(0));
if (XIsType(X(0), float16)) { DispatchHelper<TensorTypes
RunImpl<float16>(); <float, float16, double>
} else if (XIsType(X(0), float)) { >::Call(this, X(0));
RunImpl<float>();
} else if (XIsType(X(0), double)) {
RunImpl<double>();
} else {
LOG(FATAL) << DTypeString(X(0),
{ "float16", "float32", "float64" }
);
}
} }
DEPLOY_CPU(Exp); DEPLOY_CPU(Exp);
......