Commit 1d551431 by Ting PAN

Re-implement Softmax Focal Loss

1 parent 5dea1524
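
For context, the operator re-implemented here computes the softmax focal loss of Lin et al., 2017 (the paper linked in the `SparseSoftmaxFocalLoss` docstring below). A sketch of the per-example loss, written with this commit's argument names and assuming `p_t` denotes the softmax probability of the target class:

```latex
% Focal loss (Lin et al., 2017), using this commit's argument names.
% As the new pos_alpha/neg_alpha arguments suggest, alpha_t = pos_alpha (= alpha)
% for positive classes and alpha_t = neg_alpha (= 1 - alpha) when the label equals
% neg_id; the new defaults are alpha = 0.25, gamma = 2, neg_id = 0.
FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```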
@@ -52,9 +52,9 @@ using Set = std::unordered_set<Value>;
 /*
  * Define the Kernel version.
  *
- * | Major(2) | Minor(2) | Patch(06) |
+ * | Major(2) | Minor(2) | Patch(07) |
  */
-#define DRAGON_VERSION 2206
+#define DRAGON_VERSION 2207
 /*
  * Define the default random seed.
......
@@ -90,8 +90,10 @@ class Operator : public OperatorBase {
  public:
     Operator(const OperatorDef& def, Workspace* ws)
         : OperatorBase(def, ws), ctx_(def.device_option()),
-          recomputing_aware_(OperatorBase::Arg<bool>(
-              "recomputing_aware", false)) {
+          allow_recompute_(OperatorBase::Arg<bool>(
+              "recomputing_aware", false)),
+          do_sync_(OperatorBase::Arg<bool>(
+              "do_sync", true)) {
         allow_run_ = true;
         allow_run_ &= _MPICheck();
         allow_run_ &= (!(OutputSize() == 1 &&
@@ -100,12 +102,12 @@ class Operator : public OperatorBase {
     virtual void Run() final {
         if (!allow_run_) return;
-        if (recomputing_aware_) MakeResource();
+        if (allow_recompute_) MakeResource();
         ctx().SwitchToDevice();
         MemorySwitch();
         RunOnDevice();
-        ctx().FinishDeviceCompution();
-        if (recomputing_aware_) CleanResource();
+        if (do_sync_) ctx().FinishDeviceCompution();
+        if (allow_recompute_) CleanResource();
     }
     virtual void ElimateCorruption();
@@ -126,7 +128,7 @@ class Operator : public OperatorBase {
  protected:
     Context ctx_;
-    bool allow_run_, recomputing_aware_;
+    bool allow_run_, allow_recompute_, do_sync_;
  private:
     bool _MPICheck() {
......
@@ -24,11 +24,11 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
           axis(OperatorBase::Arg<int>("axis", 1)),
           normalization(OperatorBase::Arg<string>(
               "normalization", "VALID")) {
-        vector<int> ignores = OperatorBase::Args<int>("ignore_labels");
-        if (ignores.size()) {
-            ignore.Reshape({ (TIndex)ignores.size() });
-            auto* Idata = ignore.mutable_data<int, CPUContext>();
-            for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i];
+        auto xs = OperatorBase::Args<int>("ignore_labels");
+        if (xs.size()) {
+            ignores.Reshape({ (TIndex)xs.size() });
+            auto* Idata = ignores.mutable_data<int, CPUContext>();
+            for (int i = 0; i < xs.size(); i++) Idata[i] = xs[i];
         }
     }
     USE_OPERATOR_FUNCTIONS;
@@ -41,8 +41,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
  protected:
     TIndex axis, outer_dim, inner_dim;
-    Tensor ignore, valid, losses;
-    Tensor* prob;
+    Tensor* prob, losses, flags, ignores;
     unique_ptr<OperatorBase> softmax_op;
     string normalization;
 };
@@ -55,11 +54,11 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
           axis(OperatorBase::Arg<int>("axis", 1)),
           normalization(OperatorBase::Arg<string>(
               "normalization", "VALID")) {
-        vector<int> ignores = OperatorBase::Args<int>("ignore_labels");
-        if (ignores.size()) {
-            ignore.Reshape({ (TIndex)ignores.size() });
-            auto* Idata = ignore.mutable_data<int, CPUContext>();
-            for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i];
+        auto xs = OperatorBase::Args<int>("ignore_labels");
+        if (xs.size()) {
+            ignores.Reshape({ (TIndex)xs.size() });
+            auto* Idata = ignores.mutable_data<int, CPUContext>();
+            for (int i = 0; i < xs.size(); i++) Idata[i] = xs[i];
         }
     }
     USE_OPERATOR_FUNCTIONS;
@@ -69,8 +68,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
  protected:
     TIndex axis, outer_dim, inner_dim;
-    Tensor ignore, valid;
-    Tensor* prob;
+    Tensor* prob, ignores, flags;
     string normalization;
 };
......
@@ -17,18 +17,19 @@
 namespace dragon {
 template <class Context>
-class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Context> {
+class SparseSoftmaxFocalLossOp final
+    : public SparseSoftmaxCrossEntropyOp<Context> {
  public:
     SparseSoftmaxFocalLossOp(const OperatorDef& def, Workspace* ws)
         : SparseSoftmaxCrossEntropyOp<Context>(def, ws),
           axis(OperatorBase::Arg<int>("axis", 1)),
           normalization(OperatorBase::Arg<string>(
               "normalization", "VALID")),
-          alpha(OperatorBase::Arg<float>("alpha", 0.5)),
-          gamma(OperatorBase::Arg<float>("gamma", 0.0)),
-          neg_id(OperatorBase::Arg<int>("neg_id", -1)) {
-        pos_alpha = alpha * 2.0;
-        neg_alpha = (1 - alpha) * 2.0;
+          alpha(OperatorBase::Arg<float>("alpha", 0.25f)),
+          gamma(OperatorBase::Arg<float>("gamma", 2.f)),
+          neg_id(OperatorBase::Arg<int>("neg_id", 0)) {
+        pos_alpha = alpha;
+        neg_alpha = 1.f - alpha;
     }
     USE_OPERATOR_FUNCTIONS;
@@ -36,35 +37,36 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
     template <typename T> void RunWithType();
  protected:
-    float alpha, gamma;
-    int neg_id;
-    float pos_alpha, neg_alpha;
-    TIndex axis, outer_dim, inner_dim;
-    Tensor* scale;
+    float alpha, gamma, pos_alpha, neg_alpha;
+    TIndex axis, neg_id, outer_dim, inner_dim;
+    Tensor losses, flags;
     string normalization;
 };
 template <class Context>
-class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyGradientOp<Context> {
+class SparseSoftmaxFocalLossGradientOp final
+    : public SparseSoftmaxCrossEntropyGradientOp<Context> {
  public:
     SparseSoftmaxFocalLossGradientOp(const OperatorDef& def, Workspace* ws)
         : SparseSoftmaxCrossEntropyGradientOp<Context>(def, ws),
           axis(OperatorBase::Arg<int>("axis", 1)),
           normalization(OperatorBase::Arg<string>(
               "normalization", "VALID")),
-          gamma(OperatorBase::Arg<float>("gamma", 0.0)),
-          eps(OperatorBase::Arg<float>("eps", float(1e-10))),
-          neg_id(OperatorBase::Arg<int>("neg_id", -1)) {}
+          alpha(OperatorBase::Arg<float>("alpha", 0.25f)),
+          gamma(OperatorBase::Arg<float>("gamma", 2.f)),
+          neg_id(OperatorBase::Arg<int>("neg_id", 0)) {
+        pos_alpha = alpha;
+        neg_alpha = 1.f - alpha;
+    }
     USE_OPERATOR_FUNCTIONS;
     void RunOnDevice() override;
     template <typename T> void RunWithType();
  protected:
-    float gamma, eps;
-    int neg_id;
-    TIndex axis, outer_dim, inner_dim;
-    Tensor* scale;
+    float alpha, gamma, pos_alpha, neg_alpha;
+    TIndex axis, neg_id, outer_dim, inner_dim;
+    Tensor flags;
     string normalization;
 };
......
@@ -289,37 +289,36 @@ void SoftmaxCrossEntropy(
 template <typename Tx, typename Ty, class Context>
 void SparseSoftmaxCrossEntropy(
-    const int count,
-    const int classes,
     const int outer_dim,
+    const int axis_dim,
     const int inner_dim,
     const Tx* prob,
     const Ty* labels,
-    Tx* loss,
-    Tx* valid,
-    Tensor* ignore,
+    const int* ignores,
+    const int num_ignores,
+    Tx* losses,
+    Tx* flags,
     Context* ctx);
 template <typename Tx, typename Ty, class Context>
 void SparseSoftmaxCrossEntropyGrad(
-    const int count,
-    const int classes,
     const int outer_dim,
+    const int axis_dim,
     const int inner_dim,
     const Tx* prob,
     const Ty* labels,
-    Tx* valid,
-    Tensor* ignore,
+    const int* ignores,
+    const int num_ignores,
     Tx* dx,
+    Tx* flags,
     Context* ctx);
 /******************** loss.sparse_softmax_focal_loss ********************/
 template <typename T, class Context>
 void SparseSoftmaxFocalLoss(
-    const int count,
-    const int classes,
     const int outer_dim,
+    const int axis_dim,
     const int inner_dim,
     const float pos_alpha,
     const float neg_alpha,
@@ -327,26 +326,28 @@ void SparseSoftmaxFocalLoss(
     const int neg_id,
     const T* prob,
     const T* labels,
-    T* scale,
-    T* loss,
-    T* valid,
-    Tensor* ignore);
+    const int* ignores,
+    const int num_ignores,
+    T* losses,
+    T* flags,
+    Context* ctx);
 template <typename T, class Context>
 void SparseSoftmaxFocalLossGrad(
-    const int count,
-    const int classes,
     const int outer_dim,
+    const int axis_dim,
     const int inner_dim,
+    const float pos_alpha,
+    const float neg_alpha,
     const float gamma,
     const int neg_id,
-    const float eps,
-    const T* scale,
     const T* prob,
     const T* labels,
-    T* valid,
-    Tensor* ignore,
-    T* dx);
+    const int* ignores,
+    const int num_ignores,
+    T* dx,
+    T* flags,
+    Context* ctx);
 /******************** misc.astype ********************/
......
@@ -227,6 +227,7 @@ PyMethodDef* GetAllMethods() {
     PYFUNC(HasTensorCC),
     PYFUNC(CreateTensorCC),
     PYFUNC(CreateFillerCC),
+    PYFUNC(GetFillerTypeCC),
     PYFUNC(RenameTensorCC),
     PYFUNC(TensorFromShapeCC),
     PYFUNC(TensorFromPyArrayCC),
......
@@ -56,12 +56,14 @@ class NumpyFetcher : public TensorFetcherBase {
         for (const auto dim : tensor.dims()) npy_dims.push_back(dim);
         int npy_type = TypeMetaToNPY(tensor.meta());
         if (npy_type == -1) {
-            string s = "The data type of Tensor(" + tensor.name() + ") is unknown. Have you solved it ?";
+            string s = "The data type of Tensor(" +
+                tensor.name() + ") is unknown. Have you solved it ?";
             PyErr_SetString(PyExc_RuntimeError, s.c_str());
             return nullptr;
         }
         // create a empty array with r shape
-        PyObject* array = PyArray_SimpleNew(tensor.ndim(), npy_dims.data(), npy_type);
+        PyObject* array = PyArray_SimpleNew(
+            tensor.ndim(), npy_dims.data(), npy_type);
         // copy the tensor data to the numpy array
         if (tensor.memory_state() == MixedMemory::STATE_AT_CUDA) {
             CUDAContext::Memcpy<CPUContext, CUDAContext>(tensor.nbytes(),
......
@@ -52,6 +52,11 @@ inline PyObject* CreateFillerCC(PyObject* self, PyObject* args) {
     Py_RETURN_TRUE;
 }
+inline PyObject* GetFillerTypeCC(PyObject* self, PyObject* args) {
+    const auto* f = ws()->GetFiller(ParseName(self, args));
+    return String_AsPyUnicode(f->type());
+}
 inline PyObject* RenameTensorCC(PyObject* self, PyObject* args) {
     char* ori_name, *tar_name;
     if (!PyArg_ParseTuple(args, "ss", &ori_name, &tar_name)) {
......
@@ -44,6 +44,7 @@ __all__ = [
     'HasTensor',
     'CreateTensor',
     'CreateFiller',
+    'GetFillerType',
     'GetTensorName',
     'RenameTensor',
     'FeedTensor',
@@ -335,7 +336,7 @@ def CreateFiller(filler_def):
     Parameters
     ----------
     filler_def : dragon_pb2.TensorFiller
-        The
+        The filler.
     Returns
     -------
@@ -356,6 +357,31 @@ def CreateFiller(filler_def):
     CreateFillerCC(filler_def)
+def GetFillerType(tensor):
+    """Get the filler type of specific tensor.
+
+    It is useful if you want to tag some tensors,
+    e.g. tag with ``numpy``, and get to initialize them lazily.
+
+    Parameters
+    ----------
+    tensor : Tensor or str
+        The tensor to query.
+
+    Returns
+    -------
+    str
+        The filler type.
+
+    References
+    ----------
+    The wrapper of ``GetFillerTypeCC``.
+
+    """
+    return GetFillerTypeCC(_stringify_tensor(tensor))
 def GetTensorName(tensor):
     """Query the name represented in current workspace.
......
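
A minimal usage sketch for the new query (assuming it is exposed from `dragon.core.workspace`, as the `__all__` hunk above suggests; the tensor name is hypothetical):

```python
# Hypothetical usage of the new GetFillerType wrapper; assumes a tensor named
# 'conv1/weight' was already created with a filler in the current workspace.
import dragon.core.workspace as ws

filler_type = ws.GetFillerType('conv1/weight')  # e.g. 'constant', 'gaussian', 'numpy'
if filler_type == 'numpy':
    # tensors tagged with a 'numpy' filler can be initialized lazily by the caller
    pass
```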
@@ -218,7 +218,7 @@ def L2Loss(inputs, normalization='BATCH_SIZE', **kwargs):
 def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=(),
-                           alpha=0.5, gamma=0.0, eps=1e-10, neg_id=-1, **kwargs):
+                           alpha=0.25, gamma=2.0, neg_id=0, **kwargs):
     """SoftmaxFocalLoss with sparse labels. `[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`_.
     Parameters
@@ -232,13 +232,11 @@ def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=
     ignore_label : tuple or list
         The label id to ignore. Default is ``empty``.
     alpha : float
-        The scale factor on the rare class. Default is ``0.5``.
+        The scale factor on the rare class. Default is ``0.25``.
     gamma : float
-        The exponential decay factor on the easy examples. Default is ``0.0``.
+        The exponential decay factor on the easy examples. Default is ``2.0``.
-    eps : float
-        The eps.
     neg_id : int
-        The negative id. Default is ``-1`` (Without Class Balance)
+        The negative id. Default is ``0``.
     Returns
     -------
......
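
A minimal call sketch under the new defaults; `score` and `label` stand for existing Tensors (logits and sparse labels), and the `dragon.ops` import path is assumed rather than taken from this diff:

```python
# Sketch only: assumes `score` (logits, shape [N, C, ...]) and `label`
# (sparse class ids, shape [N, ...]) are Tensors defined elsewhere.
import dragon.ops as ops

loss = ops.SparseSoftmaxFocalLoss(
    [score, label],
    axis=1,                   # class axis
    normalization='VALID',    # average over non-ignored positions
    alpha=0.25, gamma=2.0,    # defaults introduced by this commit
    neg_id=0)                 # label id treated as the background class
```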
@@ -14,7 +14,7 @@ from __future__ import division
 from __future__ import print_function
 version = '0.2.2'
-full_version = '0.2.2.5'
+full_version = '0.2.2.7'
 release = False
 if not release:
......
@@ -149,8 +149,6 @@ class SoftmaxWithFocalLossLayer(Layer):
         The scale on the rare class. Refer `FocalLossParameter.alpha`_.
     gamma : float
         The exponential decay. Refer `FocalLossParameter.gamma`_.
-    eps : float
-        The eps. Refer `FocalLossParameter.eps`_.
     neg_id : int
         The negative id. Refer `FocalLossParameter.neg_id`_.
     normalization : NormalizationMode
@@ -174,7 +172,6 @@ class SoftmaxWithFocalLossLayer(Layer):
                  'ignore_labels': [param.ignore_label] if param.HasField('ignore_label') else [],
                  'alpha': float(focal_loss_param.alpha),
                  'gamma': float(focal_loss_param.gamma),
-                 'eps': float(focal_loss_param.eps),
                  'neg_id': focal_loss_param.neg_id}
     def Setup(self, bottom):
......
@@ -1504,10 +1504,9 @@ message DenseConcatParameter {
 }
 message FocalLossParameter {
-  optional float alpha = 1 [default = 0.5];
-  optional float gamma = 2 [default = 0.0];
-  optional float eps = 3 [default = 1e-10];
-  optional int32 neg_id = 4 [default = -1];
+  optional float alpha = 1 [default = 0.25];
+  optional float gamma = 2 [default = 2.0];
+  optional int32 neg_id = 3 [default = 0];
 }
 message GatherParameter {
......
@@ -42,7 +42,7 @@ find_modules()
 setup(name = 'dragon',
-      version='0.2.2.6',
+      version='0.2.2.7',
       description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
       url='https://github.com/seetaresearch/Dragon',
       author='Ting Pan',
......
@@ -229,11 +229,35 @@ void Operator<Context>::CleanResource() {
     }
 }
-DEFINE_REGISTRY(CPUOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
-DEFINE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
-DEFINE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
-DEFINE_REGISTRY(GradientRegistry, GradientMakerBase, const OperatorDef&, const vector<string>&);
-DEFINE_REGISTRY(NoGradientRegistry, GradientMakerBase, const OperatorDef&, const vector<string>&);
+DEFINE_REGISTRY(
+    CPUOperatorRegistry,
+    OperatorBase,
+    const OperatorDef&,
+    Workspace*);
+DEFINE_REGISTRY(
+    CUDAOperatorRegistry,
+    OperatorBase,
+    const OperatorDef&,
+    Workspace*);
+DEFINE_REGISTRY(
+    CUDNNOperatorRegistry,
+    OperatorBase,
+    const OperatorDef&,
+    Workspace*);
+DEFINE_REGISTRY(
+    GradientRegistry,
+    GradientMakerBase,
+    const OperatorDef&,
+    const vector<string>&);
+DEFINE_REGISTRY(
+    NoGradientRegistry,
+    GradientMakerBase,
+    const OperatorDef&,
+    const vector<string>&);
 #define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname) \
     template <> T OperatorBase::Arg( \
@@ -252,7 +276,6 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(string, s)
 INSTANTIATE_GET_SINGLE_ARGUMENT(bool, b);
 INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i64);
 #define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname) \
     template<> vector<T> OperatorBase::Args<T>(const string& name) { \
         if(args_.count(name) == 0) return vector<T>(); \
......
@@ -42,16 +42,17 @@ void SparseSoftmaxCrossEntropyOp<Context>::SoftmaxRunFP16() {
 template <class Context> template <typename Tx, typename Ty>
 void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() {
-    auto* prob_data = prob->template data<Tx, Context>();
-    auto* label_data = Input(1).template data<Ty, Context>();
-    auto* loss_data = losses.template mutable_data<Tx, Context>();
-    auto* valid_data = valid.template mutable_data<Tx, Context>();
+    auto* Pdata = prob->template data<Tx, Context>();
+    auto* Tdata = Input(1).template data<Ty, Context>();
+    auto* Idata = !ignores.count() ? nullptr :
+        ignores.template data<int, Context>();
+    auto* Ldata = losses.template mutable_data<Tx, Context>();
+    auto* Fdata = flags.template mutable_data<Tx, Context>();
     kernel::SparseSoftmaxCrossEntropy<Tx, Ty, Context>(
-        Input(0).count(), Input(0).dim(axis),
-        outer_dim, inner_dim,
-        prob_data, label_data, loss_data,
-        valid_data, &ignore, &ctx());
+        outer_dim, Input(0).dim(axis), inner_dim,
+        Pdata, Tdata, Idata, ignores.count(),
+        Ldata, Fdata, &ctx());
     if (normalization == "UNIT") {
         Output(0)->ReshapeLike(losses);
@@ -61,11 +62,12 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() {
     Tx normalizer;
     if (normalization == "VALID")
-        normalizer = std::max(math::ASum<Tx, Context>(valid.count(), valid_data), (Tx)1.f);
+        normalizer = std::max(
+            math::ASum<Tx, Context>(flags.count(), Fdata), (Tx)1.f);
     else if (normalization == "BATCH_SIZE") normalizer = Input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
-    Tx loss = math::ASum<Tx, Context>(losses.count(), loss_data);
+    Tx loss = math::ASum<Tx, Context>(losses.count(), Ldata);
     Output(0)->Reshape({ 1 });
     auto* Ydata = Output(0)->template mutable_data<Tx, Context>();
     math::Set<Tx, Context>(1, loss / normalizer, Ydata);
@@ -77,11 +79,12 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() {
     inner_dim = Input(0).count(axis + 1);
     CHECK_EQ(outer_dim * inner_dim, Input(1).count())
         << "\nNumber of predictions must match the number of labels.";
-    valid.Reshape({ outer_dim * inner_dim });
     losses.Reshape({ outer_dim * inner_dim });
+    flags.Reshape({ outer_dim * inner_dim });
     prob = ws()->CreateTensor("/mnt/" + anchor() + "/softmax/prob");
-    if (XIsType(Input(0), float) || XIsType(Input(0), float16)) {
+    if (XIsType(Input(0), float) ||
+        XIsType(Input(0), float16)) {
         if (XIsType(Input(0), float16)) SoftmaxRunFP16();
         else SoftmaxRun();
         if (XIsType(Input(1), float)) RunWithType<float, float>();
@@ -98,33 +101,35 @@ OPERATOR_SCHEMA(SparseSoftmaxCrossEntropy).NumInputs(2).NumOutputs(1);
 template <class Context> template <typename Tx, typename Ty>
 void SparseSoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
-    auto* label_data = Input(1).template data<Ty, Context>();
-    auto* prob_data = prob->template mutable_data<Tx, Context>();
+    auto* Pdata = prob->template mutable_data<Tx, Context>();
+    auto* Tdata = Input(1).template data<Ty, Context>();
+    auto* Idata = !ignores.count() ? nullptr :
+        ignores.template data<int, Context>();
     auto* dXdata = Output(0)->template mutable_data<Tx, Context>();
-    auto* valid_data = valid.template mutable_data<Tx, Context>();
-    ctx().template Copy<Tx, Context, Context>(prob->count(), dXdata, prob_data);
+    auto* Fdata = flags.template mutable_data<Tx, Context>();
+    ctx().template Copy<Tx, Context, Context>(
+        prob->count(), dXdata, Pdata);
     kernel::SparseSoftmaxCrossEntropyGrad<Tx, Ty, Context>(
-        Output(0)->count(), Output(0)->dim(axis),
-        outer_dim, inner_dim,
-        prob_data, label_data, valid_data,
-        &ignore, dXdata, &ctx());
+        outer_dim, Output(0)->dim(axis), inner_dim,
+        Pdata, Tdata, Idata, ignores.count(),
+        dXdata, Fdata, &ctx());
     if (normalization == "UNIT") {
         auto* dYdata = Input(-1).template data<Tx, Context>();
         kernel::SumGrad<Tx, Context>(
             Input(0).count() / Input(0).dim(axis),
             Input(0).dim(axis), inner_dim,
-            1.0, dYdata, prob_data);
+            1.0, dYdata, Pdata);
         math::Mul<Tx, Context>(
-            Output(0)->count(), prob_data, dXdata, dXdata);
+            Output(0)->count(), Pdata, dXdata, dXdata);
         return;
     }
     Tx normalizer;
     if (normalization == "VALID")
         normalizer = std::max(
-            math::ASum<Tx, Context>(valid.count(), valid_data), (Tx)1.f);
+            math::ASum<Tx, Context>(flags.count(), Fdata), (Tx)1.f);
     else if (normalization == "BATCH_SIZE") normalizer = Input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
@@ -141,7 +146,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
     outer_dim = prob->count(0, axis);
     inner_dim = prob->count(axis + 1);
     Output(0)->ReshapeLike(Input(0));
-    valid.Reshape({ outer_dim * inner_dim });
+    flags.Reshape({ outer_dim * inner_dim });
     if (XIsType(Input(0), float) || XIsType(Input(0), float16)) {
         if (XIsType(Input(1), float)) RunWithType<float, float>();
......
@@ -9,31 +9,33 @@ namespace dragon {
 template <class Context> template <typename T>
 void SparseSoftmaxFocalLossOp<Context>::RunWithType() {
-    auto* prob_data = this->prob->template data<T, Context>();
-    auto* label_data = Input(1).template data<T, Context>();
-    auto* loss_data = this->losses.template mutable_data<T, Context>();
-    auto* valid_data = this->valid.template mutable_data<T, Context>();
-    auto* scale_data = scale->template mutable_data<T, Context>();
+    auto* Pdata = this->prob->template data<T, Context>();
+    auto* Tdata = Input(1).template data<T, Context>();
+    auto* Idata = !this->ignores.count() ? nullptr :
+        this->ignores.template data<int, Context>();
+    auto* Ldata = losses.template mutable_data<T, Context>();
+    auto* Fdata = flags.template mutable_data<T, Context>();
     kernel::SparseSoftmaxFocalLoss<T, Context>(
-        Input(0).count(), Input(0).dim(axis), outer_dim, inner_dim,
+        outer_dim, Input(0).dim(axis), inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
-        prob_data, label_data, scale_data,
-        loss_data, valid_data, &this->ignore);
+        Pdata, Tdata, Idata, this->ignores.count(),
+        Ldata, Fdata, &ctx());
     if (normalization == "UNIT") {
-        Output(0)->ReshapeLike(this->losses);
-        Output(0)->template Copy<Context, Context>(this->losses);
+        Output(0)->ReshapeLike(losses);
+        Output(0)->template Copy<Context, Context>(losses);
         return;
     }
     T normalizer;
     if (normalization == "VALID")
-        normalizer = std::max(math::ASum<T, Context>(this->valid.count(), valid_data), 1.f);
+        normalizer = std::max(
+            math::ASum<T, Context>(flags.count(), Fdata), 1.f);
     else if (normalization == "BATCH_SIZE") normalizer = Input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
-    T loss = math::ASum<T, Context>(this->losses.count(), loss_data);
+    T loss = math::ASum<T, Context>(losses.count(), Ldata);
     Output(0)->Reshape({ 1 });
     auto* Ydata = Output(0)->template mutable_data<T, Context>();
     math::Set<T, Context>(1, loss / normalizer, Ydata);
@@ -45,13 +47,11 @@ void SparseSoftmaxFocalLossOp<Context>::RunOnDevice() {
     inner_dim = Input(0).count(axis + 1);
     CHECK_EQ(outer_dim * inner_dim, Input(1).count())
         << "\nNumber of predictions must match the number of labels.";
-    this->valid.Reshape({ outer_dim * inner_dim });
-    this->losses.Reshape({ outer_dim * inner_dim });
+    flags.Reshape({ outer_dim * inner_dim });
+    losses.Reshape({ outer_dim * inner_dim });
     ws()->CreateTensor("/mnt/" + anchor() + "/softmax/prob");
     this->SoftmaxRun();
     this->prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax/prob");
-    scale = ws()->CreateTensor("/mnt/" + anchor() + "/focal/scale");
-    scale->ReshapeLike(*this->prob);
     if (XIsType(Input(0), float)) RunWithType<float>();
     else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
@@ -65,31 +65,33 @@ OPERATOR_SCHEMA(SparseSoftmaxFocalLoss).NumInputs(2).NumOutputs(1);
 template <class Context> template <typename T>
 void SparseSoftmaxFocalLossGradientOp<Context>::RunWithType() {
-    auto* label_data = Input(1).template data<T, Context>();
-    auto* prob_data = this->prob->template mutable_data<T, Context>();
+    auto* Pdata = this->prob->template mutable_data<T, Context>();
+    auto* Tdata = Input(1).template data<T, Context>();
+    auto* Idata = !this->ignores.count() ? nullptr :
+        this->ignores.template data<int, Context>();
     auto* dXdata = Output(0)->template mutable_data<T, Context>();
-    auto* valid_data = this->valid.template mutable_data<T, Context>();
-    auto* scale_data = scale->template mutable_data<T, Context>();
+    auto* Fdata = flags.template mutable_data<T, Context>();
     kernel::SparseSoftmaxFocalLossGrad<T, Context>(
-        Output(0)->count(), Output(0)->dim(axis), outer_dim, inner_dim,
-        gamma, neg_id, eps, scale_data, prob_data, label_data,
-        valid_data, &this->ignore, dXdata);
+        outer_dim, Output(0)->dim(axis), inner_dim,
+        pos_alpha, neg_alpha, gamma, neg_id,
+        Pdata, Tdata, Idata, this->ignores.count(),
+        dXdata, Fdata, &ctx());
     if (normalization == "UNIT") {
         auto* dYdata = Input(-1).template data<T, Context>();
         kernel::SumGrad<T, Context>(
             Input(0).count() / Input(0).dim(axis),
             Input(0).dim(axis), inner_dim,
-            1.0, dYdata, prob_data);
+            1.0, dYdata, Pdata);
         math::Mul<T, Context>(Output(0)->count(),
-            prob_data, dXdata, dXdata); return;
+            Pdata, dXdata, dXdata); return;
     }
     T normalizer;
     if (normalization == "VALID")
         normalizer = std::max(
-            math::ASum<T, Context>(this->valid.count(), valid_data), 1.f);
+            math::ASum<T, Context>(flags.count(), Fdata), 1.f);
     else if (normalization == "BATCH_SIZE") normalizer = Input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
@@ -103,11 +105,10 @@ void SparseSoftmaxFocalLossGradientOp<Context>::RunWithType() {
 template <class Context>
 void SparseSoftmaxFocalLossGradientOp<Context>::RunOnDevice() {
     this->prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax/prob");
-    scale = ws()->GetTensor("/mnt/" + anchor() + "/focal/scale");
     outer_dim = this->prob->count(0, axis);
     inner_dim = this->prob->count(axis + 1);
     Output(0)->ReshapeLike(Input(0));
-    this->valid.Reshape({ outer_dim * inner_dim });
+    flags.Reshape({ outer_dim * inner_dim });
     if (XIsType(Input(0), float)) RunWithType<float>();
     else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
......
@@ -3,12 +3,15 @@
 #ifdef WITH_PYTHON
 #ifdef WITH_PYTHON3
-#define PyBytes_FromStringAndSize PyUnicode_FromStringAndSize
+#define PyBytes_FromStringAndSize \
+    PyUnicode_FromStringAndSize
 #endif
-#define String(str) \
+#define Bytes(str) \
     PyBytes_FromStringAndSize(str, string(str).size())
+#define CS2Bytes(cstr) Bytes(cstr.c_str())
 namespace dragon {
 template <class Context>
@@ -17,6 +20,9 @@ RunOp<Context>::RunOp(const OperatorDef& def, Workspace* ws)
       module(OperatorBase::Arg<string>("module", "")),
       op(OperatorBase::Arg<string>("op", "")),
       param_str((OperatorBase::Arg<string>("param_str", ""))) {
-    if (!AllowRun()) return;
+    // optimization for all python ops
+    this->do_sync_ = false;
     // init interpreter & load module
     Py_Initialize();
     PyObject* py_module = PyImport_ImportModule(module.c_str());
@@ -27,37 +33,38 @@ RunOp<Context>::RunOp(const OperatorDef& def, Workspace* ws)
         << " from module: " << module;
     self = PyObject_CallObject(py_op, NULL);
-    // pass param string
-    PyObject_SetAttr(self, String("param_str"), String(param_str.c_str()));
-    PyObject_SetAttr(self, String("param_str_"), String(param_str.c_str()));
-    // build inputs and outputs for Python
+    // wrap inputs and outputs
     inputs = PyList_New(InputSize());
     for (int i = 0; i < InputSize(); i++)
-        PyList_SetItem(inputs, i, String(Input(i).name().c_str()));
+        PyList_SetItem(inputs, i, CS2Bytes(Input(i).name()));
     outputs = PyList_New(OutputSize());
     for (int i = 0; i < OutputSize(); i++)
-        PyList_SetItem(outputs, i, String(Output(i)->name().c_str()));
+        PyList_SetItem(outputs, i, CS2Bytes(Output(i)->name()));
+    if (!AllowRun()) return;
+    // backward compatibility: param_str
+    PyObject_SetAttr(self, Bytes("param_str"), CS2Bytes(param_str));
+    PyObject_SetAttr(self, Bytes("param_str_"), CS2Bytes(param_str));
-    // setup
-    if (PyObject_HasAttr(self, String("setup")))
+    // backward compatibility: self.setup(inputs, outputs)
+    if (PyObject_HasAttr(self, Bytes("setup"))) {
         PyObject_CallMethod(self, "setup", "OO", inputs, outputs);
+    }
 }
 template <class Context>
 void RunOp<Context>::RunOnDevice() {
-    // init phase
-    PyObject_SetAttr(self, String("phase"), String(phase().c_str()));
+    // reset phase
+    PyObject_SetAttr(self, Bytes("phase"), CS2Bytes(phase()));
-    // reshape
-    if (PyObject_HasAttr(self, String("reshape")))
+    // backward compatibility: reshape(inputs, outputs)
+    if (PyObject_HasAttr(self, Bytes("reshape"))) {
         PyObject_CallMethod(self, "reshape", "OO", inputs, outputs);
+    }
-    // run
-    if (PyObject_HasAttr(self, String("forward"))) {
+    // overloaded run inferfaces
+    if (PyObject_HasAttr(self, Bytes("forward"))) {
         PyObject_CallMethod(self, "forward", "OO", inputs, outputs);
-    } else if (PyObject_HasAttr(self, String("run"))) {
+    } else if (PyObject_HasAttr(self, Bytes("run"))) {
         PyObject_CallMethod(self, "run", "OO", inputs, outputs);
     }
 }
@@ -72,18 +79,23 @@ NO_GRADIENT(Run);
 template <class Context>
 void TemplateGradientOp<Context>::RunOnDevice() {
-    // init phase
-    PyObject_SetAttr(this->self, String("phase"), String(phase().c_str()));
+    // reset phase
+    PyObject_SetAttr(this->self,
+        Bytes("phase"), CS2Bytes(phase()));
-    // reshape
-    if (PyObject_HasAttr(this->self, String("reshape")))
-        PyObject_CallMethod(this->self, "reshape", "OO", this->inputs, this->outputs);
+    // backward compatibility: reshape(inputs, outputs)
+    if (PyObject_HasAttr(this->self, Bytes("reshape"))) {
+        PyObject_CallMethod(this->self, "reshape",
+            "OO", this->inputs, this->outputs);
+    }
-    // run
-    if (PyObject_HasAttr(this->self, String("backward"))) {
-        PyObject_CallMethod(this->self, "forward", "OO", this->inputs, this->outputs);
-    } else if (PyObject_HasAttr(this->self, String("grad"))) {
-        PyObject_CallMethod(this->self, "grad", "OO", this->inputs, this->outputs);
+    // overloaded run inferfaces
+    if (PyObject_HasAttr(this->self, Bytes("backward"))) {
+        PyObject_CallMethod(this->self, "forward",
+            "OO", this->inputs, this->outputs);
+    } else if (PyObject_HasAttr(this->self, Bytes("grad"))) {
+        PyObject_CallMethod(this->self, "grad",
+            "OO", this->inputs, this->outputs);
     }
 }
......
@@ -235,7 +235,7 @@ void ConvOpBase<Context>::Reshape() {
         weight_shape.push_back(conv_in_channels / group);
         weight_shape.push_back(conv_out_channels);
     }
-    bias_shape.assign(1, num_output);
+    bias_shape = { num_output };
     // determine the bottom and top shape
     bottom_shape = Input(0).dims();
......