// -------------------------------------------------------- // Dragon // Copyright(c) 2017 SeetaTech // Written by Ting Pan // -------------------------------------------------------- #ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #include "core/operator.h" #include "utils/filler.h" namespace dragon { template <class Context> class InitializeOp: public Operator<Context> { public: InitializeOp(const OperatorDef& op_def, Workspace* ws) : Operator<Context>(op_def, ws), static_shape(OperatorBase::GetRepeatedArg<int>("static_shape")), dynamic_shape(OperatorBase::GetSingleArg<string>("dynamic_shape", "")) {} void RunOnDevice() override; template <typename T> void RunWithType(); protected: TensorFiller filler; vector<int> static_shape; string dynamic_shape; }; template <class Context> class FillOp final : public InitializeOp<Context> { public: FillOp(const OperatorDef& op_def, Workspace* ws) : InitializeOp<Context>(op_def, ws) { this->filler.set_type("constant"); this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0)); } }; template <class Context> class RandomUniformOp final : public InitializeOp<Context> { public: RandomUniformOp(const OperatorDef& op_def, Workspace* ws) : InitializeOp<Context>(op_def, ws) { this->filler.set_type("uniform"); this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0)); this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0)); } }; template <class Context> class RandomNormalOp final : public InitializeOp<Context> { public: RandomNormalOp(const OperatorDef& op_def, Workspace* ws) : InitializeOp<Context>(op_def, ws) { this->filler.set_type("normal"); this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0)); this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0)); } }; template <class Context> class TruncatedNormalOp final : public InitializeOp<Context> { public: TruncatedNormalOp(const OperatorDef& op_def, Workspace* ws) : InitializeOp<Context>(op_def, ws) { this->filler.set_type("truncated_normal"); float mu = OperatorBase::GetSingleArg<float>("mean", 0.0); float sigma = OperatorBase::GetSingleArg<float>("std", 1.0); this->filler.set_mean(mu); this->filler.set_std(sigma); this->filler.set_low(mu - 2 * sigma); this->filler.set_high(mu + 2 * sigma); } }; template <class Context> class GlorotUniformOp final : public InitializeOp<Context> { public: GlorotUniformOp(const OperatorDef& op_def, Workspace* ws) : InitializeOp<Context>(op_def, ws) { string mode = OperatorBase::GetSingleArg<string>("mode", "fan_in"); float scale = OperatorBase::GetSingleArg<float>("scale", 3.0); this->filler.set_type("xavier"); if (mode == "fan_avg") { this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_AVG); } else if (mode == "fan_out") { this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_OUT); } else { this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_IN); } this->filler.set_scale(scale); } }; template <class Context> class GlorotNormalOp final : public InitializeOp<Context> { public: GlorotNormalOp(const OperatorDef& op_def, Workspace* ws) : InitializeOp<Context>(op_def, ws) { string mode = OperatorBase::GetSingleArg<string>("mode", "fan_in"); float scale = OperatorBase::GetSingleArg<float>("scale", 2.0); this->filler.set_type("msra"); if (mode == "fan_avg") { this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_AVG); } else if (mode == "fan_out") { this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_OUT); } else { this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_IN); } this->filler.set_scale(scale); } }; } // namespace #endif // DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_