Commit a7a7e4fc by Ting PAN

Fix the skipped algorithm finding in cached CUDNN convolution

Summary:
This commit enforces the algorithm finding even if the backward of filter or data
will not be executed. Otherwise, the empty algorithm will be encountered between
two cached operation with the same arguments and input shape.
1 parent 218796ed
Showing with 275 additions and 283 deletions
......@@ -18,6 +18,9 @@ dragon/core
`class Operator <core/Operator.html>`_
: The base operator class with context.
`class OpSchema <core/OpSchema.html>`_
: Class to record the schema of operator.
`class Tensor <core/Tensor.html>`_
: The base tensor class, manage memory or not.
......@@ -37,6 +40,7 @@ dragon/core
core/CUDAContext
core/Graph
core/Operator
core/OpSchema
core/Tensor
core/TypeMeta
core/UnifiedMemory
......
OpSchema
========
.. doxygenclass:: dragon::OpSchema
Constructors
------------
.. doxygenfunction:: dragon::OpSchema::OpSchema()
.. doxygenfunction:: dragon::OpSchema::OpSchema(const string &op_type, const string &file, const int line)
Public Functions
----------------
AllowInplace
############
.. doxygenfunction:: dragon::OpSchema::AllowInplace(set<pair<int, int>> inplace)
AllowInplace
############
.. doxygenfunction:: dragon::OpSchema::AllowInplace(std::function<bool(int, int)> inplace)
NumInputs
#########
.. doxygenfunction:: dragon::OpSchema::NumInputs(int n)
NumInputs
#########
.. doxygenfunction:: dragon::OpSchema::NumInputs(int min_num, int max_num)
NumOutputs
##########
.. doxygenfunction:: dragon::OpSchema::NumOutputs(int n)
NumOutputs
##########
.. doxygenfunction:: dragon::OpSchema::NumOutputs(int min_num, int max_num)
Verify
######
.. doxygenfunction:: dragon::OpSchema::Verify
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
......@@ -470,7 +470,6 @@ zero\_
.. _torch.div(...): div.html
.. _torch.eq(...): eq.html
.. _torch.exp(...): exp.html
.. _torch.expand(...): expand.html
.. _torch.floor(...): floor.html
.. _torch.ge(...): ge.html
.. _torch.gt(...): gt.html
......
......@@ -79,12 +79,12 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
Map<string, vec32_t> subgraph_indices;
int opt = 3; // default: O3
if (args().count("optimization")) opt = arg("optimization").i();
if (opt >= 1) def_v2 = graph_optimizer.PruneNodes(def);
if (opt >= 2) graph_optimizer.AddInplace(def_v2, output_aliases_);
if (opt >= 1) def_v2 = graph_optimizer.EliminateUnused(def);
if (opt >= 2) graph_optimizer.PlanInplace(def_v2, output_aliases_);
if (opt >= 3) {
if (phase() == "TRAIN") {
def_v2 = graph_optimizer.MirrorStage(def_v2, subgraph_indices);
def_v2 = gradient_maker.Share(def_v2);
def_v2 = graph_optimizer.PlanCheckpoint(def_v2, subgraph_indices);
def_v2 = gradient_maker.Optimize(def_v2);
} else {
def_v2 = graph_optimizer.SimulateGC(def_v2);
}
......@@ -98,8 +98,8 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
Map<string, vector<OperatorBase*>> subgraph;
for (const auto& it : subgraph_indices) {
subgraph[it.first] = vector<OperatorBase*>();
for (const auto& idx : subgraph_indices[it.first])
subgraph[it.first].push_back(cached_ops_[idx]);
for (auto op_idx : subgraph_indices[it.first])
subgraph[it.first].push_back(cached_ops_[op_idx]);
}
for (auto* op : cached_ops_) {
op->set_subgraph(subgraph);
......
......@@ -19,15 +19,15 @@ namespace dragon {
class DRAGON_API GraphGradientMaker {
public:
/*! \brief Generate graph def from the op defs */
/*! \brief Generate graph from the executed ops */
void Make(
const vector<OperatorDef*>& op_defs,
const vector<OperatorDef*>& ops,
const vector<string>& targets,
const vector<string>& input_grads,
GraphDef& graph_def);
GraphDef& graph);
/*! \brief Rewrite graph def to share the intermediate grads */
GraphDef Share(const GraphDef& input_def);
/*! \brief Eliminate the unused and make sharing of outputs */
GraphDef Optimize(const GraphDef& graph);
/*! \brief Add an empty gradient */
void add_empty_grad(const string& name) {
......@@ -47,14 +47,14 @@ class DRAGON_API GraphGradientMaker {
private:
/*! \brief Check the missing grads */
bool CheckGrad(
const OperatorDef& op_def,
const OperatorDef& op,
const Set<string>& targets,
vector<pair<string, int>>& gen_grads);
/*! \brief Return a dummy operator name */
string GetOperatorName() {
if (op_prefix_.empty()) return "GradientOp";
return op_prefix_ + str::to(op_index_++);
return op_prefix_ + str::to(op_idx_++);
}
/*! \brief The mapping from input to grad */
......@@ -70,7 +70,7 @@ class DRAGON_API GraphGradientMaker {
string op_prefix_;
/*! \brief The counter of op name */
int64_t op_index_ = 0;
int64_t op_idx_ = 0;
};
} // namespace dragon
......
......@@ -32,45 +32,42 @@ class GraphOptimizer {
/*! \brief Default constructor */
GraphOptimizer(Workspace* ws) : ws_(ws) {}
/*! \brief Build the DAG resources for given def */
void BuildDAG(const GraphDef& input_def);
/*! \brief Build the DAG */
void BuildDAG(const GraphDef& graph);
/*! \brief Prune the redundant nodes (-O1) */
GraphDef PruneNodes(const GraphDef& input_def);
/*! \brief Eliminate the unused outputs and operators */
GraphDef EliminateUnused(const GraphDef& graph);
/*! \brief Add the inplace for outputs (-O2) */
void AddInplace(
const GraphDef& input_def,
/*! \brief Plan the inplace for inputs */
void PlanInplace(
const GraphDef& graph,
Map<string, Set<string>>& output_aliases);
/*! \brief Plan the recomputing for inputs (-O3) */
GraphDef MirrorStage(
const GraphDef& input_def,
Map<string, vec32_t>& op_indices);
/*! \brief Plan the checkpoint for inputs */
GraphDef PlanCheckpoint(
const GraphDef& graph,
Map<string, vec32_t>& subgraph_indices);
/*! \brief Allocate the buffer for outputs (-O3) */
GraphDef SimulateGC(const GraphDef& input_def);
/*! \brief Allocate the shared buffer for outputs */
GraphDef SimulateGC(const GraphDef& graph);
protected:
/*! \brief Pass from gradients to remove unused nodes */
void ForwardPrunePass(
const string& u,
const string& leaf,
const std::deque<string>& path);
/*! \brief Remote the unused nodes from a sink to all sources */
void EliminateUnusedNode(const string& sink);
/*! \brief Pass from targets to remove unused nodes */
void BackwardPrunePass(const string& v);
/*! \brief Remote the unused nodes from a source to a sink */
void EliminateUnusedNode(const string& source, const string& sink);
/* \brief Store the workspace of parent graph */
/* \brief The graph workspace */
Workspace* ws_;
/* \brief Store the DAG */
Map<string, Node> dag_;
/* \brief The graph nodes */
Map<string, Node> nodes_;
/* \brief Store the traversal flags */
Map<string, bool> visited_, colored_;
/* \brief The traversal flags */
Map<string, bool> visited_, used_;
/* \brief Store the count of references */
/* \brief The reference count */
Map<string, int> reference_count_;
private:
......
......@@ -173,10 +173,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) {
OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) {
auto* schema = OpSchemaRegistry::Schema(def.type());
if (schema != nullptr) {
CHECK(schema->Verify(def))
<< "\nOperator failed to pass the schema checking.";
}
if (schema != nullptr) CHECK(schema->Verify(def));
OperatorDef mutable_def(def);
// Heuristically make each random seed slightly different
static unsigned int seed_offset = 0;
......
......@@ -14,15 +14,13 @@ bool OpSchema::Verify(const OperatorDef& def) const {
<< " is not in range [min=" << min_output_
<< ", max=" << max_output_ << "]";
}
if (CheckInplace != nullptr) {
for (int i = 0; i < def.input_size(); ++i) {
if (def.input(i).empty()) continue;
for (int j = 0; j < def.output_size(); ++j) {
if (def.output(j).empty()) continue;
if (def.input(i) == def.output(j) && !CheckInplace(i, j)) {
LOG(FATAL) << header << "Input(" << i << ") and Output(" << j << ") "
<< "can not be set to inplace.";
}
for (int i = 0; i < def.input_size(); ++i) {
if (def.input(i).empty()) continue;
for (int j = 0; j < def.output_size(); ++j) {
if (def.output(j).empty()) continue;
if (def.input(i) == def.output(j) && !CheckInplace(i, j)) {
LOG(FATAL) << header << "Input(" << i << ") and Output(" << j << ") "
<< "can not be set to inplace.";
}
}
}
......@@ -49,7 +47,12 @@ OpSchema& OpSchema::NumOutputs(int min_num, int max_num) {
return *this;
}
OpSchema& OpSchema::Inplace(set<pair<int, int>> inplace) {
OpSchema& OpSchema::AllowInplace(std::function<bool(int, int)> inplace) {
CheckInplace = inplace;
return *this;
}
OpSchema& OpSchema::AllowInplace(set<pair<int, int>> inplace) {
CheckInplace = [inplace](int in, int out) -> bool {
return (inplace.count(std::make_pair(in, out)) > 0);
};
......
......@@ -20,6 +20,9 @@
namespace dragon {
/*!
* \brief Class to record the schema of operator.
*/
class DRAGON_API OpSchema {
public:
/*! \brief Default constructor */
......@@ -27,15 +30,12 @@ class DRAGON_API OpSchema {
Init();
}
/*! \brief Constructor with defined spec */
/*! \brief Constructor with the defined spec */
OpSchema(const string& op_type, const string& file, const int line)
: op_type_(op_type), file_(file), line_(line) {
Init();
}
/*! \brief Check if the in-place setting is matched */
std::function<bool(int, int)> CheckInplace = nullptr;
/*! \brief Set a fixed number of inputs */
OpSchema& NumInputs(int n);
......@@ -48,12 +48,18 @@ class DRAGON_API OpSchema {
/*! \brief Set the min and max number of outputs */
OpSchema& NumOutputs(int min_num, int max_num);
/*! \brief Set the in-place setting */
OpSchema& Inplace(set<pair<int, int>> inplace);
/*! \brief Set the rule to allow inplace with a group of indices */
OpSchema& AllowInplace(set<pair<int, int>> inplace);
/*! \brief Verify if the def matches the schema */
/*! \brief Set the rule to allow inplace with a function */
OpSchema& AllowInplace(std::function<bool(int, int)> inplace);
/*! \brief Check if the given def matches this schema */
bool Verify(const OperatorDef& def) const;
/*! \brief Check if the inplace is allowed */
std::function<bool(int, int)> CheckInplace = [](int, int) { return false; };
private:
/*! \brief Initialize the default settings */
void Init() {
......
......@@ -242,7 +242,7 @@ PYBIND11_MODULE(libdragon_python, m) {
maker.Make(op_defs, targets, input_grads, graph_def);
py::gil_scoped_release g;
if (!retain_grads) {
graph_def = maker.Share(graph_def);
graph_def = maker.Optimize(graph_def);
}
for (const auto& op_def : graph_def.op()) {
if (verbose) {
......
......@@ -129,7 +129,7 @@ OPERATOR_SCHEMA(DropBlock2d)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
DEPLOY_CPU(DropBlock2dGradient);
#ifdef USE_CUDA
......@@ -142,7 +142,7 @@ OPERATOR_SCHEMA(DropBlock2dGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(DropBlock2d, SimpleGradientMaker);
......
......@@ -95,7 +95,7 @@ OPERATOR_SCHEMA(DropPath)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DropPathGradient)
/* dY */
......@@ -103,7 +103,7 @@ OPERATOR_SCHEMA(DropPathGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(DropPath, SimpleGradientMaker);
......
......@@ -84,7 +84,7 @@ OPERATOR_SCHEMA(Dropout)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DropoutGradient)
/* dY */
......@@ -92,7 +92,7 @@ OPERATOR_SCHEMA(DropoutGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Dropout, SimpleGradientMaker);
......
......@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Elu)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(EluGradient)
/* Y, dY */
......@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(EluGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Elu, InplaceGradientMaker);
......
......@@ -73,7 +73,7 @@ OPERATOR_SCHEMA(Relu)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReluGradient)
/* Y, dY */
......@@ -81,7 +81,7 @@ OPERATOR_SCHEMA(ReluGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Relu, InplaceGradientMaker);
......
......@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Selu)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SeluGradient)
/* Y, dY */
......@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(SeluGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Selu, InplaceGradientMaker);
......
......@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Sigmoid)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SigmoidGradient)
/* Y, dY */
......@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(SigmoidGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Sigmoid, InplaceGradientMaker);
......
......@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(Softmax)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SoftmaxGradient)
/* Y, dY */
......@@ -68,7 +68,7 @@ OPERATOR_SCHEMA(SoftmaxGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Softmax, InplaceGradientMaker);
......
......@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Tanh)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(TanhGradient)
/* Y, dY */
......@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(TanhGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Tanh, InplaceGradientMaker);
......
......@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(ExpandDims)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ExpandDimsGradient)
/* dY */
......@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(ExpandDimsGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(ExpandDims, SimpleGradientMaker);
......
......@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Flatten)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(FlattenGradient)
/* dY */
......@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(FlattenGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Flatten, SimpleGradientMaker);
......
......@@ -69,7 +69,7 @@ OPERATOR_SCHEMA(Reshape)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReshapeGradient)
/* dY */
......@@ -77,7 +77,7 @@ OPERATOR_SCHEMA(ReshapeGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Reshape, SimpleGradientMaker);
......
......@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(Squeeze)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SqueezeGradient)
/* dY */
......@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(SqueezeGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Squeeze, SimpleGradientMaker);
......
......@@ -193,7 +193,7 @@ DEPLOY_CPU(Collective);
DEPLOY_CUDA(Collective);
#endif
OPERATOR_SCHEMA(Collective);
OPERATOR_SCHEMA(Collective).AllowInplace([](int, int) -> bool { return true; });
} // namespace dragon
......
......@@ -122,7 +122,7 @@ OPERATOR_SCHEMA(GradientAdd)
/* Y */
.NumOutputs(1)
/* X1 => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(StopGradient)
/* X */
......@@ -130,7 +130,7 @@ OPERATOR_SCHEMA(StopGradient)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
NO_GRADIENT(StopGradient);
......
......@@ -107,16 +107,16 @@ OPERATOR_SCHEMA(Add)
.NumInputs(2)
/* Y */
.NumOutputs(1)
/* A => Y */
.Inplace({{0, 0}, {1, 0}});
/* A => Y, B => Y */
.AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(AddGradient)
/* dY */
.NumInputs(1)
/* dA, dB */
.NumOutputs(2)
/* dY => dA */
.Inplace({{0, 0}, {0, 1}});
/* dY => dA, dY => dB */
.AllowInplace({{0, 0}, {0, 1}});
REGISTER_GRADIENT(Add, SimpleGradientMaker);
......
......@@ -151,7 +151,7 @@ OPERATOR_SCHEMA(Affine)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(AffineGradient)
/* X, W, dY */
......@@ -159,7 +159,7 @@ OPERATOR_SCHEMA(AffineGradient)
/* dX, dW, dB */
.NumOutputs(3)
/* dY => dX */
.Inplace({{2, 0}});
.AllowInplace({{2, 0}});
namespace {
......
......@@ -6,49 +6,28 @@ namespace dragon {
template <class Context>
template <typename T>
void AxpbyOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) {
CHECK_EQ(X->count(), Y->count());
auto* x = X->template data<T, Context>();
auto* y = Y->template mutable_data<T, Context>();
void AxpbyOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
auto* x = X.template data<T, Context>();
auto* y = Y->ReshapeLike(X)->template mutable_data<T, Context>();
if (beta_ == 1.f) {
if (alpha_ == 1.f) {
math::Add(X->count(), x, y, y, ctx());
math::Add(X.count(), x, y, y, ctx());
} else {
math::Axpy(X->count(), alpha_, x, y, ctx());
math::Axpy(X.count(), alpha_, x, y, ctx());
}
} else {
if (alpha_ == 0.f) {
math::Scale(X->count(), beta_, y, y, ctx());
math::Scale(X.count(), beta_, y, y, ctx());
} else {
math::Axpby(X->count(), alpha_, x, beta_, y, ctx());
math::Axpby(X.count(), alpha_, x, beta_, y, ctx());
}
}
}
template <class Context>
void AxpbyOp<Context>::RunOnDevice() {
for (int i = 0; i < InputSize(); i++) {
auto &X = Input(i), *Y = Output(i);
Y->ReshapeLike(X);
if (XIsType(X, int8_t)) {
DoRunWithType<int8_t>(&X, Y);
} else if (XIsType(X, uint8_t)) {
DoRunWithType<uint8_t>(&X, Y);
} else if (XIsType(X, int)) {
DoRunWithType<int>(&X, Y);
} else if (XIsType(X, int64_t)) {
DoRunWithType<int64_t>(&X, Y);
} else if (XIsType(X, float16)) {
DoRunWithType<float16>(&X, Y);
} else if (XIsType(X, float)) {
DoRunWithType<float>(&X, Y);
} else if (XIsType(X, double)) {
DoRunWithType<double>(&X, Y);
} else
LOG(FATAL) << MessageForUnsupported(
types::to_string(X.meta()),
{"int8", "uint8", "int32", "int64", "float16", "float32", "float64"});
}
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Axpby);
......@@ -57,10 +36,12 @@ DEPLOY_CUDA(Axpby);
#endif
OPERATOR_SCHEMA(Axpby)
/* X1, ... */
.NumInputs(1, INT_MAX)
/* Y1, ... */
.NumOutputs(1, INT_MAX);
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
NO_GRADIENT(Axpby);
......
......@@ -207,16 +207,16 @@ OPERATOR_SCHEMA(Div)
.NumInputs(2)
/* Y */
.NumOutputs(1)
/* A => Y */
.Inplace({{0, 0}, {1, 0}});
/* A => Y, B => Y */
.AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(DivGradient)
/* A, B, dY */
.NumInputs(3)
/* dA, dB */
.NumOutputs(2)
/* dY => dA */
.Inplace({{2, 0}, {2, 1}});
/* dY => dA, dY => dB */
.AllowInplace({{2, 0}, {2, 1}});
REGISTER_GRADIENT(Div, GenericGradientMaker);
......
......@@ -172,15 +172,15 @@ DEPLOY_CUDA(Greater);
DEPLOY_CUDA(GreaterEqual);
#endif
OPERATOR_SCHEMA(Ceil).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Floor).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Round).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Sign).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Sqrt).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Rsqrt).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Exp).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Log).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Invert).NumInputs(1).NumOutputs(1).Inplace({{0, 0}});
OPERATOR_SCHEMA(Ceil).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Floor).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Round).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sign).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sqrt).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Rsqrt).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Exp).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Log).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Invert).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sin).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Cos).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Square).NumInputs(1).NumOutputs(1);
......
......@@ -43,7 +43,7 @@ class AxpbyOp final : public Operator<Context> {
void RunOnDevice() override;
template <typename T>
void DoRunWithType(Tensor* X, Tensor* Y);
void DoRunWithType();
protected:
float alpha_, beta_;
......
......@@ -31,7 +31,7 @@ OPERATOR_SCHEMA(ExpGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Exp, InplaceGradientMaker);
......
......@@ -189,16 +189,16 @@ OPERATOR_SCHEMA(Mul)
.NumInputs(2)
/* Y */
.NumOutputs(1)
/* A => Y */
.Inplace({{0, 0}, {1, 0}});
/* A => Y, B => Y */
.AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(MulGradient)
/* A, B, dY */
.NumInputs(3)
/* dA, dB */
.NumOutputs(2)
/* dY => dA */
.Inplace({{2, 0}, {2, 1}});
/* dY => dA, dY => dB */
.AllowInplace({{2, 0}, {2, 1}});
REGISTER_GRADIENT(Mul, GenericGradientMaker);
......
......@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Neg)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(NegGradient)
/* dY */
......@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(NegGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Neg, SimpleGradientMaker);
......
......@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(Reciprocal)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReciprocalGradient)
/* Y, dY */
......@@ -61,7 +61,7 @@ OPERATOR_SCHEMA(ReciprocalGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Reciprocal, InplaceGradientMaker);
......
......@@ -32,7 +32,7 @@ OPERATOR_SCHEMA(RsqrtGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Rsqrt, InplaceGradientMaker);
......
......@@ -30,7 +30,7 @@ OPERATOR_SCHEMA(SignGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Sign, SimpleGradientMaker);
......
......@@ -37,7 +37,7 @@ OPERATOR_SCHEMA(SqrtGradient)
/* dX */
.NumOutputs(1)
/* dY => dX */
.Inplace({{1, 0}});
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(Sqrt, InplaceGradientMaker);
......
......@@ -112,16 +112,16 @@ OPERATOR_SCHEMA(Sub)
.NumInputs(2)
/* Y */
.NumOutputs(1)
/* A => Y */
.Inplace({{0, 0}, {1, 0}});
/* A => Y, B => Y */
.AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(SubGradient)
/* dY */
.NumInputs(1)
/* dA, dB */
.NumOutputs(2)
/* dY => dA */
.Inplace({{0, 0}, {0, 1}});
/* dY => dA, dY => dB */
.AllowInplace({{0, 0}, {0, 1}});
REGISTER_GRADIENT(Sub, SimpleGradientMaker);
......
......@@ -89,7 +89,7 @@ OPERATOR_SCHEMA(BiasAdd)
/* Y */
.NumOutputs(1)
/* X => Y */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(BiasAddGradient)
/* dY */
......@@ -97,7 +97,7 @@ OPERATOR_SCHEMA(BiasAddGradient)
/* dX, dB */
.NumOutputs(2)
/* dY => dX */
.Inplace({{0, 0}});
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(BiasAdd, SimpleGradientMaker);
......
......@@ -289,7 +289,7 @@ template <class Context>
template <typename T>
void CuDNNConv2dGradientOp<Context>::ResetDesc() {
auto &X = Input(0), &W = Input(1), &dY = Input(-1);
auto *dX = Output(0), *dW = Output(1);
// auto *dX = Output(0), *dW = Output(1);
bool input_changed = (X.dims() != input_dims_);
bool filter_changed = (W.dims() != filter_dims_);
if (input_changed || filter_changed) {
......@@ -328,8 +328,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
exhaustive_search_data_ = true;
exhaustive_search_filter_ = true;
} else {
if (dW->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0)
{
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos];
......@@ -353,20 +353,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> "
<< "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
}
if (dX->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0)
{
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS;
cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos];
......@@ -390,18 +378,27 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardData> "
<< "under the current desc and workspace limit.";
}
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_data_algo_));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
}
}
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size
}
......
......@@ -287,7 +287,6 @@ template <class Context>
template <typename T>
void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
auto &X = Input(0), &W = Input(1), &dY = Input(-1);
auto *dX = Output(0), *dW = Output(1);
bool input_changed = (X.dims() != input_dims_);
bool filter_changed = (W.dims() != filter_dims_);
if (input_changed || filter_changed) {
......@@ -324,8 +323,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
exhaustive_search_data_ = true;
exhaustive_search_filter_ = true;
} else {
if (dW->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0)
{
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos];
......@@ -349,20 +348,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> "
<< "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
}
if (dX->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0)
{
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS;
cudnnConvolutionFwdAlgoPerf_t stats[num_algos];
......@@ -386,18 +373,27 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionForward> "
<< "under the current desc and workspace limit.";
}
#else
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_data_algo_));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
}
}
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size
}
......
......@@ -50,11 +50,11 @@ def set_optimization(level=1):
* level = ``0``: Do nothing.
* level = ``1``: Prune the redundant nodes.
* level = ``1``: Eliminate the unused outputs and operators.
* level = ``2``: Add the inplace to outputs.
* level = ``2``: Apply inplace to the inputs if available.
* level = ``3``: Allocate the buffer for outputs.
* level = ``3``: Allocate shared buffer for the outputs.
Parameters
----------
......
......@@ -78,7 +78,7 @@ class GradientMaker(object):
if not is_skip:
for input, grad_input in zip(op_def.input, grad_inputs):
inputs_to_grads[input] = grad_input
# Add def for ``GradientGenerateOp``
# Add ``GradientGenerateOp``
if len(gen_grads) > 0:
inputs, outputs, values = [], [], []
for name, i in gen_grads:
......@@ -94,7 +94,7 @@ class GradientMaker(object):
device_option=op_def.device_option
if op_def.HasField('device_option') else None)
backward_defs.append(gen_op)
# Add def for ``GradientOp``
# Add ``GradientOp``
for grad_def in grad_defs:
grad_def.name = OpDef.get_name()
backward_defs.append(grad_def)
......
......@@ -130,7 +130,7 @@ def affine(inputs, axis=1, num_axes=1, **kwargs):
return op_lib.blend(**args)
@OpSchema.num_inputs(1, 2147483647)
@OpSchema.num_inputs(1)
def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
r"""Compute the element-wise addition from input to output.
......@@ -140,10 +140,10 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
Parameters
----------
inputs : Union[dragon.Tensor, Sequence[dragon.Tensor]]
The input tensor(s).
outputs : Union[dragon.Tensor, Sequence[dragon.Tensor]], optional
The output tensor(s).
inputs : dragon.Tensor
The input tensor.
outputs : dragon.Tensor, optional
The output tensor.
alpha : number, optional, default=1.
The value to :math:`\alpha`.
beta : number, optional, default=1.
......@@ -151,23 +151,17 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
Returns
-------
Union[dragon.Tensor, Sequence[dragon.Tensor]]
The output tensor(s).
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
args['alpha'], args['beta'] = float(alpha), float(beta)
if types.is_tensor(inputs):
inputs = [inputs]
if outputs is not None and types.is_tensor(outputs):
args['outputs'] = [outputs]
op_lib = math_ops_lib.Axpby
if context.executing_eagerly():
return op_lib \
.instantiate(
alpha=args['alpha'],
beta=args['beta'],
).apply(inputs, args['outputs'])
.instantiate(alpha=args['alpha'], beta=args['beta']) \
.apply([inputs], [outputs])
else:
return op_lib.blend(**args)
......
......@@ -65,8 +65,7 @@ class _BatchNorm(Module):
.format(**self.__dict__)
def forward(self, input):
training = self.training or \
not self.track_running_stats
training = self.training or not self.track_running_stats
return F.batch_norm(
input, *self.inputs,
training=training,
......
......@@ -146,10 +146,8 @@ def cat(seq, dim=0, out=None):
"""
return _functions.Concat \
.instantiate(
seq[0].device,
axis=dim,
).apply(seq, out)
.instantiate(seq[0].device, axis=dim) \
.apply(seq, out)
def channel_normalize(
......@@ -618,10 +616,7 @@ def nonzero(input, out=None):
The output tensor.
"""
return _functions.NonZero \
.instantiate(
input.device,
).apply(input, out)
return _functions.NonZero.instantiate(input.device).apply(input, out)
def one_hot(input, depth):
......@@ -647,8 +642,7 @@ def one_hot(input, depth):
The output tensor.
"""
return _functions.OneHot \
.instantiate(input.device, depth=depth).apply(input)
return _functions.OneHot.instantiate(input.device, depth=depth).apply(input)
def permute(input, dims):
......@@ -715,18 +709,14 @@ def reshape(input, shape, out=None):
"""
shape = nest.flatten(shape)
return _functions.Reshape \
.instantiate(
input.device,
ndim=len(shape),
).apply(input, shape, out)
.instantiate(input.device, ndim=len(shape)) \
.apply(input, shape, out)
def slice(input, starts, sizes):
return _functions.Slice \
.instantiate(
input.device,
ndim=len(starts),
).apply(input, starts, sizes)
.instantiate(input.device, ndim=len(starts)) \
.apply(input, starts, sizes)
def split(tensor, split_size_or_sections, dim=0):
......@@ -1015,9 +1005,8 @@ def where(condition, x, y):
"""
return _functions.Where \
.instantiate(
utils.unify_devices([condition, x, y]),
).apply(condition, x, y)
.instantiate(utils.unify_devices([condition, x, y])) \
.apply(condition, x, y)
def _arg_reduce(input, op_type, dim=None, keepdim=False, out=None):
......
......@@ -567,10 +567,6 @@ def expand(self, *sizes):
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.expand(...)`_
"""
return array_funcs.expand(self, sizes)
......
......@@ -51,6 +51,5 @@ class GradAccumulate(function.Function):
'arguments': {'alpha': 1., 'beta': float(self.momentum)},
}
def forward(self, grads):
outputs = [grad.id + '[accum]' for grad in grads]
return self.dispatch(grads, outputs, no_grad=True)
def forward(self, grad):
return self.dispatch([grad], [grad.id + '[accum]'], no_grad=True)
......@@ -14,18 +14,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.util import nest
from dragon.vm.torch.core.ops.training import _functions
def accumulate_grad(grads, momentum=1):
"""Accumulate the gradients."""
grads = nest.flatten(grads)
if len(grads) == 0:
return
def accumulate_grad(grad, momentum=1):
"""Accumulate the gradient."""
return _functions.GradAccumulate \
.instantiate(grads[0].device, momentum=momentum) \
.apply(grads)
.instantiate(grad.device, momentum=momentum).apply(grad)
def update_param(
......
......@@ -97,15 +97,13 @@ class Optimizer(object):
The momentum to the accumulated value.
"""
grads = []
current_ws = workspace.get_workspace()
for group in self.param_groups:
group['_internal/grad_accum'] = True
for param in group['params']:
grad = self._steal_grad(current_ws, param)
if grad is not None:
grads.append(grad)
training_funcs.accumulate_grad(grads, momentum)
training_funcs.accumulate_grad(grad)
def add_param_group(self, param_group):
"""Add a new param group into the optimizer.
......
......@@ -776,10 +776,6 @@ class Tensor(object):
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.expand(...)`_
"""
def expand_as(self, other):
......@@ -795,10 +791,6 @@ class Tensor(object):
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.expand(...)`_
"""
return self.expand(*other.size())
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!