Commit 6b82cb26 by Ting PAN

Fix the missing alias in SimulateGC

1 parent 4eab1c68
Showing with 619 additions and 605 deletions
...@@ -5,6 +5,18 @@ ...@@ -5,6 +5,18 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Workspace
---------
============================== =============================================================================
List Brief
============================== =============================================================================
`Workspace(object)`_ A wrapper for the C implemented workspace.
`get_default_workspace`_ Return the current default workspace.
`reset_default_workspace`_ Reset the global default workspace.
============================== =============================================================================
Tensor Tensor
------ ------
...@@ -41,20 +53,14 @@ List Brief ...@@ -41,20 +53,14 @@ List Brief
`RunGraph`_ Run the specific graph. `RunGraph`_ Run the specific graph.
============================== ============================================================================= ============================== =============================================================================
Misc I/O
---- ---
============================== ============================================================================= ============================== =============================================================================
List Brief List Brief
============================== ============================================================================= ============================== =============================================================================
`Snapshot`_ Snapshot tensors into a binary file. `Snapshot`_ Snapshot tensors into a binary file.
`Restore`_ Restore tensors from a binary file. `Restore`_ Restore tensors from a binary file.
`SwitchWorkspace`_ Switch to the specific Workspace.
`MoveWorkspace`_ Move the source workspace into the target workspace.
`ResetWorkspace`_ Reset the specific workspace.
`ClearWorkspace`_ Clear the specific workspace.
`LogMetaGraph`_ Log the meta graph.
`ExportMetaGraph`_ Export the meta graph into a file under specific folder.
============================== ============================================================================= ============================== =============================================================================
API Reference API Reference
...@@ -63,12 +69,15 @@ API Reference ...@@ -63,12 +69,15 @@ API Reference
.. automodule:: dragon.core.workspace .. automodule:: dragon.core.workspace
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance:
.. _SwitchWorkspace: #dragon.core.workspace.SwitchWorkspace .. autoclass:: Workspace
.. _MoveWorkspace: #dragon.core.workspace.MoveWorkspace :members:
.. _ResetWorkspace: #dragon.core.workspace.ResetWorkspace
.. _ClearWorkspace: #dragon.core.workspace.ClearWorkspace .. automethod:: __init__
.. _Workspace(object): #dragon.core.workspace.Workspace
.. _get_default_workspace: #dragon.core.workspace.get_default_workspace
.. _reset_default_workspace: #dragon.core.workspace.reset_default_workspace
.. _CreateGraph: #dragon.core.workspace.CreateGraph .. _CreateGraph: #dragon.core.workspace.CreateGraph
.. _HasTensor: #dragon.core.workspace.HasTensor .. _HasTensor: #dragon.core.workspace.HasTensor
.. _GetTensorName: #dragon.core.workspace.GetTensorName .. _GetTensorName: #dragon.core.workspace.GetTensorName
...@@ -81,8 +90,5 @@ API Reference ...@@ -81,8 +90,5 @@ API Reference
.. _RunGraph: #dragon.core.workspace.RunGraph .. _RunGraph: #dragon.core.workspace.RunGraph
.. _Snapshot: #dragon.core.workspace.Snapshot .. _Snapshot: #dragon.core.workspace.Snapshot
.. _Restore: #dragon.core.workspace.Restore .. _Restore: #dragon.core.workspace.Restore
.. _LogMetaGraph: #dragon.core.workspace.LogMetaGraph
.. _ExportMetaGraph: #dragon.core.workspace.ExportMetaGraph
.. _theano.function(*args, **kwargs): ../vm/theano/compile.html#dragon.vm.theano.compile.function.function .. _theano.function(*args, **kwargs): ../vm/theano/compile.html#dragon.vm.theano.compile.function.function
\.. _config.ExportMetaGraph(prefix): ../config.html#dragon.config.ExportMetaGraph
.. _config.ExportMetaGraph(prefix): ../config.html#dragon.config.ExportMetaGraph
\ No newline at end of file
...@@ -178,7 +178,7 @@ if(WIN32) ...@@ -178,7 +178,7 @@ if(WIN32)
endif() endif()
if(UNIX) if(UNIX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -fPIC") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -w -fPIC -O3 -m64 -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -fPIC -O3 -m64 -std=c++11")
if (WITH_OMP AND (NOT APPLE)) if (WITH_OMP AND (NOT APPLE))
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
......
...@@ -39,18 +39,39 @@ class GraphBase { ...@@ -39,18 +39,39 @@ class GraphBase {
const string& exclude, const string& exclude,
int stream_id = 0) = 0; int stream_id = 0) = 0;
/*! \brief Return the name of this graph */ /*! \brief Return the graph name */
string name() const { return name_; } string name() const { return name_; }
/*! \brief Return the defined running phase */
const string& phase() const { return phase_; }
/*! \brief Return the argument map */
const Map<std::string, const Argument*>& args() { return args_; }
/*! \brief Return the specified argument */
const Argument& arg(const string& name) { return *(args_[name]); }
/*! \brief Return the stored raw def */
const GraphDef& def() const { return def_; }
/*! \brief Return the stored opt def */
const GraphDef& opt_def() const { return opt_def_; }
/*! \brief Return the parent workspace */
Workspace* ws() const { return ws_; }
protected: protected:
/*! \brief Store the name and running phase */ /*! \brief Store the name and running phase */
string name_, phase_; string name_, phase_;
/*! \brief Store the defined arguments */ /*! \brief Store the defined arguments */
Map<string, Argument> args_; Map<string, const Argument*> args_;
/*! \brief Store the parent workspace */ /*! \brief Store the parent workspace */
Workspace* ws_; Workspace* ws_;
/*! \brief Store the def */
GraphDef def_, opt_def_;
}; };
class Graph : public GraphBase { class Graph : public GraphBase {
...@@ -72,9 +93,6 @@ class Graph : public GraphBase { ...@@ -72,9 +93,6 @@ class Graph : public GraphBase {
const string& exclude, const string& exclude,
int stream_id = 0) override; int stream_id = 0) override;
/*! \brief Return the parent workspace */
Workspace* ws() const { return ws_; }
protected: protected:
/*! \brief Store the internal operators */ /*! \brief Store the internal operators */
vector<OperatorBase*> ops_; vector<OperatorBase*> ops_;
......
...@@ -48,13 +48,13 @@ class OperatorBase { ...@@ -48,13 +48,13 @@ class OperatorBase {
/*! \brief Return the number of outputs */ /*! \brief Return the number of outputs */
int YSize() { return (int)outputs_.size(); } int YSize() { return (int)outputs_.size(); }
/*! \brief Modify this operator according to the given def */ /*! \brief Modify operator according to the given def */
void UpdateFrom(const OperatorDef& def); void UpdateFrom(const OperatorDef& def);
/*! \brief Switch the internal running phase */ /*! \brief Switch the internal running phase */
void SwitchToPhase(const string& phase) { phase_ = phase; } void SwitchToPhase(const string& phase) { phase_ = phase; }
/*! \brief Run this operator on the specified stream */ /*! \brief Run operator on the specified stream */
virtual void Run(int stream_id = 0) { NOT_IMPLEMENTED; } virtual void Run(int stream_id = 0) { NOT_IMPLEMENTED; }
/*! \brief Fusion this operator into the specified graph */ /*! \brief Fusion this operator into the specified graph */
...@@ -69,18 +69,18 @@ class OperatorBase { ...@@ -69,18 +69,18 @@ class OperatorBase {
/*! \brief Return the current running phase */ /*! \brief Return the current running phase */
const string& phase() const { return phase_; } const string& phase() const { return phase_; }
/*! \brief Return the anchor name of this operator */ /*! \brief Return the resource handle */
const string& anchor() const { return anchor_; } const string& handle() const { return handle_; }
/*! \brief Return the data type of this operator */ /*! \brief Return the data type */
const string& dtype() const { return dtype_; } const string& dtype() const { return dtype_; }
/*! \brief Return the data format of this operator */ /*! \brief Return the data format */
const string& data_format() const { return data_format_; } const string& data_format() const { return data_format_; }
/*! \brief Return the unique name in this operator */ /*! \brief Return the unique name in this operator */
const string unique_name(const string& name) const { const string unique_name(const string& name) const {
return "/mnt/" + anchor_ + "/" + name; return "/mnt/" + handle_ + "/" + name;
} }
/*! \brief Return the parent workspace */ /*! \brief Return the parent workspace */
...@@ -94,7 +94,7 @@ class OperatorBase { ...@@ -94,7 +94,7 @@ class OperatorBase {
template <typename T> template <typename T>
vector<T> Args(const string& name); vector<T> Args(const string& name);
/*! \brief Return the argument map of this operator */ /*! \brief Return the argument map */
const Map<std::string, const Argument*>& args() { return args_; } const Map<std::string, const Argument*>& args() { return args_; }
/*! \brief Return the specified argument */ /*! \brief Return the specified argument */
...@@ -102,7 +102,7 @@ class OperatorBase { ...@@ -102,7 +102,7 @@ class OperatorBase {
typedef Map<string, vector<OperatorBase*>> SubGraph; typedef Map<string, vector<OperatorBase*>> SubGraph;
/*! \brief Return the recomputing subgraph of this operator */ /*! \brief Return the recomputing subgraph */
SubGraph& subgraph() { return subgraph_; } SubGraph& subgraph() { return subgraph_; }
/*! \brief Set the given recomputing subgraph */ /*! \brief Set the given recomputing subgraph */
...@@ -110,29 +110,38 @@ class OperatorBase { ...@@ -110,29 +110,38 @@ class OperatorBase {
subgraph_ = subgraph; subgraph_ = subgraph;
} }
/*! \brief Return the stored operator def */ /*! \brief Return the stored def */
const OperatorDef& def() const { return def_; } const OperatorDef& def() const { return def_; }
/*! \brief Return the debug string of the stored operator def */ /*! \brief Return the debug string of stored def */
string DebugString() const { return def_.DebugString(); } string DebugString() const { return def_.DebugString(); }
/*! \brief Return the dtype string according to given tensor */ /*! \brief Return the dtype string according to given tensor */
string DTypeString( string DTypeString(const Tensor&, const Set<string>&) const;
const Tensor& tensor,
const Set<string>& dtypes) const;
/* \brief Return the dtype string according to given type */ /* \brief Return the dtype string according to given type */
string DTypeString( string DTypeString(const string&, const Set<string>&) const;
const string& dtype,
const Set<string>& dtypes) const;
protected: protected:
/*! \brief Store the parent workspace */
Workspace* ws_; Workspace* ws_;
/*! \brief Store the def */
OperatorDef def_; OperatorDef def_;
/*! \brief Store the recomputing subgraph */
SubGraph subgraph_; SubGraph subgraph_;
string phase_, anchor_;
/*! \brief Store the phase and handle */
string phase_, handle_;
/*! \brief Store the data type and format */
string dtype_, data_format_; string dtype_, data_format_;
/*! \brief Store the pointer of inputs and outputs */
vector<Tensor*> inputs_, outputs_; vector<Tensor*> inputs_, outputs_;
/*! \brief Store the defined arguments */
Map<string, const Argument*> args_; Map<string, const Argument*> args_;
}; };
...@@ -236,7 +245,7 @@ OperatorBase* NewOperator( ...@@ -236,7 +245,7 @@ OperatorBase* NewOperator(
using OperatorBase::name; \ using OperatorBase::name; \
using OperatorBase::type; \ using OperatorBase::type; \
using OperatorBase::phase; \ using OperatorBase::phase; \
using OperatorBase::anchor; \ using OperatorBase::handle; \
using OperatorBase::dtype; \ using OperatorBase::dtype; \
using OperatorBase::data_format; \ using OperatorBase::data_format; \
using OperatorBase::unique_name; \ using OperatorBase::unique_name; \
...@@ -432,7 +441,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(TensorTypes, RunImpl); ...@@ -432,7 +441,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(TensorTypes, RunImpl);
<< arg##_desc_.size() << ")."; \ << arg##_desc_.size() << ")."; \
auto* arg##T = ws()->GetTensor( \ auto* arg##T = ws()->GetTensor( \
str::replace_first(arg##_desc_[i], \ str::replace_first(arg##_desc_[i], \
"${ANCHOR}", anchor())); \ "${HANDLE}", handle())); \
CHECK(arg##T->template IsType<type>()) \ CHECK(arg##T->template IsType<type>()) \
<< "\nThe type of " << #arg << " should be " << #type << "."; \ << "\nThe type of " << #arg << " should be " << #type << "."; \
CHECK_EQ(arg##T->count(), 1) \ CHECK_EQ(arg##T->count(), 1) \
......
...@@ -48,15 +48,16 @@ class GradientMakerBase { ...@@ -48,15 +48,16 @@ class GradientMakerBase {
virtual Gradient Make() { virtual Gradient Make() {
auto new_defs = MakeDef(); auto new_defs = MakeDef();
if (def.has_uid()) { if (def.has_uid()) {
// Attach the anchor to name if having UID // Attach the handle to name if having UID
for (int i = 0; i < new_defs.size(); i++) for (int i = 0; i < new_defs.size(); i++)
new_defs[i].set_name(def.name()); new_defs[i].set_name(def.name());
} else { } else {
// Otherwise, just put it into the arguments // Otherwise, just put it into the arguments
Argument anchor; Argument arg;
anchor.set_name("anchor"); anchor.set_s(def.name()); arg.set_name("handle");
arg.set_s(def.name());
for (int i = 0; i < new_defs.size(); i++) for (int i = 0; i < new_defs.size(); i++)
new_defs[i].add_arg()->CopyFrom(anchor); new_defs[i].add_arg()->CopyFrom(arg);
} }
return Gradient(new_defs, g_inputs_, defaults()); return Gradient(new_defs, g_inputs_, defaults());
}; };
......
...@@ -45,13 +45,13 @@ class Workspace { ...@@ -45,13 +45,13 @@ class Workspace {
void Clear(); void Clear();
/*! \brief Merge from a external workspace */ /*! \brief Merge from a external workspace */
void MergeFrom(Workspace* ws); void MergeFrom(Workspace*);
/*! \brief Query the real name of specified tensor */ /*! \brief Query the real name of specified tensor */
string GetTensorName(const string& name) const; string GetTensorName(const string&) const;
/*! \brief Try to serach the specified tensor in this workspace */ /*! \brief Try to serach the specified tensor in this workspace */
Tensor* TryGetTensor(const string& name, bool use_remote = true) const; Tensor* TryGetTensor(const string&, bool = true) const;
/*! \brief Whether the specified tensor is in this workspace */ /*! \brief Whether the specified tensor is in this workspace */
bool HasTensor(const string& name, bool use_remote = true) const { bool HasTensor(const string& name, bool use_remote = true) const {
...@@ -59,22 +59,22 @@ class Workspace { ...@@ -59,22 +59,22 @@ class Workspace {
} }
/*! \brief Create the specified tensor */ /*! \brief Create the specified tensor */
Tensor* CreateTensor(const string& name); Tensor* CreateTensor(const string&);
/*! \brief Return the specified tensor */ /*! \brief Return the specified tensor */
Tensor* GetTensor(const string& name, bool use_remote = true) const; Tensor* GetTensor(const string&, bool = true) const;
/*! \brief Reset the specified tensor */ /*! \brief Reset the specified tensor */
void ResetTensor(const string& name); void ResetTensor(const string&);
/* \brief Whether the specified filler is in this workspace */ /* \brief Whether the specified filler is in this workspace */
bool HasFiller(const string& name, bool use_remote = true) const; bool HasFiller(const string&, bool = true) const;
/*! \brief Create the specified filler */ /*! \brief Create the specified filler */
void CreateFiller(const TensorFillerProto& filler); void CreateFiller(const TensorFillerProto&);
/*! \brief Return the specified filler */ /*! \brief Return the specified filler */
const TensorFillerProto* GetFiller(const string& name) const; const TensorFillerProto* GetFiller(const string&) const;
/*! \brief Create temporal data segments */ /*! \brief Create temporal data segments */
template <class Context> template <class Context>
...@@ -103,16 +103,16 @@ class Workspace { ...@@ -103,16 +103,16 @@ class Workspace {
} }
/*! \brief Create a operator in this workspace */ /*! \brief Create a operator in this workspace */
OperatorBase* CreateOperator(const OperatorDef& def); OperatorBase* CreateOperator(const OperatorDef&);
/*! \brief Run the specified persistent operator */ /*! \brief Run the specified persistent operator */
void RunOperator(const OperatorDef& def); void RunOperator(const OperatorDef&);
/*! \brief Try to run the operator in a adaptive mode */ /*! \brief Try to run the operator in a adaptive mode */
void RunOperatorOnce(const OperatorDef& def); void RunOperatorOnce(const OperatorDef&);
/*! \brief Create a Graph in this workspace */ /*! \brief Create a Graph in this workspace */
GraphBase* CreateGraph(const GraphDef& def); GraphBase* CreateGraph(const GraphDef&);
/*! \brief Run the specifed graph by name and rules */ /*! \brief Run the specifed graph by name and rules */
void RunGraph( void RunGraph(
......
...@@ -28,7 +28,7 @@ class CollectiveUpdateOp final ...@@ -28,7 +28,7 @@ class CollectiveUpdateOp final
Workspace* ws) Workspace* ws)
: MPIOpBase<Context>(def, ws), : MPIOpBase<Context>(def, ws),
mode_(OpArg<string>("mode", "")) { mode_(OpArg<string>("mode", "")) {
if (mode_.find("NCCL") != string::npos) InitNCCL(); if (str::find(mode_, "NCCL")) InitNCCL();
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPI_FUNCTIONS; USE_MPI_FUNCTIONS;
...@@ -37,7 +37,7 @@ class CollectiveUpdateOp final ...@@ -37,7 +37,7 @@ class CollectiveUpdateOp final
/* TODO(PhyscalX): Temporarily disable it, /* TODO(PhyscalX): Temporarily disable it,
to avoid a unhandled error. */ to avoid a unhandled error. */
#ifdef WITH_NCCL #ifdef WITH_NCCL
if (mode_.find("NCCL") != string::npos) { if (str::find(mode_, "NCCL")) {
/* ncclCommDestroy(nccl_comm); */ /* ncclCommDestroy(nccl_comm); */
} }
#endif #endif
......
...@@ -30,7 +30,7 @@ inline void LoadCaffeModel( ...@@ -30,7 +30,7 @@ inline void LoadCaffeModel(
const auto& layer_name = layer.name(); const auto& layer_name = layer.name();
auto prefix = layer_name + "/param:"; auto prefix = layer_name + "/param:";
for (int j = 0; j < layer.blobs_size(); j++) { for (int j = 0; j < layer.blobs_size(); j++) {
auto tensor_name = prefix + std::to_string(j); auto tensor_name = prefix + str::to(j);
if (!ws->HasTensor(tensor_name)) { if (!ws->HasTensor(tensor_name)) {
LOG(WARNING) LOG(WARNING)
<< "Tensor(" << tensor_name << ") " << "Tensor(" << tensor_name << ") "
......
...@@ -23,6 +23,11 @@ namespace dragon { ...@@ -23,6 +23,11 @@ namespace dragon {
namespace str { namespace str {
template <typename T>
inline std::string to(T val) {
return std::to_string(val);
}
inline std::vector<std::string> split( inline std::vector<std::string> split(
const std::string& str, const std::string& str,
const std::string& c) { const std::string& c) {
...@@ -37,6 +42,12 @@ inline std::vector<std::string> split( ...@@ -37,6 +42,12 @@ inline std::vector<std::string> split(
return ret; return ret;
} }
inline bool find(
const std::string& str,
const std::string& pattern) {
return str.find(pattern) != std::string::npos;
}
inline std::string replace_first( inline std::string replace_first(
const std::string& str, const std::string& str,
const std::string& pattern, const std::string& pattern,
......
...@@ -214,13 +214,6 @@ PYBIND11_MODULE(libdragon, m) { ...@@ -214,13 +214,6 @@ PYBIND11_MODULE(libdragon, m) {
return self->GetTensor(name); return self->GetTensor(name);
}, pybind11::return_value_policy::reference_internal) }, pybind11::return_value_policy::reference_internal)
/*! \brief Return the filler type of a tensor */
.def("GetFillerType", [](
Workspace* self,
const string& name) {
return self->GetFiller(name)->type();
})
/* \brief Set an alias for the tensor */ /* \brief Set an alias for the tensor */
.def("SetTensorAlias", []( .def("SetTensorAlias", [](
Workspace* self, Workspace* self,
...@@ -311,12 +304,14 @@ PYBIND11_MODULE(libdragon, m) { ...@@ -311,12 +304,14 @@ PYBIND11_MODULE(libdragon, m) {
<< "\nFailed to parse the GraphDef."; << "\nFailed to parse the GraphDef.";
auto* graph = self->CreateGraph(graph_def); auto* graph = self->CreateGraph(graph_def);
if (verbose) { if (verbose) {
bool could_be_serialized = true;
const auto& def = graph->opt_def();
for (auto& op : def.op())
if (op.type() == "GivenTensorFill")
could_be_serialized = false;
if (could_be_serialized) {
// It is not a good design to print the debug string // It is not a good design to print the debug string
auto* T = self->CreateTensor( std::cout << def.DebugString() << std::endl;
"/graph_def/optimized/" + graph->name());
if (T->count() > 0) {
auto* data = T->mutable_data<string, CPUContext>();
std::cout << data[0] << std::endl;
} }
} }
// Return the graph name may be different from the def // Return the graph name may be different from the def
......
...@@ -72,4 +72,4 @@ void AddProtoMethods(pybind11::module& m) { ...@@ -72,4 +72,4 @@ void AddProtoMethods(pybind11::module& m) {
} // namespace dragon } // namespace dragon
#endif DRAGON_PYTHON_PY_PROTO_H_ #endif // DRAGON_PYTHON_PY_PROTO_H_
\ No newline at end of file \ No newline at end of file
...@@ -38,22 +38,22 @@ from dragon.core import proto_utils as _proto_utils ...@@ -38,22 +38,22 @@ from dragon.core import proto_utils as _proto_utils
class TensorPool(object): class TensorPool(object):
"""We apply the TensorPool to manage the reused tensors. """A wrapper to manage the reused tensors.
Tensors with the same scope in the pool will be reused by turns, Tensors with the same scope will be reused by turns,
which speeds up the whole system by reducing the unnecessary deconstructing. and thus, unnecessary deconstructing is reduced.
Heuristically, we have used 5 pools with different scopes: Heuristically, five pools with different scopes are used:
* scope(Leaf): A Pool to reuse leaf tensors. * *${LEAF}*: A pool to reuse leaf tensors.
* scope(NumPy): A pool to reuse leaf tensors from numpy. * *${NUMPY}*: A pool to reuse leaf tensors from numpy.
* scope(Join): A pool to reuse RT(runtime) tensors required by forward-backward. * *${JOIN}*: A pool to reuse RT(runtime) tensors required by forward-backward.
* scope(Detach): A pool to reuse RT(runtime) tensors required by forward only. * *${DETACH}*: A pool to reuse RT(runtime) tensors required by forward-pass only.
* scope(Reference): A pool to reuse reshaped tensors(sharing contents). * *${REFERENCE}*: A pool to reuse reshaped tensors(sharing contents).
""" """
def __init__(self): def __init__(self):
...@@ -61,6 +61,19 @@ class TensorPool(object): ...@@ -61,6 +61,19 @@ class TensorPool(object):
self._scope2keys = defaultdict(deque) self._scope2keys = defaultdict(deque)
def get(self, scope='${DETACH}'): def get(self, scope='${DETACH}'):
"""Return a unique name under the specified scope.
Parameters
----------
scope : str, optional, default='${DETACH}'
The optional
Returns
-------
str
The unique name can be used.
"""
try: try:
return self._scope2keys[scope].popleft() return self._scope2keys[scope].popleft()
except: except:
...@@ -71,6 +84,14 @@ class TensorPool(object): ...@@ -71,6 +84,14 @@ class TensorPool(object):
return self._scope2keys[scope].popleft() return self._scope2keys[scope].popleft()
def put(self, name): def put(self, name):
"""Collect a unique name.
Parameters
----------
name : str
The name to collect.
"""
if '${POOL}' in name: if '${POOL}' in name:
scope, _ = name[8:].split('/') scope, _ = name[8:].split('/')
self._scope2keys[scope].append(name) self._scope2keys[scope].append(name)
...@@ -79,13 +100,14 @@ class TensorPool(object): ...@@ -79,13 +100,14 @@ class TensorPool(object):
class OperatorPool(object): class OperatorPool(object):
"""Operators whose gradients is required will hold a resource handle, """A wrapper to manage the resource handle of operators.
which is also called ``Anchor`` in the backend.
Operators whose gradients is required will hold a resource handle.
We apply this pool to collect the handles according to the type of operator, We collect the handles according to the type of operator,
as the mem size of temporal resources varies greatly. as the size of resources varies greatly.
The resource handle will be released after the gradient flow automatically. Handle will be released after the backward-pass automatically.
""" """
def __init__(self): def __init__(self):
...@@ -93,6 +115,19 @@ class OperatorPool(object): ...@@ -93,6 +115,19 @@ class OperatorPool(object):
self._type2keys = defaultdict(deque) self._type2keys = defaultdict(deque)
def get(self, op_type): def get(self, op_type):
"""Return a unique handle according to the op type.
Parameters
----------
op_type : str
The type of the operator.
Returns
-------
str
The unique handle can be used.
"""
try: try:
return self._type2keys[op_type].popleft() return self._type2keys[op_type].popleft()
except: except:
...@@ -102,13 +137,21 @@ class OperatorPool(object): ...@@ -102,13 +137,21 @@ class OperatorPool(object):
domain='Operator', zero_based=False)) domain='Operator', zero_based=False))
return self._type2keys[op_type].popleft() return self._type2keys[op_type].popleft()
def put(self, op_name): def put(self, handle):
op_type, _ = op_name[8:].split('_') """Collect a unique handle.
self._type2keys[op_type].append(op_name)
Parameters
----------
name : str
The name to collect.
"""
op_type, _ = handle[8:].split('_')
self._type2keys[op_type].append(handle)
class Workspace(_C.Workspace): class Workspace(_C.Workspace):
"""A wrapper for the C implemented workspace. """A wrapper for the C++ implemented workspace.
This class is a fusion of *Workspace*, *Pool* and *tf.Graph*. This class is a fusion of *Workspace*, *Pool* and *tf.Graph*.
...@@ -116,20 +159,41 @@ class Workspace(_C.Workspace): ...@@ -116,20 +159,41 @@ class Workspace(_C.Workspace):
""" """
def __init__(self, name=''): def __init__(self, name=''):
"""Construct a Workspace instance.
Parameters
----------
name : str, optional, default=''
The optional workspace name.
Returns
-------
Workspace
A new workspace.
"""
super(Workspace, self).__init__(name) super(Workspace, self).__init__(name)
self._ref_objects = [] self._ref_objects = []
self._collections = {} self._collections = {}
self.tensor_pool = TensorPool() self.tensor_pool = TensorPool()
self.operator_pool = OperatorPool() self.operator_pool = OperatorPool()
def get_collection_ref(self, name):
coll_list = self._collections.get(name, None)
if coll_list is None:
coll_list = []
self._collections[name] = coll_list
return coll_list
def get_collection(self, name, scope=None): def get_collection(self, name, scope=None):
"""Return the specified collection.
Parameters
----------
name : str
The key to the collection.
scope : str, optional
The optional regex keyword.
Returns
-------
list
The collection list.
"""
coll_list = self._collections.get(name, None) coll_list = self._collections.get(name, None)
if coll_list is None: if coll_list is None:
return [] return []
...@@ -143,13 +207,61 @@ class Workspace(_C.Workspace): ...@@ -143,13 +207,61 @@ class Workspace(_C.Workspace):
filter_coll_list.append(item) filter_coll_list.append(item)
return filter_coll_list return filter_coll_list
def get_collection_ref(self, name):
"""Return the reference of specified collection.
Parameters
----------
name : str
The key to the collection.
Returns
-------
list
The collection list.
"""
coll_list = self._collections.get(name, None)
if coll_list is None:
coll_list = []
self._collections[name] = coll_list
return coll_list
def add_to_collection(self, name, value): def add_to_collection(self, name, value):
"""Add the value to the specified collection.
Parameters
----------
name : str
The key to the collection.
value : object
The value object.
Returns
-------
None
"""
if name not in self._collections: if name not in self._collections:
self._collections[name] = [value] self._collections[name] = [value]
else: else:
self._collections[name].append(value) self._collections[name].append(value)
def add_to_collections(self, names, value): def add_to_collections(self, names, value):
"""Add the value to the specified collections.
Parameters
----------
name : sequence of str
The key to the collections.
value : object
The value object.
Returns
-------
None
"""
for name in names: for name in names:
self.add_to_collection(name, value) self.add_to_collection(name, value)
...@@ -246,12 +358,21 @@ def CreateGraph(graph_def): ...@@ -246,12 +358,21 @@ def CreateGraph(graph_def):
The graph name to run. The graph name to run.
""" """
LogMetaGraph(graph_def)
ExportMetaGraph(graph_def)
options = _cfg.GetGlobalOptions() options = _cfg.GetGlobalOptions()
if options['log_meta_graph']: print(graph_def)
if options['export_meta_graph']:
if not os.path.exists(options['export_meta_graph']):
try:
os.makedirs(options['export_meta_graph'])
except Exception:
raise ValueError('The given prefix is invalid.')
path = os.path.join(
options['export_meta_graph'],
graph_def.name + '.metatxt')
with open(path, 'w') as f: f.write(str(graph_def))
_logging.info('Export meta graph to: {}'.format(path))
return get_default_workspace().CreateGraph( return get_default_workspace().CreateGraph(
_stringify_proto(graph_def), _stringify_proto(graph_def), options['log_optimized_graph'])
options['log_optimized_graph'])
def RunOperator(op_def, verbose=False): def RunOperator(op_def, verbose=False):
...@@ -332,28 +453,6 @@ def CreateFiller(filler_def): ...@@ -332,28 +453,6 @@ def CreateFiller(filler_def):
get_default_workspace().CreateFiller(filler_def) get_default_workspace().CreateFiller(filler_def)
def GetFillerType(tensor):
"""Get the filler type of specific tensor.
It is useful if you want to tag some tensors,
e.g. tag with ``numpy``, and get to initialize them lazily.
Parameters
----------
tensor : Tensor or str
The tensor to query.
Returns
-------
str
The filler type.
"""
tensor = _stringify_tensor(tensor)
return get_default_workspace().GetFillerType(tensor)
def GetTensorName(tensor): def GetTensorName(tensor):
"""Query the name represented in current workspace. """Query the name represented in current workspace.
...@@ -585,54 +684,6 @@ def Backward( ...@@ -585,54 +684,6 @@ def Backward(
) )
def LogMetaGraph(graph_def):
"""Log the meta graph.
Parameters
----------
graph_def : GraphDef
The definition of meta graph.
Returns
-------
None
"""
options = _cfg.GetGlobalOptions()
if options['log_meta_graph']: print(graph_def)
def ExportMetaGraph(graph_def):
"""Export the meta graph into a file under specific folder.
You can set the exporting prefix by `config.ExportMetaGraph(prefix)`_.
Parameters
----------
graph_def : GraphDef
The definition of meta graph.
Returns
-------
None
"""
options = _cfg.GetGlobalOptions()
if options['export_meta_graph']:
if not os.path.exists(options['export_meta_graph']):
try:
os.makedirs(options['export_meta_graph'])
except Exception:
raise ValueError('The given prefix is invalid.')
path = os.path.join(
options['export_meta_graph'],
graph_def.name + '.metatxt')
with open(path, 'w') as f: f.write(str(graph_def))
_logging.info('Export meta graph into: {}'.format(path))
def Snapshot( def Snapshot(
tensors, tensors,
filename, filename,
...@@ -663,10 +714,6 @@ def Snapshot( ...@@ -663,10 +714,6 @@ def Snapshot(
------- -------
None None
Notes
-----
""" """
file_path = prefix + filename + suffix file_path = prefix + filename + suffix
...@@ -709,7 +756,6 @@ def Restore(binary_file, format='pickle'): ...@@ -709,7 +756,6 @@ def Restore(binary_file, format='pickle'):
""" """
assert os.path.exists(binary_file), \ assert os.path.exists(binary_file), \
'Binary file({}) does not exist.'.format(binary_file) 'Binary file({}) does not exist.'.format(binary_file)
if format == 'pickle': if format == 'pickle':
try: try:
state_dict = pickle.load(open(binary_file, 'rb')) state_dict = pickle.load(open(binary_file, 'rb'))
......
...@@ -55,9 +55,9 @@ class BlobFetcher(multiprocessing.Process): ...@@ -55,9 +55,9 @@ class BlobFetcher(multiprocessing.Process):
im, labels = self.Q_in.get() im, labels = self.Q_in.get()
im_blob = numpy.zeros(shape=([self._batch_size] + list(im.shape)), dtype='uint8') im_blob = numpy.zeros(shape=([self._batch_size] + list(im.shape)), dtype='uint8')
label_blob = numpy.zeros((self._batch_size, len(labels)), dtype='int64') label_blob = numpy.zeros((self._batch_size, len(labels)), dtype='int64')
for ix in range(self._batch_size): for i in range(self._batch_size):
im_blob[ix, :, :, :], label_blob[ix, :] = im, labels im_blob[i, :, :, :], label_blob[i, :] = im, labels
if ix != self._batch_size - 1: im, labels = self.Q_in.get() if i != self._batch_size - 1: im, labels = self.Q_in.get()
return im_blob, label_blob return im_blob, label_blob
def run(self): def run(self):
......
...@@ -14,8 +14,8 @@ from __future__ import division ...@@ -14,8 +14,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import time import time
import multiprocessing
from multiprocessing import Queue
from dragon.core import mpi as _mpi from dragon.core import mpi as _mpi
from dragon.core import logging as _logging from dragon.core import logging as _logging
...@@ -39,14 +39,10 @@ class DataBatch(object): ...@@ -39,14 +39,10 @@ class DataBatch(object):
---------- ----------
source : str source : str
The path of database. The path of database.
multiple_nodes: boolean, optional, default=False
Whether to split data for multiple parallel nodes.
shuffle : bool, optional, default=False shuffle : bool, optional, default=False
Whether to shuffle the data. Whether to shuffle the data.
num_chunks : int, optional, default=2048 num_chunks : int, optional, default=2048
The number of chunks to split. The number of chunks to split.
chunk_size : int, optional, default=-1
The size(MB) of each chunk.
padding : int, optional, default=0 padding : int, optional, default=0
The zero-padding size. The zero-padding size.
fill_value : int, optional, default=127 fill_value : int, optional, default=127
...@@ -77,13 +73,12 @@ class DataBatch(object): ...@@ -77,13 +73,12 @@ class DataBatch(object):
""" """
super(DataBatch, self).__init__() super(DataBatch, self).__init__()
# Init mpi # Init mpi
global_rank = 0; local_rank = 0; group_size = 1 global_rank, local_rank, group_size = 0, 0, 1
if _mpi.Is_Init() and kwargs.get( if _mpi.Is_Init() and kwargs.get(
'phase', 'TRAIN') == 'TRAIN': 'phase', 'TRAIN') == 'TRAIN':
rank, group = _mpi.AllowParallel() rank, group = _mpi.AllowParallel()
if rank != -1: # DataParallel if rank != -1: # DataParallel
global_rank = _mpi.Rank() global_rank, group_size = _mpi.Rank(), len(group)
group_size = len(group)
for i, node in enumerate(group): for i, node in enumerate(group):
if global_rank == node: local_rank = i if global_rank == node: local_rank = i
kwargs['group_size'] = group_size kwargs['group_size'] = group_size
...@@ -109,39 +104,31 @@ class DataBatch(object): ...@@ -109,39 +104,31 @@ class DataBatch(object):
if kwargs.get('crop_size', 0) > 0 and \ if kwargs.get('crop_size', 0) > 0 and \
kwargs.get('phase', 'TRAIN') == 'TRAIN': kwargs.get('phase', 'TRAIN') == 'TRAIN':
self._num_transformers += 1 self._num_transformers += 1
self._num_transformers = min(self._num_transformers, self._max_transformers) self._num_transformers = min(
self._num_transformers, self._max_transformers)
self._batch_size = kwargs.get('batch_size', 128) self._batch_size = kwargs.get('batch_size', 128)
self._partition = kwargs.get('partition', False) self._partition = kwargs.get('partition', False)
if self._partition: if self._partition: self._batch_size //= kwargs['group_size']
self._batch_size = int(self._batch_size / kwargs['group_size'])
# Init queues # Init queues
self.Q_level_1 = multiprocessing.Queue( self.Q1 = Queue(self._prefetch * self._num_readers * self._batch_size)
self._prefetch * self._num_readers * self._batch_size) self.Q2 = Queue(self._prefetch * self._num_readers * self._batch_size)
self.Q_level_2 = multiprocessing.Queue( self.Q3 = Queue(self._prefetch * self._num_readers)
self._prefetch * self._num_readers * self._batch_size)
self.Q_level_3 = multiprocessing.Queue(
self._prefetch * self._num_readers)
# Init readers # Init readers
self._readers = [] self._readers = []
for i in range(self._num_readers): for i in range(self._num_readers):
self._readers.append(DataReader(**kwargs)) self._readers.append(DataReader(**kwargs))
self._readers[-1].Q_out = self.Q_level_1 self._readers[-1].Q_out = self.Q1
for i in range(self._num_readers): for i in range(self._num_readers):
num_parts = self._num_readers part_idx, num_parts = i, self._num_readers
part_idx = i
if self._readers[i]._multiple_nodes or \
self._readers[i]._use_shuffle:
num_parts *= group_size num_parts *= group_size
part_idx += local_rank * self._num_readers part_idx += local_rank * self._num_readers
self._readers[i]._num_parts = num_parts self._readers[i]._num_parts = num_parts
self._readers[i]._part_idx = part_idx self._readers[i]._part_idx = part_idx
self._readers[i]._random_seed += part_idx self._readers[i]._rng_seed += part_idx
self._readers[i].start() self._readers[i].start()
time.sleep(0.1) time.sleep(0.1)
...@@ -149,9 +136,8 @@ class DataBatch(object): ...@@ -149,9 +136,8 @@ class DataBatch(object):
self._transformers = [] self._transformers = []
for i in range(self._num_transformers): for i in range(self._num_transformers):
transformer = DataTransformer(**kwargs) transformer = DataTransformer(**kwargs)
transformer._random_seed += (i + local_rank * self._num_transformers) transformer._rng_seed += (i + local_rank * self._num_transformers)
transformer.Q_in = self.Q_level_1 transformer.Q_in, transformer.Q_out = self.Q1, self.Q2
transformer.Q_out = self.Q_level_2
transformer.start() transformer.start()
self._transformers.append(transformer) self._transformers.append(transformer)
time.sleep(0.1) time.sleep(0.1)
...@@ -160,8 +146,7 @@ class DataBatch(object): ...@@ -160,8 +146,7 @@ class DataBatch(object):
self._fetchers = [] self._fetchers = []
for i in range(self._num_fetchers): for i in range(self._num_fetchers):
fetcher = BlobFetcher(**kwargs) fetcher = BlobFetcher(**kwargs)
fetcher.Q_in = self.Q_level_2 fetcher.Q_in, fetcher.Q_out = self.Q2, self.Q3
fetcher.Q_out = self.Q_level_3
fetcher.start() fetcher.start()
self._fetchers.append(fetcher) self._fetchers.append(fetcher)
time.sleep(0.1) time.sleep(0.1)
...@@ -189,4 +174,4 @@ class DataBatch(object): ...@@ -189,4 +174,4 @@ class DataBatch(object):
The batch, representing data and labels respectively. The batch, representing data and labels respectively.
""" """
return self.Q_level_3.get() return self.Q3.get()
\ No newline at end of file \ No newline at end of file
...@@ -34,28 +34,19 @@ class DataReader(multiprocessing.Process): ...@@ -34,28 +34,19 @@ class DataReader(multiprocessing.Process):
---------- ----------
source : str source : str
The path of database. The path of database.
multiple_nodes: boolean, optional, default=False
Whether to split data for multiple parallel nodes.
shuffle : bool, optional, default=False shuffle : bool, optional, default=False
Whether to shuffle the data. Whether to shuffle the data.
num_chunks : int, optional, default=2048 num_chunks : int, optional, default=2048
The number of chunks to split. The number of chunks to split.
chunk_size : int, optional, default=-1
The size(MB) of each chunk.
""" """
super(DataReader, self).__init__() super(DataReader, self).__init__()
self._source = kwargs.get('source', '') self._source = kwargs.get('source', '')
self._multiple_nodes = kwargs.get('multiple_nodes', False)
self._use_shuffle = kwargs.get('shuffle', False) self._use_shuffle = kwargs.get('shuffle', False)
self._use_instance_chunk = kwargs.get('instance_chunk', False)
self._num_chunks = kwargs.get('num_chunks', 2048) self._num_chunks = kwargs.get('num_chunks', 2048)
self._chunk_size = kwargs.get('chunk_size', -1)
self._part_idx, self._num_parts = 0, 1 self._part_idx, self._num_parts = 0, 1
self._cur_idx, self._cur_chunk_idx = 0, 0 self._cursor, self._chunk_cursor = 0, 0
self._random_seed = _cfg.GetRandomSeed() self._rng_seed = _cfg.GetRandomSeed()
self.Q_out = None self.Q_out = None
self.daemon = True self.daemon = True
...@@ -70,13 +61,13 @@ class DataReader(multiprocessing.Process): ...@@ -70,13 +61,13 @@ class DataReader(multiprocessing.Process):
""" """
return self._db.value() return self._db.value()
def redirect(self, target_idx): def redirect(self, target):
"""Redirect to the target position. """Redirect to the target position.
Parameters Parameters
---------- ----------
target_idx : int target : int
The key of instance in ``LMDB``. The key of the record.
Returns Returns
------- -------
...@@ -84,17 +75,17 @@ class DataReader(multiprocessing.Process): ...@@ -84,17 +75,17 @@ class DataReader(multiprocessing.Process):
Notes Notes
----- -----
The redirection reopens the ``LMDB``. The redirection reopens the database.
You can drop caches by ``echo 3 > /proc/sys/vm/drop_caches``. You can drop caches by ``echo 3 > /proc/sys/vm/drop_caches``.
This will disturb getting stuck when ``Database Size`` >> ``RAM Size``. This will disturb getting stuck when *Database Size* >> *RAM Size*.
""" """
self._db.close() self._db.close()
self._db.open(self._source) self._db.open(self._source)
self._cur_idx = target_idx self._cursor = target
self._db.set(str(self._cur_idx).zfill(self._zfill)) self._db.set(str(target).zfill(self._zfill))
def reset(self): def reset(self):
"""Reset the cursor and environment. """Reset the cursor and environment.
...@@ -104,21 +95,18 @@ class DataReader(multiprocessing.Process): ...@@ -104,21 +95,18 @@ class DataReader(multiprocessing.Process):
None None
""" """
if self._multiple_nodes or self._use_shuffle: if self._num_parts > 1 or self._use_shuffle:
if self._use_shuffle: self._chunk_cursor = 0
self._perm = numpy.random.permutation( self._part_idx = (self._part_idx + 1) % self._num_parts
self._num_shuffle_parts) if self._use_shuffle: self._perm = numpy.random.permutation(self._perm_size)
self._cur_chunk_idx = 0 self._head = self._part_idx * self._perm_size + self._perm[self._chunk_cursor]
self._start_idx = int(self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx]) self._tail = self._head * self._chunk_size
self._start_idx = int(self._start_idx * self._chunk_size) if self._head >= self._num_entries: self.next_chunk()
if self._start_idx >= self._num_entries: self.next_chunk() self._tail = self._head + self._chunk_size
self._end_idx = self._start_idx + self._chunk_size self._tail = min(self._num_entries, self._tail)
self._end_idx = min(self._num_entries, self._end_idx)
else: else:
self._start_idx = 0 self._head, self._tail = 0, self._num_entries
self._end_idx = self._num_entries self.redirect(self._head)
self.redirect(self._start_idx)
def next_record(self): def next_record(self):
"""Step the cursor of records. """Step the cursor of records.
...@@ -128,8 +116,8 @@ class DataReader(multiprocessing.Process): ...@@ -128,8 +116,8 @@ class DataReader(multiprocessing.Process):
None None
""" """
self._cur_idx += 1
self._db.next() self._db.next()
self._cursor += 1
def next_chunk(self): def next_chunk(self):
"""Step the cursor of shuffling chunks. """Step the cursor of shuffling chunks.
...@@ -139,16 +127,17 @@ class DataReader(multiprocessing.Process): ...@@ -139,16 +127,17 @@ class DataReader(multiprocessing.Process):
None None
""" """
self._cur_chunk_idx += 1 self._chunk_cursor += 1
if self._cur_chunk_idx >= self._num_shuffle_parts: self.reset() if self._chunk_cursor >= self._perm_size: self.reset()
else: else:
self._start_idx = self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx] self._head = self._part_idx * self._perm_size + self._perm[self._chunk_cursor]
self._start_idx = self._start_idx * self._chunk_size self._head = self._head * self._chunk_size
if self._start_idx >= self._num_entries: self.next_chunk() if self._head >= self._num_entries:
self.next_chunk()
else: else:
self._end_idx = self._start_idx + self._chunk_size self._tail = self._head + self._chunk_size
self._end_idx = min(self._num_entries, self._end_idx) self._tail = min(self._num_entries, self._tail)
self.redirect(self._start_idx) self.redirect(self._head)
def run(self): def run(self):
"""Start the process. """Start the process.
...@@ -158,44 +147,42 @@ class DataReader(multiprocessing.Process): ...@@ -158,44 +147,42 @@ class DataReader(multiprocessing.Process):
None None
""" """
# fix seed # Fix seed
numpy.random.seed(self._random_seed) numpy.random.seed(self._rng_seed)
# init db # Init db
self._db = _db.LMDB() self._db = _db.LMDB()
self._db.open(self._source) self._db.open(self._source)
self._zfill = self._db.zfill() self._zfill = self._db.zfill()
self._num_entries = self._db.num_entries() self._num_entries = self._db.num_entries()
self._epoch_size = int(self._num_entries / self._num_parts + 1)
epoch_size = self._num_entries // self._num_parts + 1
if self._use_shuffle: if self._use_shuffle:
if self._chunk_size == 1: if self._num_chunks <= 0:
# Each chunk has at most 1 record (Naive Shuffle) # Each chunk has at most 1 record (Record-Wise)
self._chunk_size, self._num_shuffle_parts = \ self._chunk_size, self._perm_size = 1, epoch_size
1, int(self._num_entries / self._num_parts) + 1
else: else:
if self._use_shuffle and self._chunk_size == -1: # Search a optimal chunk size (Chunk-Wise)
# Search a optimal chunk size by chunks (Chunk Shuffle) min_size, max_size = \
max_chunk_size = self._db._total_size / ((self._num_chunks * (1 << 20))) 1, self._db._total_size * 1.0 \
min_chunk_size = 1 / ((self._num_chunks * (1 << 20)))
while min_chunk_size * 2 < max_chunk_size: min_chunk_size *= 2 while min_size * 2 < max_size: min_size *= 2
self._chunk_size = min_chunk_size self._perm_size = int(math.ceil(
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 / self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20))) (self._num_parts * min_size << 20)))
self._chunk_size = int(self._num_entries / self._num_shuffle_parts / self._num_parts + 1) self._chunk_size = int(
limit = (self._num_parts - 0.5) * self._num_shuffle_parts * self._chunk_size self._num_entries * 1.0 /
(self._perm_size * self._num_parts) + 1)
limit = (self._num_parts - 0.5) * self._perm_size * self._chunk_size
if self._num_entries <= limit: if self._num_entries <= limit:
# Roll back to naive shuffle # Roll back to Record-Wise shuffle
self._chunk_size, self._num_shuffle_parts = \ self._chunk_size, self._perm_size = 1, epoch_size
1, int(self._num_entries / self._num_parts) + 1
else: else:
# Each chunk has at most K records # One chunk has at most K records
# Note that if ``shuffle`` and ``multiple_nodes`` are all *False*, self._chunk_size, self._perm_size = epoch_size, 1
# ``chunk_size`` and ``num_shuffle_parts`` are meaningless
self._chunk_size = int(self._num_entries / self._num_parts) + 1
self._num_shuffle_parts = 1
self._perm = numpy.arange(self._num_shuffle_parts) self._perm = numpy.arange(self._perm_size)
# Init env # Init env
self.reset() self.reset()
...@@ -204,7 +191,7 @@ class DataReader(multiprocessing.Process): ...@@ -204,7 +191,7 @@ class DataReader(multiprocessing.Process):
while True: while True:
self.Q_out.put(self.element()) self.Q_out.put(self.element())
self.next_record() self.next_record()
if self._cur_idx >= self._end_idx: if self._cursor >= self._tail:
if self._multiple_nodes or \ if self._num_parts > 1 or self._use_shuffle:
self._use_shuffle: self.next_chunk() self.next_chunk()
else: self.reset() else: self.reset()
\ No newline at end of file
...@@ -74,7 +74,7 @@ class DataTransformer(multiprocessing.Process): ...@@ -74,7 +74,7 @@ class DataTransformer(multiprocessing.Process):
self._max_rand_scale = kwargs.get('max_random_scale', 1.0) self._max_rand_scale = kwargs.get('max_random_scale', 1.0)
self._force_color = kwargs.get('force_color', False) self._force_color = kwargs.get('force_color', False)
self._phase = kwargs.get('phase', 'TRAIN') self._phase = kwargs.get('phase', 'TRAIN')
self._random_seed = _cfg.GetRandomSeed() self._rng_seed = _cfg.GetRandomSeed()
self.Q_in = self.Q_out = None self.Q_in = self.Q_out = None
self.daemon = True self.daemon = True
...@@ -186,7 +186,7 @@ class DataTransformer(multiprocessing.Process): ...@@ -186,7 +186,7 @@ class DataTransformer(multiprocessing.Process):
""" """
# Fix the random seed # Fix the random seed
numpy.random.seed(self._random_seed) numpy.random.seed(self._rng_seed)
# Run! # Run!
while True: while True:
......
...@@ -73,7 +73,7 @@ def fetch_initializer(initializer): ...@@ -73,7 +73,7 @@ def fetch_initializer(initializer):
def fetch_argument(op_def, desc, ws): def fetch_argument(op_def, desc, ws):
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
desc = desc.decode('utf-8') desc = desc.decode('utf-8')
desc = desc.replace('${ANCHOR}', op_def.name) desc = desc.replace('${HANDLE}', op_def.name)
argument_value = ws.FetchTensor(desc) argument_value = ws.FetchTensor(desc)
if argument_value.size == 1: if argument_value.size == 1:
return argument_value.flatten()[0] return argument_value.flatten()[0]
......
...@@ -24,6 +24,7 @@ import dragon ...@@ -24,6 +24,7 @@ import dragon
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from dragon import config as _cfg
from dragon.core import scope as _scope from dragon.core import scope as _scope
from dragon.core import logging as _logging from dragon.core import logging as _logging
from dragon.core import proto_utils as _proto_utils from dragon.core import proto_utils as _proto_utils
...@@ -310,6 +311,7 @@ class Module(object): ...@@ -310,6 +311,7 @@ class Module(object):
return self._module_key return self._module_key
def _gen_module_def(self): def _gen_module_def(self):
rng_seed = _cfg.GetGlobalOptions()['random_seed']
self._module_def = \ self._module_def = \
_proto_utils.MakeCXXOperatorDef( _proto_utils.MakeCXXOperatorDef(
name='runtime', name='runtime',
...@@ -318,7 +320,9 @@ class Module(object): ...@@ -318,7 +320,9 @@ class Module(object):
device_option=_proto_utils. device_option=_proto_utils.
GetDeviceOption( GetDeviceOption(
self._device.type, self._device.type,
self._device.index), self._device.index,
rng_seed=rng_seed,
),
**self.op_meta['arguments'] **self.op_meta['arguments']
) )
......
...@@ -36,10 +36,10 @@ class Indexing(BaseModule): ...@@ -36,10 +36,10 @@ class Indexing(BaseModule):
'op_type': 'Crop', 'op_type': 'Crop',
'arguments': { 'arguments': {
'starts_desc': [ 'starts_desc': [
'${{ANCHOR}}/starts[{}]'.format(n) '${{HANDLE}}/starts[{}]'.format(n)
for n in range(self.nstarts)], for n in range(self.nstarts)],
'sizes_desc': [ 'sizes_desc': [
'${{ANCHOR}}/sizes[{}]'.format(n) '${{HANDLE}}/sizes[{}]'.format(n)
for n in range(self.nsizes)], for n in range(self.nsizes)],
}, },
} }
...@@ -231,7 +231,7 @@ class Reshape(BaseModule): ...@@ -231,7 +231,7 @@ class Reshape(BaseModule):
'op_type': 'Reshape', 'op_type': 'Reshape',
'arguments': { 'arguments': {
'dims_desc': [ 'dims_desc': [
'${{ANCHOR}}/dims[{}]'.format(n) '${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim) for n in range(self.ndim)
], ],
}, },
...@@ -298,7 +298,7 @@ class Permute(BaseModule): ...@@ -298,7 +298,7 @@ class Permute(BaseModule):
self.op_meta = { self.op_meta = {
'op_type': 'Transpose', 'op_type': 'Transpose',
'arguments': { 'arguments': {
'perm_desc': ['${{ANCHOR}}/perm[{}]'.format(n) 'perm_desc': ['${{HANDLE}}/perm[{}]'.format(n)
for n in range(self.nperm)], for n in range(self.nperm)],
}, },
} }
...@@ -326,7 +326,7 @@ class Repeat(BaseModule): ...@@ -326,7 +326,7 @@ class Repeat(BaseModule):
'op_type': 'Tile', 'op_type': 'Tile',
'arguments': { 'arguments': {
'multiples_desc': [ 'multiples_desc': [
'${{ANCHOR}}/multiples[{}]'.format(n) '${{HANDLE}}/multiples[{}]'.format(n)
for n in range(self.ntimes) for n in range(self.ntimes)
], ],
}, },
......
...@@ -66,10 +66,10 @@ class Assign(BaseModule): ...@@ -66,10 +66,10 @@ class Assign(BaseModule):
'op_type': 'Assign', 'op_type': 'Assign',
'arguments': { 'arguments': {
'starts_desc': [ 'starts_desc': [
'${{ANCHOR}}/starts[{}]'.format(n) '${{HANDLE}}/starts[{}]'.format(n)
for n in range(self.nstarts)], for n in range(self.nstarts)],
'sizes_desc': [ 'sizes_desc': [
'${{ANCHOR}}/sizes[{}]'.format(n) '${{HANDLE}}/sizes[{}]'.format(n)
for n in range(self.nsizes)], for n in range(self.nsizes)],
}, },
} }
......
...@@ -45,7 +45,7 @@ class Fill(_InitModule): ...@@ -45,7 +45,7 @@ class Fill(_InitModule):
'dtype': self.dtype, 'dtype': self.dtype,
'value': float(self.value), 'value': float(self.value),
'dims_desc': [ 'dims_desc': [
'${{ANCHOR}}/dims[{}]'.format(n) '${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim) for n in range(self.ndim)
], ],
}, },
...@@ -67,7 +67,7 @@ class RandomNormal(_InitModule): ...@@ -67,7 +67,7 @@ class RandomNormal(_InitModule):
'mean': float(self.mean), 'mean': float(self.mean),
'std': float(self.std), 'std': float(self.std),
'dims_desc': [ 'dims_desc': [
'${{ANCHOR}}/dims[{}]'.format(n) '${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim) for n in range(self.ndim)
], ],
}, },
...@@ -89,7 +89,7 @@ class RandomUniform(_InitModule): ...@@ -89,7 +89,7 @@ class RandomUniform(_InitModule):
'low': float(self.low), 'low': float(self.low),
'high': float(self.high), 'high': float(self.high),
'dims_desc': [ 'dims_desc': [
'${{ANCHOR}}/dims[{}]'.format(n) '${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim) for n in range(self.ndim)
], ],
}, },
......
...@@ -32,7 +32,7 @@ class Resize2d(BaseModule): ...@@ -32,7 +32,7 @@ class Resize2d(BaseModule):
'fx': self.fx, 'fy': self.fy, 'fx': self.fx, 'fy': self.fy,
'data_format': 'NCHW', 'data_format': 'NCHW',
'dsize_desc': [ 'dsize_desc': [
'${{ANCHOR}}/dsize[{}]'.format(n) '${{HANDLE}}/dsize[{}]'.format(n)
for n in range(2) for n in range(2)
], ],
}, },
......
...@@ -28,19 +28,18 @@ class BlobFetcher(Process): ...@@ -28,19 +28,18 @@ class BlobFetcher(Process):
Parameters Parameters
---------- ----------
batch_size : int batch_size : int, optional, default=128
The size of a training batch. The size of a mini-batch.
partition : boolean partition : bool, optional, default=False
Whether to partition batch. Default is ``False``. Whether to partition batch for parallelism.
prefetch : int prefetch : int, optional, default=5
The prefetch count. Default is ``5``. The prefetch count.
""" """
super(BlobFetcher, self).__init__() super(BlobFetcher, self).__init__()
self._batch_size = kwargs.get('batch_size', 100) self._batch_size = kwargs.get('batch_size', 100)
self._partition = kwargs.get('partition', False) self._partition = kwargs.get('partition', False)
if self._partition: if self._partition: self._batch_size //= kwargs['group_size']
self._batch_size = int(self._batch_size / kwargs['group_size'])
self.Q_in = self.Q_out = None self.Q_in = self.Q_out = None
self.daemon = True self.daemon = True
...@@ -53,15 +52,12 @@ class BlobFetcher(Process): ...@@ -53,15 +52,12 @@ class BlobFetcher(Process):
The blob of image and labels. The blob of image and labels.
""" """
# fill blobs
im, labels = self.Q_in.get() im, labels = self.Q_in.get()
im_blob = np.zeros(shape=([self._batch_size] + list(im.shape)), dtype=im.dtype) im_blob = np.zeros(shape=([self._batch_size] + list(im.shape)), dtype=im.dtype)
label_blob = np.zeros((self._batch_size, len(labels)), dtype=np.int64) label_blob = np.zeros((self._batch_size, len(labels)), dtype=np.int64)
for ix in range(0, self._batch_size): for i in range(0, self._batch_size):
im_blob[ix, :, :, :], label_blob[ix, :] = im, labels im_blob[i, :, :, :], label_blob[i, :] = im, labels
if ix != self._batch_size - 1: im, labels = self.Q_in.get() if i != self._batch_size - 1: im, labels = self.Q_in.get()
# mean subtraction & numerical scale
im_blob = im_blob.astype(np.float32) im_blob = im_blob.astype(np.float32)
return im_blob, label_blob return im_blob, label_blob
...@@ -73,5 +69,4 @@ class BlobFetcher(Process): ...@@ -73,5 +69,4 @@ class BlobFetcher(Process):
None None
""" """
while True: while True: self.Q_out.put(self.get())
self.Q_out.put(self.get()) \ No newline at end of file
\ No newline at end of file
...@@ -17,8 +17,8 @@ import time ...@@ -17,8 +17,8 @@ import time
import pprint import pprint
from multiprocessing import Queue from multiprocessing import Queue
import dragon.core.mpi as mpi from dragon.core import mpi as _mpi
import dragon.core.logging as logging from dragon.core import logging as _logging
from dragon.utils.vision import DataReader from dragon.utils.vision import DataReader
from .data_transformer import DataTransformer from .data_transformer import DataTransformer
...@@ -39,52 +39,48 @@ class DataBatch(object): ...@@ -39,52 +39,48 @@ class DataBatch(object):
---------- ----------
source : str source : str
The path of database. The path of database.
multiple_nodes: boolean multiple_nodes: boolean, optional, default=False
Whether to split data for multiple parallel nodes. Default is ``False``. Whether to split data for multiple parallel nodes.
shuffle : boolean shuffle : bool, optional, default=False
Whether to shuffle the data. Default is ``False``. Whether to shuffle the data.
num_chunks : int num_chunks : int, optional, default=2048
The number of chunks to split. Default is ``2048``. The number of chunks to split.
chunk_size : int padding : int, optional, default=0
The size(MB) of each chunk. Default is -1 (Refer ``num_chunks``). The zero-padding size.
mean_values : list fill_value : int, optional, default=127
The mean value of each image channel. The value to fill when padding is valid.
scale : float crop_size : int, optional, default=0
The scale performed after mean subtraction. Default is ``1.0``. The cropping size.
padding : int cutout_size : int, optional, default=0
The zero-padding size. Default is ``0`` (Disabled). The square size to cutout.
fill_value : int mirror : bool, optional, default=False
The value to fill when padding is valid. Default is ``127``. Whether to mirror(flip horizontally) images.
crop_size : int color_augmentation : bool, optional, default=False
The crop size. Default is ``0`` (Disabled). Whether to use color distortion.1
mirror : boolean min_random_scale : float, optional, default=1.
Whether to flip(horizontally) images. Default is ``False``. The min scale of the input images.
color_augmentation : boolean max_random_scale : float, optional, default=1.
Whether to distort colors. Default is ``False``. The max scale of the input images.
min_random_scale : float force_gray : bool, optional, default=False
The min scale of the input images. Default is ``1.0``. Set not to duplicate channel for gray.
max_random_scale : float phase : {'TRAIN', 'TEST'}, optional
The max scale of the input images. Default is ``1.0``. The optional running phase.
force_color : boolean batch_size : int, optional, default=128
Set to duplicate channels for gray. Default is ``False``. The size of a mini-batch.
phase : str partition : bool, optional, default=False
The phase of this operator, ``TRAIN`` or ``TEST``. Default is ``TRAIN``. Whether to partition batch for parallelism.
batch_size : int prefetch : int, optional, default=5
The size of a training batch. The prefetch count.
partition : boolean
Whether to partition batch. Default is ``False``.
prefetch : int
The prefetch count. Default is ``5``.
""" """
super(DataBatch, self).__init__() super(DataBatch, self).__init__()
# init mpi # Init mpi
global_rank = 0; local_rank = 0; group_size = 1 global_rank, local_rank, group_size = 0, 0, 1
if mpi.Is_Init(): if _mpi.Is_Init() and kwargs.get(
idx, group = mpi.AllowParallel() 'phase', 'TRAIN') == 'TRAIN':
if idx != -1: # data parallel rank, group = _mpi.AllowParallel()
global_rank = mpi.Rank() if rank != -1: # DataParallel
group_size = len(group) global_rank, group_size = _mpi.Rank(), len(group)
for i, node in enumerate(group): for i, node in enumerate(group):
if global_rank == node: local_rank = i if global_rank == node: local_rank = i
kwargs['group_size'] = group_size kwargs['group_size'] = group_size
...@@ -106,73 +102,66 @@ class DataBatch(object): ...@@ -106,73 +102,66 @@ class DataBatch(object):
if kwargs.get('max_random_scale', 1.0) - \ if kwargs.get('max_random_scale', 1.0) - \
kwargs.get('min_random_scale', 1.0) != 0: kwargs.get('min_random_scale', 1.0) != 0:
self._num_transformers += 1 self._num_transformers += 1
self._num_transformers = min(self._num_transformers, self._max_transformers) self._num_transformers = min(
self._num_transformers, self._max_transformers)
self._batch_size = kwargs.get('batch_size', 100) self._batch_size = kwargs.get('batch_size', 128)
self._partition = kwargs.get('partition', False) self._partition = kwargs.get('partition', False)
if self._partition: if self._partition: self._batch_size //= kwargs['group_size']
self._batch_size = int(self._batch_size / kwargs['group_size'])
# init queues # init queues
self.Q_level_1 = Queue(self._prefetch * self._num_readers * self._batch_size) self.Q1 = Queue(self._prefetch * self._num_readers * self._batch_size)
self.Q_level_2 = Queue(self._prefetch * self._num_readers * self._batch_size) self.Q2 = Queue(self._prefetch * self._num_readers * self._batch_size)
self.Q_level_3 = Queue(self._prefetch * self._num_readers) self.Q3 = Queue(self._prefetch * self._num_readers)
# init readers # init readers
self._readers = [] self._readers = []
for i in range(self._num_readers): for i in range(self._num_readers):
self._readers.append(DataReader(**kwargs)) self._readers.append(DataReader(**kwargs))
self._readers[-1].Q_out = self.Q_level_1 self._readers[-1].Q_out = self.Q1
for i in range(self._num_readers): for i in range(self._num_readers):
num_parts = self._num_readers part_idx, num_parts = i, self._num_readers
part_idx = i if self._readers[i]._multi_nodes or \
if self._readers[i]._multiple_nodes or \
self._readers[i]._use_shuffle: self._readers[i]._use_shuffle:
num_parts *= group_size num_parts *= group_size
part_idx += local_rank * self._num_readers part_idx += local_rank * self._num_readers
self._readers[i]._num_parts = num_parts self._readers[i]._num_parts = num_parts
self._readers[i]._part_idx = part_idx self._readers[i]._part_idx = part_idx
self._readers[i]._random_seed += part_idx self._readers[i]._rng_seed += part_idx
self._readers[i].start() self._readers[i].start()
time.sleep(0.1) time.sleep(0.1)
# init transformers # Init transformers
self._transformers = [] self._transformers = []
for i in range(self._num_transformers): for i in range(self._num_transformers):
transformer = DataTransformer(**kwargs) transformer = DataTransformer(**kwargs)
transformer._random_seed += (i + local_rank * self._num_transformers) transformer._rng_seed += (i + local_rank * self._num_transformers)
transformer.Q_in = self.Q_level_1 transformer.Q_in, transformer.Q_out = self.Q1, self.Q2
transformer.Q_out = self.Q_level_2
transformer.start() transformer.start()
self._transformers.append(transformer) self._transformers.append(transformer)
time.sleep(0.1) time.sleep(0.1)
# init blob fetchers # Init blob fetchers
self._fetchers = [] self._fetchers = []
for i in range(self._num_fetchers): for i in range(self._num_fetchers):
fetcher = BlobFetcher(**kwargs) fetcher = BlobFetcher(**kwargs)
fetcher.Q_in = self.Q_level_2 fetcher.Q_in, fetcher.Q_out = self.Q2, self.Q3
fetcher.Q_out = self.Q_level_3
fetcher.start() fetcher.start()
self._fetchers.append(fetcher) self._fetchers.append(fetcher)
time.sleep(0.1) time.sleep(0.1)
# prevent to echo multiple nodes
if local_rank == 0: self.echo()
def cleanup(): def cleanup():
def terminate(processes): def terminate(processes):
for process in processes: for process in processes:
process.terminate() process.terminate()
process.join() process.join()
terminate(self._fetchers) terminate(self._fetchers)
if local_rank == 0: logging.info('Terminating BlobFetcher ......') if local_rank == 0: _logging.info('Terminate BlobFetcher.')
terminate(self._transformers) terminate(self._transformers)
if local_rank == 0: logging.info('Terminating DataTransformer ......') if local_rank == 0: _logging.info('Terminate DataTransformer.')
terminate(self._readers) terminate(self._readers)
if local_rank == 0: logging.info('Terminating DataReader......') if local_rank == 0: _logging.info('Terminate DataReader.')
import atexit import atexit
atexit.register(cleanup) atexit.register(cleanup)
...@@ -185,22 +174,4 @@ class DataBatch(object): ...@@ -185,22 +174,4 @@ class DataBatch(object):
The batch, representing data and labels respectively. The batch, representing data and labels respectively.
""" """
return self.Q_level_3.get() return self.Q3.get()
\ No newline at end of file
def echo(self):
"""Print I/O Information.
Returns
-------
None
"""
print('---------------------------------------------------------')
print('BatchFetcher({} Threads), Using config:'.format(
self._num_readers + self._num_transformers + self._num_fetchers))
params = {'queue_size': self._prefetch,
'n_readers': self._num_readers,
'n_transformers': self._num_transformers,
'n_fetchers': self._num_fetchers}
pprint.pprint(params)
print('---------------------------------------------------------')
...@@ -13,9 +13,8 @@ from __future__ import absolute_import ...@@ -13,9 +13,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy
import numpy.random as npr import multiprocessing
from multiprocessing import Process
import dragon.config as config import dragon.config as config
import dragon.vm.caffe.proto.caffe_pb2 as pb import dragon.vm.caffe.proto.caffe_pb2 as pb
...@@ -26,7 +25,7 @@ except ImportError as e: ...@@ -26,7 +25,7 @@ except ImportError as e:
print('Failed to import cv2. Error: {0}'.format(str(e))) print('Failed to import cv2. Error: {0}'.format(str(e)))
class DataTransformer(Process): class DataTransformer(multiprocessing.Process):
"""DataTransformer is deployed to queue transformed images from `DataReader`_. """DataTransformer is deployed to queue transformed images from `DataReader`_.
Nearly all common image augmentation methods are supported. Nearly all common image augmentation methods are supported.
...@@ -37,11 +36,11 @@ class DataTransformer(Process): ...@@ -37,11 +36,11 @@ class DataTransformer(Process):
Parameters Parameters
---------- ----------
transform : lambda transform : lambda, optional
The transforms. The transforms.
color_space : str color_space : {'RGB', 'BGR'}, optional
The color space. The color space.
pack : boolean pack : boolean, optional, default=False
Pack the images automatically. Pack the images automatically.
""" """
...@@ -49,7 +48,7 @@ class DataTransformer(Process): ...@@ -49,7 +48,7 @@ class DataTransformer(Process):
self.transform = transform self.transform = transform
self.color_space = color_space self.color_space = color_space
self.pack = pack self.pack = pack
self._random_seed = config.GetRandomSeed() self._rng_seed = config.GetRandomSeed()
self.Q_in = self.Q_out = None self.Q_in = self.Q_out = None
self.daemon = True self.daemon = True
...@@ -70,7 +69,7 @@ class DataTransformer(Process): ...@@ -70,7 +69,7 @@ class DataTransformer(Process):
# Decode # Decode
datum = pb.Datum() datum = pb.Datum()
datum.ParseFromString(serialized) datum.ParseFromString(serialized)
im = np.fromstring(datum.data, np.uint8) im = numpy.fromstring(datum.data, numpy.uint8)
if datum.encoded is True: if datum.encoded is True:
im = cv2.imdecode(im, -1) im = cv2.imdecode(im, -1)
else: else:
...@@ -94,7 +93,7 @@ class DataTransformer(Process): ...@@ -94,7 +93,7 @@ class DataTransformer(Process):
None None
""" """
npr.seed(self._random_seed) numpy.random.seed(self._rng_seed)
while True: while True:
serialized = self.Q_in.get() serialized = self.Q_in.get()
im, label = self.get(serialized) im, label = self.get(serialized)
...@@ -103,5 +102,5 @@ class DataTransformer(Process): ...@@ -103,5 +102,5 @@ class DataTransformer(Process):
self.Q_out.put((im[ix], label)) self.Q_out.put((im[ix], label))
else: else:
if len(im.shape) == 3 and self.pack: if len(im.shape) == 3 and self.pack:
im = np.expand_dims(im, axis=0) im = numpy.expand_dims(im, axis=0)
self.Q_out.put((im, label)) self.Q_out.put((im, label))
\ No newline at end of file
...@@ -21,7 +21,7 @@ class _DataLoaderIter(object): ...@@ -21,7 +21,7 @@ class _DataLoaderIter(object):
self.loader = loader self.loader = loader
def __len__(self): def __len__(self):
return len(self.loader.batch.Q_level_3.qsize()) return len(self.loader.batch.Q3.qsize())
def __next__(self): def __next__(self):
return self.loader.batch.get() return self.loader.batch.get()
...@@ -33,28 +33,28 @@ class _DataLoaderIter(object): ...@@ -33,28 +33,28 @@ class _DataLoaderIter(object):
class DataLoader(object): class DataLoader(object):
def __init__(self, def __init__(
dataset, batch_size=1, shuffle=False, self,
partition=False, multiple_nodes=False, dataset,
num_chunks=2048, chunk_size=-1): batch_size=1,
shuffle=False,
num_chunks=2048,
phase='TRAIN',
):
"""A MPI-Aware DataLoader. Forked from ``dragon.utils.vision``. """A MPI-Aware DataLoader. Forked from ``dragon.utils.vision``.
Parameters Parameters
---------- ----------
dataset : torch.utils.data.dataset.Dataset dataset : torch.utils.data.dataset.Dataset
The dataset. The dataset.
batch_size : int batch_size : int, optional, default=1
The batch size. Divided by n mpi-nodes if ``partition`` is True. The batch size.
shuffle : boolean shuffle : boolean, optional, default=False
Whether to shuffle the data. Whether to shuffle the data.
partition : boolean num_chunks : int, optional, default=2048
Whether to partition batch. Default is ``False``. The number of chunks to split.
multiple_nodes: boolean phase : {'TRAIN', 'TEST'}, optional
Whether to split data for multiple parallel nodes. The optional running phase.
num_chunks : int
The number of chunks to split. Default is ``2048``.
chunk_size : int
The size(MB) of each chunk. Default is -1 (Refer ``num_chunks``).
""" """
self.dataset = dataset self.dataset = dataset
...@@ -65,12 +65,10 @@ class DataLoader(object): ...@@ -65,12 +65,10 @@ class DataLoader(object):
n_transformers = dataset.transform.n_transformers n_transformers = dataset.transform.n_transformers
self.batch = _Batch(**{ self.batch = _Batch(**{
'source': dataset.database, 'source': dataset.database,
'multiple_nodes': multiple_nodes,
'shuffle': shuffle, 'shuffle': shuffle,
'num_chunks': num_chunks, 'num_chunks': num_chunks,
'chunk_size': chunk_size, 'phase': phase,
'batch_size': batch_size, 'batch_size': batch_size,
'partition': partition,
'transform': dataset.transform, 'transform': dataset.transform,
'color_space': dataset.color_space, 'color_space': dataset.color_space,
'num_transformers': n_transformers, 'num_transformers': n_transformers,
...@@ -80,7 +78,7 @@ class DataLoader(object): ...@@ -80,7 +78,7 @@ class DataLoader(object):
return _DataLoaderIter(self) return _DataLoaderIter(self)
def __len__(self): def __len__(self):
return self.batch.Q_level_3.qsize() return self.batch.Q3.qsize()
def next(self): def next(self):
return self.batch.get() return self.batch.get()
......
...@@ -8,59 +8,57 @@ namespace dragon { ...@@ -8,59 +8,57 @@ namespace dragon {
/* Default constructor of <GraphBase> */ /* Default constructor of <GraphBase> */
GraphBase::GraphBase(const GraphDef& def, Workspace* ws) GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
: name_(def.name()), ws_(ws) { : def_(def), ws_(ws), name_(def.name()), phase_("TEST") {
for (auto arg : def.arg()) { // Scan the defined arguments
for (auto& arg : def_.arg()) {
CHECK_GT(arg.name().size(), 0); CHECK_GT(arg.name().size(), 0);
CHECK_EQ(args_.count(arg.name()), 0); CHECK_EQ(args_.count(arg.name()), 0);
args_[arg.name()] = arg; args_[arg.name()] = &arg;
if (arg.name() == "phase") phase_ = arg.s();
} }
Set<string> known_tensors; // Collect outputs
Set<string> outputs;
// Topo-check for a graph
for (const auto& op : def.op()) { for (const auto& op : def.op()) {
// Check inputs
for (const auto& in : op.input()) for (const auto& in : op.input())
CHECK(known_tensors.count(in) || ws_->HasTensor(in)) CHECK(outputs.count(in) || ws_->HasTensor(in))
<< "\nInput: " << in << " for op: " << "\nInput: " << in << " for op: "
<< op.name() << " is unknown."; << op.name() << " is unknown.";
// Add outputs for (const auto& out : op.output()) outputs.insert(out);
for (const auto& out : op.output()) known_tensors.insert(out);
} }
// Check for all solving targets // Check targets
Set<string> objective_targets; Set<string> targets;
for (const auto& target : def.output()) { for (const auto& target : def.output()) {
CHECK(known_tensors.count(target) || CHECK(outputs.count(target) || ws_->HasTensor(target))
ws_->HasTensor(target))
<< "\nTarget: " << target << "\nTarget: " << target
<< " does not exist in computional graph."; << " does not exist in the graph.";
objective_targets.insert(target); targets.insert(target);
} }
// Check for all gradients // Check gradients
for (const auto& gradient : def.gradient()) { for (const auto& gradient : def.gradient()) {
const auto& cost = gradient.cost(); const auto& cost = gradient.cost();
const auto& wrt = gradient.wrt(); const auto& wrt = gradient.wrt();
CHECK(known_tensors.count(cost) || ws_->HasTensor(cost)) CHECK(outputs.count(cost) || ws_->HasTensor(cost))
<< "\nTarget: " << cost << "\nTarget: " << cost
<< "_grad does not exist in computional graph."; << "does not exist in the graph.";
CHECK(known_tensors.count(wrt) || ws_->HasTensor(wrt)) CHECK(outputs.count(wrt) || ws_->HasTensor(wrt))
<< "\nTarget: " << wrt << "\nTarget: " << wrt
<< "_grad does not exist in computional graph."; << "does not exist in the graph.";
CHECK_GT(objective_targets.count(cost), 0) CHECK_GT(targets.count(cost), 0)
<< "\nTo solve d(" << cost << ")/d(" << wrt << "), " << "\nTo solve d(" << cost << ")/d(" << wrt << "),\n"
<< "must set " << cost << cost << " should be set as a target.";
<< "\nas a objective tensor to solve before derivating it.";
} }
} }
/* Create a graph from the optimized def */ /* Create a graph from the optimized def */
bool Graph::Create(const GraphDef& def, Workspace* ws) { bool Graph::Create(const GraphDef& def, Workspace* ws) {
this->opt_def_ = def; // Store for debugging
bool has_device_option = def.has_device_option(); bool has_device_option = def.has_device_option();
for (int i = 0; i < def.op_size(); i++) { for (int i = 0; i < def.op_size(); i++) {
OperatorDef op_def(def.op(i)); auto op_def(def.op(i));
LOG(DEBUG) << "Create Operator " << op_def.name() LOG(DEBUG) << "Create Operator " << op_def.name()
<< ": " << op_def.type(); << ": " << op_def.type();
// Inherit device option if necessary // Inherit device option if necessary
...@@ -75,8 +73,7 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) { ...@@ -75,8 +73,7 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) {
arg.set_name("do_sync"); arg.set_name("do_sync");
arg.set_i(1); op_def.add_arg()->CopyFrom(arg); arg.set_i(1); op_def.add_arg()->CopyFrom(arg);
} }
OperatorBase* op = NewOperator(op_def, ws); ops_.push_back(NewOperator(op_def, ws));
ops_.push_back(op);
} }
return true; return true;
} }
...@@ -89,14 +86,14 @@ Graph::Graph(const GraphDef& def, Workspace* ws) ...@@ -89,14 +86,14 @@ Graph::Graph(const GraphDef& def, Workspace* ws)
GraphDef opt_def = def; GraphDef opt_def = def;
GraphOptimizer graph_optim(ws); GraphOptimizer graph_optim(ws);
GraphGradientMaker gradient_maker; GraphGradientMaker gradient_maker;
Map< string, vec32_t > subgraph_indices; Map<string, vec32_t> subgraph_indices;
int opt = 3; // defaults: O3 int opt = 3; // defaults: O3
if (this->args_.count("optimization_level")) if (args().count("optimization_level"))
opt = this->args_["optimization_level"].i(); opt = arg("optimization_level").i();
if (opt >= 1) opt_def = graph_optim.PruneNodes(def); if (opt >= 1) opt_def = graph_optim.PruneNodes(def);
if (opt >= 2) opt_def = graph_optim.AddInplace(opt_def); if (opt >= 2) opt_def = graph_optim.AddInplace(opt_def);
if (opt >= 3) { if (opt >= 3) {
if (this->args_["phase"].s() == "TRAIN") { if (phase() == "TRAIN") {
opt_def = graph_optim.MirrorStage( opt_def = graph_optim.MirrorStage(
opt_def, subgraph_indices); opt_def, subgraph_indices);
opt_def = gradient_maker.Share(opt_def); opt_def = gradient_maker.Share(opt_def);
...@@ -105,23 +102,10 @@ Graph::Graph(const GraphDef& def, Workspace* ws) ...@@ -105,23 +102,10 @@ Graph::Graph(const GraphDef& def, Workspace* ws)
} }
} }
// Try to store the final graph as a tensor for visualization
bool could_be_serialized = true;
for (auto& op : opt_def.op())
if (op.type() == "GivenTensorFill")
could_be_serialized = false;
if (could_be_serialized) {
ws_->CreateTensor(
"/graph_def/optimized/" + opt_def.name())
->Reshape({ 1 })
->mutable_data<string, CPUContext>()[0]
= opt_def.DebugString();
}
// Create // Create
Create(opt_def, ws); Create(opt_def, ws);
// Recomputing-aware // Recomputation and SubGraph
if (subgraph_indices.size() > 0) { if (subgraph_indices.size() > 0) {
Map<string, vector<OperatorBase*>> subgraph; Map<string, vector<OperatorBase*>> subgraph;
for (const auto& it : subgraph_indices) { for (const auto& it : subgraph_indices) {
...@@ -141,11 +125,13 @@ bool Graph::Run( ...@@ -141,11 +125,13 @@ bool Graph::Run(
int stream_id) { int stream_id) {
LOG(DEBUG) << "Run Graph: " << name(); LOG(DEBUG) << "Run Graph: " << name();
for (auto op : ops_) { for (auto op : ops_) {
if (!include.empty()) if (!include.empty() &&
if (op->type().find(include) == string::npos) continue; !str::find(op->type(), include)
if (!exclude.empty()) ) continue;
if (op->type().find(exclude) != string::npos) continue; if (!exclude.empty() &&
op->SwitchToPhase(this->args_["phase"].s()); str::find(op->type(), exclude)
) continue;
op->SwitchToPhase(phase());
LOG(DEBUG) << "$ Before Operator: " << op->name(); LOG(DEBUG) << "$ Before Operator: " << op->name();
op->Run(stream_id); op->Run(stream_id);
LOG(DEBUG) << "$ After Operator: " << op->name(); LOG(DEBUG) << "$ After Operator: " << op->name();
......
...@@ -41,7 +41,7 @@ bool GraphGradientMaker::CheckGrad( ...@@ -41,7 +41,7 @@ bool GraphGradientMaker::CheckGrad(
string GraphGradientMaker::GetOperatorName() { string GraphGradientMaker::GetOperatorName() {
if (op_prefix_.empty()) return "runtime"; if (op_prefix_.empty()) return "runtime";
return op_prefix_ + std::to_string(cur_op_idx_++) + op_suffix_; return op_prefix_ + str::to(cur_op_idx_++) + op_suffix_;
} }
void GraphGradientMaker::Make( void GraphGradientMaker::Make(
...@@ -91,7 +91,7 @@ void GraphGradientMaker::Make( ...@@ -91,7 +91,7 @@ void GraphGradientMaker::Make(
if (g_output.empty()) g_output = "NULL"; if (g_output.empty()) g_output = "NULL";
g_outputs.emplace_back(g_output); g_outputs.emplace_back(g_output);
} }
Gradient grad = MakeGradientForOp(op, g_outputs); auto grad = MakeGradientForOp(op, g_outputs);
// Process the RAW grad ops // Process the RAW grad ops
vector<OperatorDef> gather_ops; vector<OperatorDef> gather_ops;
...@@ -101,11 +101,11 @@ void GraphGradientMaker::Make( ...@@ -101,11 +101,11 @@ void GraphGradientMaker::Make(
// Rename if necessary // Rename if necessary
if (terms_.size() > 0) { if (terms_.size() > 0) {
for (int i = 0; i < g_op.input_size(); ++i) { for (int i = 0; i < g_op.input_size(); ++i) {
string* input = g_op.mutable_input(i); auto* input = g_op.mutable_input(i);
if (terms_.count(*input)) *input = terms_[*input]; if (terms_.count(*input)) *input = terms_[*input];
} }
for (int i = 0; i < g_op.output_size(); ++i) { for (int i = 0; i < g_op.output_size(); ++i) {
string* output = g_op.mutable_output(i); auto* output = g_op.mutable_output(i);
if (terms_.count(*output)) *output = terms_[*output]; if (terms_.count(*output)) *output = terms_[*output];
} }
for (int i = 0; i < grad.g_inputs.size(); ++i) { for (int i = 0; i < grad.g_inputs.size(); ++i) {
...@@ -115,7 +115,7 @@ void GraphGradientMaker::Make( ...@@ -115,7 +115,7 @@ void GraphGradientMaker::Make(
} }
// Split & gather grads for multi-used input // Split & gather grads for multi-used input
for (int i = 0; i < g_op.output_size(); ++i) { for (int i = 0; i < g_op.output_size(); ++i) {
string* output = g_op.mutable_output(i); auto* output = g_op.mutable_output(i);
int original_idx = -1; int original_idx = -1;
for (int j = 0; j < grad.g_inputs.size(); ++j) for (int j = 0; j < grad.g_inputs.size(); ++j)
if (g_op.output(i) == grad.g_inputs[j]) original_idx = j; if (g_op.output(i) == grad.g_inputs[j]) original_idx = j;
...@@ -126,11 +126,11 @@ void GraphGradientMaker::Make( ...@@ -126,11 +126,11 @@ void GraphGradientMaker::Make(
if (g_op.output(i) == input) output_in_inputs = true; if (g_op.output(i) == input) output_in_inputs = true;
if (output_in_inputs) continue; if (output_in_inputs) continue;
// Found a split branch // Found a split branch
string original_name = op.input(original_idx); auto original_name = op.input(original_idx);
if (inputs_count[original_name] > 1) { if (inputs_count[original_name] > 1) {
// Split // Split
string split_name = *output + "_autosplit_" auto split_name = *output + "_autosplit_"
+ std::to_string(grads_count[*output]++); + str::to(grads_count[*output]++);
if (!is_skip) all_split_grads.insert(split_name); if (!is_skip) all_split_grads.insert(split_name);
// Gather // Gather
if (grads_count[*output] == inputs_count[original_name]) { if (grads_count[*output] == inputs_count[original_name]) {
...@@ -142,7 +142,7 @@ void GraphGradientMaker::Make( ...@@ -142,7 +142,7 @@ void GraphGradientMaker::Make(
gather_op.mutable_device_option() gather_op.mutable_device_option()
->CopyFrom(g_op.device_option()); ->CopyFrom(g_op.device_option());
for (int j = 0; j < grads_count[*output]; j++) { for (int j = 0; j < grads_count[*output]; j++) {
string key = *output + "_autosplit_" + std::to_string(j); auto key = *output + "_autosplit_" + str::to(j);
if (all_split_grads.count(key)) gather_op.add_input(key); if (all_split_grads.count(key)) gather_op.add_input(key);
} }
gather_ops.emplace_back(gather_op); gather_ops.emplace_back(gather_op);
...@@ -219,7 +219,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -219,7 +219,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < input_def.op_size(); ++i) {
const OperatorDef& op = input_def.op(i); const OperatorDef& op = input_def.op(i);
// Ignore the non-gradient ops // Ignore the non-gradient ops
if (op.type().find("Gradient") == string::npos) continue; if (!str::find(op.type(), "Gradient")) continue;
if (op.type() == "GradientGather") { if (op.type() == "GradientGather") {
invalid_ops.insert(i); invalid_ops.insert(i);
if (ignore_grads_.count(op.output(0))) { if (ignore_grads_.count(op.output(0))) {
...@@ -237,7 +237,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -237,7 +237,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
} }
} }
for (const auto& input : op.input()) for (const auto& input : op.input())
if (input.find("grad") != string::npos) if (str::find(input, "grad"))
ref_count[input] += 1; ref_count[input] += 1;
} }
...@@ -247,7 +247,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -247,7 +247,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
if (invalid_ops.count(i)) continue; if (invalid_ops.count(i)) continue;
const OperatorDef& op = input_def.op(i); const OperatorDef& op = input_def.op(i);
output_def.add_op()->CopyFrom(op); output_def.add_op()->CopyFrom(op);
if (op.type().find("Gradient") == string::npos) continue; if (!str::find(op.type(), "Gradient")) continue;
for (const auto& output : op.output()) { for (const auto& output : op.output()) {
const auto& find_iter = ssa_map.find(output); const auto& find_iter = ssa_map.find(output);
if (find_iter != ssa_map.end()) { if (find_iter != ssa_map.end()) {
...@@ -277,7 +277,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -277,7 +277,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
auto get_temporary_grad = [&]() mutable { auto get_temporary_grad = [&]() mutable {
if (grads_pool.empty()) { if (grads_pool.empty()) {
return "/share/buffer/grad:" + return "/share/buffer/grad:" +
std::to_string(temporary_idx++); str::to(temporary_idx++);
} else { } else {
/*! /*!
* LIFO is more memory efficent than FIFO usually, * LIFO is more memory efficent than FIFO usually,
...@@ -295,7 +295,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -295,7 +295,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
for (int i = 0; i < output_def.op_size(); ++i) { for (int i = 0; i < output_def.op_size(); ++i) {
OperatorDef* op = output_def.mutable_op(i); OperatorDef* op = output_def.mutable_op(i);
// Ignore the non-gradient ops // Ignore the non-gradient ops
if (op->type().find("Gradient") == string::npos) continue; if (!str::find(op->type(), "Gradient")) continue;
// GC to store the grads that have finished lifecycle // GC to store the grads that have finished lifecycle
vector<string> GC; vector<string> GC;
// Inplace-aware // Inplace-aware
......
...@@ -13,7 +13,7 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) { ...@@ -13,7 +13,7 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) {
dag_.clear(); colored_.clear(); dag_.clear(); colored_.clear();
// Build DAG // Build DAG
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < input_def.op_size(); ++i) {
const OperatorDef& op = input_def.op(i); const auto& op = input_def.op(i);
for (const auto& v : op.output()) { for (const auto& v : op.output()) {
vector<string> sp_u; vector<string> sp_u;
if (!op.input_size()) sp_u.resize(op.output_size()); if (!op.input_size()) sp_u.resize(op.output_size());
...@@ -36,8 +36,8 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) { ...@@ -36,8 +36,8 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) {
// Forward dyeing through connected path for all gradients // Forward dyeing through connected path for all gradients
for (const auto& gradient : input_def.gradient()) { for (const auto& gradient : input_def.gradient()) {
string u = gradient.cost() + "_grad"; auto u = gradient.cost() + "_grad";
string v = gradient.wrt() + "_grad"; auto v = gradient.wrt() + "_grad";
if (ws_->HasTensor(u)) u = ws_->GetTensor(u)->name(); if (ws_->HasTensor(u)) u = ws_->GetTensor(u)->name();
if (ws_->HasTensor(v)) v = ws_->GetTensor(v)->name(); if (ws_->HasTensor(v)) v = ws_->GetTensor(v)->name();
visited_.clear(); visited_.clear();
...@@ -111,7 +111,7 @@ GraphDef GraphOptimizer::AddInplace(const GraphDef& input_def) { ...@@ -111,7 +111,7 @@ GraphDef GraphOptimizer::AddInplace(const GraphDef& input_def) {
dag_.clear(); renamed_.clear(); dag_.clear(); renamed_.clear();
// Build DAG // Build DAG
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < input_def.op_size(); ++i) {
const OperatorDef& op = input_def.op(i); const auto& op = input_def.op(i);
for (const auto& v : op.output()) { for (const auto& v : op.output()) {
vector<string> sp_u; vector<string> sp_u;
if (!op.input_size()) sp_u.resize(op.output_size()); if (!op.input_size()) sp_u.resize(op.output_size());
...@@ -142,7 +142,7 @@ GraphDef GraphOptimizer::AddInplace(const GraphDef& input_def) { ...@@ -142,7 +142,7 @@ GraphDef GraphOptimizer::AddInplace(const GraphDef& input_def) {
// Rename to create in-place // Rename to create in-place
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < input_def.op_size(); ++i) {
const OperatorDef& op = input_def.op(i); const auto& op = input_def.op(i);
for (int j = 0; j < op.input_size(); ++j) { for (int j = 0; j < op.input_size(); ++j) {
if (whitelist.count(op.input(j)) == 0 && if (whitelist.count(op.input(j)) == 0 &&
renamed_.count(op.input(j)) && renamed_.count(op.input(j)) &&
...@@ -172,7 +172,7 @@ GraphDef GraphOptimizer::AddInplace(const GraphDef& input_def) { ...@@ -172,7 +172,7 @@ GraphDef GraphOptimizer::AddInplace(const GraphDef& input_def) {
GraphDef GraphOptimizer::MirrorStage( GraphDef GraphOptimizer::MirrorStage(
const GraphDef& input_def, const GraphDef& input_def,
Map<string, vec32_t >& op_indices) { Map<string, vec32_t>& op_indices) {
GraphDef output_def(input_def); GraphDef output_def(input_def);
Map<string, set<int>> fake_op_indices; Map<string, set<int>> fake_op_indices;
Map<string, string> rename_map; Map<string, string> rename_map;
...@@ -180,7 +180,7 @@ GraphDef GraphOptimizer::MirrorStage( ...@@ -180,7 +180,7 @@ GraphDef GraphOptimizer::MirrorStage(
// Check mirror stage // Check mirror stage
for (const auto& op : input_def.op()) { for (const auto& op : input_def.op()) {
if (op.type().find("Gradient") != string::npos) continue; if (str::find(op.type(), "Gradient")) continue;
bool mirror_stage = false; bool mirror_stage = false;
for (auto& arg : op.arg()) for (auto& arg : op.arg())
if (arg.name() == "mirror_stage") if (arg.name() == "mirror_stage")
...@@ -194,8 +194,8 @@ GraphDef GraphOptimizer::MirrorStage( ...@@ -194,8 +194,8 @@ GraphDef GraphOptimizer::MirrorStage(
// Allocate the temporal buffers // Allocate the temporal buffers
string v2_name, version_name; string v2_name, version_name;
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < input_def.op_size(); ++i) {
const OperatorDef& op = input_def.op(i); const auto& op = input_def.op(i);
OperatorDef* op_v2 = output_def.mutable_op(i); auto* op_v2 = output_def.mutable_op(i);
vector<string> used_buffers; vector<string> used_buffers;
for (int j = 0; j < op.input_size(); ++j) { for (int j = 0; j < op.input_size(); ++j) {
const auto& it = rename_map.find(op.input(j)); const auto& it = rename_map.find(op.input(j));
...@@ -217,15 +217,16 @@ GraphDef GraphOptimizer::MirrorStage( ...@@ -217,15 +217,16 @@ GraphDef GraphOptimizer::MirrorStage(
continue; continue;
} }
for (int k = 0; k < GRAPH_TEPORAL_OUTPUT_MAX_SIZE; ++k) { for (int k = 0; k < GRAPH_TEPORAL_OUTPUT_MAX_SIZE; ++k) {
v2_name = "/share/buffer/symbol:" + std::to_string(k); v2_name = "/share/buffer/symbol:" + str::to(k);
for (const auto& buffer : used_buffers) for (const auto& buffer : used_buffers)
if (buffer.find(v2_name) != string::npos) { v2_name.clear(); } if (str::find(buffer, v2_name)) { v2_name.clear(); }
if (!v2_name.empty()) { used_buffers.emplace_back(v2_name); break; } if (!v2_name.empty()) { used_buffers.emplace_back(v2_name); break; }
} }
CHECK(!v2_name.empty()) << "\nNo enough buffers for outputs."; CHECK(!v2_name.empty()) << "\nNo enough buffers for outputs.";
ws_->CreateTensor(v2_name)->set_version(0); ws_->CreateTensor(v2_name)->set_version(0);
version_name = "/ver:" + std::to_string(versions[v2_name]++); version_name = "/ver:" + str::to(versions[v2_name]++);
*op_v2->mutable_output(j) = rename_map[op.output(j)] = *op_v2->mutable_output(j) =
rename_map[op.output(j)] =
v2_name + version_name; v2_name + version_name;
} }
} }
...@@ -233,8 +234,8 @@ GraphDef GraphOptimizer::MirrorStage( ...@@ -233,8 +234,8 @@ GraphDef GraphOptimizer::MirrorStage(
// Plan the minimum recomputing ops for temporal buffers // Plan the minimum recomputing ops for temporal buffers
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < input_def.op_size(); ++i) {
const OperatorDef& input_op = input_def.op(i); const auto& input_op = input_def.op(i);
const OperatorDef& output_op = output_def.op(i); const auto& output_op = output_def.op(i);
/* ---------------------------------------------------------- /* ----------------------------------------------------------
* *
...@@ -244,7 +245,7 @@ GraphDef GraphOptimizer::MirrorStage( ...@@ -244,7 +245,7 @@ GraphDef GraphOptimizer::MirrorStage(
* *
* ---------------------------------------------------------- */ * ---------------------------------------------------------- */
set<int> minimum_ops = {i}; set<int> minimum_ops = { i };
for (int j = 0; j < input_op.input_size(); ++j) { for (int j = 0; j < input_op.input_size(); ++j) {
if (input_op.input(j) != output_op.input(j)) { if (input_op.input(j) != output_op.input(j)) {
for (auto idx : fake_op_indices[input_op.input(j)]) for (auto idx : fake_op_indices[input_op.input(j)])
...@@ -270,86 +271,91 @@ GraphDef GraphOptimizer::MirrorStage( ...@@ -270,86 +271,91 @@ GraphDef GraphOptimizer::MirrorStage(
/* Allocate the buffer for outputs (-O3) */ /* Allocate the buffer for outputs (-O3) */
GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) { GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) {
GraphDef output_def(input_def); Set<string> blacklist;
Set<string> targets;
Map<string, int> ref_count; Map<string, int> ref_count;
Map<string, string> rename_map; Map<string, string> rename_map;
Set<string> dimension_ops = { static Set<string> dim_ops = {
"Shape", "Reshape", "Flatten", "Shape",
"ExpandDims", "Squeeze", "Squeeze",
}, blacklist_outputs, star_ops = { "Reshape",
"Reshape", "Crop", "NNResize", "Flatten",
"ExpandDims",
}, star_ops = {
"Crop",
"Reshape",
"NNResize",
"BilinearResize", "BilinearResize",
}; };
// Prepare the Outputs Pool // Prepare the pool
int temporary_idx = 0; int temporary_idx = 0;
std::deque<string> outputs_pool; std::deque<string> pool;
auto get_temporary_output = [&]() mutable { auto get_temporary_output = [&]() mutable {
if (outputs_pool.empty()) { if (pool.empty()) {
return "/share/buffer/output:" + return "/share/buffer/output:" +
std::to_string(temporary_idx++); str::to(temporary_idx++);
} else { } else {
string temporary_output = outputs_pool.back(); auto temporary_output = pool.back();
outputs_pool.pop_back(); pool.pop_back();
return temporary_output; return temporary_output;
} }
}; };
// Count the references // Count the references
for (const auto& op : input_def.op()) { for (const auto& op : input_def.op()) {
for (auto& input : op.input()) for (const auto& input : op.input())
ref_count[input] += 1; ref_count[input] += 1;
if (dimension_ops.count(op.type())) if (dim_ops.count(op.type()))
blacklist_outputs.insert(op.input(0)); blacklist.insert(op.input(0));
if (star_ops.count(op.type())) { if (star_ops.count(op.type())) {
for (const auto& arg : op.arg()) for (const auto& arg : op.arg())
if (arg.name() == "shape_like") if (arg.name() == "shape_like")
blacklist_outputs.insert(arg.s()); blacklist.insert(
ws_->GetTensorName(arg.s())
);
} }
} }
// We should preserve the targets // We should preserve the targets
for (auto& e : input_def.output()) targets.insert(e); for (auto& e : input_def.output()) blacklist.insert(e);
// Rewritten the inputs and outputs // Rewritten the inputs and outputs
string ori_name; bool inplace_flag; auto output_def(input_def);
string name; bool inplace_flag;
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < input_def.op_size(); ++i) {
vector<string> GC; vector<string> GC;
const OperatorDef& op = input_def.op(i); const auto& op = input_def.op(i);
OperatorDef* op_v2 = output_def.mutable_op(i); auto* op_v2 = output_def.mutable_op(i);
// Ignore the init operators // Ignore the init operators
if (op.input_size() == 0) continue; if (op.input_size() == 0) continue;
// Analyze the inputs // Analyze the inputs
for (int j = 0; j < op.input_size(); ++j) { for (int j = 0; j < op.input_size(); ++j) {
ori_name = op.input(j); name = op.input(j);
if (rename_map.count(ori_name)) { if (rename_map.count(name)) {
*op_v2->mutable_input(j) = *op_v2->mutable_input(j) =
rename_map[ori_name]; rename_map[name];
} }
ref_count[ori_name]--; ref_count[name]--;
if (ref_count[ori_name] == 0 && op_v2->input(j).find( if (ref_count[name] == 0 && str::find(
"/share/buffer/output:") != string::npos) op_v2->input(j), "/share/buffer/output:"))
GC.push_back(op_v2->input(j)); GC.push_back(op_v2->input(j));
} }
// Allocate the buffers // Allocate the buffers
if (!dimension_ops.count(op.type())) { if (!dim_ops.count(op.type())) {
for (int j = 0; j < op.output_size(); ++j) { for (int j = 0; j < op.output_size(); ++j) {
name = op.output(j);
inplace_flag = false; inplace_flag = false;
ori_name = op.output(j); if (blacklist.count(name)) continue;
if (targets.count(ori_name) ||
blacklist_outputs.count(ori_name))
continue;
for (const auto& input : op.input()) for (const auto& input : op.input())
if (ori_name == input) inplace_flag = true; if (name == input) inplace_flag = true;
if (inplace_flag) { if (inplace_flag) {
*op_v2->mutable_output(j) = op_v2->input(j); *op_v2->mutable_output(j) = op_v2->input(j);
} else { } else {
rename_map[ori_name] = rename_map[name] =
*op_v2->mutable_output(j) = *op_v2->mutable_output(j) =
get_temporary_output(); get_temporary_output();
} }
...@@ -357,7 +363,7 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) { ...@@ -357,7 +363,7 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) {
} }
// Update the pool from GC // Update the pool from GC
for (auto& e : GC) outputs_pool.emplace_back(e); for (const auto& e : GC) pool.emplace_back(e);
} }
return output_def; return output_def;
...@@ -377,8 +383,8 @@ void GraphOptimizer::ForwardPruneTraversal( ...@@ -377,8 +383,8 @@ void GraphOptimizer::ForwardPruneTraversal(
} }
visited_[u] = false; visited_[u] = false;
for (int i = 0; i < dag_[u].childs.size(); ++i) { for (int i = 0; i < dag_[u].childs.size(); ++i) {
string v = dag_[u].childs[i]; auto v = dag_[u].childs[i];
vector<string> new_path(path); auto new_path(path);
new_path.push_back(v); new_path.push_back(v);
if (v == leaf) { if (v == leaf) {
for (const auto& node : new_path) for (const auto& node : new_path)
...@@ -394,7 +400,7 @@ void GraphOptimizer::ForwardPruneTraversal( ...@@ -394,7 +400,7 @@ void GraphOptimizer::ForwardPruneTraversal(
void GraphOptimizer::BackwardPruneTraversal(const string& v) { void GraphOptimizer::BackwardPruneTraversal(const string& v) {
colored_[v] = true; colored_[v] = true;
for (int i = 0; i < dag_[v].parents.size(); ++i) { for (int i = 0; i < dag_[v].parents.size(); ++i) {
string u = dag_[v].parents[i]; auto u = dag_[v].parents[i];
if (colored_.count(u)) continue; if (colored_.count(u)) continue;
BackwardPruneTraversal(u); BackwardPruneTraversal(u);
} }
......
...@@ -153,7 +153,7 @@ const Map<string, string> MixedMemory::info() const { ...@@ -153,7 +153,7 @@ const Map<string, string> MixedMemory::info() const {
<< "but got invalid mem pointer."; << "but got invalid mem pointer.";
} }
s2s["device_type"] = _state_; s2s["device_type"] = _state_;
s2s["device_id"] = std::to_string(ptr_device_); s2s["device_id"] = str::to(ptr_device_);
return s2s; return s2s;
} }
......
...@@ -9,17 +9,15 @@ namespace dragon { ...@@ -9,17 +9,15 @@ namespace dragon {
OperatorBase::OperatorBase( OperatorBase::OperatorBase(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: def_(def), ws_(ws), : def_(def), ws_(ws), handle_(def.name()),
anchor_(def.name()), dtype_("float32"), data_format_("NCHW") {
dtype_("float32"),
data_format_("NCHW") {
// Scan the defined arguments // Scan the defined arguments
for (auto& arg : def_.arg()) { for (auto& arg : def_.arg()) {
CHECK_GT(arg.name().size(), 0); CHECK_GT(arg.name().size(), 0);
CHECK_EQ(args_.count(arg.name()), 0); CHECK_EQ(args_.count(arg.name()), 0);
args_[arg.name()] = &arg; args_[arg.name()] = &arg;
if (arg.name() == "anchor") { if (arg.name() == "handle") {
anchor_ = arg.s(); handle_ = arg.s();
} else if (arg.name() == "dtype") { } else if (arg.name() == "dtype") {
dtype_ = arg.s(); dtype_ = arg.s();
} else if (arg.name() == "data_format") { } else if (arg.name() == "data_format") {
...@@ -28,20 +26,18 @@ OperatorBase::OperatorBase( ...@@ -28,20 +26,18 @@ OperatorBase::OperatorBase(
} }
// Set the inputs and outputs // Set the inputs and outputs
string tensor_name; size_t ver_pos; string name; size_t ver_pos;
for (auto& e : def.input()) { for (const auto& e : def.input()) {
tensor_name = e; name = e;
if ((ver_pos = e.find("/ver:")) != string::npos) if ((ver_pos = e.find("/ver:")) != string::npos)
tensor_name = e.substr(0, ver_pos); name = e.substr(0, ver_pos);
auto* tensor = ws->GetTensor(tensor_name); inputs_.push_back(ws->GetTensor(name));
inputs_.push_back(tensor);
} }
for (auto& e : def.output()) { for (const auto& e : def.output()) {
tensor_name = e; name = e;
if ((ver_pos = e.find("/ver:")) != string::npos) if ((ver_pos = e.find("/ver:")) != string::npos)
tensor_name = e.substr(0, ver_pos); name = e.substr(0, ver_pos);
auto* tensor = ws->CreateTensor(tensor_name); outputs_.push_back(ws->CreateTensor(name));
outputs_.push_back(tensor);
} }
} }
...@@ -90,10 +86,10 @@ string OperatorBase::DTypeString( ...@@ -90,10 +86,10 @@ string OperatorBase::DTypeString(
return ss.str(); return ss.str();
} }
/* Modify this operator according to the given def */ /* Modify operator according to the given def */
void OperatorBase::UpdateFrom(const OperatorDef& def) { void OperatorBase::UpdateFrom(const OperatorDef& def) {
anchor_ = def.name(); handle_ = def.name();
inputs_.resize(def.input_size()); inputs_.resize(def.input_size());
outputs_.resize(def.output_size()); outputs_.resize(def.output_size());
for (int i = 0; i < inputs_.size(); i++) for (int i = 0; i < inputs_.size(); i++)
...@@ -102,7 +98,7 @@ void OperatorBase::UpdateFrom(const OperatorDef& def) { ...@@ -102,7 +98,7 @@ void OperatorBase::UpdateFrom(const OperatorDef& def) {
outputs_[i] = ws()->CreateTensor(def.output(i)); outputs_[i] = ws()->CreateTensor(def.output(i));
} }
/*! Create a operator instance from the factory */ /* Create an operator from the factory */
OperatorBase* TryCreateOperator( OperatorBase* TryCreateOperator(
const string& key, const string& key,
...@@ -127,7 +123,7 @@ OperatorBase* TryCreateOperator( ...@@ -127,7 +123,7 @@ OperatorBase* TryCreateOperator(
} }
} }
/* New a operator from the raw def */ /* New an operator from the raw def */
OperatorBase* NewOperator( OperatorBase* NewOperator(
const OperatorDef& def, const OperatorDef& def,
...@@ -164,7 +160,7 @@ Gradient MakeGradientForOp( ...@@ -164,7 +160,7 @@ Gradient MakeGradientForOp(
for (int i = 0; i < grad.ops.size(); ++i) { for (int i = 0; i < grad.ops.size(); ++i) {
grad.ops[i].set_uid( grad.ops[i].set_uid(
reference_def.uid() + "/grad" + reference_def.uid() + "/grad" +
(i > 0 ? (":" + std::to_string(i)) : "") (i > 0 ? (":" + str::to(i)) : "")
); );
} }
} }
......
...@@ -241,9 +241,9 @@ string Workspace::GetDummyName( ...@@ -241,9 +241,9 @@ string Workspace::GetDummyName(
while (1) { while (1) {
index = dmap[required_name]++; index = dmap[required_name]++;
accepted_name = index ? base_name + "_" + accepted_name = index ? base_name + "_" +
std::to_string(index) + suffix : str::to(index) + suffix :
zero_based ? required_name : zero_based ? required_name :
base_name + "_" + std::to_string( base_name + "_" + str::to(
dmap[required_name]++) + suffix; dmap[required_name]++) + suffix;
if (remote_workspaces_.empty()) break; if (remote_workspaces_.empty()) break;
if (!HasTensor(accepted_name)) break; if (!HasTensor(accepted_name)) break;
......
...@@ -18,18 +18,18 @@ void _IndexSelect( ...@@ -18,18 +18,18 @@ void _IndexSelect(
const T* x, const T* x,
T* y, T* y,
CPUContext* ctx) { CPUContext* ctx) {
int64_t x_offset, select_idx; int64_t x_ofs, select_idx;
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
for (int i = 0; i < num_indices; ++i) { for (int i = 0; i < num_indices; ++i) {
select_idx = indices[i]; select_idx = indices[i];
select_idx = select_idx >= 0 ? select_idx = select_idx >= 0 ?
select_idx : select_idx + axis_dim; select_idx : select_idx + axis_dim;
x_offset = ( x_ofs = (
n * axis_dim + select_idx n * axis_dim + select_idx
) * inner_dim; ) * inner_dim;
math::Copy( math::Copy(
inner_dim, inner_dim,
x + x_offset, x + x_ofs,
y, ctx y, ctx
); y += inner_dim; ); y += inner_dim;
} }
...@@ -48,7 +48,7 @@ void _IndexSelectGrad( ...@@ -48,7 +48,7 @@ void _IndexSelectGrad(
const T* dy, const T* dy,
T* dx, T* dx,
CPUContext* ctx) { CPUContext* ctx) {
int64_t x_offset, select_idx; int64_t x_ofs, select_idx;
auto nelements = outer_dim * axis_dim * inner_dim; auto nelements = outer_dim * axis_dim * inner_dim;
math::Set(nelements, cast::to<T>(0.f), dx, ctx); math::Set(nelements, cast::to<T>(0.f), dx, ctx);
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
...@@ -56,13 +56,13 @@ void _IndexSelectGrad( ...@@ -56,13 +56,13 @@ void _IndexSelectGrad(
select_idx = indices[i]; select_idx = indices[i];
select_idx = select_idx >= 0 ? select_idx = select_idx >= 0 ?
select_idx : select_idx + axis_dim; select_idx : select_idx + axis_dim;
x_offset = ( x_ofs = (
n * axis_dim + select_idx n * axis_dim + select_idx
) * inner_dim; ) * inner_dim;
math::Add( math::Add(
inner_dim, inner_dim,
dy, dx + x_offset, dy, dx + x_ofs,
dx + x_offset, ctx dx + x_ofs, ctx
); dy += inner_dim; ); dy += inner_dim;
} }
} }
......
...@@ -59,7 +59,7 @@ ONNXImporterReturns ONNXBackend::ConvPoolNodeImporter( ...@@ -59,7 +59,7 @@ ONNXImporterReturns ONNXBackend::ConvPoolNodeImporter(
auto mode = attributes.get<string>("auto_pad"); auto mode = attributes.get<string>("auto_pad");
auto* padding = attributes.AddRewrittenAttribute("padding"); auto* padding = attributes.AddRewrittenAttribute("padding");
// SAME, SAME_LOWER, or SAME_UPPER // SAME, SAME_LOWER, or SAME_UPPER
if (mode.find("SAME") != string::npos) padding->set_s(mode); if (str::find(mode, "SAME")) padding->set_s(mode);
else padding->set_s("VALID"); // Use explicit pads else padding->set_s("VALID"); // Use explicit pads
attributes.remove("auto_pad"); attributes.remove("auto_pad");
...@@ -77,7 +77,7 @@ ONNXImporterReturns ONNXBackend::ConvPoolNodeImporter( ...@@ -77,7 +77,7 @@ ONNXImporterReturns ONNXBackend::ConvPoolNodeImporter(
// Determine the op type // Determine the op type
OperatorDef* op_def = returns.GetOp(0); OperatorDef* op_def = returns.GetOp(0);
auto ks = attributes.get<ONNX_INTS>("kernel_shape"); auto ks = attributes.get<ONNX_INTS>("kernel_shape");
*(op_def->mutable_type()) += (std::to_string(ks.size()) + "d"); *(op_def->mutable_type()) += (str::to(ks.size()) + "d");
return returns; return returns;
} }
...@@ -329,7 +329,7 @@ ONNXImporterReturns ONNXBackend::LpNormNodeImporter( ...@@ -329,7 +329,7 @@ ONNXImporterReturns ONNXBackend::LpNormNodeImporter(
// Determine the "p", i.e. op type // Determine the "p", i.e. op type
auto p = attributes.get<int64_t>("p", 2); auto p = attributes.get<int64_t>("p", 2);
node.set_op_type("L" + std::to_string(p) + "Norm"); node.set_op_type("L" + str::to(p) + "Norm");
attributes.remove("p"); attributes.remove("p");
auto* num_axes = attributes.AddRewrittenAttribute("num_axes"); auto* num_axes = attributes.AddRewrittenAttribute("num_axes");
......
...@@ -24,7 +24,7 @@ void CuDNNDropoutOp<Context>::RunImpl() { ...@@ -24,7 +24,7 @@ void CuDNNDropoutOp<Context>::RunImpl() {
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
auto* states_tensor = ws()->CreateTensor( auto* states_tensor = ws()->CreateTensor(
"/share/cudnn/dropout:" + "/share/cudnn/dropout:" +
std::to_string(rng_seed_) + "/states"); str::to(rng_seed_) + "/states");
if (states_tensor->count() > 0) { if (states_tensor->count() > 0) {
auto* states = states_tensor->template auto* states = states_tensor->template
mutable_data<uint8_t, Context>(); mutable_data<uint8_t, Context>();
...@@ -95,7 +95,7 @@ void CuDNNDropoutGradientOp<Context>::RunImpl() { ...@@ -95,7 +95,7 @@ void CuDNNDropoutGradientOp<Context>::RunImpl() {
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
auto* states_tensor = ws()->CreateTensor( auto* states_tensor = ws()->CreateTensor(
"/share/cudnn/dropout:" + "/share/cudnn/dropout:" +
std::to_string(rng_seed_) + "/states"); str::to(rng_seed_) + "/states");
if (states_tensor->count() > 0) { if (states_tensor->count() > 0) {
auto* states = states_tensor->template auto* states = states_tensor->template
mutable_data<uint8_t, Context>(); mutable_data<uint8_t, Context>();
......
...@@ -68,7 +68,7 @@ void DropoutGradientOp<Context>::RunImpl() { ...@@ -68,7 +68,7 @@ void DropoutGradientOp<Context>::RunImpl() {
dx, ctx() dx, ctx()
); );
} else { } else {
LOG(FATAL) << "Incorrect Op phase: " << phase(); LOG(FATAL) << "Unknown Phase: " << phase();
} }
} }
......
...@@ -27,7 +27,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -27,7 +27,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
auto* states_tensor = ws()->CreateTensor( auto* states_tensor = ws()->CreateTensor(
"/share/cudnn/dropout:" + "/share/cudnn/dropout:" +
std::to_string(rng_seed_) + "/states"); str::to(rng_seed_) + "/states");
if (states_tensor->count() > 0) { if (states_tensor->count() > 0) {
auto* states = states_tensor->template auto* states = states_tensor->template
mutable_data<uint8_t, Context>(); mutable_data<uint8_t, Context>();
......
...@@ -15,7 +15,7 @@ void ConvOpBase<Context>::ComputeOutShape() { ...@@ -15,7 +15,7 @@ void ConvOpBase<Context>::ComputeOutShape() {
if (!Transposed()) { if (!Transposed()) {
auto idm = x_shape_[axis_ + i]; auto idm = x_shape_[axis_ + i];
auto dk = dilation_[i] * (kshape_[i] - 1) + 1; auto dk = dilation_[i] * (kshape_[i] - 1) + 1;
if (padding_.find("SAME") == string::npos) { if (!str::find(padding_, "SAME")) {
// Explicit pads // Explicit pads
auto odm = ( auto odm = (
idm + pad_l_[i] + pad_r_[i] - dk idm + pad_l_[i] + pad_r_[i] - dk
...@@ -38,7 +38,7 @@ void ConvOpBase<Context>::ComputeOutShape() { ...@@ -38,7 +38,7 @@ void ConvOpBase<Context>::ComputeOutShape() {
} else { } else {
auto idm = x_shape_[axis_ + i]; auto idm = x_shape_[axis_ + i];
auto dk = dilation_[i] * (kshape_[i] - 1) + 1; auto dk = dilation_[i] * (kshape_[i] - 1) + 1;
if (padding_.find("SAME") == string::npos) { if (!str::find(padding_, "SAME")) {
// Explicit pads // Explicit pads
auto odm = stride_[i] * (idm - 1 auto odm = stride_[i] * (idm - 1
) + dk - pad_l_[i] - pad_r_[i]; ) + dk - pad_l_[i] - pad_r_[i];
......
...@@ -17,7 +17,7 @@ namespace dragon { ...@@ -17,7 +17,7 @@ namespace dragon {
for (int i = 0; i < 2; i++) \ for (int i = 0; i < 2; i++) \
kshape_[i] = X(0).dim(i + 2); \ kshape_[i] = X(0).dim(i + 2); \
} \ } \
if (padding_.find("SAME") != string::npos) { \ if (str::find(padding_, "SAME")) { \
for (int i = 0; i < 2; i++) { \ for (int i = 0; i < 2; i++) { \
auto idm = X(0).dim(i + 2); \ auto idm = X(0).dim(i + 2); \
int64_t odm = (idm + stride_[i] - 1 \ int64_t odm = (idm + stride_[i] - 1 \
...@@ -38,7 +38,7 @@ namespace dragon { ...@@ -38,7 +38,7 @@ namespace dragon {
for (int i = 0; i < 2; i++) \ for (int i = 0; i < 2; i++) \
kshape_[i] = X(0).dim(i + 1); \ kshape_[i] = X(0).dim(i + 1); \
} \ } \
if (padding_.find("SAME") != string::npos) { \ if (str::find(padding_, "SAME")) { \
for (int i = 0; i < 2; i++) { \ for (int i = 0; i < 2; i++) { \
auto idm = X(0).dim(i + 1); \ auto idm = X(0).dim(i + 1); \
int64_t odm = (idm + stride_[i] - 1 \ int64_t odm = (idm + stride_[i] - 1 \
...@@ -55,7 +55,7 @@ namespace dragon { ...@@ -55,7 +55,7 @@ namespace dragon {
} else { \ } else { \
LOG(FATAL) << "Unknown DataFormat: " << data_format(); \ LOG(FATAL) << "Unknown DataFormat: " << data_format(); \
} \ } \
if (padding_.find("SAME") == string::npos) { \ if (!str::find(padding_, "SAME")) { \
/*! Case 1: infer output shape with explicit pads */ \ /*! Case 1: infer output shape with explicit pads */ \
if (ceil_mode_) { \ if (ceil_mode_) { \
pool_h_ = ceil( \ pool_h_ = ceil( \
......
...@@ -35,7 +35,7 @@ LogSeverity StrToLogSeverity(std::string level) { ...@@ -35,7 +35,7 @@ LogSeverity StrToLogSeverity(std::string level) {
} }
std::string GenLogHashKey(const char* file, int line) { std::string GenLogHashKey(const char* file, int line) {
return std::string(file) + std::to_string(line); return std::string(file) + str::to(line);
} }
int EveryNRegister( int EveryNRegister(
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!