Commit a7a7e4fc by Ting PAN

Fix the skipped algorithm finding in cached CUDNN convolution

Summary:
This commit forces the algorithm search to run even when the backward pass for the filter or the data
will not be executed. Otherwise, an empty (never-initialized) algorithm is picked up when two cached
operations share the same arguments and input shape.
1 parent 218796ed
Showing with 250 additions and 258 deletions
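To make the rationale concrete, here is a minimal C++ sketch of the failure mode, using hypothetical names (`BwdAlgoCacheEntry`, `ResetDescBeforeFix`, ...) rather than Dragon's actual classes: the algorithm search results are cached per argument/shape configuration, so skipping the search for an output that is currently unused poisons later runs that reuse the same cache entry.

```cpp
// Minimal sketch of the failure mode, with hypothetical names (this is not the
// Dragon implementation): gradient ops cache their cuDNN algorithm choice per
// (arguments, input shape), so a skipped search leaves an empty slot behind.
#include <cassert>
#include <map>
#include <string>

struct BwdAlgoCacheEntry {
  bool filter_algo_found = false;  // set after the backward-filter search ran
  bool data_algo_found = false;    // set after the backward-data search ran
};

std::map<std::string, BwdAlgoCacheEntry> algo_cache;

// Before this commit: the search for each direction ran only if the matching
// output (dW or dX) was actually requested.
void ResetDescBeforeFix(const std::string& key, bool need_dw, bool need_dx) {
  auto& entry = algo_cache[key];
  if (need_dw) entry.filter_algo_found = true;  // backward-filter algorithm search
  if (need_dx) entry.data_algo_found = true;    // backward-data algorithm search
}

// After this commit: both searches always run once the descriptors change.
void ResetDescAfterFix(const std::string& key) {
  auto& entry = algo_cache[key];
  entry.filter_algo_found = true;
  entry.data_algo_found = true;
}

int main() {
  const std::string key = "Conv2d/k3s1/NCHW(32,64,56,56)";
  // A first cached op only needs dX, so the filter algorithm is never searched.
  ResetDescBeforeFix(key, /*need_dw=*/false, /*need_dx=*/true);
  // A second cached op with the same arguments and input shape now needs dW and
  // hits the empty algorithm: this is the bug the commit fixes.
  assert(!algo_cache[key].filter_algo_found);
  ResetDescAfterFix(key);
  assert(algo_cache[key].filter_algo_found);
  return 0;
}
```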
...@@ -18,6 +18,9 @@ dragon/core ...@@ -18,6 +18,9 @@ dragon/core
`class Operator <core/Operator.html>`_ `class Operator <core/Operator.html>`_
: The base operator class with context. : The base operator class with context.
`class OpSchema <core/OpSchema.html>`_
: Class to record the schema of an operator.
`class Tensor <core/Tensor.html>`_ `class Tensor <core/Tensor.html>`_
: The base tensor class, manage memory or not. : The base tensor class, manage memory or not.
...@@ -37,6 +40,7 @@ dragon/core ...@@ -37,6 +40,7 @@ dragon/core
core/CUDAContext core/CUDAContext
core/Graph core/Graph
core/Operator core/Operator
core/OpSchema
core/Tensor core/Tensor
core/TypeMeta core/TypeMeta
core/UnifiedMemory core/UnifiedMemory
......
OpSchema
========
.. doxygenclass:: dragon::OpSchema
Constructors
------------
.. doxygenfunction:: dragon::OpSchema::OpSchema()
.. doxygenfunction:: dragon::OpSchema::OpSchema(const string &op_type, const string &file, const int line)
Public Functions
----------------
AllowInplace
############
.. doxygenfunction:: dragon::OpSchema::AllowInplace(set<pair<int, int>> inplace)
AllowInplace
############
.. doxygenfunction:: dragon::OpSchema::AllowInplace(std::function<bool(int, int)> inplace)
NumInputs
#########
.. doxygenfunction:: dragon::OpSchema::NumInputs(int n)
NumInputs
#########
.. doxygenfunction:: dragon::OpSchema::NumInputs(int min_num, int max_num)
NumOutputs
##########
.. doxygenfunction:: dragon::OpSchema::NumOutputs(int n)
NumOutputs
##########
.. doxygenfunction:: dragon::OpSchema::NumOutputs(int min_num, int max_num)
Verify
######
.. doxygenfunction:: dragon::OpSchema::Verify
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
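The schema entries above are easier to read next to a small, self-contained sketch of the inplace rule (simplified; `MiniSchema` is not the Dragon class, but it mirrors the `AllowInplace`/`CheckInplace` code changed in this commit):

```cpp
// A self-contained sketch of the inplace rule (simplified, illustrative only).
// AllowInplace with a set of (input, output) pairs is sugar over the functional form.
#include <functional>
#include <iostream>
#include <set>
#include <utility>

struct MiniSchema {
  // Default rule: no input may alias any output.
  std::function<bool(int, int)> CheckInplace = [](int, int) { return false; };

  MiniSchema& AllowInplace(std::function<bool(int, int)> rule) {
    CheckInplace = std::move(rule);
    return *this;
  }
  MiniSchema& AllowInplace(std::set<std::pair<int, int>> pairs) {
    return AllowInplace([pairs](int in, int out) {
      return pairs.count({in, out}) > 0;
    });
  }
};

int main() {
  MiniSchema dropout;
  dropout.AllowInplace({{0, 0}});  // X => Y only
  MiniSchema collective;
  collective.AllowInplace([](int, int) { return true; });  // any pair, as Collective does
  std::cout << dropout.CheckInplace(0, 0) << dropout.CheckInplace(1, 0) << "\n";  // prints 10
  std::cout << collective.CheckInplace(3, 1) << "\n";                             // prints 1
  return 0;
}
```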
...@@ -470,7 +470,6 @@ zero\_ ...@@ -470,7 +470,6 @@ zero\_
.. _torch.div(...): div.html .. _torch.div(...): div.html
.. _torch.eq(...): eq.html .. _torch.eq(...): eq.html
.. _torch.exp(...): exp.html .. _torch.exp(...): exp.html
.. _torch.expand(...): expand.html
.. _torch.floor(...): floor.html .. _torch.floor(...): floor.html
.. _torch.ge(...): ge.html .. _torch.ge(...): ge.html
.. _torch.gt(...): gt.html .. _torch.gt(...): gt.html
......
...@@ -79,12 +79,12 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { ...@@ -79,12 +79,12 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
Map<string, vec32_t> subgraph_indices; Map<string, vec32_t> subgraph_indices;
int opt = 3; // default: O3 int opt = 3; // default: O3
if (args().count("optimization")) opt = arg("optimization").i(); if (args().count("optimization")) opt = arg("optimization").i();
if (opt >= 1) def_v2 = graph_optimizer.PruneNodes(def); if (opt >= 1) def_v2 = graph_optimizer.EliminateUnused(def);
if (opt >= 2) graph_optimizer.AddInplace(def_v2, output_aliases_); if (opt >= 2) graph_optimizer.PlanInplace(def_v2, output_aliases_);
if (opt >= 3) { if (opt >= 3) {
if (phase() == "TRAIN") { if (phase() == "TRAIN") {
def_v2 = graph_optimizer.MirrorStage(def_v2, subgraph_indices); def_v2 = graph_optimizer.PlanCheckpoint(def_v2, subgraph_indices);
def_v2 = gradient_maker.Share(def_v2); def_v2 = gradient_maker.Optimize(def_v2);
} else { } else {
def_v2 = graph_optimizer.SimulateGC(def_v2); def_v2 = graph_optimizer.SimulateGC(def_v2);
} }
...@@ -98,8 +98,8 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { ...@@ -98,8 +98,8 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
Map<string, vector<OperatorBase*>> subgraph; Map<string, vector<OperatorBase*>> subgraph;
for (const auto& it : subgraph_indices) { for (const auto& it : subgraph_indices) {
subgraph[it.first] = vector<OperatorBase*>(); subgraph[it.first] = vector<OperatorBase*>();
for (const auto& idx : subgraph_indices[it.first]) for (auto op_idx : subgraph_indices[it.first])
subgraph[it.first].push_back(cached_ops_[idx]); subgraph[it.first].push_back(cached_ops_[op_idx]);
} }
for (auto* op : cached_ops_) { for (auto* op : cached_ops_) {
op->set_subgraph(subgraph); op->set_subgraph(subgraph);
......
...@@ -19,15 +19,15 @@ namespace dragon { ...@@ -19,15 +19,15 @@ namespace dragon {
class DRAGON_API GraphGradientMaker { class DRAGON_API GraphGradientMaker {
public: public:
/*! \brief Generate graph def from the op defs */ /*! \brief Generate the gradient graph from the executed ops */
void Make( void Make(
const vector<OperatorDef*>& op_defs, const vector<OperatorDef*>& ops,
const vector<string>& targets, const vector<string>& targets,
const vector<string>& input_grads, const vector<string>& input_grads,
GraphDef& graph_def); GraphDef& graph);
/*! \brief Rewrite graph def to share the intermediate grads */ /*! \brief Eliminate the unused gradients and share the intermediate outputs */
GraphDef Share(const GraphDef& input_def); GraphDef Optimize(const GraphDef& graph);
/*! \brief Add an empty gradient */ /*! \brief Add an empty gradient */
void add_empty_grad(const string& name) { void add_empty_grad(const string& name) {
...@@ -47,14 +47,14 @@ class DRAGON_API GraphGradientMaker { ...@@ -47,14 +47,14 @@ class DRAGON_API GraphGradientMaker {
private: private:
/*! \brief Check the missing grads */ /*! \brief Check the missing grads */
bool CheckGrad( bool CheckGrad(
const OperatorDef& op_def, const OperatorDef& op,
const Set<string>& targets, const Set<string>& targets,
vector<pair<string, int>>& gen_grads); vector<pair<string, int>>& gen_grads);
/*! \brief Return a dummy operator name */ /*! \brief Return a dummy operator name */
string GetOperatorName() { string GetOperatorName() {
if (op_prefix_.empty()) return "GradientOp"; if (op_prefix_.empty()) return "GradientOp";
return op_prefix_ + str::to(op_index_++); return op_prefix_ + str::to(op_idx_++);
} }
/*! \brief The mapping from input to grad */ /*! \brief The mapping from input to grad */
...@@ -70,7 +70,7 @@ class DRAGON_API GraphGradientMaker { ...@@ -70,7 +70,7 @@ class DRAGON_API GraphGradientMaker {
string op_prefix_; string op_prefix_;
/*! \brief The counter of op name */ /*! \brief The counter of op name */
int64_t op_index_ = 0; int64_t op_idx_ = 0;
}; };
} // namespace dragon } // namespace dragon
......
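For context, the renamed methods are consumed by the Python binding further down in this diff; the usage pattern is (a fragment, not a standalone program):

```cpp
// Build the gradient graph from the executed op defs, then drop unused
// gradients and share intermediates when they need not be retained.
GraphGradientMaker maker;
GraphDef graph_def;
maker.Make(op_defs, targets, input_grads, graph_def);
if (!retain_grads) {
  graph_def = maker.Optimize(graph_def);  // formerly maker.Share(graph_def)
}
```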
...@@ -32,45 +32,42 @@ class GraphOptimizer { ...@@ -32,45 +32,42 @@ class GraphOptimizer {
/*! \brief Default constructor */ /*! \brief Default constructor */
GraphOptimizer(Workspace* ws) : ws_(ws) {} GraphOptimizer(Workspace* ws) : ws_(ws) {}
/*! \brief Build the DAG resources for given def */ /*! \brief Build the DAG */
void BuildDAG(const GraphDef& input_def); void BuildDAG(const GraphDef& graph);
/*! \brief Prune the redundant nodes (-O1) */ /*! \brief Eliminate the unused outputs and operators */
GraphDef PruneNodes(const GraphDef& input_def); GraphDef EliminateUnused(const GraphDef& graph);
/*! \brief Add the inplace for outputs (-O2) */ /*! \brief Plan the inplace for inputs */
void AddInplace( void PlanInplace(
const GraphDef& input_def, const GraphDef& graph,
Map<string, Set<string>>& output_aliases); Map<string, Set<string>>& output_aliases);
/*! \brief Plan the recomputing for inputs (-O3) */ /*! \brief Plan the checkpoint for inputs */
GraphDef MirrorStage( GraphDef PlanCheckpoint(
const GraphDef& input_def, const GraphDef& graph,
Map<string, vec32_t>& op_indices); Map<string, vec32_t>& subgraph_indices);
/*! \brief Allocate the buffer for outputs (-O3) */ /*! \brief Allocate the shared buffer for outputs */
GraphDef SimulateGC(const GraphDef& input_def); GraphDef SimulateGC(const GraphDef& graph);
protected: protected:
/*! \brief Pass from gradients to remove unused nodes */ /*! \brief Remove the unused nodes from a sink to all sources */
void ForwardPrunePass( void EliminateUnusedNode(const string& sink);
const string& u,
const string& leaf,
const std::deque<string>& path);
/*! \brief Pass from targets to remove unused nodes */ /*! \brief Remove the unused nodes from a source to a sink */
void BackwardPrunePass(const string& v); void EliminateUnusedNode(const string& source, const string& sink);
/* \brief Store the workspace of parent graph */ /* \brief The graph workspace */
Workspace* ws_; Workspace* ws_;
/* \brief Store the DAG */ /* \brief The graph nodes */
Map<string, Node> dag_; Map<string, Node> nodes_;
/* \brief Store the traversal flags */ /* \brief The traversal flags */
Map<string, bool> visited_, colored_; Map<string, bool> visited_, used_;
/* \brief Store the count of references */ /* \brief The reference count */
Map<string, int> reference_count_; Map<string, int> reference_count_;
private: private:
......
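Put together, the renamed passes are driven from Graph::Graph as shown earlier in this diff; restated as a fragment for readability (`graph_optimizer`, `gradient_maker`, and the surrounding variables come from that constructor):

```cpp
// -O1 drops unused outputs and operators, -O2 plans inplace, and -O3 either
// plans checkpointing (training) or simulates garbage collection (inference).
if (opt >= 1) def_v2 = graph_optimizer.EliminateUnused(def);
if (opt >= 2) graph_optimizer.PlanInplace(def_v2, output_aliases_);
if (opt >= 3) {
  if (phase() == "TRAIN") {
    def_v2 = graph_optimizer.PlanCheckpoint(def_v2, subgraph_indices);
    def_v2 = gradient_maker.Optimize(def_v2);
  } else {
    def_v2 = graph_optimizer.SimulateGC(def_v2);
  }
}
```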
...@@ -173,10 +173,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) { ...@@ -173,10 +173,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) {
OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) { OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) {
auto* schema = OpSchemaRegistry::Schema(def.type()); auto* schema = OpSchemaRegistry::Schema(def.type());
if (schema != nullptr) { if (schema != nullptr) CHECK(schema->Verify(def));
CHECK(schema->Verify(def))
<< "\nOperator failed to pass the schema checking.";
}
OperatorDef mutable_def(def); OperatorDef mutable_def(def);
// Heuristically make each random seed slightly different // Heuristically make each random seed slightly different
static unsigned int seed_offset = 0; static unsigned int seed_offset = 0;
......
...@@ -14,7 +14,6 @@ bool OpSchema::Verify(const OperatorDef& def) const { ...@@ -14,7 +14,6 @@ bool OpSchema::Verify(const OperatorDef& def) const {
<< " is not in range [min=" << min_output_ << " is not in range [min=" << min_output_
<< ", max=" << max_output_ << "]"; << ", max=" << max_output_ << "]";
} }
if (CheckInplace != nullptr) {
for (int i = 0; i < def.input_size(); ++i) { for (int i = 0; i < def.input_size(); ++i) {
if (def.input(i).empty()) continue; if (def.input(i).empty()) continue;
for (int j = 0; j < def.output_size(); ++j) { for (int j = 0; j < def.output_size(); ++j) {
...@@ -25,7 +24,6 @@ bool OpSchema::Verify(const OperatorDef& def) const { ...@@ -25,7 +24,6 @@ bool OpSchema::Verify(const OperatorDef& def) const {
} }
} }
} }
}
return true; return true;
} }
...@@ -49,7 +47,12 @@ OpSchema& OpSchema::NumOutputs(int min_num, int max_num) { ...@@ -49,7 +47,12 @@ OpSchema& OpSchema::NumOutputs(int min_num, int max_num) {
return *this; return *this;
} }
OpSchema& OpSchema::Inplace(set<pair<int, int>> inplace) { OpSchema& OpSchema::AllowInplace(std::function<bool(int, int)> inplace) {
CheckInplace = inplace;
return *this;
}
OpSchema& OpSchema::AllowInplace(set<pair<int, int>> inplace) {
CheckInplace = [inplace](int in, int out) -> bool { CheckInplace = [inplace](int in, int out) -> bool {
return (inplace.count(std::make_pair(in, out)) > 0); return (inplace.count(std::make_pair(in, out)) > 0);
}; };
......
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
namespace dragon { namespace dragon {
/*!
* \brief Class to record the schema of an operator.
*/
class DRAGON_API OpSchema { class DRAGON_API OpSchema {
public: public:
/*! \brief Default constructor */ /*! \brief Default constructor */
...@@ -27,15 +30,12 @@ class DRAGON_API OpSchema { ...@@ -27,15 +30,12 @@ class DRAGON_API OpSchema {
Init(); Init();
} }
/*! \brief Constructor with defined spec */ /*! \brief Constructor with the defined spec */
OpSchema(const string& op_type, const string& file, const int line) OpSchema(const string& op_type, const string& file, const int line)
: op_type_(op_type), file_(file), line_(line) { : op_type_(op_type), file_(file), line_(line) {
Init(); Init();
} }
/*! \brief Check if the in-place setting is matched */
std::function<bool(int, int)> CheckInplace = nullptr;
/*! \brief Set a fixed number of inputs */ /*! \brief Set a fixed number of inputs */
OpSchema& NumInputs(int n); OpSchema& NumInputs(int n);
...@@ -48,12 +48,18 @@ class DRAGON_API OpSchema { ...@@ -48,12 +48,18 @@ class DRAGON_API OpSchema {
/*! \brief Set the min and max number of outputs */ /*! \brief Set the min and max number of outputs */
OpSchema& NumOutputs(int min_num, int max_num); OpSchema& NumOutputs(int min_num, int max_num);
/*! \brief Set the in-place setting */ /*! \brief Set the rule to allow inplace with a group of indices */
OpSchema& Inplace(set<pair<int, int>> inplace); OpSchema& AllowInplace(set<pair<int, int>> inplace);
/*! \brief Verify if the def matches the schema */ /*! \brief Set the rule to allow inplace with a function */
OpSchema& AllowInplace(std::function<bool(int, int)> inplace);
/*! \brief Check if the given def matches this schema */
bool Verify(const OperatorDef& def) const; bool Verify(const OperatorDef& def) const;
/*! \brief Check if the inplace is allowed */
std::function<bool(int, int)> CheckInplace = [](int, int) { return false; };
private: private:
/*! \brief Initialize the default settings */ /*! \brief Initialize the default settings */
void Init() { void Init() {
......
...@@ -242,7 +242,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -242,7 +242,7 @@ PYBIND11_MODULE(libdragon_python, m) {
maker.Make(op_defs, targets, input_grads, graph_def); maker.Make(op_defs, targets, input_grads, graph_def);
py::gil_scoped_release g; py::gil_scoped_release g;
if (!retain_grads) { if (!retain_grads) {
graph_def = maker.Share(graph_def); graph_def = maker.Optimize(graph_def);
} }
for (const auto& op_def : graph_def.op()) { for (const auto& op_def : graph_def.op()) {
if (verbose) { if (verbose) {
......
...@@ -129,7 +129,7 @@ OPERATOR_SCHEMA(DropBlock2d) ...@@ -129,7 +129,7 @@ OPERATOR_SCHEMA(DropBlock2d)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
DEPLOY_CPU(DropBlock2dGradient); DEPLOY_CPU(DropBlock2dGradient);
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -142,7 +142,7 @@ OPERATOR_SCHEMA(DropBlock2dGradient) ...@@ -142,7 +142,7 @@ OPERATOR_SCHEMA(DropBlock2dGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(DropBlock2d, SimpleGradientMaker); REGISTER_GRADIENT(DropBlock2d, SimpleGradientMaker);
......
...@@ -95,7 +95,7 @@ OPERATOR_SCHEMA(DropPath) ...@@ -95,7 +95,7 @@ OPERATOR_SCHEMA(DropPath)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DropPathGradient) OPERATOR_SCHEMA(DropPathGradient)
/* dY */ /* dY */
...@@ -103,7 +103,7 @@ OPERATOR_SCHEMA(DropPathGradient) ...@@ -103,7 +103,7 @@ OPERATOR_SCHEMA(DropPathGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(DropPath, SimpleGradientMaker); REGISTER_GRADIENT(DropPath, SimpleGradientMaker);
......
...@@ -84,7 +84,7 @@ OPERATOR_SCHEMA(Dropout) ...@@ -84,7 +84,7 @@ OPERATOR_SCHEMA(Dropout)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DropoutGradient) OPERATOR_SCHEMA(DropoutGradient)
/* dY */ /* dY */
...@@ -92,7 +92,7 @@ OPERATOR_SCHEMA(DropoutGradient) ...@@ -92,7 +92,7 @@ OPERATOR_SCHEMA(DropoutGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Dropout, SimpleGradientMaker); REGISTER_GRADIENT(Dropout, SimpleGradientMaker);
......
...@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Elu) ...@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Elu)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(EluGradient) OPERATOR_SCHEMA(EluGradient)
/* Y, dY */ /* Y, dY */
...@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(EluGradient) ...@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(EluGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Elu, InplaceGradientMaker); REGISTER_GRADIENT(Elu, InplaceGradientMaker);
......
...@@ -73,7 +73,7 @@ OPERATOR_SCHEMA(Relu) ...@@ -73,7 +73,7 @@ OPERATOR_SCHEMA(Relu)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReluGradient) OPERATOR_SCHEMA(ReluGradient)
/* Y, dY */ /* Y, dY */
...@@ -81,7 +81,7 @@ OPERATOR_SCHEMA(ReluGradient) ...@@ -81,7 +81,7 @@ OPERATOR_SCHEMA(ReluGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Relu, InplaceGradientMaker); REGISTER_GRADIENT(Relu, InplaceGradientMaker);
......
...@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Selu) ...@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Selu)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SeluGradient) OPERATOR_SCHEMA(SeluGradient)
/* Y, dY */ /* Y, dY */
...@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(SeluGradient) ...@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(SeluGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Selu, InplaceGradientMaker); REGISTER_GRADIENT(Selu, InplaceGradientMaker);
......
...@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Sigmoid) ...@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Sigmoid)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SigmoidGradient) OPERATOR_SCHEMA(SigmoidGradient)
/* Y, dY */ /* Y, dY */
...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(SigmoidGradient) ...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(SigmoidGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Sigmoid, InplaceGradientMaker); REGISTER_GRADIENT(Sigmoid, InplaceGradientMaker);
......
...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(Softmax) ...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(Softmax)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SoftmaxGradient) OPERATOR_SCHEMA(SoftmaxGradient)
/* Y, dY */ /* Y, dY */
...@@ -68,7 +68,7 @@ OPERATOR_SCHEMA(SoftmaxGradient) ...@@ -68,7 +68,7 @@ OPERATOR_SCHEMA(SoftmaxGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Softmax, InplaceGradientMaker); REGISTER_GRADIENT(Softmax, InplaceGradientMaker);
......
...@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Tanh) ...@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Tanh)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(TanhGradient) OPERATOR_SCHEMA(TanhGradient)
/* Y, dY */ /* Y, dY */
...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(TanhGradient) ...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(TanhGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Tanh, InplaceGradientMaker); REGISTER_GRADIENT(Tanh, InplaceGradientMaker);
......
...@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(ExpandDims) ...@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(ExpandDims)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ExpandDimsGradient) OPERATOR_SCHEMA(ExpandDimsGradient)
/* dY */ /* dY */
...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(ExpandDimsGradient) ...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(ExpandDimsGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(ExpandDims, SimpleGradientMaker); REGISTER_GRADIENT(ExpandDims, SimpleGradientMaker);
......
...@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Flatten) ...@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Flatten)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(FlattenGradient) OPERATOR_SCHEMA(FlattenGradient)
/* dY */ /* dY */
...@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(FlattenGradient) ...@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(FlattenGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Flatten, SimpleGradientMaker); REGISTER_GRADIENT(Flatten, SimpleGradientMaker);
......
...@@ -69,7 +69,7 @@ OPERATOR_SCHEMA(Reshape) ...@@ -69,7 +69,7 @@ OPERATOR_SCHEMA(Reshape)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReshapeGradient) OPERATOR_SCHEMA(ReshapeGradient)
/* dY */ /* dY */
...@@ -77,7 +77,7 @@ OPERATOR_SCHEMA(ReshapeGradient) ...@@ -77,7 +77,7 @@ OPERATOR_SCHEMA(ReshapeGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Reshape, SimpleGradientMaker); REGISTER_GRADIENT(Reshape, SimpleGradientMaker);
......
...@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(Squeeze) ...@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(Squeeze)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SqueezeGradient) OPERATOR_SCHEMA(SqueezeGradient)
/* dY */ /* dY */
...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(SqueezeGradient) ...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(SqueezeGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Squeeze, SimpleGradientMaker); REGISTER_GRADIENT(Squeeze, SimpleGradientMaker);
......
...@@ -193,7 +193,7 @@ DEPLOY_CPU(Collective); ...@@ -193,7 +193,7 @@ DEPLOY_CPU(Collective);
DEPLOY_CUDA(Collective); DEPLOY_CUDA(Collective);
#endif #endif
OPERATOR_SCHEMA(Collective); OPERATOR_SCHEMA(Collective).AllowInplace([](int, int) -> bool { return true; });
} // namespace dragon } // namespace dragon
......
...@@ -122,7 +122,7 @@ OPERATOR_SCHEMA(GradientAdd) ...@@ -122,7 +122,7 @@ OPERATOR_SCHEMA(GradientAdd)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X1 => Y */ /* X1 => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(StopGradient) OPERATOR_SCHEMA(StopGradient)
/* X */ /* X */
...@@ -130,7 +130,7 @@ OPERATOR_SCHEMA(StopGradient) ...@@ -130,7 +130,7 @@ OPERATOR_SCHEMA(StopGradient)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
NO_GRADIENT(StopGradient); NO_GRADIENT(StopGradient);
......
...@@ -107,16 +107,16 @@ OPERATOR_SCHEMA(Add) ...@@ -107,16 +107,16 @@ OPERATOR_SCHEMA(Add)
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* A => Y */ /* A => Y, B => Y */
.Inplace({{0, 0}, {1, 0}}); .AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(AddGradient) OPERATOR_SCHEMA(AddGradient)
/* dY */ /* dY */
.NumInputs(1) .NumInputs(1)
/* dA, dB */ /* dA, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dA */ /* dY => dA, dY => dB */
.Inplace({{0, 0}, {0, 1}}); .AllowInplace({{0, 0}, {0, 1}});
REGISTER_GRADIENT(Add, SimpleGradientMaker); REGISTER_GRADIENT(Add, SimpleGradientMaker);
......
...@@ -151,7 +151,7 @@ OPERATOR_SCHEMA(Affine) ...@@ -151,7 +151,7 @@ OPERATOR_SCHEMA(Affine)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(AffineGradient) OPERATOR_SCHEMA(AffineGradient)
/* X, W, dY */ /* X, W, dY */
...@@ -159,7 +159,7 @@ OPERATOR_SCHEMA(AffineGradient) ...@@ -159,7 +159,7 @@ OPERATOR_SCHEMA(AffineGradient)
/* dX, dW, dB */ /* dX, dW, dB */
.NumOutputs(3) .NumOutputs(3)
/* dY => dX */ /* dY => dX */
.Inplace({{2, 0}}); .AllowInplace({{2, 0}});
namespace { namespace {
......
...@@ -6,49 +6,28 @@ namespace dragon { ...@@ -6,49 +6,28 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void AxpbyOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) { void AxpbyOp<Context>::DoRunWithType() {
CHECK_EQ(X->count(), Y->count()); auto &X = Input(0), *Y = Output(0);
auto* x = X->template data<T, Context>(); auto* x = X.template data<T, Context>();
auto* y = Y->template mutable_data<T, Context>(); auto* y = Y->ReshapeLike(X)->template mutable_data<T, Context>();
if (beta_ == 1.f) { if (beta_ == 1.f) {
if (alpha_ == 1.f) { if (alpha_ == 1.f) {
math::Add(X->count(), x, y, y, ctx()); math::Add(X.count(), x, y, y, ctx());
} else { } else {
math::Axpy(X->count(), alpha_, x, y, ctx()); math::Axpy(X.count(), alpha_, x, y, ctx());
} }
} else { } else {
if (alpha_ == 0.f) { if (alpha_ == 0.f) {
math::Scale(X->count(), beta_, y, y, ctx()); math::Scale(X.count(), beta_, y, y, ctx());
} else { } else {
math::Axpby(X->count(), alpha_, x, beta_, y, ctx()); math::Axpby(X.count(), alpha_, x, beta_, y, ctx());
} }
} }
} }
template <class Context> template <class Context>
void AxpbyOp<Context>::RunOnDevice() { void AxpbyOp<Context>::RunOnDevice() {
for (int i = 0; i < InputSize(); i++) { DispatchHelper<MathTensorTypes>::Call(this, Input(0));
auto &X = Input(i), *Y = Output(i);
Y->ReshapeLike(X);
if (XIsType(X, int8_t)) {
DoRunWithType<int8_t>(&X, Y);
} else if (XIsType(X, uint8_t)) {
DoRunWithType<uint8_t>(&X, Y);
} else if (XIsType(X, int)) {
DoRunWithType<int>(&X, Y);
} else if (XIsType(X, int64_t)) {
DoRunWithType<int64_t>(&X, Y);
} else if (XIsType(X, float16)) {
DoRunWithType<float16>(&X, Y);
} else if (XIsType(X, float)) {
DoRunWithType<float>(&X, Y);
} else if (XIsType(X, double)) {
DoRunWithType<double>(&X, Y);
} else
LOG(FATAL) << MessageForUnsupported(
types::to_string(X.meta()),
{"int8", "uint8", "int32", "int64", "float16", "float32", "float64"});
}
} }
DEPLOY_CPU(Axpby); DEPLOY_CPU(Axpby);
...@@ -57,10 +36,12 @@ DEPLOY_CUDA(Axpby); ...@@ -57,10 +36,12 @@ DEPLOY_CUDA(Axpby);
#endif #endif
OPERATOR_SCHEMA(Axpby) OPERATOR_SCHEMA(Axpby)
/* X1, ... */ /* X */
.NumInputs(1, INT_MAX) .NumInputs(1)
/* Y1, ... */ /* Y */
.NumOutputs(1, INT_MAX); .NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
NO_GRADIENT(Axpby); NO_GRADIENT(Axpby);
......
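The `DispatchHelper<MathTensorTypes>::Call(this, Input(0))` call replaces the hand-written dtype chain; below is a self-contained sketch of what such a type-dispatch helper does (simplified, `MiniDispatch` and `MiniAxpbyOp` are illustrations, not Dragon's `DispatchHelper`):

```cpp
// Walk a type list and call DoRunWithType<T>() for the first entry whose
// typeid matches the runtime dtype of the input tensor.
#include <cstdint>
#include <iostream>
#include <typeindex>
#include <typeinfo>

struct MiniAxpbyOp {
  template <typename T>
  void DoRunWithType() {
    std::cout << "running Axpby with dtype " << typeid(T).name() << "\n";
  }
};

template <typename... Types>
struct MiniDispatch;

template <typename T, typename... Rest>
struct MiniDispatch<T, Rest...> {
  template <typename Op>
  static void Call(Op* op, const std::type_index& dtype) {
    if (dtype == std::type_index(typeid(T))) {
      op->template DoRunWithType<T>();
    } else {
      MiniDispatch<Rest...>::Call(op, dtype);
    }
  }
};

template <>
struct MiniDispatch<> {
  template <typename Op>
  static void Call(Op*, const std::type_index&) {
    std::cerr << "unsupported dtype\n";  // mirrors LOG(FATAL) in the old code
  }
};

int main() {
  MiniAxpbyOp op;
  // In spirit: DispatchHelper<MathTensorTypes>::Call(this, Input(0)).
  MiniDispatch<int8_t, uint8_t, int, int64_t, float, double>::Call(
      &op, std::type_index(typeid(float)));
  return 0;
}
```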
...@@ -207,16 +207,16 @@ OPERATOR_SCHEMA(Div) ...@@ -207,16 +207,16 @@ OPERATOR_SCHEMA(Div)
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* A => Y */ /* A => Y, B => Y */
.Inplace({{0, 0}, {1, 0}}); .AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(DivGradient) OPERATOR_SCHEMA(DivGradient)
/* A, B, dY */ /* A, B, dY */
.NumInputs(3) .NumInputs(3)
/* dA, dB */ /* dA, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dA */ /* dY => dA, dY => dB */
.Inplace({{2, 0}, {2, 1}}); .AllowInplace({{2, 0}, {2, 1}});
REGISTER_GRADIENT(Div, GenericGradientMaker); REGISTER_GRADIENT(Div, GenericGradientMaker);
......
...@@ -172,15 +172,15 @@ DEPLOY_CUDA(Greater); ...@@ -172,15 +172,15 @@ DEPLOY_CUDA(Greater);
DEPLOY_CUDA(GreaterEqual); DEPLOY_CUDA(GreaterEqual);
#endif #endif
OPERATOR_SCHEMA(Ceil).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Ceil).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Floor).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Floor).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Round).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Round).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sign).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Sign).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sqrt).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Sqrt).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Rsqrt).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Rsqrt).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Exp).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Exp).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Log).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Log).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Invert).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Invert).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sin).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Sin).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Cos).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Cos).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Square).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Square).NumInputs(1).NumOutputs(1);
......
...@@ -43,7 +43,7 @@ class AxpbyOp final : public Operator<Context> { ...@@ -43,7 +43,7 @@ class AxpbyOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> template <typename T>
void DoRunWithType(Tensor* X, Tensor* Y); void DoRunWithType();
protected: protected:
float alpha_, beta_; float alpha_, beta_;
......
...@@ -31,7 +31,7 @@ OPERATOR_SCHEMA(ExpGradient) ...@@ -31,7 +31,7 @@ OPERATOR_SCHEMA(ExpGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Exp, InplaceGradientMaker); REGISTER_GRADIENT(Exp, InplaceGradientMaker);
......
...@@ -189,16 +189,16 @@ OPERATOR_SCHEMA(Mul) ...@@ -189,16 +189,16 @@ OPERATOR_SCHEMA(Mul)
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* A => Y */ /* A => Y, B => Y */
.Inplace({{0, 0}, {1, 0}}); .AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(MulGradient) OPERATOR_SCHEMA(MulGradient)
/* A, B, dY */ /* A, B, dY */
.NumInputs(3) .NumInputs(3)
/* dA, dB */ /* dA, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dA */ /* dY => dA, dY => dB */
.Inplace({{2, 0}, {2, 1}}); .AllowInplace({{2, 0}, {2, 1}});
REGISTER_GRADIENT(Mul, GenericGradientMaker); REGISTER_GRADIENT(Mul, GenericGradientMaker);
......
...@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Neg) ...@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Neg)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(NegGradient) OPERATOR_SCHEMA(NegGradient)
/* dY */ /* dY */
...@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(NegGradient) ...@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(NegGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Neg, SimpleGradientMaker); REGISTER_GRADIENT(Neg, SimpleGradientMaker);
......
...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(Reciprocal) ...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(Reciprocal)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReciprocalGradient) OPERATOR_SCHEMA(ReciprocalGradient)
/* Y, dY */ /* Y, dY */
...@@ -61,7 +61,7 @@ OPERATOR_SCHEMA(ReciprocalGradient) ...@@ -61,7 +61,7 @@ OPERATOR_SCHEMA(ReciprocalGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Reciprocal, InplaceGradientMaker); REGISTER_GRADIENT(Reciprocal, InplaceGradientMaker);
......
...@@ -32,7 +32,7 @@ OPERATOR_SCHEMA(RsqrtGradient) ...@@ -32,7 +32,7 @@ OPERATOR_SCHEMA(RsqrtGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Rsqrt, InplaceGradientMaker); REGISTER_GRADIENT(Rsqrt, InplaceGradientMaker);
......
...@@ -30,7 +30,7 @@ OPERATOR_SCHEMA(SignGradient) ...@@ -30,7 +30,7 @@ OPERATOR_SCHEMA(SignGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Sign, SimpleGradientMaker); REGISTER_GRADIENT(Sign, SimpleGradientMaker);
......
...@@ -37,7 +37,7 @@ OPERATOR_SCHEMA(SqrtGradient) ...@@ -37,7 +37,7 @@ OPERATOR_SCHEMA(SqrtGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Sqrt, InplaceGradientMaker); REGISTER_GRADIENT(Sqrt, InplaceGradientMaker);
......
...@@ -112,16 +112,16 @@ OPERATOR_SCHEMA(Sub) ...@@ -112,16 +112,16 @@ OPERATOR_SCHEMA(Sub)
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* A => Y */ /* A => Y, B => Y */
.Inplace({{0, 0}, {1, 0}}); .AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(SubGradient) OPERATOR_SCHEMA(SubGradient)
/* dY */ /* dY */
.NumInputs(1) .NumInputs(1)
/* dA, dB */ /* dA, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dA */ /* dY => dA, dY => dB */
.Inplace({{0, 0}, {0, 1}}); .AllowInplace({{0, 0}, {0, 1}});
REGISTER_GRADIENT(Sub, SimpleGradientMaker); REGISTER_GRADIENT(Sub, SimpleGradientMaker);
......
...@@ -89,7 +89,7 @@ OPERATOR_SCHEMA(BiasAdd) ...@@ -89,7 +89,7 @@ OPERATOR_SCHEMA(BiasAdd)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(BiasAddGradient) OPERATOR_SCHEMA(BiasAddGradient)
/* dY */ /* dY */
...@@ -97,7 +97,7 @@ OPERATOR_SCHEMA(BiasAddGradient) ...@@ -97,7 +97,7 @@ OPERATOR_SCHEMA(BiasAddGradient)
/* dX, dB */ /* dX, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(BiasAdd, SimpleGradientMaker); REGISTER_GRADIENT(BiasAdd, SimpleGradientMaker);
......
...@@ -289,7 +289,7 @@ template <class Context> ...@@ -289,7 +289,7 @@ template <class Context>
template <typename T> template <typename T>
void CuDNNConv2dGradientOp<Context>::ResetDesc() { void CuDNNConv2dGradientOp<Context>::ResetDesc() {
auto &X = Input(0), &W = Input(1), &dY = Input(-1); auto &X = Input(0), &W = Input(1), &dY = Input(-1);
auto *dX = Output(0), *dW = Output(1); // auto *dX = Output(0), *dW = Output(1);
bool input_changed = (X.dims() != input_dims_); bool input_changed = (X.dims() != input_dims_);
bool filter_changed = (W.dims() != filter_dims_); bool filter_changed = (W.dims() != filter_dims_);
if (input_changed || filter_changed) { if (input_changed || filter_changed) {
...@@ -328,8 +328,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -328,8 +328,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
exhaustive_search_data_ = true; exhaustive_search_data_ = true;
exhaustive_search_filter_ = true; exhaustive_search_filter_ = true;
} else { } else {
if (dW->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
{
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos]; cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos];
...@@ -353,20 +353,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -353,20 +353,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found) CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> " << "\nNo algorithms available for <cudnnConvolutionBackwardFilter> "
<< "under the current desc and workspace limit."; << "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
if (dX->has_name()) { {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS;
cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos]; cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos];
...@@ -390,7 +378,17 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -390,7 +378,17 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found) CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardData> " << "\nNo algorithms available for <cudnnConvolutionBackwardData> "
<< "under the current desc and workspace limit."; << "under the current desc and workspace limit.";
}
#else #else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm( CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
filter_desc_, filter_desc_,
...@@ -402,7 +400,6 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -402,7 +400,6 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
&bwd_data_algo_)); &bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0) #endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
}
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size
} }
} }
......
...@@ -287,7 +287,6 @@ template <class Context> ...@@ -287,7 +287,6 @@ template <class Context>
template <typename T> template <typename T>
void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
auto &X = Input(0), &W = Input(1), &dY = Input(-1); auto &X = Input(0), &W = Input(1), &dY = Input(-1);
auto *dX = Output(0), *dW = Output(1);
bool input_changed = (X.dims() != input_dims_); bool input_changed = (X.dims() != input_dims_);
bool filter_changed = (W.dims() != filter_dims_); bool filter_changed = (W.dims() != filter_dims_);
if (input_changed || filter_changed) { if (input_changed || filter_changed) {
...@@ -324,8 +323,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -324,8 +323,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
exhaustive_search_data_ = true; exhaustive_search_data_ = true;
exhaustive_search_filter_ = true; exhaustive_search_filter_ = true;
} else { } else {
if (dW->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
{
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos]; cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos];
...@@ -349,20 +348,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -349,20 +348,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found) CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> " << "\nNo algorithms available for <cudnnConvolutionBackwardFilter> "
<< "under the current desc and workspace limit."; << "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
if (dX->has_name()) { {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS;
cudnnConvolutionFwdAlgoPerf_t stats[num_algos]; cudnnConvolutionFwdAlgoPerf_t stats[num_algos];
...@@ -386,7 +373,17 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -386,7 +373,17 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found) CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionForward> " << "\nNo algorithms available for <cudnnConvolutionForward> "
<< "under the current desc and workspace limit."; << "under the current desc and workspace limit.";
}
#else #else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm( CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
input_desc_, input_desc_,
...@@ -398,7 +395,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -398,7 +395,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
&bwd_data_algo_)); &bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0) #endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
}
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size
} }
} }
......
...@@ -50,11 +50,11 @@ def set_optimization(level=1): ...@@ -50,11 +50,11 @@ def set_optimization(level=1):
* level = ``0``: Do nothing. * level = ``0``: Do nothing.
* level = ``1``: Prune the redundant nodes. * level = ``1``: Eliminate the unused outputs and operators.
* level = ``2``: Add the inplace to outputs. * level = ``2``: Apply inplace to the inputs if available.
* level = ``3``: Allocate the buffer for outputs. * level = ``3``: Allocate shared buffer for the outputs.
Parameters Parameters
---------- ----------
......
...@@ -78,7 +78,7 @@ class GradientMaker(object): ...@@ -78,7 +78,7 @@ class GradientMaker(object):
if not is_skip: if not is_skip:
for input, grad_input in zip(op_def.input, grad_inputs): for input, grad_input in zip(op_def.input, grad_inputs):
inputs_to_grads[input] = grad_input inputs_to_grads[input] = grad_input
# Add def for ``GradientGenerateOp`` # Add ``GradientGenerateOp``
if len(gen_grads) > 0: if len(gen_grads) > 0:
inputs, outputs, values = [], [], [] inputs, outputs, values = [], [], []
for name, i in gen_grads: for name, i in gen_grads:
...@@ -94,7 +94,7 @@ class GradientMaker(object): ...@@ -94,7 +94,7 @@ class GradientMaker(object):
device_option=op_def.device_option device_option=op_def.device_option
if op_def.HasField('device_option') else None) if op_def.HasField('device_option') else None)
backward_defs.append(gen_op) backward_defs.append(gen_op)
# Add def for ``GradientOp`` # Add ``GradientOp``
for grad_def in grad_defs: for grad_def in grad_defs:
grad_def.name = OpDef.get_name() grad_def.name = OpDef.get_name()
backward_defs.append(grad_def) backward_defs.append(grad_def)
......
...@@ -130,7 +130,7 @@ def affine(inputs, axis=1, num_axes=1, **kwargs): ...@@ -130,7 +130,7 @@ def affine(inputs, axis=1, num_axes=1, **kwargs):
return op_lib.blend(**args) return op_lib.blend(**args)
@OpSchema.num_inputs(1, 2147483647) @OpSchema.num_inputs(1)
def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs): def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
r"""Compute the element-wise addition from input to output. r"""Compute the element-wise addition from input to output.
...@@ -140,10 +140,10 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs): ...@@ -140,10 +140,10 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
Parameters Parameters
---------- ----------
inputs : Union[dragon.Tensor, Sequence[dragon.Tensor]] inputs : dragon.Tensor
The input tensor(s). The input tensor.
outputs : Union[dragon.Tensor, Sequence[dragon.Tensor]], optional outputs : dragon.Tensor, optional
The output tensor(s). The output tensor.
alpha : number, optional, default=1. alpha : number, optional, default=1.
The value to :math:`\alpha`. The value to :math:`\alpha`.
beta : number, optional, default=1. beta : number, optional, default=1.
...@@ -151,23 +151,17 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs): ...@@ -151,23 +151,17 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
Returns Returns
------- -------
Union[dragon.Tensor, Sequence[dragon.Tensor]] dragon.Tensor
The output tensor(s). The output tensor.
""" """
args = parse_args(locals()) args = parse_args(locals())
args['alpha'], args['beta'] = float(alpha), float(beta) args['alpha'], args['beta'] = float(alpha), float(beta)
if types.is_tensor(inputs):
inputs = [inputs]
if outputs is not None and types.is_tensor(outputs):
args['outputs'] = [outputs]
op_lib = math_ops_lib.Axpby op_lib = math_ops_lib.Axpby
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(alpha=args['alpha'], beta=args['beta']) \
alpha=args['alpha'], .apply([inputs], [outputs])
beta=args['beta'],
).apply(inputs, args['outputs'])
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -65,8 +65,7 @@ class _BatchNorm(Module): ...@@ -65,8 +65,7 @@ class _BatchNorm(Module):
.format(**self.__dict__) .format(**self.__dict__)
def forward(self, input): def forward(self, input):
training = self.training or \ training = self.training or not self.track_running_stats
not self.track_running_stats
return F.batch_norm( return F.batch_norm(
input, *self.inputs, input, *self.inputs,
training=training, training=training,
......
...@@ -146,10 +146,8 @@ def cat(seq, dim=0, out=None): ...@@ -146,10 +146,8 @@ def cat(seq, dim=0, out=None):
""" """
return _functions.Concat \ return _functions.Concat \
.instantiate( .instantiate(seq[0].device, axis=dim) \
seq[0].device, .apply(seq, out)
axis=dim,
).apply(seq, out)
def channel_normalize( def channel_normalize(
...@@ -618,10 +616,7 @@ def nonzero(input, out=None): ...@@ -618,10 +616,7 @@ def nonzero(input, out=None):
The output tensor. The output tensor.
""" """
return _functions.NonZero \ return _functions.NonZero.instantiate(input.device).apply(input, out)
.instantiate(
input.device,
).apply(input, out)
def one_hot(input, depth): def one_hot(input, depth):
...@@ -647,8 +642,7 @@ def one_hot(input, depth): ...@@ -647,8 +642,7 @@ def one_hot(input, depth):
The output tensor. The output tensor.
""" """
return _functions.OneHot \ return _functions.OneHot.instantiate(input.device, depth=depth).apply(input)
.instantiate(input.device, depth=depth).apply(input)
def permute(input, dims): def permute(input, dims):
...@@ -715,18 +709,14 @@ def reshape(input, shape, out=None): ...@@ -715,18 +709,14 @@ def reshape(input, shape, out=None):
""" """
shape = nest.flatten(shape) shape = nest.flatten(shape)
return _functions.Reshape \ return _functions.Reshape \
.instantiate( .instantiate(input.device, ndim=len(shape)) \
input.device, .apply(input, shape, out)
ndim=len(shape),
).apply(input, shape, out)
def slice(input, starts, sizes): def slice(input, starts, sizes):
return _functions.Slice \ return _functions.Slice \
.instantiate( .instantiate(input.device, ndim=len(starts)) \
input.device, .apply(input, starts, sizes)
ndim=len(starts),
).apply(input, starts, sizes)
def split(tensor, split_size_or_sections, dim=0): def split(tensor, split_size_or_sections, dim=0):
...@@ -1015,9 +1005,8 @@ def where(condition, x, y): ...@@ -1015,9 +1005,8 @@ def where(condition, x, y):
""" """
return _functions.Where \ return _functions.Where \
.instantiate( .instantiate(utils.unify_devices([condition, x, y])) \
utils.unify_devices([condition, x, y]), .apply(condition, x, y)
).apply(condition, x, y)
def _arg_reduce(input, op_type, dim=None, keepdim=False, out=None): def _arg_reduce(input, op_type, dim=None, keepdim=False, out=None):
......
...@@ -567,10 +567,6 @@ def expand(self, *sizes): ...@@ -567,10 +567,6 @@ def expand(self, *sizes):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.expand(...)`_
""" """
return array_funcs.expand(self, sizes) return array_funcs.expand(self, sizes)
......
...@@ -51,6 +51,5 @@ class GradAccumulate(function.Function): ...@@ -51,6 +51,5 @@ class GradAccumulate(function.Function):
'arguments': {'alpha': 1., 'beta': float(self.momentum)}, 'arguments': {'alpha': 1., 'beta': float(self.momentum)},
} }
def forward(self, grads): def forward(self, grad):
outputs = [grad.id + '[accum]' for grad in grads] return self.dispatch([grad], [grad.id + '[accum]'], no_grad=True)
return self.dispatch(grads, outputs, no_grad=True)
...@@ -14,18 +14,13 @@ from __future__ import absolute_import ...@@ -14,18 +14,13 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.util import nest
from dragon.vm.torch.core.ops.training import _functions from dragon.vm.torch.core.ops.training import _functions
def accumulate_grad(grads, momentum=1): def accumulate_grad(grad, momentum=1):
"""Accumulate the gradients.""" """Accumulate the gradient."""
grads = nest.flatten(grads)
if len(grads) == 0:
return
return _functions.GradAccumulate \ return _functions.GradAccumulate \
.instantiate(grads[0].device, momentum=momentum) \ .instantiate(grad.device, momentum=momentum).apply(grad)
.apply(grads)
def update_param( def update_param(
......
...@@ -97,15 +97,13 @@ class Optimizer(object): ...@@ -97,15 +97,13 @@ class Optimizer(object):
The momentum to the accumulated value. The momentum to the accumulated value.
""" """
grads = []
current_ws = workspace.get_workspace() current_ws = workspace.get_workspace()
for group in self.param_groups: for group in self.param_groups:
group['_internal/grad_accum'] = True group['_internal/grad_accum'] = True
for param in group['params']: for param in group['params']:
grad = self._steal_grad(current_ws, param) grad = self._steal_grad(current_ws, param)
if grad is not None: if grad is not None:
grads.append(grad) training_funcs.accumulate_grad(grad)
training_funcs.accumulate_grad(grads, momentum)
def add_param_group(self, param_group): def add_param_group(self, param_group):
"""Add a new param group into the optimizer. """Add a new param group into the optimizer.
......
...@@ -776,10 +776,6 @@ class Tensor(object): ...@@ -776,10 +776,6 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.expand(...)`_
""" """
def expand_as(self, other): def expand_as(self, other):
...@@ -795,10 +791,6 @@ class Tensor(object): ...@@ -795,10 +791,6 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.expand(...)`_
""" """
return self.expand(*other.size()) return self.expand(*other.size())
......