Commit 4bef6a6b by Ting PAN

Add ND-Crop & ND-Pad support

1 parent b7365da6
Showing with 1094 additions and 333 deletions
......@@ -19,12 +19,12 @@
namespace dragon {
class CPUObject{
class CPUObject {
public:
unique_ptr<std::mt19937> rand_generator;
};
class CPUContext{
class CPUContext {
public:
CPUContext(): random_seed_(3) { generator(); }
CPUContext(unsigned int random_seed): random_seed_(random_seed) { generator(); }
......
......@@ -21,11 +21,13 @@ class GraphBase {
string op_type;
};
GraphBase(const GraphDef& graph_def, Workspace* ws);
GraphBase(const GraphDef& meta_graph, Workspace* ws);
virtual bool Create(const GraphDef& graph_def, Workspace* ws) = 0;
virtual bool Create(const GraphDef& optimized_graph, Workspace* ws) = 0;
virtual bool Run(const string& include, const string& exclude) = 0;
inline string name() const { return name_; }
protected:
string name_, phase_;
Map<string, Argument> args_;
......@@ -34,15 +36,15 @@ class GraphBase {
class Graph final : public GraphBase {
public:
Graph(const GraphDef& graph_def, Workspace* ws);
Graph(const GraphDef& meta_graph, Workspace* ws);
bool Create(const GraphDef& graph_def, Workspace* ws) override;
bool Create(const GraphDef& optimized_graph, Workspace* ws) override;
bool Run(const string& include, const string& exclude) override;
GraphDef Prune(const GraphDef& graph_def);
GraphDef Share(const GraphDef& graph_def);
GraphDef MakeUpdate(const GraphDef& graph_def);
void RecomputingAware(const GraphDef& graph_def, Workspace* ws);
GraphDef Prune(const GraphDef& meta_graph);
GraphDef MakeUpdate(const GraphDef& meta_graph);
GraphDef Share(const GraphDef& optimized_graph);
void RecomputingAware(const GraphDef& optimized_graph, Workspace* ws);
inline Workspace* ws() const { return ws_; }
......@@ -58,7 +60,7 @@ class Graph final : public GraphBase {
Set<string> targets_;
};
GraphBase* NewGraph(const GraphDef& graph_def, Workspace* ws);
GraphBase* NewGraph(const GraphDef& meta_graph, Workspace* ws);
DECLARE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*);
} // namespace dragon
......
......@@ -13,7 +13,7 @@
namespace dragon {
class MixedMemory{
class MixedMemory {
public:
enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED };
MixedMemory()
......
......@@ -22,7 +22,7 @@ namespace dragon {
class Workspace;
class OperatorBase{
class OperatorBase {
public:
OperatorBase(const OperatorDef& op_def, Workspace* ws);
......
......@@ -14,7 +14,7 @@
namespace dragon {
class OpSchema{
class OpSchema {
public:
OpSchema()
: op_type_("unknown"), file_("unknown"), line_(0) { Init(); }
......
......@@ -17,8 +17,9 @@ namespace dragon {
#define WORKSPACE_GRAD_BUFFER_SIZE 1
#define WORKSPACE_MAX_CORRUPTED_SIZE 2
class Workspace{
class Workspace {
public:
typedef Map<string, Workspace*> WorkspaceMap;
typedef Map<string, unique_ptr<Tensor> > TensorMap;
typedef Map<string, stack<string> > BufferMap;
typedef Map<string, unique_ptr<mutex> > LockMap;
......@@ -26,7 +27,7 @@ class Workspace{
typedef Map<string, TensorFiller> FillerMap;
typedef Map<string, string> RenameMap;
Workspace() { init(); }
Workspace(const string& name) : name_(name) { init(); }
~Workspace();
void init() {
......@@ -35,16 +36,35 @@ class Workspace{
CreateBuffer("Grad", WORKSPACE_GRAD_BUFFER_SIZE);
}
const string& name() { return name_; }
/******************** Workspace ********************/
inline Workspace* MoveWorkspace(Workspace* ws) {
CHECK(ws) << "The given Workspace is invalid.";
if (workspace_map_.count(ws->name()))
return workspace_map_[ws->name()];
return workspace_map_[ws->name()] = ws;
}
/******************** Tensor ********************/
inline string GetTensorName(const string& name) {
if (rename_map_.count(name)) return rename_map_[name];
else return name;
if (rename_map_.count(name) > 0) {
return rename_map_[name];
} else { return name; }
}
inline bool HasTensor(const string& name) {
inline bool HasTensor(const string& name, bool use_remote=true) {
// search local workspace
string query = GetTensorName(name);
return tensor_map_.count(query) > 0;
bool result = tensor_map_.count(query) > 0;
if (!use_remote) return result;
// search remote workspace
for (auto& it : workspace_map_)
result |= it.second->HasTensor(query);
return result;
}
inline Tensor* CreateTensor(const string& name) {
......@@ -54,11 +74,21 @@ class Workspace{
return tensor_map_[query].get();
}
inline Tensor* GetTensor(const string& name) {
inline Tensor* GetTensor(const string& name, bool use_remote=true) {
string query = GetTensorName(name);
CHECK(HasTensor(query))
<< "Tensor(" << name << ") does not exist.";
// search local workspace
if (tensor_map_.count(query) > 0)
return tensor_map_[query].get();
if (use_remote) {
// search remote workspace
for (auto& it : workspace_map_) {
if (it.second->HasTensor(query))
return it.second->GetTensor(query);
}
}
LOG(FATAL) << "Tensor(" << name << ") does not exist "
<< "in current workspace and it's sub-workspace.";
return nullptr;
}
inline void LockTensor(const string& name) {
......@@ -76,15 +106,23 @@ class Workspace{
}
inline void ReleaseTensor(const string& name) {
CHECK(HasTensor(name)) << "\nTensor(" << name << ") does not "
<< "belong to workspace, could not release it.";
CHECK(HasTensor(name, false))
<< "\nTensor(" << name << ") does not "
<< "belong to current workspace, could not release it.";
string query = GetTensorName(name);
tensor_map_[query]->Reset();
}
inline vector<string> GetTensors() {
vector<string> names;
for (auto& it : tensor_map_) names.push_back(it.first);
// search local workspace
for (auto& it : tensor_map_)
names.push_back(it.first);
// serach remote workspace
for (auto& it : workspace_map_) {
vector<string> sub_names = it.second->GetTensors();
names.insert(names.end(), sub_names.begin(), sub_names.end());
}
return names;
}
......@@ -118,7 +156,7 @@ class Workspace{
if (!buffer_map_[category].empty()) {
string name = buffer_map_[category].top();
buffer_map_[category].pop();
return GetTensor(name);
return tensor_map_[name].get();
}
LOG(FATAL) << "Buffers of [" << category << "] "
<< "are not enough, add more if necessary.";
......@@ -142,9 +180,11 @@ class Workspace{
/******************** Graph ********************/
GraphBase* CreateGraph(const GraphDef& graph_def);
GraphBase* CreateGraph(const GraphDef& meta_graph);
inline bool RunGraph(const string& graph_name,
const string& include, const string& exclude) {
const string& include,
const string& exclude) {
if (!graph_map_.count(graph_name)) {
LOG(ERROR) << "Graph(" << graph_name << ") does not exist.";
return false;
......@@ -166,6 +206,8 @@ class Workspace{
}
private:
string name_;
WorkspaceMap workspace_map_;
TensorMap tensor_map_;
BufferMap buffer_map_;
LockMap lock_map_;
......
......@@ -16,29 +16,24 @@ class CropOp: public Operator<Context> {
public:
CropOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 2)),
offsets_param(OperatorBase::GetRepeatedArg<int>("offsets")),
starts(OperatorBase::GetRepeatedArg<int>("starts")),
ends(OperatorBase::GetRepeatedArg<int>("ends")),
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
CHECK(shape.size() * shape_like.size() == 0)
<< "\nCan not set shape and shape_like both.";
CHECK(shape.size() + shape_like.size() != 0)
<< "\nMust set shape and shape_like either.";
}
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {}
void ComputeOutputShape();
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename T> void RecursiveRunWithType(vector<TIndex> idxs,
const vector<TIndex>& offsets,
int cur_dim,
Tensor* x,
Tensor* y);
protected:
TIndex axis;
vector<int> offsets_param, shape;
vector<TIndex> output_shape, offsets;
TIndex start_axis;
string shape_like;
vector<int> starts, ends, offsets, shape;
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
};
template <class Context>
......@@ -46,29 +41,24 @@ class CropGradientOp final : public Operator<Context > {
public:
CropGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 2)),
offsets_param(OperatorBase::GetRepeatedArg<int>("offsets")),
starts(OperatorBase::GetRepeatedArg<int>("starts")),
ends(OperatorBase::GetRepeatedArg<int>("ends")),
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
CHECK(shape.size() * shape_like.size() == 0)
<< "\ncan not set shape and shape_like both.";
CHECK(shape.size() + shape_like.size() != 0)
<< "\nmust set shape and shape_like either.";
}
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {}
void ComputeOutputShape();
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename T> void RecursiveRunWithType(vector<TIndex> idxs,
const vector<TIndex>& offsets,
int cur_dim,
Tensor* dy,
Tensor* dx);
protected:
TIndex axis;
vector<int> offsets_param, shape;
vector<TIndex> output_shape, offsets;
TIndex start_axis;
string shape_like;
vector<int> starts, ends, offsets, shape;
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
};
} // namespace dragon
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
#define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class PadOp final : public Operator<Context> {
public:
PadOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
pad_l(OperatorBase::GetRepeatedArg<int>("pad_l")),
pad_r(OperatorBase::GetRepeatedArg<int>("pad_r")),
mode(OperatorBase::GetSingleArg<string>("mode", "CONSTANT")),
value(OperatorBase::GetSingleArg<float>("value", 0.0f)) {
if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length.";
for (int i = 0; i < pad_l.size(); i++) {
int padding_size = pad_l[i] + pad_r[i];
if (padding_size > 0)
process_axes.push_back({ padding_size, i });
}
std::sort(process_axes.begin(), process_axes.end());
}
void RunOnDevice() override;
template <typename T> void ConstRunWithType();
template <typename T> void ReflectRunWithType();
template <typename T> void EdgeRunWithType();
protected:
vector<int> pad_l, pad_r;
string mode;
float value;
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
};
template <class Context>
class PadGradientOp final : public Operator<Context> {
public:
PadGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
pad_l(OperatorBase::GetRepeatedArg<int>("pad_l")),
pad_r(OperatorBase::GetRepeatedArg<int>("pad_r")),
mode(OperatorBase::GetSingleArg<string>("mode", "CONSTANT")) {
if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length.";
for (int i = 0; i < pad_l.size(); i++) {
int padding_size = pad_l[i] + pad_r[i];
if (padding_size > 0)
process_axes.push_back({ padding_size, i });
}
std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end());
}
void RunOnDevice() override;
template <typename T> void ConstRunWithType();
template <typename T> void ReflectRunWithType();
template <typename T> void EdgeRunWithType();
protected:
vector<int> pad_l, pad_r;
string mode;
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
\ No newline at end of file
......@@ -19,7 +19,8 @@ class TileOp : public Operator<Context> {
multiples(OperatorBase::GetRepeatedArg<int>("multiples")) {
for (int i = 0; i < multiples.size(); i++)
if (multiples[i] > 1)
process_axes.push_back({ i, multiples[i] });
process_axes.push_back({ multiples[i], i });
std::sort(process_axes.begin(), process_axes.end());
}
void RunOnDevice() override;
......@@ -38,9 +39,11 @@ class TileGradientOp : public Operator<Context> {
TileGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
multiples(OperatorBase::GetRepeatedArg<int>("multiples")) {
for (int i = (int)multiples.size() - 1; i >= 0; i--)
for (int i = 0; i < multiples.size(); i++)
if (multiples[i] > 1)
process_axes.push_back({ i, multiples[i] });
process_axes.push_back({ multiples[i], i });
std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end());
}
void RunOnDevice() override;
......
......@@ -34,7 +34,7 @@ static const int CUDA_NUM_THREADS = 1024;
#define CUDA_CHECK(condition) \
do { \
cudaError_t error = condition; \
CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
CHECK_EQ(error, cudaSuccess) << "\n" << cudaGetErrorString(error); \
} while (0)
#define CUBLAS_CHECK(condition) \
......@@ -53,7 +53,7 @@ static const int CUDA_NUM_THREADS = 1024;
#define NCCL_CHECK(condition) \
do { \
ncclResult_t status = condition; \
CHECK_EQ(status, ncclSuccess) << " " << ncclGetErrorString(status); \
CHECK_EQ(status, ncclSuccess) << "\n" << ncclGetErrorString(status); \
} while (0)
#endif // WITH_MPI_NCCL
......
......@@ -28,7 +28,7 @@ class Tensor;
#define CUDNN_CHECK(condition) \
do { \
cudnnStatus_t status = condition; \
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " "\
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "\n" \
<< cudnnGetErrorString(status); \
} while (0)
......
......@@ -319,20 +319,80 @@ void ConcatGrad(const int count,
/******************** ndarray.crop ********************/
template <typename T, class Context>
void Crop2D(vector<TIndex> idxs,
const vector<TIndex>& offsets,
const int cur_dim,
Tensor* x,
Tensor* y,
Context* context);
void Crop1D(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int start,
const T* x,
T* y);
template <typename T, class Context>
void Crop2DGrad(vector<TIndex> idxs,
const vector<TIndex>& offsets,
const int cur_dim,
Tensor* dy,
Tensor* dx,
Context* context);
void Crop1DGrad(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int start,
const int end,
const T* dy,
T* dx);
/******************** ndarray.pad ********************/
template <typename T, class Context>
void ConstPad1D(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float value,
const T* x,
T* y);
template <typename T, class Context>
void ReflectPad1D(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const T* x,
T* y);
template <typename T, class Context>
void EdgePad1D(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const T* x,
T* y);
template <typename T, class Context>
void ConstPad1DGrad(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const T* dy,
T* dx);
template <typename T, class Context>
void ReflectPad1DGrad(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const T* dy,
T* dx);
template <typename T, class Context>
void EdgePad1DGrad(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const T* dy,
T* dx);
/******************** ndarray.one_hot ********************/
......
......@@ -121,7 +121,7 @@ bool SwitchWorkspaceInternal(const string& name, const bool create_if_missing) {
g_workspace = g_workspaces[name].get();
return true;
} else if (create_if_missing) {
unique_ptr<Workspace> new_workspace(new Workspace());
unique_ptr<Workspace> new_workspace(new Workspace(name));
g_workspace = new_workspace.get();
g_workspaces[name] = std::move(new_workspace);
g_current_workspace = name;
......@@ -171,7 +171,7 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
CHECK(g_workspaces.count(target_workspace))
<< "\nWorkspace(" << target_workspace << ") does not exist, can not be reset.";
LOG(INFO) << "Reset the Workspace(" << target_workspace << ")";
g_workspaces[target_workspace].reset(new Workspace());
g_workspaces[target_workspace].reset(new Workspace(target_workspace));
g_workspace = g_workspaces[target_workspace].get();
Py_RETURN_TRUE;
}
......
......@@ -34,6 +34,16 @@ option['debug_mode'] = False
# Set it by the memonger
option['share_grads'] = False
# Whether to log the meta graphs
option['log_meta_graph'] = False
# The prefix of exporting directory
# An empty string leads to invalid exporting
option['export_meta_graph'] = ''
# Whether to log the optimized graphs
option['log_optimized_graph'] = False
def EnableCPU():
"""Enable CPU mode globally.
......@@ -150,6 +160,67 @@ def SetDebugMode(enabled=True):
option['debug_mode'] = enabled
def LogMetaGraph(enabled=True):
"""Enable to log meta graph globally.
The meta graph is a describer generated by the VM frontend.
Parameters
----------
enabled : boolean
Whether to enable logging.
Returns
-------
None
"""
global option
option['log_meta_graph'] = enabled
def LogOptimizedGraph(enabled=True):
"""Enable to log optimized graph globally.
The optimized graph is a describer optimized by the VM backend.
Parameters
----------
enabled : boolean
Whether to enable logging.
Returns
-------
None
"""
global option
option['log_optimized_graph'] = enabled
def ExportMetaGraph(prefix=''):
"""Enable to export all runnable meta graphs into text files.
These text files will be saved as the following format:
``prefix/Graph_xxx.metatxt``
Note that an empty prefix will leads to invalid exporting.
Parameters
----------
prefix : str
The prefix of the exporting.
Returns
-------
None
"""
global option
option['export_meta_graph'] = prefix
def SetLoggingLevel(level):
"""Set the minimum level of Logging.
......
......@@ -6,10 +6,10 @@
from collections import defaultdict
TENSOR_SCOPE = ''
PHASE_SCOPE = ''
DEVICE_SCOPE = ''
ENGINE_SCOPE = ''
_TENSOR_SCOPE = ''
_PHASE_SCOPE = ''
_DEVICE_SCOPE = ''
_ENGINE_SCOPE = ''
SEPARATOR = '/'
......@@ -50,8 +50,8 @@ def GetTensorIdx():
"""
global _SCOPE_TENSOR_IDX
_SCOPE_TENSOR_IDX[TENSOR_SCOPE] += 1
return _SCOPE_TENSOR_IDX[TENSOR_SCOPE] - 1
_SCOPE_TENSOR_IDX[_TENSOR_SCOPE] += 1
return _SCOPE_TENSOR_IDX[_TENSOR_SCOPE] - 1
def GetOperatorName(name=None):
......@@ -107,13 +107,19 @@ class TensorScope(object):
self.prefix = prefix + SEPARATOR
def __enter__(self):
global TENSOR_SCOPE
TENSOR_SCOPE += self.prefix
global _TENSOR_SCOPE
_TENSOR_SCOPE += self.prefix
return self.prefix.split(SEPARATOR)[0]
def __exit__(self, type, value, traceback):
global TENSOR_SCOPE
assert TENSOR_SCOPE.endswith(self.prefix)
TENSOR_SCOPE = TENSOR_SCOPE[:-len(self.prefix)]
global _TENSOR_SCOPE
assert _TENSOR_SCOPE.endswith(self.prefix)
_TENSOR_SCOPE = _TENSOR_SCOPE[:-len(self.prefix)]
def set_tensor_scope(name_scope):
global _TENSOR_SCOPE
_TENSOR_SCOPE = name_scope
class PhaseScope(object):
......@@ -135,13 +141,13 @@ class PhaseScope(object):
self.phase = phase
def __enter__(self):
global PHASE_SCOPE
PHASE_SCOPE = self.phase
global _PHASE_SCOPE
_PHASE_SCOPE = self.phase
def __exit__(self, type, value, traceback):
global PHASE_SCOPE
assert PHASE_SCOPE == self.phase
PHASE_SCOPE = ''
global _PHASE_SCOPE
assert _PHASE_SCOPE == self.phase
_PHASE_SCOPE = ''
class DeviceScope(object):
......@@ -163,11 +169,11 @@ class DeviceScope(object):
self.id = id
def __enter__(self):
global DEVICE_SCOPE, ENGINE_SCOPE
DEVICE_SCOPE = '/' + self.device + ':' + str(self.id)
ENGINE_SCOPE = self.engine
global _DEVICE_SCOPE, _ENGINE_SCOPE
_DEVICE_SCOPE = '/' + self.device + ':' + str(self.id)
_ENGINE_SCOPE = self.engine
def __exit__(self, type, value, traceback):
global DEVICE_SCOPE, ENGINE_SCOPE
DEVICE_SCOPE = ''
ENGINE_SCOPE = ''
\ No newline at end of file
global _DEVICE_SCOPE, _ENGINE_SCOPE
_DEVICE_SCOPE = ''
_ENGINE_SCOPE = ''
\ No newline at end of file
......@@ -227,9 +227,9 @@ class Tensor(object):
@name.setter
def name(self, value):
from .scope import TENSOR_SCOPE
if value is None: self._name = TENSOR_SCOPE + GetTensorName()
else: self._name = TENSOR_SCOPE + value
from .scope import _TENSOR_SCOPE
if value is None: self._name = _TENSOR_SCOPE + GetTensorName()
else: self._name = _TENSOR_SCOPE + value
@property
def grad_wrts(self):
......@@ -399,27 +399,45 @@ class Tensor(object):
ws.FeedTensor(tensor, np.array(indices, dtype=np.float32))
return tensor
if not isinstance(item, tuple):
# 1D At
if isinstance(item, int):
output = self.CreateOperator(inputs=[self, wrapper_indices([item])], nout=1, op_type='At')
if self.shape is not None:
output.shape = self.shape[:]
output.shape[0] = 1
return output
else:
# ND Crop
item = (item, )
starts = []
ends = []
output_dims = []
for it in item:
if isinstance(it, slice):
# handle start
if it.start is None: starts.append(0)
else: starts.append(it.start)
# handle stop
if it.stop is None: ends.append(0)
else: ends.append(it.stop)
# handle step
if it.step is not None:
raise NotImplementedError('Cropping with step has not been implemented yet. ')
output_dims.append(min(ends[-1] - starts[-1], 1))
elif isinstance(it, int):
starts.append(it)
ends.append(it + 1)
output_dims.append(1)
else:
raise TypeError('Unsupported type of indices: {}'.format(type(type(it))))
output = self.CreateOperator(inputs=self, nout=1, op_type='Crop', starts=starts, ends=ends)
elif isinstance(item, slice):
indices = [i for i in xrange(item.start, item.stop, item.step
if item.step is not None else 1)]
outputs = []
for idx in indices:
output = self.CreateOperator(inputs=[self, wrapper_indices([idx])], nout=1, op_type='At')
if self.shape is not None:
output.shape = self.shape[:]
output.shape[0] = 1
outputs.append(output)
return outputs
output.shape = output_dims[:]
elif isinstance(item, Tensor):
return self.CreateOperator(inputs=[self, item], nout=1, op_type='At')
return output
def __add__(self, other):
"""Calculate x + y.
......@@ -926,13 +944,13 @@ class Tensor(object):
outputs_name = [output.name for output in outputs]
op_idx, op_name = GetOperatorName(name)
device_option = None
from dragon.core.scope import DEVICE_SCOPE, ENGINE_SCOPE
if DEVICE_SCOPE != '':
from dragon.core.scope import _DEVICE_SCOPE, _ENGINE_SCOPE
if _DEVICE_SCOPE != '':
supports = {'/cpu': 0, '/gpu': 1}
device_option = pb.DeviceOption()
device_option.device_type = supports[DEVICE_SCOPE.split(':')[0]]
device_option.gpu_id = int(DEVICE_SCOPE.split(':')[1])
device_option.engine = ENGINE_SCOPE
device_option.device_type = supports[_DEVICE_SCOPE.split(':')[0]]
device_option.gpu_id = int(_DEVICE_SCOPE.split(':')[1])
device_option.engine = _ENGINE_SCOPE
op_def = MakeOperatorDef(op_type, inputs_name, outputs_name, op_name,
device_option=device_option, **kwargs)
expressions[op_idx] = op_def
......
......@@ -31,9 +31,9 @@ __all__ = [
'CreateFiller',
'Snapshot',
'Restore',
'PrintRawGraphDef',
'PrintOptimizedGraph',
'WriteOptimizedGraph'
'LogMetaGraph',
'LogOptimizedGraph',
'ExportMetaGraph'
]
_DATA_TYPES = {
......@@ -55,7 +55,7 @@ def _stringify_proto(obj):
def SwitchWorkspace(workspace, create_if_missing=True):
"""Switch to the specific Workspace.
"""Switch to the specific workspace.
Parameters
----------
......@@ -76,35 +76,27 @@ def SwitchWorkspace(workspace, create_if_missing=True):
SwitchWorkspaceCC(workspace, create_if_missing)
def CreateGraph(graph_def):
"""Create the graph in the backend.
def CreateGraph(meta_graph):
"""Create the graph in the VM backend.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of the raw graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
Returns
-------
None
Notes
-----
Uncomment `PrintRawGraphDef`_ will print the raw prototxt.
Uncomment `PrintOptimizedGraph`_ will print the optimized prototxt.
Uncomment `WriteOptimizedGraph`_ will generate the optimized prototxt file.
References
----------
The wrapper of ``CreateGraphCC``.
"""
#PrintRawGraphDef(graph_def)
CreateGraphCC(_stringify_proto(graph_def))
#PrintOptimizedGraph(graph_def)
#WriteOptimizedGraph(graph_def)
LogMetaGraph(meta_graph)
ExportMetaGraph(meta_graph)
CreateGraphCC(_stringify_proto(meta_graph))
LogOptimizedGraph(meta_graph)
def HasTensor(tensor):
......@@ -248,12 +240,12 @@ def FeedTensor(tensor, ndarray, force_cpu=False, dtype=None):
dev = None
if force_cpu is True: dev = utils.MakeDeviceOption(0, 0)
else:
from dragon.core.scope import DEVICE_SCOPE
if DEVICE_SCOPE != '':
from dragon.core.scope import _DEVICE_SCOPE
if _DEVICE_SCOPE != '':
supports = {'/cpu': 0, '/gpu': 1}
dev = pb.DeviceOption()
dev.device_type = supports[DEVICE_SCOPE.split(':')[0]]
dev.gpu_id = int(DEVICE_SCOPE.split(':')[1])
dev.device_type = supports[_DEVICE_SCOPE.split(':')[0]]
dev.gpu_id = int(_DEVICE_SCOPE.split(':')[1])
else:
from dragon.config import option
if option['device'] == 'CUDA':
......@@ -267,17 +259,21 @@ def FeedTensor(tensor, ndarray, force_cpu=False, dtype=None):
auto_dtype = np.float32 if dtype is None else dtype
else:
auto_dtype = ndarray.dtype if dtype is None else dtype
if hasattr(tensor, 'dtype') and tensor.dtype is not None:
if tensor.dtype not in _DATA_TYPES:
raise TypeError('Unsupported data types: {}.'.format(tensor.dtype))
preset_dtype = _DATA_TYPES[tensor.dtype]
if dtype is not None:
if dtype != preset_dtype:
raise TypeError('The preset data type is {}, but force to {}.'
.format(preset_dtype, dtype))
raise TypeError('The preset data type is {}, but force to {}.'.
format(preset_dtype, dtype))
auto_dtype = preset_dtype
ndarray = np.array(ndarray, dtype=auto_dtype)
if hasattr(tensor, 'shape'): tensor.shape = list(ndarray.shape)
FeedTensorCC(name, ndarray, _stringify_proto(dev))
stages = {
'forward': {'include': '', 'exclude': 'Gradient'},
'backward': {'include': 'Gradient', 'exclude': 'Generate'},
......@@ -285,6 +281,7 @@ stages = {
'external_grads': {'include': '', 'exclude': 'Generate'}
}
def RunGraph(graph_name, inputs=(), outputs=[], stage=None, return_outputs=True):
"""Run the specific graph.
......@@ -329,37 +326,39 @@ def RunGraph(graph_name, inputs=(), outputs=[], stage=None, return_outputs=True)
else: return [outputs[i].get_value() for i in xrange(len(outputs))]
def PrintRawGraphDef(graph_def):
"""Print the raw prototxt.
def LogMetaGraph(meta_graph):
"""Log the meta graph.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of the raw graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
Returns
-------
None
"""
logger.info(graph_def)
"""
from dragon.config import option
if option['log_meta_graph']:
logger.info(meta_graph)
def GetOptimizedGraph(graph_def):
"""Return the optimized prototxt.
def GetOptimizedGraph(meta_graph):
"""Return the optimized graph.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of the raw graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
Returns
-------
graph_def : dragon_pb2.GraphDef
The definition of the optimized graph.
The definition of optimized graph.
"""
graph_name = graph_def.name
graph_name = meta_graph.name
graph_tensor = 'GraphDef_' + graph_name
if not HasTensorCC(graph_tensor):
......@@ -371,40 +370,52 @@ def GetOptimizedGraph(graph_def):
return opt_graph_def
def PrintOptimizedGraph(graph_def):
"""Print the optimized prototxt.
def LogOptimizedGraph(meta_graph):
"""Log the optimized graph.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of the raw graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
Returns
-------
None
"""
from dragon.config import option
if option['log_optimized_graph']:
optimized_graph = GetOptimizedGraph(meta_graph)
logger.info(optimized_graph)
opt_graph_def = GetOptimizedGraph(graph_def)
logger.info(opt_graph_def)
def ExportMetaGraph(meta_graph):
"""Export the meta graph into a file under specific folder.
def WriteOptimizedGraph(graph_def):
"""Generate the optimized prototxt file under ``__main__`` folder.
You can set the exporting prefix by `config.ExportMetaGraph(prefix)`_.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of the raw graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
Returns
-------
None
"""
opt_graph_def = GetOptimizedGraph(graph_def)
with open(opt_graph_def.name + '.txt', 'w') as f:
f.write(str(opt_graph_def))
logger.info('write serialized graph to: {}'.format(opt_graph_def.name + '.txt'))
"""
from dragon.config import option
if option['export_meta_graph']:
if not os.path.exists(option['export_meta_graph']):
try:
os.makedirs(option['export_meta_graph'])
except Exception:
raise ValueError('The given prefix is invalid.')
filepath = os.path.join(option['export_meta_graph'],
meta_graph.name + '.metatxt')
with open(filepath, 'w') as f:
f.write(str(meta_graph))
logger.info('Export meta graph into: {}'.format(filepath))
def Snapshot(tensors, filename, prefix='', suffix='.bin', format='default'):
......
......@@ -18,6 +18,9 @@ List Brief
`SetGPU`_ Set the global id GPU.
`GetGPU`_ Get the global id of GPU.
`SetDebugMode`_ Enable Debug mode globally.
`LogMetaGraph`_ Enable to log meta graph globally.
`LogOptimizedGraph`_ Enable to log optimized graph globally.
`ExportMetaGraph`_ Enable to export all runnable meta graphs into text files.
`SetLoggingLevel`_ Set the minimum level of Logging.
==================== =============================================================================
......@@ -34,4 +37,7 @@ API Reference
.. _SetGPU: #dragon.config.SetGPU
.. _GetGPU: #dragon.config.GetGPU
.. _SetDebugMode: #dragon.config.SetDebugMode
.. _LogMetaGraph: #dragon.config.LogMetaGraph
.. _LogOptimizedGraph: #dragon.config.LogOptimizedGraph
.. _ExportMetaGraph: #dragon.config.ExportMetaGraph
.. _SetLoggingLevel: #dragon.config.SetLoggingLevel
\ No newline at end of file
......@@ -18,7 +18,7 @@ List Brief
`Tensor.get_value`_ Fetch the values from C++ backend.
`Tensor.copy`_ Return a Tensor with same content.
`Tensor.reshape`_ Reshape the dimensions of input.
`Tensor.dimshuffle`_ Shuffle the dimen`sions.
`Tensor.dimshuffle`_ Shuffle the dimensions.
`Tensor.CreateOperator`_ Construct a new Tensor with specific operator descriptor.
`Tensor.Fill`_ Fill self with the specific type of filler.
`Tensor.PrintExpressions`_ Return the stringified internal expressions.
......
......@@ -37,9 +37,9 @@ List Brief
`Snapshot`_ Snapshot tensors into a binary file.
`Restore`_ Restore tensors from a binary file.
`SwitchWorkspace`_ Switch to the specific Workspace.
`PrintRawGraphDef`_ Print the raw prototxt.
`PrintOptimizedGraph`_ Print the optimized prototxt.
`WriteOptimizedGraph`_ Generate the optimized prototxt into a file.
`LogMetaGraph`_ Log the meta graph.
`LogOptimizedGraph`_ Log the optimized graph.
`ExportMetaGraph`_ Export the meta graph into a file under specific folder.
============================== =============================================================================
API Reference
......@@ -60,8 +60,9 @@ API Reference
.. _RunGraph: #dragon.core.workspace.RunGraph
.. _Snapshot: #dragon.core.workspace.Snapshot
.. _Restore: #dragon.core.workspace.Restore
.. _PrintRawGraphDef: #dragon.core.workspace.PrintRawGraphDef
.. _PrintOptimizedGraph: #dragon.core.workspace.PrintOptimizedGraph
.. _WriteOptimizedGraph: #dragon.core.workspace.WriteOptimizedGraph
.. _LogMetaGraph: #dragon.core.workspace.LogMetaGraph
.. _LogOptimizedGraph: #dragon.core.workspace.LogOptimizedGraph
.. _ExportMetaGraph: #dragon.core.workspace.ExportMetaGraph
.. _theano.function(*args, **kwargs): ../vm/theano/compile.html#dragon.vm.theano.compile.function.function
.. _config.ExportMetaGraph(prefix): ../config.html#dragon.config.ExportMetaGraph
\ No newline at end of file
......@@ -122,7 +122,6 @@ List Brief
=============== ======================================================================
`At`_ 1D At interface of NDArray.
`RandomPick`_ 1D RandomPick interface of NDArray.
`Crop`_ 2D Crop interface interface of NDArray.
`Reduce`_ The general reduce operator.
`Sum`_ Compute the sum along the given axis.
`Mean`_ Compute the mean along the given axis.
......@@ -134,6 +133,8 @@ List Brief
`Repeat`_ Repeat the input along the given axis.
`Transpose`_ Transpose the input according to the given permutations.
`Tile`_ Tile the input according to the given multiples.
`Pad`_ Pad the input according to the given paddings.
`Crop`_ Crop the input according to the given starts and ends.
`Flatten`_ Flatten the input along the given axes.
`Reshape`_ Reshape the dimensions of input.
`ExpandDims`_ ExpandDims interface of NDArray.
......@@ -257,6 +258,7 @@ List Brief
.. _Transpose: operators/ndarray.html#dragon.operators.ndarray.Transpose
.. _Repeat: operators/ndarray.html#dragon.operators.ndarray.Repeat
.. _Tile: operators/ndarray.html#dragon.operators.ndarray.Tile
.. _Pad: operators/ndarray.html#dragon.operators.ndarray.Pad
.. _Flatten: operators/ndarray.html#dragon.operators.ndarray.Flatten
.. _Reshape: operators/ndarray.html#dragon.operators.ndarray.Reshape
.. _ExpandDims: operators/ndarray.html#dragon.operators.ndarray.ExpandDims
......
......@@ -67,6 +67,7 @@ List Brief
`PermuteLayer`_ The implementation of ``PermuteLayer``.
`FlattenLayer`_ The implementation of ``FlattenLayer``.
`SoftmaxLayer`_ The implementation of ``SoftmaxLayer``.
`ArgMaxLayer`_ The implementation of ``ArgMaxLayer``.
`BatchNormLayer`_ The implementation of ``BatchNormLayer``.
`BatchRenormLayer`_ The implementation of ``BatchRenormLayer``.
`InstanceNormLayer`_ The implementation of ``InstanceNormLayer``.
......@@ -170,6 +171,7 @@ API Reference
.. _PermuteLayer: #dragon.vm.caffe.layers.common.PermuteLayer
.. _FlattenLayer: #dragon.vm.caffe.layers.common.FlattenLayer
.. _SoftmaxLayer: #dragon.vm.caffe.layers.common.SoftmaxLayer
.. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
.. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
.. _BatchRenormLayer: #dragon.vm.caffe.layers.common.BatchRenormLayer
.. _InstanceNormLayer: #dragon.vm.caffe.layers.common.InstanceNormLayer
......@@ -258,6 +260,8 @@ API Reference
.. _FlattenParameter.axis: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L748
.. _FlattenParameter.end_axis: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L753
.. _SoftmaxParameter.axis: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L1142
.. _ArgMaxParameter.top_k: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L485
.. _ArgMaxParameter.axis: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L490
.. _BatchNormParameter.use_global_stats: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L511
.. _BatchNormParameter.moving_average_fraction: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L520
.. _BatchNormParameter.eps: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L523
......
......@@ -83,8 +83,11 @@ def RandomPick(inputs, max_samples=1, axis=0, **kwargs):
return outputs
def Crop(inputs, shape, shape_like=None, axis=2, offsets=(), **kwargs):
"""2D Crop interface interface of NDArray.
def Crop(inputs, starts, ends, start_axis=None,
offsets=None, shape=None, shape_like=None, **kwargs):
"""Crop the input according to the given starts and ends.
Set ``starts`` and ``ends`` to None, if want to use ``start_axis``, ``offsets`` and ``shape``.
Set ``shape`` to None, if you want to use ``shape_like``.
......@@ -92,23 +95,44 @@ def Crop(inputs, shape, shape_like=None, axis=2, offsets=(), **kwargs):
----------
inputs : Tensor
The input tensor.
shape : list or None
The shape of cropping.
starts : int, list of int or None
The starts.
ends : int, list of int or None
The ends.
start_axis : int or None
The axis to start. Default is ``None`` (Disabled).
offsets : int, list of int or None
The offsets. Ignore the axes before ``start_axis``.
shape : list, tuple or None
The referring shape. Use ``-1`` to represent the unknown dimensions.
shape_like : Tensor or None
The shape of cropping. Default is ``None`` (Use ``shape``).
axis : int
The start axis of cropping.
offsets : int or list of int
The offsets. A single value or list of values.
The referring shape. Default is ``None`` (Disabled).
Returns
-------
Tensor
The output tensor.
Examples
--------
>>> x = Tensor('x', dtype='float32').Variable()
>>> x.set_value(np.arange(1, 25).reshape((1, 2, 3, 4)))
>>> y = Crop(x, starts=[0, 1, 0, 2], ends=[1, 2, 0, 0])
>>> y = x[0:1, 1:2, :, 2:] # the same as above
>>> y = Crop(x, None, None, start_axis=1, offsets=(1, 0, 2), shape=(-1, 1, 3, 2)) # the same as above
"""
CheckInputs(inputs, 1)
arguments = ParseArguments(locals())
if starts is not None:
if not isinstance(starts, (list, tuple)):
arguments['starts'] = [starts]
if ends is not None:
if not isinstance(ends, (list, tuple)):
arguments['ends'] = [ends]
if offsets is not None:
if not isinstance(offsets, (list, tuple)):
arguments['offsets'] = [offsets]
if shape is None: arguments['shape'] = []
if shape_like is not None:
if not isinstance(shape_like, Tensor):
......@@ -471,6 +495,51 @@ def Tile(inputs, multiples, **kwargs):
return output
def Pad(inputs, paddings, mode='CONSTANT', value=0, **kwargs):
"""Pad the input according to the given paddings.
Parameters
----------
input : Tensor
The input tensor.
paddings : list or tuple
The paddings, 1D/2D list or tuple.
mode : str
The padding mode, ``CONSTANT``, ``REFLECT`` or ``EDGE``.
value : basic numerical type
The value to use on the ``CONSTANT`` mode.
Returns
-------
Tensor
The output tensor.
"""
CheckInputs(inputs, 1)
arguments = ParseArguments(locals())
pad_l = []; pad_r = []
for padding in paddings:
if isinstance(padding, (list, tuple)):
if len(padding) != 2:
raise ValueError('The padding should be a list or tuple of length 2.')
pad_l.append(int(padding[0]))
pad_r.append(int(padding[1]))
else:
pad_l.append(int(padding))
pad_r.append(int(padding))
arguments['paddings'] = None
arguments['pad_l'] = pad_l
arguments['pad_r'] = pad_r
arguments['value'] = float(arguments['value'])
output = Tensor.CreateOperator(nout=1, op_type='Pad', **arguments)
return output
def OneHot(inputs, depth, on_value=1, off_value=0, **kwargs):
"""Generate the one-hot representation of inputs.
......
......@@ -103,6 +103,7 @@ Concat = ndarray.Concat
Transpose = ndarray.Transpose
Repeat = ndarray.Repeat
Tile = ndarray.Tile
Pad = ndarray.Pad
OneHot = ndarray.OneHot
Flatten = ndarray.Flatten
Reshape = ndarray.Reshape
......
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......@@ -44,6 +44,7 @@ from .common import InnerProductLayer, \
EltwiseLayer, \
ScaleLayer, \
SoftmaxLayer, \
ArgMaxLayer, \
PermuteLayer, \
FlattenLayer, \
ConcatLayer, \
......
# --------------------------------------------------------
# Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......@@ -190,12 +190,13 @@ class CropLayer(Layer):
def __init__(self, LayerParameter):
super(CropLayer, self).__init__(LayerParameter)
param = LayerParameter.crop_param
self._param = {'axis': param.axis,
self._param = {'start_axis': param.axis,
'offsets': [int(element) for element in param.offset]}
def Setup(self, bottom):
super(CropLayer, self).Setup(bottom)
self._param['shape_like'] = bottom[1]
self._param['starts'] = self._param['ends'] = None
return ops.Crop(bottom[0], **self._param)
......@@ -285,6 +286,30 @@ class SoftmaxLayer(Layer):
return ops.Softmax(input, **self._param)
class ArgMaxLayer(Layer):
"""The implementation of ``ArgMaxLayer``.
Parameters
----------
top_k : int
The top k results to keep. Refer `ArgMaxParameter.top_k`_.
axis : int
The axis to perform argmax. Refer `ArgMaxParameter.axis`_.
"""
def __init__(self, LayerParameter):
super(ArgMaxLayer, self).__init__(LayerParameter)
param = LayerParameter.argmax_param
self._param = {'top_k': param.top_k,
'axis': param.axis,
'keep_dims': True}
def Setup(self, bottom):
super(ArgMaxLayer, self).Setup(bottom)
input = bottom[0] if isinstance(bottom, list) else bottom
return ops.Argmax(input, **self._param)
class BatchNormLayer(Layer):
"""The implementation of ``BatchNormLayer``.
......
# --------------------------------------------------------
# Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......@@ -45,7 +45,7 @@ class ConvolutionLayer(Layer):
'group': int(param.group)}
if param.HasField('kernel_h'):
assert param.HasField('kernel_w')
self._param['kernel'] = [param.kernel_h, param.kernel_w]
self._param['kernel_size'] = [param.kernel_h, param.kernel_w]
if param.HasField('stride_h'):
assert param.HasField('stride_w')
self._param['stride'] = [param.stride_h, param.stride_w]
......
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Caffe for Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
from collections import OrderedDict, Counter
from .proto import caffe_pb2
import six
......
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
\ No newline at end of file
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Caffe for Dragon
# Caffe @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......@@ -17,13 +17,13 @@ from dragon.core.gradient_maker import GraphGradientMaker
from dragon.core.scope import GetOperatorName, GetTensorName
from dragon.core.tensor import Tensor
def GraphDef_Grad(graph_def, targets):
def GraphDef_Grad(meta_graph, targets):
"""Inject the gradient targets into GraphDef.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
targets : list
The solving targets.
......@@ -45,18 +45,18 @@ def GraphDef_Grad(graph_def, targets):
g_target = pb.GradientTarget()
g_target.cost = str(pair[0])
g_target.wrt = str(pair[1])
graph_def.g_target.extend([g_target])
meta_graph.g_target.extend([g_target])
def GraphDef_Phase(graph_def, targets):
def GraphDef_Phase(meta_graph, targets):
"""Inject the phase into GraphDef.
If existing gradients, we assume it should be ``TRAIN``, and vice versa.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
targets : list
The solving targets.
......@@ -66,26 +66,25 @@ def GraphDef_Phase(graph_def, targets):
"""
phase = 'TEST'
from dragon.core.scope import PHASE_SCOPE
global PHASE_SCOPE
if PHASE_SCOPE != '': phase = PHASE_SCOPE.upper()
from dragon.core.scope import _PHASE_SCOPE
if _PHASE_SCOPE != '': phase = _PHASE_SCOPE.upper()
else:
for target in targets:
if len(target.grad_wrts) > 0:
phase = 'TRAIN'
break
graph_def.arg.extend([MakeArgument('phase', phase)])
meta_graph.arg.extend([MakeArgument('phase', phase)])
def GraphDef_Update(graph_def, updater):
def GraphDef_Update(meta_graph, updater):
"""Inject the update targets into GraphDef.
The ``updater`` should generate update targets before.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
updater : BaseUpdater
The updater.
......@@ -96,7 +95,7 @@ def GraphDef_Update(graph_def, updater):
"""
if updater is None: return
updater._prefix = graph_def.name + '_'
updater._prefix = meta_graph.name + '_'
extra_arguments = updater._extra_kwargs
extra_arguments['domain'] = updater._prefix
parallel_arguments = {}
......@@ -114,7 +113,7 @@ def GraphDef_Update(graph_def, updater):
= mpi.CreateGroup(root=group[0], incl=group)
parallel_arguments['root'] = group[0]
for k, v in parallel_arguments.items():
graph_def.arg.add().CopyFrom(MakeArgument(k, v))
meta_graph.arg.add().CopyFrom(MakeArgument(k, v))
for tuple in updater._tuples:
tensors = tuple[0]; arguments = tuple[1]
......@@ -126,16 +125,16 @@ def GraphDef_Update(graph_def, updater):
u_target.tensor.append(tensor)
for k, v in kwargs.items():
u_target.arg.add().CopyFrom(MakeArgument(k, v))
graph_def.u_target.extend([u_target])
meta_graph.u_target.extend([u_target])
def GraphDef_Opt(graph_def):
def GraphDef_Opt(meta_graph):
"""Inject the optimization options into GraphDef.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
Returns
-------
......@@ -149,17 +148,17 @@ def GraphDef_Opt(graph_def):
"""
from dragon.config import option
graph_def.debug_mode = option['debug_mode']
graph_def.share_grads = option['share_grads']
meta_graph.debug_mode = option['debug_mode']
meta_graph.share_grads = option['share_grads']
def GraphDef_Device(graph_def):
def GraphDef_Device(meta_graph):
"""Inject the device option into GraphDef.
Parameters
----------
graph_def : dragon_pb2.GraphDef
The definition of graph.
meta_graph : dragon_pb2.GraphDef
The definition of meta graph.
Returns
-------
......@@ -182,7 +181,7 @@ def GraphDef_Device(graph_def):
device_option.gpu_id = option['gpu_id']
device_option.random_seed = option['random_seed']
if option['use_cudnn']: device_option.engine = 'CUDNN'
graph_def.device_option.CopyFrom(device_option)
meta_graph.device_option.CopyFrom(device_option)
def function(inputs=None, outputs=None, givens=None, updater=None):
......@@ -239,22 +238,22 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
all_exprs = {}; all_extra_targets = set()
if not isinstance(outputs, list): outputs = [outputs]
graph_def = pb.GraphDef()
meta_graph = pb.GraphDef()
graph_def.name = 'Graph_' + str(ws.CURRENT_GRAPH_IDX)
meta_graph.name = 'Graph_' + str(ws.CURRENT_GRAPH_IDX)
ws.CURRENT_GRAPH_IDX += 1
# extract operators and targets from expressions
existing_grads = False
for output in outputs:
graph_def.target.extend([output.name])
meta_graph.target.extend([output.name])
if sys.version_info >= (3, 0):
all_exprs = OrderedDict(all_exprs, **output.expressions)
else:
all_exprs = dict(all_exprs, **output.expressions)
all_extra_targets = all_extra_targets.union(output.extra_targets)
if len(output.grad_wrts) > 0: existing_grads = True
for extra_target in all_extra_targets: graph_def.target.extend([extra_target])
for extra_target in all_extra_targets: meta_graph.target.extend([extra_target])
# we should sort out the topology of these operators before using
all_exprs = sorted(all_exprs.items(), key=lambda d:d[0])
......@@ -284,24 +283,25 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
# handle grads
if existing_grads:
targets = [output.name for output in outputs]
targets.extend(all_extra_targets)
forward_ops, grad_ops = GraphGradientMaker.Make(forward_ops, targets)
else: grad_ops = []
graph_def.op.extend(forward_ops + grad_ops)
meta_graph.op.extend(forward_ops + grad_ops)
if len(outputs) > 0:
GraphDef_Device(graph_def)
GraphDef_Opt(graph_def)
GraphDef_Grad(graph_def, outputs)
GraphDef_Phase(graph_def, outputs)
GraphDef_Device(meta_graph)
GraphDef_Opt(meta_graph)
GraphDef_Grad(meta_graph, outputs)
GraphDef_Phase(meta_graph, outputs)
elif updater is not None:
GraphDef_Device(graph_def)
GraphDef_Opt(graph_def)
GraphDef_Update(graph_def, updater)
GraphDef_Device(meta_graph)
GraphDef_Opt(meta_graph)
GraphDef_Update(meta_graph, updater)
# call c api to create graph
ws.CreateGraph(graph_def)
ws.CreateGraph(meta_graph)
# return a lambda point to run this graph
return lambda *args, **kwargs: \
ws.RunGraph(graph_def.name, (inputs, args), outputs, **kwargs)
\ No newline at end of file
ws.RunGraph(meta_graph.name, (inputs, args), outputs, **kwargs)
\ No newline at end of file
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......@@ -38,7 +38,11 @@ def grad(cost, wrt, **kwargs):
for w in wrt:
cost.grad_wrts.append(w.name)
w.grad_objs.append(cost.name)
grads.append(Tensor(w.name + '_grad'))
w_grad = Tensor(w.name + '_grad')
w_grad.extra_targets.add(cost.name)
w_grad.expressions = cost.expressions
w_grad.grad_wrts.append(w.name)
grads.append(w_grad)
if len(grads) == 1: return grads[0]
return grads
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
# --------------------------------------------------------
# Dragon
# Theano @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
......
......@@ -4,20 +4,21 @@
namespace dragon {
GraphBase* Workspace::CreateGraph(const GraphDef& graph_def) {
CHECK(graph_def.has_name());
if (graph_map_.count(graph_def.name()))
return graph_map_[graph_def.name()].get();
LOG(DEBUG) << "Create Graph: " << graph_def.name();
graph_map_[graph_def.name()] = unique_ptr<GraphBase>(NewGraph(graph_def, this));
return graph_map_[graph_def.name()].get();
GraphBase* Workspace::CreateGraph(const GraphDef& meta_graph) {
CHECK(meta_graph.has_name())
<< "The name of given meta graph should not be empty.";
if (graph_map_.count(meta_graph.name()))
return graph_map_[meta_graph.name()].get();
LOG(DEBUG) << "Create Graph: " << meta_graph.name();
graph_map_[meta_graph.name()] = unique_ptr<GraphBase>(NewGraph(meta_graph, this));
return graph_map_[meta_graph.name()].get();
}
Workspace::~Workspace() {
for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) {
string name = "_t_mirror_stage_buffer_" + dragon_cast<string, int>(i);
if (HasTensor(name)) {
MixedMemory* mem = GetTensor(name)->memory();
if (tensor_map_.count(name) > 0) {
MixedMemory* mem = tensor_map_[name]->memory();
if (mem != nullptr) delete mem;
}
}
......
......@@ -72,5 +72,3 @@ DEPLOY_CUDNN(SoftmaxGradient);
} // namespace dragon
#endif // WITH_CUDNN
\ No newline at end of file
#include "operators/cast/float2half_op.h"
#include "core/workspace.h"
#include "utils/op_kernel.h"
namespace dragon {
#ifdef WITH_CUDA_FP16
template <class Context>
void FloatToHalfOp<Context>::RunOnDevice() {
CHECK(input(0).template IsType<float>())
<< "The type of input should be float32.";
output(0)->ReshapeLike(input(0));
// cast
auto* Xdata = input(0).template data<float, Context>();
auto* Ydata = output(0)->template mutable_data<float16, Context>();
kernel::Float2Half<float, Context>(output(0)->count(), Xdata, Ydata);
// release & share
input(0).Reset();
input(0).ReshapeLike(*output(0));
input(0).Share(*output(0));
}
#ifdef WITH_CUDA
DEPLOY_CUDA(FloatToHalf);
#endif
OPERATOR_SCHEMA(FloatToHalf).NumInputs(1).NumOutputs(1);
NO_GRADIENT(FloatToHalf);
#endif
} // namespace dragon
\ No newline at end of file
#include "operators/ndarray/pad_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename T>
void PadOp<Context>::ConstRunWithType() {
auto* Xdata = source->template data<T, Context>();
auto* Ydata = dest->template mutable_data<T, Context>();
kernel::ConstPad1D<T, Context>(dest->count(),
dim,
dim + pad_l[axis] + pad_r[axis],
inner_dim,
pad_l[axis],
value,
Xdata,
Ydata);
}
template <class Context> template <typename T>
void PadOp<Context>::ReflectRunWithType() {
auto* Xdata = source->template data<T, Context>();
auto* Ydata = dest->template mutable_data<T, Context>();
kernel::ReflectPad1D<T, Context>(dest->count(),
dim,
dim + pad_l[axis] + pad_r[axis],
inner_dim,
pad_l[axis],
Xdata,
Ydata);
}
template <class Context> template <typename T>
void PadOp<Context>::EdgeRunWithType() {
auto* Xdata = source->template data<T, Context>();
auto* Ydata = dest->template mutable_data<T, Context>();
kernel::EdgePad1D<T, Context>(dest->count(),
dim,
dim + pad_l[axis] + pad_r[axis],
inner_dim,
pad_l[axis],
Xdata,
Ydata);
}
template <class Context>
void PadOp<Context>::RunOnDevice() {
CHECK_EQ(input(0).ndim(), pad_l.size())
<< "\nThe padding is performed on " << pad_l.size() << " dimensions, "
<< "but the num of dimensions of input is " << input(0).ndim() << ".";
// do nothing
if (process_axes.size() == 0) {
output(0)->ReshapeLike(input(0));
output(0)->Share(input(0));
return;
}
// select source & dest
source = &input(0);
if (process_axes.size() % 2 == 1) dest = output(0);
else dest = ws()->GetBuffer();
for (auto& task : process_axes) {
axis = task.second;
vector<TIndex> dims = source->dims();
inner_dim = source->count(axis + 1);
dim = source->dim(axis);
dims[axis] += (pad_l[axis] + pad_r[axis]);
dest->Reshape(dims);
if (mode == "CONSTANT") {
if (input(0).template IsType<float>()) ConstRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else if (mode == "REFLECT") {
CHECK_LE(pad_l[axis], dim + 1)
<< "\nThe dimension of axis " << axis << " is " << dim << ","
<< "\nwhile the excepted bounds of pad_l for reflecting are (0, " << dim + 1 << "].";
CHECK_LE(pad_r[axis], dim - 1)
<< "\nThe dimension of axis " << axis << " is " << dim << ","
<< "\nwhile the excepted bounds of pad_r for reflecting are (0, " << dim - 1 << "].";
if (input(0).template IsType<float>()) ReflectRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else if (mode == "EDGE") {
if (input(0).template IsType<float>()) EdgeRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else {
LOG(FATAL) << "Unsupported padding mode: " << mode << " .";
}
// allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest);
if (process_axes.size() % 2 == 1) {
if (dest == &input(0)) dest = ws()->GetBuffer();
} else {
if (dest == &input(0)) dest = output(0);
}
}
ws()->ReleaseBuffer(dest);
}
DEPLOY_CPU(Pad);
#ifdef WITH_CUDA
DEPLOY_CUDA(Pad);
#endif
OPERATOR_SCHEMA(Pad).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T>
void PadGradientOp<Context>::ConstRunWithType() {
auto* dYdata = source->template data<T, Context>();
auto* dXdata = dest->template mutable_data<T, Context>();
math::Set<T, Context>(dest->count(), 0, dXdata);
kernel::ConstPad1DGrad<T, Context>(dest->count(),
dim - pad_l[axis] - pad_r[axis],
dim,
inner_dim,
pad_l[axis],
dYdata,
dXdata);
}
template <class Context> template <typename T>
void PadGradientOp<Context>::ReflectRunWithType() {
auto* dYdata = source->template data<T, Context>();
auto* dXdata = dest->template mutable_data<T, Context>();
math::Set<T, Context>(dest->count(), 0, dXdata);
kernel::ReflectPad1DGrad<T, Context>(source->count(),
dim - pad_l[axis] - pad_r[axis],
dim,
inner_dim,
pad_l[axis],
dYdata,
dXdata);
}
template <class Context> template <typename T>
void PadGradientOp<Context>::EdgeRunWithType() {
auto* dYdata = source->template data<T, Context>();
auto* dXdata = dest->template mutable_data<T, Context>();
math::Set<T, Context>(dest->count(), 0, dXdata);
kernel::EdgePad1DGrad<T, Context>(source->count(),
dim - pad_l[axis] - pad_r[axis],
dim,
inner_dim,
pad_l[axis],
dYdata,
dXdata);
}
template <class Context>
void PadGradientOp<Context>::RunOnDevice() {
CHECK_EQ(input(0).ndim(), pad_l.size())
<< "\nThe padding is performed on " << pad_l.size() << " dimensions, "
<< "but the number of dimensions of input is " << input(0).ndim() << ".";
// do nothing
if (process_axes.size() == 0) {
output(0)->ReshapeLike(input(-1));
output(0)->Share(input(-1));
return;
}
// select source & buffer
source = &input(-1);
if (process_axes.size() % 2 == 1) dest = output(0);
else dest = ws()->GetBuffer();
for (auto& task : process_axes) {
axis = task.second;
vector<TIndex> dims = source->dims();
inner_dim = source->count(axis + 1);
dim = source->dim(axis);
dims[axis] -= (pad_l[axis] + pad_r[axis]);
dest->Reshape(dims);
if (mode == "CONSTANT") {
if (input(0).template IsType<float>()) ConstRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else if (mode == "REFLECT") {
CHECK_LE(pad_l[axis], dim + 1)
<< "\nThe dimension of axis " << axis << " is " << dim << ","
<< "\nwhile the excepted bounds of pad_l for reflecting are (0, " << dim + 1 << "].";
CHECK_LE(pad_r[axis], dim - 1)
<< "\nThe dimension of axis " << axis << " is " << dim << ","
<< "\nwhile the excepted bounds of pad_r for reflecting are (0, " << dim - 1 << "].";
if (input(0).template IsType<float>()) ReflectRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else if (mode == "EDGE") {
if (input(0).template IsType<float>()) EdgeRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else {
LOG(FATAL) << "Unsupported padding mode: " << mode << " .";
}
// allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest);
if (process_axes.size() % 2 == 1) {
if (dest == &input(-1)) dest = ws()->GetBuffer();
} else {
if (dest == &input(-1)) dest = output(0);
}
}
ws()->ReleaseBuffer(dest);
}
DEPLOY_CPU(PadGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(PadGradient);
#endif
OPERATOR_SCHEMA(PadGradient).NumInputs(1).NumOutputs(1);
class GetPadGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetPadGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {GO(0)},
vector<string> {GI(0)});
}
};
REGISTER_GRADIENT(Pad, GetPadGradient);
} // namespace dragon
\ No newline at end of file
......@@ -21,9 +21,6 @@ void TileOp<Context>::TileRunWithType() {
Xdata,
Ydata,
&ctx());
// swap source & dest
std::swap(source, dest);
}
template <class Context>
......@@ -32,8 +29,8 @@ void TileOp<Context>::RunOnDevice() {
// do nothing
if (process_axes.size() == 0) {
output(0)->ReshapeLike(input(-1));
output(0)->Share(input(-1));
output(0)->ReshapeLike(input(0));
output(0)->Share(input(0));
return;
}
......@@ -43,11 +40,11 @@ void TileOp<Context>::RunOnDevice() {
else dest = ws()->GetBuffer();
for (auto& task : process_axes) {
axis = task.first; multiple = task.second;
axis = task.second; multiple = task.first;
if (input(0).template IsType<float>()) TileRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
// allow buffer to protect X if num axes >= 2
// allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest);
if (process_axes.size() % 2 == 1) {
if (dest == &input(0)) dest = ws()->GetBuffer();
} else {
......@@ -80,9 +77,6 @@ void TileGradientOp<Context>::TileRunWithType() {
dYdata,
dXdata,
&ctx());
// swap source & dest
std::swap(source, dest);
}
template <class Context>
......@@ -102,11 +96,11 @@ void TileGradientOp<Context>::RunOnDevice() {
else dest = ws()->GetBuffer();
for (auto& task : process_axes) {
axis = task.first; multiple = task.second;
axis = task.second; multiple = task.first;
if (input(0).template IsType<float>()) TileRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
// allow buffer to protect dY if num axes >= 2
// allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest);
if (process_axes.size() % 2 == 1) {
if (dest == &input(-1)) dest = ws()->GetBuffer();
} else {
......
......@@ -842,43 +842,156 @@ template <> void ConcatGrad<float16, CPUContext>(const int count,
/******************** ndarray.crop ********************/
template<> void Crop2D<float, CPUContext>(vector<TIndex> idxs,
const vector<TIndex>& offsets,
const int cur_dim,
Tensor* x,
Tensor* y,
CPUContext* context) {
// run as Crop1D
auto* Xdata = x->data<float, CPUContext>();
auto* Ydata = y->mutable_data<float, CPUContext>();
template<> void Crop1D<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int start,
const float* x,
float* y) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
y[idx] = x[(o * dim + ex_d + start) * inner_dim + i];
}
}
for (int i = 0; i < y->dim(cur_dim); ++i) {
vector<TIndex> idx_off(cur_dim + 1, 0);
for (int j = 0; j < cur_dim; j++) idx_off[j] = idxs[j] + offsets[j];
idx_off[cur_dim] = offsets[cur_dim];
context->Copy<float, CPUContext, CPUContext>(y->dim(cur_dim),
Ydata + y->offset(idxs),
Xdata + x->offset(idx_off));
template<> void Crop1DGrad<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int start,
const int end,
const float* dy,
float* dx) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int d = (idx / inner_dim) % dim;
const int o = idx / inner_dim / dim;
if (d >= start && d < end)
dx[idx] = dy[(o * ex_dim + d - start) * inner_dim + i];
}
}
template<> void Crop2DGrad<float, CPUContext>(vector<TIndex> idxs,
const vector<TIndex>& offsets,
const int cur_dim,
Tensor* dy,
Tensor* dx,
CPUContext* context) {
// run as Crop1D
auto* dYdata = dy->data<float, CPUContext>();
auto* dXdata = dx->mutable_data<float, CPUContext>();
/******************** ndarray.pad ********************/
template <> void ConstPad1D<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float value,
const float* x,
float* y) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
const int d = ex_d - pad_l;
y[idx] = (d < 0 || d >= dim) ? value : x[(o * dim + d) * inner_dim + i];
}
}
template <> void ReflectPad1D<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float* x,
float* y) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
int d = ex_d - pad_l;
d = std::max(d, -d);
d = std::min(d, 2 * dim - d - 2);
y[idx] = x[(o * dim + d) * inner_dim + i];
}
}
template <> void EdgePad1D<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float* x,
float* y) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
const int d = std::min(dim - 1, std::max(ex_d - pad_l, 0));
y[idx] = x[(o * dim + d) * inner_dim + i];
}
}
for (int i = 0; i < dy->dim(cur_dim); ++i) {
vector<TIndex> idx_off(cur_dim + 1, 0);
for (int j = 0; j < cur_dim; j++) idx_off[j] = idxs[j] + offsets[j];
idx_off[cur_dim] = offsets[cur_dim];
context->Copy<float, CPUContext, CPUContext>(dy->dim(cur_dim),
dXdata + dx->offset(idx_off),
dXdata + dy->offset(idxs));
template <> void ConstPad1DGrad<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float* dy,
float* dx) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % dim + pad_l;
const int o = idx / inner_dim / dim;
dx[idx] = dy[(o * ex_dim + ex_d) * inner_dim + i];
}
}
template <> void ReflectPad1DGrad<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float* dy,
float* dx) {
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
int d = ex_d - pad_l;
d = std::max(d, -d);
d = std::min(d, 2 * dim - d - 2);
dx[(o * dim + d) * inner_dim + i] += dy[idx];
}
}
template <> void EdgePad1DGrad<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float* dy,
float* dx) {
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
const int d = std::min(dim - 1, std::max(ex_d - pad_l, 0));
dx[(o * dim + d) * inner_dim + i] += dy[idx];
}
}
......@@ -1692,7 +1805,7 @@ template<> void ROIPooling<float, CPUContext>(const float spatial_scale,
Mdata += mask->offset(0, 1);
} // end c
// offset roi region
Rdata += roi->offset(1);
Rdata += 5;
} // end n
}
......
......@@ -5,16 +5,15 @@
-----
Dragon is a **C**(Computation)**G**(Graph)**V**(Virtual)**M**(Machine) based distributed deep learning framework.
Our goal is to reduce the unnecessary structures or interfaces. Therefore, in addition to feed or fetch, the last thing is designing a objective function through available operators.
Our goal is to reduce the unnecessary structures or interfaces. Therefore, in addition to feed or fetch, the last thing is designing a objective function through all available operators.
Besides, we demonstrate a cross-frameworks frontend(**Deep Learning VirtualBox**) is feasible, and further more, will get benefit from all participating crucial interfaces especially when one is not reasonable.
Besides, we demonstrate that a cross-frameworks frontend(**Deep Learning VirtualBox**) is feasible, and further more, will get benefit from all participating crucial interfaces especially when one is not reasonable.
## News
Dragon 0.2.1 Released - The preliminary documentation, and massive known bugs are fixed.
## License and Citation
Dragon is released under the [BSD 2-Clause license](https://github.com/neopenx/Dragon/blob/master/LICENSE).
Please cite Dragon in your publications if it helps your research:
......@@ -24,4 +23,5 @@ Please cite Dragon in your publications if it helps your research:
Journal = {arXiv preprint arXiv:1707.08265},
Title = {Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework},
Year = {2017}
}
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!