Commit 3d2abe69 by Ting PAN

Mix Static/Dynamic Arguments

1 parent 04fdadb0
...@@ -68,13 +68,13 @@ class OperatorBase { ...@@ -68,13 +68,13 @@ class OperatorBase {
template <class Context> template <class Context>
class Operator : public OperatorBase { class Operator : public OperatorBase {
public: public:
Operator(const OperatorDef& op_def, Workspace* ws) Operator(const OperatorDef& op_def, Workspace* ws)
: OperatorBase(op_def, ws), ctx_(op_def.device_option()) { : OperatorBase(op_def, ws), ctx_(op_def.device_option()) {
allow_run_ = true; allow_run_ = true;
allow_run_ &= _MPICheck(); allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && output(0)->name() == "ignore")); allow_run_ &= (!(OutputSize() == 1 && output(0)->name() == "ignore"));
allow_share_grads_ = (!op_def.debug_mode()); allow_share_grads_ = (!op_def.debug_mode());
allow_share_grads_ &= op_def.share_grads(); allow_share_grads_ &= op_def.share_grads();
allow_share_grads_ &= (type().find("Gradient") != string::npos); allow_share_grads_ &= (type().find("Gradient") != string::npos);
} }
...@@ -167,6 +167,53 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -167,6 +167,53 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
} \ } \
} }
#define DECLARE_ARGUMENT_WITH_DESC(type, argument) \
type argument##_value; \
string argument##_desc; \
type argument()
#define DECLARE_ARGUMENTS_WITH_DESC(type, argument) \
vector<type> argument##_value; \
vector<string> argument##_desc; \
type argument(int idx)
#define GET_ARGUMENT_WITH_DESC(type, argument, default_value) \
argument##_value = OperatorBase::GetSingleArg<type>(#argument, default_value); \
argument##_desc = OperatorBase::GetSingleArg<string>(string(#argument) + "_desc", "")
#define GET_ARGUMENTS_WITH_DESC(type, argument) \
argument##_value = OperatorBase::GetRepeatedArg<type>(#argument); \
argument##_desc = OperatorBase::GetRepeatedArg<string>(string(#argument) + "_desc")
#define DEFINE_ARGUMENT_WITH_DESC(type, classname, argument) \
template <class Context> \
type classname<Context>::argument() { \
if (argument##_desc.empty()) return argument##_value; \
Tensor* argument##_tensor = ws()->GetTensor(argument##_desc); \
CHECK(argument##_tensor->IsType<type>()) \
<< "\nThe type of " << #argument << " should be " << #type << "."; \
CHECK_EQ(argument##_tensor->count(), 1) \
<< "\nThe argument of " << #argument << " should be a scalar"; \
return argument##_tensor->template data<type, CPUContext>()[0]; \
}
#define DEFINE_ARGUMENTS_WITH_DESC(type, classname, argument) \
template <class Context> \
type classname<Context>::argument(int idx) { \
if (argument##_desc.empty()) { \
CHECK_LT(idx, argument##_value.size()); \
return argument##_value[idx]; \
} \
CHECK_LT(idx, argument##_desc.size()); \
Tensor* argument##_tensor = ws()->GetTensor(argument##_desc[idx]); \
CHECK(argument##_tensor->IsType<type>()) \
<< "\nThe type of " << #argument << " should be " << #type; \
CHECK_EQ(argument##_tensor->count(), 1) \
<< "\nThe argument of " << #argument << " at pos(" \
<< idx << ") should be a scalar."; \
return argument##_tensor->template data<type, CPUContext>()[0]; \
}
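Taken together, these macros let an operator read an argument either as a literal value baked into the OperatorDef or, when a `<name>_desc` argument is present, from a workspace tensor resolved at run time. A minimal usage sketch (a hypothetical `ScaleOp`, following the same pattern this commit applies to `DropoutOp` below):

```cpp
// Hypothetical example, not part of the commit: "alpha" may be supplied
// statically ("alpha": 2.0) or dynamically ("alpha_desc": "<tensor name>").
template <class Context>
class ScaleOp final : public Operator<Context> {
 public:
    ScaleOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws) {
        GET_ARGUMENT_WITH_DESC(float, alpha, 1.0);  // fills alpha_value / alpha_desc
    }
    void RunOnDevice() override;                    // calls alpha() wherever the value is needed

 protected:
    DECLARE_ARGUMENT_WITH_DESC(float, alpha);       // alpha_value, alpha_desc, alpha()
};

DEFINE_ARGUMENT_WITH_DESC(float, ScaleOp, alpha);   // out-of-class definition of alpha()
```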
#define DISABLE_SHARE_GRADIENT \ #define DISABLE_SHARE_GRADIENT \
this->allow_share_grads_ = false this->allow_share_grads_ = false
...@@ -202,4 +249,4 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -202,4 +249,4 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
INSTANTIATE_CUDNN_OPERATOR(name); INSTANTIATE_CUDNN_OPERATOR(name);
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_OPERATOR_H_ #endif // DRAGON_CORE_OPERATOR_H_
\ No newline at end of file
...@@ -17,32 +17,26 @@ class DropoutOp final : public Operator<Context> { ...@@ -17,32 +17,26 @@ class DropoutOp final : public Operator<Context> {
public: public:
DropoutOp(const OperatorDef& op_def, Workspace* ws) DropoutOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
prob(OperatorBase::GetSingleArg<float>("prob", 0.5)) { use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
bool use_scale = OperatorBase::GetSingleArg<bool>("scale", true); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
threshold = static_cast<unsigned int>(UINT_MAX * prob);
if (use_scale) scale = 1.0 / (1.0 - prob);
else scale = 1.0;
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float prob, scale; DECLARE_ARGUMENT_WITH_DESC(float, prob);
unsigned int threshold; bool use_scale;
Tensor* mask; Tensor* mask;
}; };
template <class Context> template <class Context>
class DropoutGradientOp final : public Operator<Context> { class DropoutGradientOp final : public Operator<Context> {
public: public:
DropoutGradientOp(const OperatorDef& op_def, Workspace* ws) DropoutGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
prob(OperatorBase::GetSingleArg<float>("prob", 0.5)) { use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
bool use_scale = OperatorBase::GetSingleArg<bool>("scale", true); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
threshold = static_cast<unsigned int>(UINT_MAX * prob);
if (use_scale) scale = 1.0 / (1.0 - prob);
else scale = 1.0;
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
...@@ -50,11 +44,14 @@ class DropoutGradientOp final : public Operator<Context> { ...@@ -50,11 +44,14 @@ class DropoutGradientOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float prob, scale; DECLARE_ARGUMENT_WITH_DESC(float, prob);
unsigned int threshold; bool use_scale;
Tensor* mask; Tensor* mask;
}; };
DEFINE_ARGUMENT_WITH_DESC(float, DropoutOp, prob);
DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob);
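For reference, a hand expansion of the three macros for `DropoutOp`'s `prob` argument looks roughly like this (an illustration of what the preprocessor produces, not code added by the commit):

```cpp
// DECLARE_ARGUMENT_WITH_DESC(float, prob)  -> members declared in the class:
float prob_value;
string prob_desc;
float prob();

// GET_ARGUMENT_WITH_DESC(float, prob, 0.5) -> run in the constructor:
prob_value = OperatorBase::GetSingleArg<float>("prob", 0.5);
prob_desc = OperatorBase::GetSingleArg<string>(string("prob") + "_desc", "");

// DEFINE_ARGUMENT_WITH_DESC(float, DropoutOp, prob) -> at namespace scope:
template <class Context>
float DropoutOp<Context>::prob() {
    if (prob_desc.empty()) return prob_value;            // static path
    Tensor* prob_tensor = ws()->GetTensor(prob_desc);     // dynamic path
    CHECK(prob_tensor->IsType<float>())
        << "\nThe type of prob should be float.";
    CHECK_EQ(prob_tensor->count(), 1)
        << "\nThe argument of prob should be a scalar";
    return prob_tensor->template data<float, CPUContext>()[0];
}
```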
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
\ No newline at end of file
...@@ -17,14 +17,15 @@ class InitializeOp: public Operator<Context> { ...@@ -17,14 +17,15 @@ class InitializeOp: public Operator<Context> {
public: public:
InitializeOp(const OperatorDef& op_def, Workspace* ws) InitializeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
dims_desc(OperatorBase::GetRepeatedArg<string>("dims")), shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) {
shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) {} GET_ARGUMENTS_WITH_DESC(int, dims);
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<string> dims_desc; DECLARE_ARGUMENTS_WITH_DESC(int, dims);
string shape_desc; string shape_desc;
TensorFiller filler; TensorFiller filler;
}; };
...@@ -116,6 +117,8 @@ public: ...@@ -116,6 +117,8 @@ public:
} }
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims);
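The repeated form resolves one element at a time: since the Python front end fills exactly one of `dims_value`/`dims_desc`, the updated `InitializeOp::RunOnDevice` (later in this commit) can take the larger of the two sizes as the rank and call `dims(i)` per dimension. A hand expansion of the accessor (illustration only):

```cpp
// DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims) expands roughly to:
template <class Context>
int InitializeOp<Context>::dims(int idx) {
    if (dims_desc.empty()) {                    // static path: literal int dims
        CHECK_LT(idx, dims_value.size());
        return dims_value[idx];
    }
    CHECK_LT(idx, dims_desc.size());            // dynamic path: one scalar tensor per dim
    Tensor* dims_tensor = ws()->GetTensor(dims_desc[idx]);
    CHECK(dims_tensor->IsType<int>())
        << "\nThe type of dims should be int";
    CHECK_EQ(dims_tensor->count(), 1)
        << "\nThe argument of dims at pos(" << idx << ") should be a scalar.";
    return dims_tensor->template data<int, CPUContext>()[0];
}
```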
} // namespace } // namespace
#endif // DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #endif // DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
\ No newline at end of file
...@@ -16,21 +16,26 @@ class ArangeOp final : public Operator<Context> { ...@@ -16,21 +16,26 @@ class ArangeOp final : public Operator<Context> {
public: public:
ArangeOp(const OperatorDef& op_def, Workspace* ws) ArangeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
start_desc(OperatorBase::GetSingleArg<string>("start", "")), dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) {
stop_desc(OperatorBase::GetSingleArg<string>("stop", "")), GET_ARGUMENT_WITH_DESC(int, start, 0);
step_desc(OperatorBase::GetSingleArg<string>("step", "")), GET_ARGUMENT_WITH_DESC(int, stop, 0);
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) {} GET_ARGUMENT_WITH_DESC(int, step, 1);
}
void Reshape();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
string start_desc, stop_desc, step_desc, dtype; DECLARE_ARGUMENT_WITH_DESC(int, start);
TIndex start, stop, step, count; DECLARE_ARGUMENT_WITH_DESC(int, stop);
DECLARE_ARGUMENT_WITH_DESC(int, step);
string dtype;
}; };
DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, start);
DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, stop);
DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, step);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_ARANGE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_ARANGE_OP_H_
\ No newline at end of file
...@@ -16,15 +16,16 @@ class RepeatOp : public Operator<Context> { ...@@ -16,15 +16,16 @@ class RepeatOp : public Operator<Context> {
public: public:
RepeatOp(const OperatorDef& op_def, Workspace* ws) RepeatOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
repeats_desc(OperatorBase::GetSingleArg<string>("repeats", "")) {} GET_ARGUMENT_WITH_DESC(int, repeats, 1);
}
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunWithType();
protected: protected:
TIndex axis, outer_dim, dim, inner_dim, reps; DECLARE_ARGUMENT_WITH_DESC(int, repeats);
string repeats_desc; TIndex axis, outer_dim, dim, inner_dim;
}; };
template <class Context> template <class Context>
...@@ -32,17 +33,21 @@ class RepeatGradientOp : public Operator<Context> { ...@@ -32,17 +33,21 @@ class RepeatGradientOp : public Operator<Context> {
public: public:
RepeatGradientOp(const OperatorDef& op_def, Workspace* ws) RepeatGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
repeats_desc(OperatorBase::GetSingleArg<string>("repeats", "")) {} GET_ARGUMENT_WITH_DESC(int, repeats, 1);
}
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunWithType();
protected: protected:
DECLARE_ARGUMENT_WITH_DESC(int, repeats);
TIndex axis, outer_dim, dim, inner_dim, reps; TIndex axis, outer_dim, dim, inner_dim, reps;
string repeats_desc;
}; };
DEFINE_ARGUMENT_WITH_DESC(int, RepeatOp, repeats);
DEFINE_ARGUMENT_WITH_DESC(int, RepeatGradientOp, repeats);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
...@@ -15,14 +15,15 @@ template <class Context> ...@@ -15,14 +15,15 @@ template <class Context>
class TileOp : public Operator<Context> { class TileOp : public Operator<Context> {
public: public:
TileOp(const OperatorDef& op_def, Workspace* ws) TileOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws) {
multiples_desc(OperatorBase::GetRepeatedArg<string>("multiples")) {} GET_ARGUMENTS_WITH_DESC(int, multiples);
}
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
protected: protected:
vector<string> multiples_desc; DECLARE_ARGUMENTS_WITH_DESC(int, multiples);
TIndex axis, multiple, outer_dim, ex_inner_dim; TIndex axis, multiple, outer_dim, ex_inner_dim;
Tensor* dest, *source; Tensor* dest, *source;
}; };
...@@ -31,8 +32,8 @@ template <class Context> ...@@ -31,8 +32,8 @@ template <class Context>
class TileGradientOp : public Operator<Context> { class TileGradientOp : public Operator<Context> {
public: public:
TileGradientOp(const OperatorDef& op_def, Workspace* ws) TileGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws) {
multiples_desc(OperatorBase::GetRepeatedArg<string>("multiples")) { GET_ARGUMENTS_WITH_DESC(int, multiples);
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
...@@ -40,11 +41,14 @@ class TileGradientOp : public Operator<Context> { ...@@ -40,11 +41,14 @@ class TileGradientOp : public Operator<Context> {
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
protected: protected:
vector<string> multiples_desc; DECLARE_ARGUMENTS_WITH_DESC(int, multiples);
TIndex axis, multiple, outer_dim, ex_inner_dim; TIndex axis, multiple, outer_dim, ex_inner_dim;
Tensor* dest, *source; Tensor* dest, *source;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, TileOp, multiples);
DEFINE_ARGUMENTS_WITH_DESC(int, TileGradientOp, multiples);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_TILE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
\ No newline at end of file
...@@ -16,10 +16,11 @@ class BilinearResizeOp : public Operator<Context> { ...@@ -16,10 +16,11 @@ class BilinearResizeOp : public Operator<Context> {
public: public:
BilinearResizeOp(const OperatorDef& op_def, Workspace* ws) BilinearResizeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
dsize_desc(OperatorBase::GetRepeatedArg<string>("dsize")),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)), fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)), fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
GET_ARGUMENTS_WITH_DESC(int, dsize);
if (data_format == "NCHW") spatial_axis = 2; if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
...@@ -28,11 +29,10 @@ class BilinearResizeOp : public Operator<Context> { ...@@ -28,11 +29,10 @@ class BilinearResizeOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<string> dsize_desc; DECLARE_ARGUMENTS_WITH_DESC(int, dsize);
float fy, fx; float fy, fx;
string data_format; string data_format, shape_like_desc;
TIndex n, c, h, w, out_h, out_w, spatial_axis; TIndex n, c, h, w, out_h, out_w, spatial_axis;
vector<TIndex> dims;
}; };
template <class Context> template <class Context>
...@@ -50,6 +50,8 @@ class BilinearResizeGradientOp : public Operator<Context> { ...@@ -50,6 +50,8 @@ class BilinearResizeGradientOp : public Operator<Context> {
TIndex n, c, h, w, out_h, out_w; TIndex n, c, h, w, out_h, out_w;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, BilinearResizeOp, dsize);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_VISION_BILINEAR_RESIZE_OP_H_ #endif // DRAGON_OPERATORS_VISION_BILINEAR_RESIZE_OP_H_
\ No newline at end of file
...@@ -21,8 +21,9 @@ class ConvOpBase : public Operator<Context> { ...@@ -21,8 +21,9 @@ class ConvOpBase : public Operator<Context> {
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")), data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")), padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
num_output(OperatorBase::GetSingleArg<int>("num_output", 1)), num_output(OperatorBase::GetSingleArg<int>("num_output", 1)),
group(OperatorBase::GetSingleArg<int>("group", 1)), group(OperatorBase::GetSingleArg<int>("group", 1)) {
output_dims_desc(OperatorBase::GetRepeatedArg<string>("output_shape")) { output_dims_value = OperatorBase::GetRepeatedArg<int>("output_shape");
output_dims_desc = OperatorBase::GetRepeatedArg<string>("output_shape_desc");
if (data_format == "NCHW") spatial_axis = 2; if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
...@@ -41,7 +42,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -41,7 +42,7 @@ class ConvOpBase : public Operator<Context> {
TIndex conv_in_channels, conv_out_channels; TIndex conv_in_channels, conv_out_channels;
TIndex conv_out_spatial_dim, kernel_dim; TIndex conv_out_spatial_dim, kernel_dim;
TIndex col_offset, output_offset, weight_offset, x_offset, y_offset; TIndex col_offset, output_offset, weight_offset, x_offset, y_offset;
vector<string> output_dims_desc; DECLARE_ARGUMENTS_WITH_DESC(int, output_dims);
bool is_1x1; bool is_1x1;
void Setup(); void Setup();
...@@ -87,6 +88,8 @@ class ConvOpBase : public Operator<Context> { ...@@ -87,6 +88,8 @@ class ConvOpBase : public Operator<Context> {
} }
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, ConvOpBase, output_dims);
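Note that `ConvOpBase` does not use `GET_ARGUMENTS_WITH_DESC` here: the member is named `output_dims` while the protobuf argument emitted by the Python front end is still called `output_shape`, so the constructor above fetches the value/desc pair by hand. For comparison, the macro would have queried the wrong keys (sketch):

```cpp
// GET_ARGUMENTS_WITH_DESC(int, output_dims) would look up "output_dims"/"output_dims_desc",
// not the "output_shape"/"output_shape_desc" keys that Conv2dTranspose actually emits:
output_dims_value = OperatorBase::GetRepeatedArg<int>("output_dims");
output_dims_desc = OperatorBase::GetRepeatedArg<string>(string("output_dims") + "_desc");
```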
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_ #endif // DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_
\ No newline at end of file
...@@ -16,10 +16,11 @@ class NNResizeOp : public Operator<Context> { ...@@ -16,10 +16,11 @@ class NNResizeOp : public Operator<Context> {
public: public:
NNResizeOp(const OperatorDef& op_def, Workspace* ws) NNResizeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
dsize_desc(OperatorBase::GetRepeatedArg<string>("dsize")),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)), fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)), fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
GET_ARGUMENTS_WITH_DESC(int, dsize);
if (data_format == "NCHW") spatial_axis = 2; if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
...@@ -29,9 +30,9 @@ class NNResizeOp : public Operator<Context> { ...@@ -29,9 +30,9 @@ class NNResizeOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<string> dsize_desc; DECLARE_ARGUMENTS_WITH_DESC(int, dsize);
float fy, fx; float fy, fx;
string data_format; string data_format, shape_like_desc;
TIndex n, c, h, w, out_h, out_w, spatial_axis; TIndex n, c, h, w, out_h, out_w, spatial_axis;
}; };
...@@ -50,6 +51,8 @@ class NNResizeGradientOp : public Operator<Context> { ...@@ -50,6 +51,8 @@ class NNResizeGradientOp : public Operator<Context> {
TIndex n, c, h, w, out_h, out_w; TIndex n, c, h, w, out_h, out_w;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, NNResizeOp, dsize);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_VISION_NN_RESIZE_OP_H_ #endif // DRAGON_OPERATORS_VISION_NN_RESIZE_OP_H_
\ No newline at end of file
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
# config
from dragon.config import *
import dragon.config as config
# core # core
from dragon.core.tensor import Tensor from dragon.core.tensor import Tensor
import dragon.core.workspace as workspace import dragon.core.workspace as workspace
......
...@@ -36,4 +36,38 @@ def CheckInputs(inputs, *args): ...@@ -36,4 +36,38 @@ def CheckInputs(inputs, *args):
def ParseArguments(locals): def ParseArguments(locals):
__all__ = locals __all__ = locals
kwargs = __all__['kwargs']; del __all__['kwargs'] kwargs = __all__['kwargs']; del __all__['kwargs']
return dict(__all__, **kwargs) return dict(__all__, **kwargs)
\ No newline at end of file
def AddArgumentWithDesc(arguments, property, name, as_target=True):
if isinstance(property, Tensor):
if as_target:
if not 'extra_inputs' in arguments:
arguments['extra_inputs'] = []
arguments['extra_inputs'].extend([property])
arguments[name] = None
arguments[name + '_desc'] = property.name
return arguments
def AddArgumentsWithDesc(arguments, properties, name, type, as_target=True):
if not isinstance(properties, (list, tuple)): properties = [properties]
# check whether to use desc
tensor_in_properties = False
for property in properties:
if isinstance(property, Tensor):
tensor_in_properties = True
if tensor_in_properties:
properties_t = []
for property in properties:
if isinstance(property, Tensor):
if as_target:
if not 'extra_inputs' in arguments:
arguments['extra_inputs'] = []
arguments['extra_inputs'].extend([property])
properties_t.append(property.name)
else:
properties_t.append(Tensor.Convert(property, dtype=type).name)
arguments[name] = None
arguments[name + '_desc'] = properties_t
return arguments
\ No newline at end of file
...@@ -201,7 +201,7 @@ def Dropout(inputs, prob=0.5, scale=True, **kwargs): ...@@ -201,7 +201,7 @@ def Dropout(inputs, prob=0.5, scale=True, **kwargs):
---------- ----------
inputs : Tensor inputs : Tensor
The input tensor. The input tensor.
prob : float prob : float or Tensor
The prob of dropping. Default is ``0.5``. The prob of dropping. Default is ``0.5``.
scale : boolean scale : boolean
Whether to scale the output during training. Whether to scale the output during training.
...@@ -214,6 +214,7 @@ def Dropout(inputs, prob=0.5, scale=True, **kwargs): ...@@ -214,6 +214,7 @@ def Dropout(inputs, prob=0.5, scale=True, **kwargs):
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments = AddArgumentWithDesc(arguments, prob, 'prob', as_target=False)
output = Tensor.CreateOperator(nout=1, op_type='Dropout', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Dropout', **arguments)
......
...@@ -16,9 +16,9 @@ def _wrap_input_shape(arguments, shape): ...@@ -16,9 +16,9 @@ def _wrap_input_shape(arguments, shape):
arguments['extra_inputs'] = shape arguments['extra_inputs'] = shape
arguments['shape'] = shape.name arguments['shape'] = shape.name
elif isinstance(shape, (list, tuple)): elif isinstance(shape, (list, tuple)):
arguments['extra_inputs'] = [Tensor.Convert(dim, dtype='int32') for dim in shape] arguments['dims'] = shape
arguments['dims'] = [dim.name for dim in arguments['extra_inputs']]
arguments['shape'] = None arguments['shape'] = None
AddArgumentsWithDesc(arguments, shape, 'dims', 'int32', as_target=True)
else: else:
raise TypeError('Unsupported type of shape: {}'.format(type(shape))) raise TypeError('Unsupported type of shape: {}'.format(type(shape)))
return arguments return arguments
......
...@@ -455,8 +455,7 @@ def Repeat(inputs, axis=-1, repeats=1, **kwargs): ...@@ -455,8 +455,7 @@ def Repeat(inputs, axis=-1, repeats=1, **kwargs):
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['extra_inputs'] = [Tensor.Convert(repeats, dtype='int32')] arguments = AddArgumentWithDesc(arguments, repeats, 'repeats', as_target=True)
arguments['repeats'] = arguments['extra_inputs'][0].name
output = Tensor.CreateOperator(nout=1, op_type='Repeat', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Repeat', **arguments)
...@@ -492,8 +491,7 @@ def Tile(inputs, multiples, **kwargs): ...@@ -492,8 +491,7 @@ def Tile(inputs, multiples, **kwargs):
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['extra_inputs'] = [Tensor.Convert(multiple, dtype='int32') for multiple in multiples] arguments = AddArgumentsWithDesc(arguments, multiples, 'multiples', 'int32', as_target=True)
arguments['multiples'] = [multiple.name for multiple in arguments['extra_inputs']]
output = Tensor.CreateOperator(nout=1, op_type='Tile', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Tile', **arguments)
...@@ -779,14 +777,11 @@ def Arange(start, stop=None, step=1, dtype='FLOAT32', **kwargs): ...@@ -779,14 +777,11 @@ def Arange(start, stop=None, step=1, dtype='FLOAT32', **kwargs):
""" """
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['extra_inputs'] = [Tensor.Convert(start, dtype='int32'),
Tensor.Convert(step, dtype='int32')]
arguments['start'] = arguments['extra_inputs'][0].name
arguments['step'] = arguments['extra_inputs'][1].name
if stop is not None:
arguments['extra_inputs'].append(Tensor.Convert(stop, dtype='int32'))
arguments['stop'] = arguments['extra_inputs'][-1].name
arguments['dtype'] = arguments['dtype'].upper() arguments['dtype'] = arguments['dtype'].upper()
arguments = AddArgumentWithDesc(arguments, start, 'start', as_target=True)
arguments = AddArgumentWithDesc(arguments, step, 'step', as_target=True)
if stop is not None:
arguments = AddArgumentWithDesc(arguments, stop, 'stop', as_target=True)
output = Tensor.CreateOperator([], nout=1, op_type='Arange', **arguments) output = Tensor.CreateOperator([], nout=1, op_type='Arange', **arguments)
......
...@@ -139,7 +139,7 @@ def Conv2dTranspose(inputs, num_output, kernel_size, ...@@ -139,7 +139,7 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
The dilation multiple(s) of deconvolution. Default is ``1``. The dilation multiple(s) of deconvolution. Default is ``1``.
group : int group : int
The group size of deconvolution. Default is ``1``. The group size of deconvolution. Default is ``1``.
output_shape : list of int or None output_shape : list or None
The deterministic output shape for **SAME** padding. The deterministic output shape for **SAME** padding.
padding : str padding : str
The padding algorithm. ``VALID`` or ``SAME``. The padding algorithm. ``VALID`` or ``SAME``.
...@@ -170,12 +170,8 @@ def Conv2dTranspose(inputs, num_output, kernel_size, ...@@ -170,12 +170,8 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format)) raise ValueError('Unsupported data format: {}'.format(data_format))
arguments['output_shape'] = None
if output_shape is not None: if output_shape is not None:
if not isinstance(output_shape, list): AddArgumentsWithDesc(arguments, output_shape, 'output_shape', 'int32', as_target=True)
raise TypeError('The output shape should be a list.')
arguments['extra_inputs'] = [Tensor.Convert(dim, dtype='int32') for dim in output_shape]
arguments['output_shape'] = [dim.name for dim in arguments['extra_inputs']]
if not isinstance(arguments['kernel_size'], list): if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']] arguments['kernel_size'] = [arguments['kernel_size']]
...@@ -400,7 +396,8 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, ...@@ -400,7 +396,8 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0,
return output return output
def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): def NNResize(inputs, dsize, shape_like=None,
fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
"""Resize the image with Nearest-Neighbor method. """Resize the image with Nearest-Neighbor method.
Set ``dsize`` to None if you want to use ``fy`` and ``fx``. Set ``dsize`` to None if you want to use ``fy`` and ``fx``.
...@@ -411,6 +408,8 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -411,6 +408,8 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
The input tensor. The input tensor.
dsize : tuple, list, Tensor or None dsize : tuple, list, Tensor or None
The output size, formats as (h, w). The output size, formats as (h, w).
shape_like : Tensor or None
The tensor for guiding the shape of resizing.
fy : float fy : float
The scale factor based on src height. Default is ``-1.0`` (Discarded). The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float fx : float
...@@ -433,11 +432,15 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -433,11 +432,15 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
if dsize is not None: if dsize is not None:
if len(dsize) != 2: if len(dsize) != 2:
raise ValueError('The dsize should be a list with 2 elements.') raise ValueError('The dsize should be a list with 2 elements.')
arguments['extra_inputs'] = [Tensor.Convert(size, dtype='int32') for size in dsize] AddArgumentsWithDesc(arguments, dsize, 'dsize', 'int32', as_target=True)
arguments['dsize'] = [size.name for size in arguments['extra_inputs']]
if dsize is None and (fy == -1.0 or fx == -1.0): if shape_like is not None:
raise RuntimeError('The dsize or fy/fx should be specified either.') if not isinstance(shape_like, Tensor):
raise TypeError('The shape_like should be a Tensor.')
arguments['shape_like'] = shape_like.name
if dsize is None and shape_like is None and (fy == -1.0 or fx == -1.0):
raise RuntimeError('The dsize, shape_like or fy/fx should be specified either.')
output = Tensor.CreateOperator(nout=1, op_type='NNResize', **arguments) output = Tensor.CreateOperator(nout=1, op_type='NNResize', **arguments)
...@@ -449,6 +452,8 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -449,6 +452,8 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
for size in dsize: for size in dsize:
if isinstance(size, Tensor): if isinstance(size, Tensor):
possible_to_infer_shape = False possible_to_infer_shape = False
if shape_like is not None:
possible_to_infer_shape = False
if possible_to_infer_shape: if possible_to_infer_shape:
output.shape = inputs.shape[:] output.shape = inputs.shape[:]
...@@ -464,7 +469,8 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -464,7 +469,8 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
return output return output
def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): def BilinearResize(inputs, dsize, shape_like=None,
fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
"""Resize the image with Bi-linear method. """Resize the image with Bi-linear method.
Set ``dsize`` to None if you want to use ``fy`` and ``fx``. Set ``dsize`` to None if you want to use ``fy`` and ``fx``.
...@@ -475,6 +481,8 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs ...@@ -475,6 +481,8 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
The input tensor. The input tensor.
dsize : tuple, list, Tensor or None dsize : tuple, list, Tensor or None
The output size, formats as (h, w). The output size, formats as (h, w).
shape_like : Tensor or None
The tensor for guiding the shape of resizing.
fy : float fy : float
The scale factor based on src height. Default is ``-1.0`` (Discarded). The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float fx : float
...@@ -497,11 +505,15 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs ...@@ -497,11 +505,15 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
if dsize is not None: if dsize is not None:
if len(dsize) != 2: if len(dsize) != 2:
raise ValueError('The dsize should be a list with 2 elements.') raise ValueError('The dsize should be a list with 2 elements.')
arguments['extra_inputs'] = [Tensor.Convert(size, dtype='int32') for size in dsize] AddArgumentsWithDesc(arguments, dsize, 'dsize', 'int32', as_target=True)
arguments['dsize'] = [size.name for size in arguments['extra_inputs']]
if shape_like is not None:
if not isinstance(shape_like, Tensor):
raise TypeError('The shape_like should be a Tensor.')
arguments['shape_like'] = shape_like.name
if dsize is None and (fy == -1.0 or fx == -1.0): if dsize is None and shape_like is None and (fy == -1.0 or fx == -1.0):
raise RuntimeError('The dsize or fy/fx should be specified either.') raise RuntimeError('The dsize, shape_like or fy/fx should be specified either.')
output = Tensor.CreateOperator(nout=1, op_type='BilinearResize', **arguments) output = Tensor.CreateOperator(nout=1, op_type='BilinearResize', **arguments)
...@@ -513,6 +525,8 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs ...@@ -513,6 +525,8 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
for size in dsize: for size in dsize:
if isinstance(size, Tensor): if isinstance(size, Tensor):
possible_to_infer_shape = False possible_to_infer_shape = False
if shape_like is not None:
possible_to_infer_shape = False
if possible_to_infer_shape: if possible_to_infer_shape:
output.shape = inputs.shape[:] output.shape = inputs.shape[:]
......
...@@ -20,7 +20,7 @@ class BaseUpdater(object): ...@@ -20,7 +20,7 @@ class BaseUpdater(object):
BaseUpdater is designed to preprocess the gradients. BaseUpdater is designed to preprocess the gradients.
""" """
def __init__(self, scale_gradient = 1.0, clip_gradient = -1.0, def __init__(self, scale_gradient = 1.0, clip_gradient = -1.0,
l2_decay = -1.0, slot=''): l2_decay = -1.0, slot='', verbose=True):
"""Construct a Updater to optimize the objectives. """Construct a Updater to optimize the objectives.
Parameters Parameters
...@@ -42,6 +42,7 @@ class BaseUpdater(object): ...@@ -42,6 +42,7 @@ class BaseUpdater(object):
self._tuples = [] self._tuples = []
self._type = None self._type = None
self._prefix = '' self._prefix = ''
self._verbose = verbose
def append(self, pair, lr_mult=1.0, decay_mult=1.0): def append(self, pair, lr_mult=1.0, decay_mult=1.0):
"""Append an ``UpdatePair`` into the updater. """Append an ``UpdatePair`` into the updater.
...@@ -117,7 +118,7 @@ class SGDUpdater(BaseUpdater): ...@@ -117,7 +118,7 @@ class SGDUpdater(BaseUpdater):
'momentum': momentum}, 'momentum': momentum},
**self._hyper_params) **self._hyper_params)
self._type = 'SGDUpdate' self._type = 'SGDUpdate'
self.echo() if self._verbose: self.echo()
class NesterovUpdater(BaseUpdater): class NesterovUpdater(BaseUpdater):
...@@ -140,7 +141,7 @@ class NesterovUpdater(BaseUpdater): ...@@ -140,7 +141,7 @@ class NesterovUpdater(BaseUpdater):
'momentum': momentum}, 'momentum': momentum},
**self._hyper_params) **self._hyper_params)
self._type = 'NesterovUpdate' self._type = 'NesterovUpdate'
self.echo() if self._verbose: self.echo()
class RMSPropUpdater(BaseUpdater): class RMSPropUpdater(BaseUpdater):
...@@ -166,7 +167,7 @@ class RMSPropUpdater(BaseUpdater): ...@@ -166,7 +167,7 @@ class RMSPropUpdater(BaseUpdater):
'eps': eps}, 'eps': eps},
**self._hyper_params) **self._hyper_params)
self._type = 'RMSPropUpdate' self._type = 'RMSPropUpdate'
self.echo() if self._verbose: self.echo()
class AdamUpdater(BaseUpdater): class AdamUpdater(BaseUpdater):
...@@ -195,4 +196,4 @@ class AdamUpdater(BaseUpdater): ...@@ -195,4 +196,4 @@ class AdamUpdater(BaseUpdater):
'eps': eps}, 'eps': eps},
**self._hyper_params) **self._hyper_params)
self._type = 'AdamUpdate' self._type = 'AdamUpdate'
self.echo() if self._verbose: self.echo()
\ No newline at end of file \ No newline at end of file
...@@ -264,9 +264,10 @@ class NNResizeLayer(Layer): ...@@ -264,9 +264,10 @@ class NNResizeLayer(Layer):
def Setup(self, bottom): def Setup(self, bottom):
super(NNResizeLayer, self).Setup(bottom) super(NNResizeLayer, self).Setup(bottom)
input = bottom[0] if isinstance(bottom, list) else bottom input = bottom[0] if isinstance(bottom, list) else bottom
if isinstance(bottom, list) and len(bottom) > 1: if self._param['dsize'] is None:
dshape = ops.Shape(bottom[1]) if len(bottom) != 2:
self._param['dsize'] = (dshape[2], dshape[3]) raise ValueError('The second bottom should be provided to determine the shape.')
self._param['shape_like'] = bottom[1]
return ops.NNResize(input, **self._param) return ops.NNResize(input, **self._param)
...@@ -296,7 +297,8 @@ class BilinearResizeLayer(Layer): ...@@ -296,7 +297,8 @@ class BilinearResizeLayer(Layer):
def Setup(self, bottom): def Setup(self, bottom):
super(BilinearResizeLayer, self).Setup(bottom) super(BilinearResizeLayer, self).Setup(bottom)
input = bottom[0] if isinstance(bottom, list) else bottom input = bottom[0] if isinstance(bottom, list) else bottom
if isinstance(bottom, list) and len(bottom) > 1: if self._param['dsize'] is None:
dshape = ops.Shape(bottom[1]) if len(bottom) != 2:
self._param['dsize'] = (dshape[2], dshape[3]) raise ValueError('The second bottom should be provided to determine the shape.')
self._param['shape_like'] = bottom[1]
return ops.BilinearResize(input, **self._param) return ops.BilinearResize(input, **self._param)
\ No newline at end of file
...@@ -36,7 +36,7 @@ find_packages('dragon') ...@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules() find_modules()
setup(name = 'dragon', setup(name = 'dragon',
version='0.2.1.8', version='0.2.1.9',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework', description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon', url='https://github.com/neopenx/Dragon',
author='Ting Pan', author='Ting Pan',
......
...@@ -9,18 +9,18 @@ void DropoutOp<Context>::RunWithType() { ...@@ -9,18 +9,18 @@ void DropoutOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
uint32_t* Mdata = mask->template mutable_data<uint32_t, Context>(); uint32_t* Mdata = mask->template mutable_data<uint32_t, Context>();
float scale = use_scale ? 1.0 / (1.0 - prob()) : 1.0;
if (this->phase() == "TRAIN") { if (this->phase() == "TRAIN") {
kernel::Dropout<T, Context>(output(0)->count(), kernel::Dropout<T, Context>(output(0)->count(),
prob, prob(),
scale, scale,
Xdata, Xdata,
Mdata, Mdata,
Ydata, Ydata,
&ctx()); &ctx());
} else if (this->phase() == "TEST") { } else if (this->phase() == "TEST") {
ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata); ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
if (scale == 1.0) math::Scal<T, Context>(output(0)->count(), 1.0 - prob, Ydata); if (scale == 1.0) math::Scal<T, Context>(output(0)->count(), 1.0 - prob(), Ydata);
} }
} }
...@@ -46,10 +46,10 @@ void DropoutGradientOp<Context>::RunWithType() { ...@@ -46,10 +46,10 @@ void DropoutGradientOp<Context>::RunWithType() {
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template data<uint32_t, Context>(); auto* Mdata = mask->template data<uint32_t, Context>();
float scale = use_scale ? 1.0 / (1.0 - prob()) : 1.0;
if (this->phase() == "TRAIN") { if (this->phase() == "TRAIN") {
kernel::DropoutGrad<T, Context>(output(0)->count(), kernel::DropoutGrad<T, Context>(output(0)->count(),
prob, prob(),
scale, scale,
dYdata, dYdata,
Mdata, Mdata,
...@@ -84,4 +84,4 @@ class GetDropoutGradient final : public GradientMakerBase { ...@@ -84,4 +84,4 @@ class GetDropoutGradient final : public GradientMakerBase {
}; };
REGISTER_GRADIENT(Dropout, GetDropoutGradient); REGISTER_GRADIENT(Dropout, GetDropoutGradient);
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -31,7 +31,7 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() { ...@@ -31,7 +31,7 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() {
else if (normalization == "NONE") normalizer = 1; else if (normalization == "NONE") normalizer = 1;
T loss = math::ASum<T, Context>(losses.count(), Ldata); T loss = math::ASum<T, Context>(losses.count(), Ldata);
output(0)->Reshape(vector<TIndex>(1, 1)); output(0)->Reshape(vector<TIndex>(1, 1));
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, CPUContext>();
Ydata[0] = loss / normalizer; Ydata[0] = loss / normalizer;
} }
......
...@@ -12,23 +12,19 @@ void InitializeOp<Context>::RunWithType() { ...@@ -12,23 +12,19 @@ void InitializeOp<Context>::RunWithType() {
template <class Context> template <class Context>
void InitializeOp<Context>::RunOnDevice() { void InitializeOp<Context>::RunOnDevice() {
vector<TIndex> dims; vector<TIndex> output_shape;
if (shape_desc.empty()) { if (shape_desc.empty()) {
// determine the shape from dimensions // determine the shape from dimensions
for (auto& dim_desc : dims_desc) { int ndims = (int)std::max(dims_value.size(), dims_desc.size());
Tensor* dim = ws()->GetTensor(dim_desc); for (int i = 0; i < ndims; i++) output_shape.push_back(dims(i));
CHECK_EQ(dim->count(), 1) << "\nThe dimension should be a scalar.";
CHECK(dim->IsType<int>()) << "\nThe type of dimension should be int32.";
dims.push_back(dim->template data<int, CPUContext>()[0]);
}
} else { } else {
// determine the shape from given shape // determine the shape from given shape
Tensor* shape = ws()->GetTensor(shape_desc); Tensor* shape = ws()->GetTensor(shape_desc);
CHECK(shape->IsType<int>()) << "\nThe type of shape should be int32."; CHECK(shape->IsType<int>()) << "\nThe type of shape should be int32.";
auto* shape_data = shape->template data<int, CPUContext>(); auto* shape_data = shape->template data<int, CPUContext>();
for (int i = 0; i < shape->count(); i++) dims.push_back(shape_data[i]); for (int i = 0; i < shape->count(); i++) output_shape.push_back(shape_data[i]);
} }
output(0)->Reshape(dims); output(0)->Reshape(output_shape);
RunWithType<float>(); RunWithType<float>();
} }
......
...@@ -4,39 +4,18 @@ ...@@ -4,39 +4,18 @@
namespace dragon { namespace dragon {
template <class Context>
void ArangeOp<Context>::Reshape() {
// parse start & step & stop
Tensor* t = ws()->GetTensor(start_desc);
CHECK_EQ(t->count(), 1) << "\nThe start should be a scalar";
CHECK(t->IsType<int>()) << "\nThe type of start should be int32.";
start = t->template data<int, CPUContext>()[0];
t = ws()->GetTensor(step_desc);
CHECK_EQ(t->count(), 1) << "\nThe step should be a scalar";
CHECK(t->IsType<int>()) << "\nThe type of step should be int32.";
step = t->template data<int, CPUContext>()[0];
if (!stop_desc.empty()) {
t = ws()->GetTensor(stop_desc);
CHECK_EQ(t->count(), 1) << "\nThe stop should be a scalar";
CHECK(t->IsType<int>()) << "\nThe type of stop should be int32.";
stop = t->template data<int, CPUContext>()[0];
} else { stop = start; start = 0; }
count = (stop - start - 1) / step + 1;
output(0)->Reshape(vector<TIndex>(1, count));
}
template <class Context> template <typename T> template <class Context> template <typename T>
void ArangeOp<Context>::RunWithType() { void ArangeOp<Context>::RunWithType() {
TIndex start_ = start(), step_ = step(), stop_ = stop(), count;
if (stop_ == 0) { stop_ = start_; start_ = 0; }
count = (stop_ - start_ - 1) / step_ + 1;
output(0)->Reshape(vector<TIndex>(1, count));
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::Arange<T, Context>(count, start, step, Ydata); kernel::Arange<T, Context>(count, start_, step_, Ydata);
} }
template <class Context> template <class Context>
void ArangeOp<Context>::RunOnDevice() { void ArangeOp<Context>::RunOnDevice() {
Reshape();
if (dtype == "FLOAT32") RunWithType<float>(); if (dtype == "FLOAT32") RunWithType<float>();
else if (dtype == "INT32") RunWithType<int>(); else if (dtype == "INT32") RunWithType<int>();
else LOG(FATAL) << "Unsupported data types"; else LOG(FATAL) << "Unsupported data types";
......
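A hand-traced example of the start/stop/step handling above (matching the logic in `RunWithType`; not code from the commit):

```cpp
// Defaults from GET_ARGUMENT_WITH_DESC: start = 0, stop = 0, step = 1.
//
//   Arange(5)        -> start=5, stop=0, step=1
//                       stop_ == 0, so swap: start_=0, stop_=5
//                       count = (5 - 0 - 1) / 1 + 1 = 5   -> [0, 1, 2, 3, 4]
//
//   Arange(2, 8, 2)  -> start=2, stop=8, step=2
//                       count = (8 - 2 - 1) / 2 + 1 = 3   -> [2, 4, 6]
```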
...@@ -12,7 +12,7 @@ void RepeatOp<Context>::RunWithType() { ...@@ -12,7 +12,7 @@ void RepeatOp<Context>::RunWithType() {
outer_dim, outer_dim,
dim, dim,
inner_dim, inner_dim,
reps, repeats(),
Xdata, Xdata,
Ydata, Ydata,
&ctx()); &ctx());
...@@ -20,20 +20,16 @@ void RepeatOp<Context>::RunWithType() { ...@@ -20,20 +20,16 @@ void RepeatOp<Context>::RunWithType() {
template <class Context> template <class Context>
void RepeatOp<Context>::RunOnDevice() { void RepeatOp<Context>::RunOnDevice() {
// parse repeats from desc
Tensor* repeats = ws()->GetTensor(repeats_desc);
CHECK(repeats->IsType<int>()) << "\nThe type of repeats should be int32.";
reps = repeats->template data<int, CPUContext>()[0];
if (axis == -1) { if (axis == -1) {
outer_dim = inner_dim = 1; outer_dim = inner_dim = 1;
dim = input(0).count(); dim = input(0).count();
output(0)->Reshape(vector<TIndex>(1, dim * reps)); output(0)->Reshape(vector<TIndex>(1, dim * repeats()));
} else { } else {
outer_dim = input(0).count(0, axis); outer_dim = input(0).count(0, axis);
dim = input(0).dim(axis); dim = input(0).dim(axis);
inner_dim = input(0).count(axis + 1); inner_dim = input(0).count(axis + 1);
vector<TIndex> dims = input(0).dims(); vector<TIndex> dims = input(0).dims();
dims[axis] *= reps; dims[axis] *= repeats();
output(0)->Reshape(dims); output(0)->Reshape(dims);
} }
...@@ -55,7 +51,7 @@ void RepeatGradientOp<Context>::RunWithType() { ...@@ -55,7 +51,7 @@ void RepeatGradientOp<Context>::RunWithType() {
outer_dim, outer_dim,
dim, dim,
inner_dim, inner_dim,
reps, repeats(),
dYdata, dYdata,
dXdata, dXdata,
&ctx()); &ctx());
...@@ -63,10 +59,6 @@ void RepeatGradientOp<Context>::RunWithType() { ...@@ -63,10 +59,6 @@ void RepeatGradientOp<Context>::RunWithType() {
template <class Context> template <class Context>
void RepeatGradientOp<Context>::RunOnDevice() { void RepeatGradientOp<Context>::RunOnDevice() {
// parse repeats from desc
Tensor* repeats = ws()->GetTensor(repeats_desc);
CHECK(repeats->IsType<int>()) << "\nThe type of repeats should be int32.";
reps = repeats->template data<int, CPUContext>()[0];
if (axis == -1) { if (axis == -1) {
outer_dim = inner_dim = 1; outer_dim = inner_dim = 1;
dim = input(0).count(); dim = input(0).count();
...@@ -98,4 +90,4 @@ class GetRepeatGradient final : public GradientMakerBase { ...@@ -98,4 +90,4 @@ class GetRepeatGradient final : public GradientMakerBase {
}; };
REGISTER_GRADIENT(Repeat, GetRepeatGradient); REGISTER_GRADIENT(Repeat, GetRepeatGradient);
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -25,15 +25,9 @@ void TileOp<Context>::TileRunWithType() { ...@@ -25,15 +25,9 @@ void TileOp<Context>::TileRunWithType() {
template <class Context> template <class Context>
void TileOp<Context>::RunOnDevice() { void TileOp<Context>::RunOnDevice() {
// parse tasks from desc
CHECK_EQ(multiples_desc.size(), input(0).ndim())
<< "\nThe num of dimensions of input is " << input(0).ndim()
<< ", but provided " << multiples_desc.size() << " multiples.";
vector< pair<int, int> > process_axes; vector< pair<int, int> > process_axes;
for (int i = 0; i < multiples_desc.size(); i++) { for (int i = 0; i < input(0).ndim(); i++)
int mult = ws()->GetTensor(multiples_desc[i])->template data<int, CPUContext>()[0]; if (multiples(i) > 1) process_axes.push_back({ multiples(i), i });
if (mult > 1) process_axes.push_back({ mult, i });
}
std::sort(process_axes.begin(), process_axes.end()); std::sort(process_axes.begin(), process_axes.end());
// do nothing // do nothing
...@@ -90,15 +84,9 @@ void TileGradientOp<Context>::TileRunWithType() { ...@@ -90,15 +84,9 @@ void TileGradientOp<Context>::TileRunWithType() {
template <class Context> template <class Context>
void TileGradientOp<Context>::RunOnDevice() { void TileGradientOp<Context>::RunOnDevice() {
// parse tasks from desc
CHECK_EQ(multiples_desc.size(), input(-1).ndim())
<< "\nThe num of dimensions of input is " << input(-1).ndim()
<< ", but provided " << multiples_desc.size() << " multiples.";
vector< pair<int, int> > process_axes; vector< pair<int, int> > process_axes;
for (int i = 0; i < multiples_desc.size(); i++) { for (int i = 0; i < input(0).ndim(); i++)
int mult = ws()->GetTensor(multiples_desc[i])->template data<int, CPUContext>()[0]; if (multiples(i) > 1) process_axes.push_back({ multiples(i), i });
if (mult > 1) process_axes.push_back({ mult, i });
}
std::sort(process_axes.begin(), process_axes.end()); std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end()); std::reverse(process_axes.begin(), process_axes.end());
......
...@@ -8,19 +8,19 @@ namespace dragon { ...@@ -8,19 +8,19 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void BilinearResizeOp<Context>::RunWithType() { void BilinearResizeOp<Context>::RunWithType() {
if (data_format == "NCHW") { if (data_format == "NCHW") {
n = dims[0]; n = input(0).dim(0);
c = dims[1]; c = input(0).dim(1);
h = input(0).dim(2); h = input(0).dim(2);
w = input(0).dim(3); w = input(0).dim(3);
out_h = dims[2]; out_h = output(0)->dim(2);
out_w = dims[3]; out_w = output(0)->dim(3);
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
n = dims[0]; n = input(0).dim(0);
h = input(0).dim(1); h = input(0).dim(1);
w = input(0).dim(2); w = input(0).dim(2);
out_h = dims[1]; c = input(0).dim(3);
out_w = dims[2]; out_h = output(0)->dim(1);
c = dims[3]; out_w = output(0)->dim(2);
} }
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
...@@ -33,14 +33,14 @@ void BilinearResizeOp<Context>::RunWithType() { ...@@ -33,14 +33,14 @@ void BilinearResizeOp<Context>::RunWithType() {
template <class Context> template <class Context>
void BilinearResizeOp<Context>::RunOnDevice() { void BilinearResizeOp<Context>::RunOnDevice() {
dims = input(0).dims(); vector<TIndex> dims = input(0).dims();
if (dsize_desc.size() > 0) { if (dsize_desc.size() > 0 || dsize_value.size() > 0) {
CHECK_EQ(dsize_desc.size(), 2) << "\nThe dsize should be a scalar with 2 elements."; for (int i = 0; i < 2; i++)
for (int i = 0; i < 2; i++) { dims[spatial_axis + i] = dsize(i);
Tensor* dsize = ws()->GetTensor(dsize_desc[i]); } else if (!shape_like_desc.empty()) {
CHECK(dsize->IsType<int>()) << "\nThe type of dsize should be int32."; Tensor* shape_like_tensor = ws()->GetTensor(shape_like_desc);
dims[spatial_axis + i] = dsize->template data<int, CPUContext>()[0]; for (int i = 0; i < 2; i++)
} dims[spatial_axis + i] = shape_like_tensor->dim(spatial_axis + i);
} else { } else {
CHECK(fy != -1.0 && fx != -1.0) CHECK(fy != -1.0 && fx != -1.0)
<< "\nThe fx and fy should be set."; << "\nThe fx and fy should be set.";
......
...@@ -29,14 +29,13 @@ void ConvOpBase<Context>::ComputeOutputShape() { ...@@ -29,14 +29,13 @@ void ConvOpBase<Context>::ComputeOutputShape() {
const TIndex output_dim = stride[i] * (input_dim - 1) + dilated_kernel - 2 * pad[i]; const TIndex output_dim = stride[i] * (input_dim - 1) + dilated_kernel - 2 * pad[i];
output_shape.push_back(output_dim); output_shape.push_back(output_dim);
} else { } else {
CHECK(output_dims_desc.size() > 0) CHECK(output_dims_desc.size() > 0 || output_dims_value.size() > 0)
<< "\nThe output shape must be specified if using SAME padding algorithm."; << "\nThe output shape must be specified if using SAME padding algorithm.";
CHECK_EQ((int)output_dims_desc.size(), num_spatial_axes + 2) int given_ndim = (int)std::max(output_dims_desc.size(), output_dims_value.size());
CHECK_EQ(given_ndim, num_spatial_axes + 2)
<< "\nThe len of output shape should be " << num_spatial_axes + 2 << "\nThe len of output shape should be " << num_spatial_axes + 2
<< ", but got " << output_dims_desc.size() << "."; << ", but got " << output_dims_desc.size() << ".";
Tensor* t = ws()->GetTensor(output_dims_desc[spatial_axis + i]); TIndex output_dim = output_dims(spatial_axis + i);
CHECK(t->IsType<int>()) << "\nThe type of output shape should be int32.";
TIndex output_dim = t->template data<int, CPUContext>()[0];
TIndex padding_needed = stride[i] * (input_dim - 1) + dilated_kernel - output_dim; TIndex padding_needed = stride[i] * (input_dim - 1) + dilated_kernel - output_dim;
CHECK_GE(padding_needed, 0) CHECK_GE(padding_needed, 0)
<< "\nThe output shape is incorrect." << "\nThe output shape is incorrect."
......
...@@ -34,13 +34,13 @@ void NNResizeOp<Context>::RunWithType() { ...@@ -34,13 +34,13 @@ void NNResizeOp<Context>::RunWithType() {
template <class Context> template <class Context>
void NNResizeOp<Context>::RunOnDevice() { void NNResizeOp<Context>::RunOnDevice() {
vector<TIndex> dims = input(0).dims(); vector<TIndex> dims = input(0).dims();
if (dsize_desc.size() > 0) { if (dsize_desc.size() > 0 || dsize_value.size() > 0) {
CHECK_EQ(dsize_desc.size(), 2) << "\nThe dsize should be a scalar with 2 elements."; for (int i = 0; i < 2; i++)
for (int i = 0; i < 2; i++) { dims[spatial_axis + i] = dsize(i);
Tensor* dsize = ws()->GetTensor(dsize_desc[i]); } else if (!shape_like_desc.empty()) {
CHECK(dsize->IsType<int>()) << "\nThe type of dsize should be int32."; Tensor* shape_like_tensor = ws()->GetTensor(shape_like_desc);
dims[spatial_axis + i] = dsize->template data<int, CPUContext>()[0]; for (int i = 0; i < 2; i++)
} dims[spatial_axis + i] = shape_like_tensor->dim(spatial_axis + i);
} else { } else {
CHECK(fy != -1.0 && fx != -1.0) CHECK(fy != -1.0 && fx != -1.0)
<< "\nThe fx and fy should be set."; << "\nThe fx and fy should be set.";
......
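Putting the two resize hunks above together, the output size is now resolved with this precedence (a summary sketch of the control flow, not new code):

```cpp
// 1. dsize      (static ints or per-element tensors) -> dims[spatial_axis + i] = dsize(i);
// 2. shape_like (name of a reference tensor)         -> dims[spatial_axis + i] = shape_like_tensor->dim(spatial_axis + i);
// 3. fy / fx    (scale factors)                      -> otherwise both must be set (see the CHECK above).
```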
...@@ -84,6 +84,8 @@ class ProcessAgent(Process): ...@@ -84,6 +84,8 @@ class ProcessAgent(Process):
continue continue
prediction, value = self.predict(self.env.current_state) prediction, value = self.predict(self.env.current_state)
action = self.select_action(prediction) action = self.select_action(prediction)
reward, done = self.env.step(action) reward, done = self.env.step(action)
reward_sum += reward reward_sum += reward
......