Commit 3b990761 by Ting PAN

Merge into the DimensionOp

1 parent abae2712
Showing with 1032 additions and 556 deletions
...@@ -18,7 +18,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ...@@ -18,7 +18,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
python3-tk \ python3-tk \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \ RUN pip3 install --no-cache-dir --upgrade setuptools wheel -i https://pypi.tuna.tsinghua.edu.cn/simple && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \ pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \ numpy \
protobuf \ protobuf \
...@@ -27,6 +27,7 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \ ...@@ -27,6 +27,7 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \
six \ six \
Pillow Pillow
matplotlib \ matplotlib \
scikit-image \
pyyaml \ pyyaml \
cython cython
......
...@@ -21,7 +21,7 @@ RUN rm /etc/apt/sources.list.d/cuda.list && rm /etc/apt/sources.list.d/nvidia-ml ...@@ -21,7 +21,7 @@ RUN rm /etc/apt/sources.list.d/cuda.list && rm /etc/apt/sources.list.d/nvidia-ml
python3-tk \ python3-tk \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \ RUN pip3 install --no-cache-dir --upgrade setuptools wheel -i https://pypi.tuna.tsinghua.edu.cn/simple && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \ pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \ numpy \
protobuf \ protobuf \
...@@ -30,6 +30,7 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \ ...@@ -30,6 +30,7 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \
six \ six \
Pillow \ Pillow \
matplotlib \ matplotlib \
scikit-image \
pyyaml \ pyyaml \
cython cython
......
...@@ -52,9 +52,9 @@ using Set = std::unordered_set<Value> ; ...@@ -52,9 +52,9 @@ using Set = std::unordered_set<Value> ;
/* /*
* Define the Kernel version. * Define the Kernel version.
* *
* | Major(2) | Minor(2) | Patch(09) | * | Major(2) | Minor(2) | Patch(10) |
*/ */
#define DRAGON_VERSION 2209 #define DRAGON_VERSION 2210
/* /*
* Define the default random seed. * Define the default random seed.
......
...@@ -114,7 +114,7 @@ class Operator : public OperatorBase { ...@@ -114,7 +114,7 @@ class Operator : public OperatorBase {
virtual void MakeResource(); virtual void MakeResource();
virtual void CleanResource(); virtual void CleanResource();
void MemorySwitch() { virtual void MemorySwitch() {
for (auto* I : inputs_) for (auto* I : inputs_)
if(I->name() != "ignore") I->SwitchToDevice(); if(I->name() != "ignore") I->SwitchToDevice();
for (auto* O : outputs_) for (auto* O : outputs_)
......
...@@ -40,8 +40,9 @@ class Tensor { ...@@ -40,8 +40,9 @@ class Tensor {
capacity_ = 0; capacity_ = 0;
} }
} else { } else {
if (ex_memory_ && TIndex(ex_memory_->nbytes()) < if (ex_memory_ && !is_shared_ &&
TIndex(new_size * meta_.itemsize())) { TIndex(ex_memory_->nbytes()) <
TIndex(new_size * meta_.itemsize())) {
delete ex_memory_; delete ex_memory_;
ex_memory_ = nullptr; ex_memory_ = nullptr;
capacity_ = 0; capacity_ = 0;
...@@ -232,18 +233,18 @@ class Tensor { ...@@ -232,18 +233,18 @@ class Tensor {
return static_cast<const T*>(raw_data<Context>()); return static_cast<const T*>(raw_data<Context>());
} }
template <class DstCTX, class SrcCTX> template <class Context>
inline void Copy(const Tensor& other) { inline void CopyFrom(const Tensor& other) {
CHECK_EQ(size_, other.size_); CHECK_EQ(size_, other.size_);
auto* src = other.template raw_data<SrcCTX>(); auto* src = other.template raw_data<Context>();
auto* dst = raw_mutable_data<DstCTX>(other.meta_); auto* dst = raw_mutable_data<Context>(other.meta_);
if (dst == src) return; if (dst == src) return;
if (TypeMeta::Id<DstCTX>() == if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CPUContext>()) { TypeMeta::Id<CPUContext>()) {
CPUContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src); CPUContext::Memcpy<Context, Context>(nbytes(), dst, src);
} else if (TypeMeta::Id<DstCTX>() == } else if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CUDAContext>()) { TypeMeta::Id<CUDAContext>()) {
CUDAContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src); CUDAContext::Memcpy<Context, Context>(nbytes(), dst, src);
} }
} }
...@@ -253,6 +254,8 @@ class Tensor { ...@@ -253,6 +254,8 @@ class Tensor {
own_mem_ = false; own_mem_ = false;
} }
inline void Share(MixedMemory* mem) { Move(mem); is_shared_ = true; }
inline void Reset() { inline void Reset() {
size_ = capacity_ = 0; size_ = capacity_ = 0;
meta_ = TypeMeta(); meta_ = TypeMeta();
...@@ -271,7 +274,8 @@ class Tensor { ...@@ -271,7 +274,8 @@ class Tensor {
string name_; string name_;
shared_ptr<MixedMemory> memory_; shared_ptr<MixedMemory> memory_;
MixedMemory* ex_memory_ = nullptr; MixedMemory* ex_memory_ = nullptr;
bool is_corrupted_ = false, own_mem_ = true; bool is_corrupted_ = false, is_shared_ = false;
bool own_mem_ = true;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -49,7 +49,8 @@ inline const TypeMeta& TypeStringToMeta( ...@@ -49,7 +49,8 @@ inline const TypeMeta& TypeStringToMeta(
{ "int64", TypeMeta::Make<int64_t>() }, { "int64", TypeMeta::Make<int64_t>() },
{ "float64", TypeMeta::Make<double>() }, { "float64", TypeMeta::Make<double>() },
{ "float16", TypeMeta::Make<float16>() }, { "float16", TypeMeta::Make<float16>() },
{ "uint8", TypeMeta::Make<uint8_t>() } { "uint8", TypeMeta::Make<uint8_t>() },
{ "int8", TypeMeta::Make<char>() },
}; };
static TypeMeta unknown_type; static TypeMeta unknown_type;
return s2m_type_map.count(str_type) ? return s2m_type_map.count(str_type) ?
...@@ -65,7 +66,8 @@ inline const std::string TypeMetaToString( ...@@ -65,7 +66,8 @@ inline const std::string TypeMetaToString(
{ TypeMeta::Id<int64_t>(), "int64" }, { TypeMeta::Id<int64_t>(), "int64" },
{ TypeMeta::Id<double>(), "float64", }, { TypeMeta::Id<double>(), "float64", },
{ TypeMeta::Id<float16>(), "float16" }, { TypeMeta::Id<float16>(), "float16" },
{ TypeMeta::Id<uint8_t>(), "uint8" } { TypeMeta::Id<uint8_t>(), "uint8" },
{ TypeMeta::Id<char>(), "int8" }
}; };
return m2s_type_map.count(meta.id()) ? return m2s_type_map.count(meta.id()) ?
m2s_type_map[meta.id()] : "unknown"; m2s_type_map[meta.id()] : "unknown";
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_
#define DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_
#include "core/operator.h"
namespace dragon {
/*********************************************
* *
* Base *
* *
**********************************************/
template <class Context>
class DimOpBase : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(DimOpBase);
void MemorySwitch() override {
/* Disable the Memory Activation */
}
};
template <class Context>
class DimGradientOpBase : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(DimGradientOpBase);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
// simply copy the dY to dX
Output(0)->ReshapeLike(Input(0));
if (Output(0)->name() != Input(-1).name())
Output(0)->template CopyFrom<Context>(Input(-1));
}
};
#define DEFINE_DIMENSION_GRADIENT_OP(name) \
template <class Context> \
class name##GradientOp final : public DimGradientOpBase<Context> { \
public: \
name##GradientOp(const OperatorDef& def, Workspace* ws) \
: DimGradientOpBase<Context>(def, ws) {} \
};
/*********************************************
* *
* Reshape *
* *
**********************************************/
template <class Context>
class ReshapeOp final : public DimOpBase<Context> {
public:
ReshapeOp(const OperatorDef& def, Workspace* ws)
: DimOpBase<Context>(def, ws),
shape_like_desc(OperatorBase::Arg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
protected:
DECLARE_ARGUMENTS_WITH_DESC(int, shape);
string shape_like_desc;
vector<TIndex> require_shape, new_shape;
};
DEFINE_ARGUMENTS_WITH_DESC(int, ReshapeOp, shape);
DEFINE_DIMENSION_GRADIENT_OP(Reshape);
/*********************************************
* *
* Flatten *
* *
**********************************************/
template <class Context>
class FlattenOp final : public DimOpBase<Context> {
public:
FlattenOp(const OperatorDef& def, Workspace* ws)
: DimOpBase<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
num_axes(OperatorBase::Arg<int>("num_axes", -1)),
keep_axes(OperatorBase::Arg<int>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
protected:
TIndex axis, num_axes, keep_axes;
};
DEFINE_DIMENSION_GRADIENT_OP(Flatten);
/*********************************************
* *
* Expand Dims *
* *
**********************************************/
template <class Context>
class ExpandDimsOp final : public DimOpBase<Context> {
public:
ExpandDimsOp(const OperatorDef& def, Workspace* ws)
: DimOpBase<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", INT_MAX)) {
if (axis == INT_MAX)
LOG(FATAL) << "Excepted a axis to insert the new dim.";
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
protected:
TIndex axis;
};
DEFINE_DIMENSION_GRADIENT_OP(ExpandDims);
/*********************************************
* *
* Squeeze *
* *
**********************************************/
template <class Context>
class SqueezeOp final : public DimOpBase<Context> {
public:
SqueezeOp(const OperatorDef& def, Workspace* ws)
: DimOpBase<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
protected:
TIndex axis;
};
DEFINE_DIMENSION_GRADIENT_OP(Squeeze);
} // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
#define DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ExpandDimsOp final : public Operator<Context> {
public:
ExpandDimsOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
protected:
TIndex axis;
};
template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpandDimsGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_
#define DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class FlattenOp final : public Operator<Context> {
public:
FlattenOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
num_axes(OperatorBase::Arg<int>("num_axes", -1)),
keep_axes(OperatorBase::Arg<int>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void SqueezeRun();
void KeepRun();
protected:
TIndex axis, num_axes, keep_axes;
};
template <class Context>
class FlattenGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(FlattenGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ReshapeOp final : public Operator<Context> {
public:
ReshapeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
shape_like_desc(OperatorBase::Arg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
protected:
DECLARE_ARGUMENTS_WITH_DESC(int, shape);
string shape_like_desc;
vector<TIndex> require_shape, new_shape;
};
template <class Context>
class ReshapeGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ReshapeGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
DEFINE_ARGUMENTS_WITH_DESC(int, ReshapeOp, shape);
} // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
\ No newline at end of file
...@@ -55,7 +55,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -55,7 +55,7 @@ class ConvOpBase : public Operator<Context> {
void GradientReshape(); void GradientReshape();
virtual void ComputeOutputShape(); virtual void ComputeOutputShape();
virtual bool ReverseDimensions() = 0; virtual bool ReverseDimensions() = 0;
virtual bool HasBias() = 0; virtual bool HasBias() { NOT_IMPLEMENTED; return true; }
template <typename T> void Wx(const T* x, template <typename T> void Wx(const T* x,
const T* weights, T* y, bool skip_im2col = false); const T* weights, T* y, bool skip_im2col = false);
......
...@@ -16,12 +16,12 @@ int type_from_string(std::string type) { ...@@ -16,12 +16,12 @@ int type_from_string(std::string type) {
} }
Device::Device() Device::Device()
: device_type_(CPU), device_id_(0) {} : device_type_(0), device_id_(0) {}
Device::Device(std::string device_type, int device_id) Device::Device(std::string device_type, int device_id)
: device_type_((DeviceType)type_from_string(device_type)), device_id_(device_id) {} : device_type_(type_from_string(device_type)), device_id_(device_id) {}
Device::Device(std::string device_type) Device::Device(std::string device_type)
: device_type_((DeviceType)type_from_string(device_type)), device_id_(0) {} : device_type_(type_from_string(device_type)), device_id_(0) {}
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/io/zero_copy_stream_impl.h>
#include "dragon.h" #include "dragon.h"
#include "protos/dragon.pb.h"
#include "core/common.h" #include "core/common.h"
#include "core/workspace.h" #include "core/workspace.h"
#include "utils/caffemodel.h" #include "utils/caffemodel.h"
...@@ -35,7 +34,8 @@ Workspace* ResetWorkspace(const std::string& name) { ...@@ -35,7 +34,8 @@ Workspace* ResetWorkspace(const std::string& name) {
g_workspaces[name].reset(new Workspace(name)); g_workspaces[name].reset(new Workspace(name));
for (auto& sub_workspace : sub_workspaces[name]) { for (auto& sub_workspace : sub_workspaces[name]) {
if (g_workspaces.count(sub_workspace) > 0) if (g_workspaces.count(sub_workspace) > 0)
g_workspaces[name]->MoveWorkspace(g_workspaces[sub_workspace].get()); g_workspaces[name]->MoveWorkspace(
g_workspaces[sub_workspace].get());
} }
return g_workspaces[name].get(); return g_workspaces[name].get();
} }
...@@ -49,7 +49,9 @@ void ReleaseWorkspace(const std::string& name) { ...@@ -49,7 +49,9 @@ void ReleaseWorkspace(const std::string& name) {
g_workspaces.erase(name); g_workspaces.erase(name);
} }
void MoveWorkspace(Workspace* target_ws, Workspace* source_ws) { void MoveWorkspace(
Workspace* target_ws,
Workspace* source_ws) {
std::unique_lock<std::mutex> lock(g_mutex); std::unique_lock<std::mutex> lock(g_mutex);
CHECK(source_ws) << "\nThe given source workspace is invalid."; CHECK(source_ws) << "\nThe given source workspace is invalid.";
CHECK(target_ws) << "\nThe given target workspace is invalid."; CHECK(target_ws) << "\nThe given target workspace is invalid.";
...@@ -59,7 +61,9 @@ void MoveWorkspace(Workspace* target_ws, Workspace* source_ws) { ...@@ -59,7 +61,9 @@ void MoveWorkspace(Workspace* target_ws, Workspace* source_ws) {
<< "into the Workspace(" << target_ws->name() << ")."; << "into the Workspace(" << target_ws->name() << ").";
} }
std::string CreateGraph(const std::string& graph_file, Workspace* ws) { std::string CreateGraph(
const std::string& graph_file,
Workspace* ws) {
GraphDef meta_graph; GraphDef meta_graph;
int fd = open(graph_file.c_str(), O_RDONLY); int fd = open(graph_file.c_str(), O_RDONLY);
CHECK_NE(fd, -1) << "\nFile not found: " << graph_file; CHECK_NE(fd, -1) << "\nFile not found: " << graph_file;
...@@ -75,7 +79,10 @@ std::string CreateGraph(const std::string& graph_file, Workspace* ws) { ...@@ -75,7 +79,10 @@ std::string CreateGraph(const std::string& graph_file, Workspace* ws) {
return meta_graph.name(); return meta_graph.name();
} }
std::string CreateGraph(const std::string& graph_file, const Device& device, Workspace* ws) { std::string CreateGraph(
const std::string& graph_file,
const Device& device,
Workspace* ws) {
GraphDef meta_graph; GraphDef meta_graph;
int fd = open(graph_file.c_str(), O_RDONLY); int fd = open(graph_file.c_str(), O_RDONLY);
CHECK_NE(fd, -1) << "\nFile not found: " << graph_file; CHECK_NE(fd, -1) << "\nFile not found: " << graph_file;
...@@ -95,26 +102,29 @@ std::string CreateGraph(const std::string& graph_file, const Device& device, Wor ...@@ -95,26 +102,29 @@ std::string CreateGraph(const std::string& graph_file, const Device& device, Wor
return meta_graph.name(); return meta_graph.name();
} }
void CreateTensor(const std::string& name, Workspace* ws) { void CreateTensor(
const std::string& name,
Workspace* ws) {
ws->CreateTensor(name); ws->CreateTensor(name);
} }
template <typename T> template <typename T>
void FeedTensor(const std::string& name, void FeedTensor(
const vector<TIndex>& shape, const std::string& name,
const T* data, const vector<TIndex>& shape,
const Device& device, const T* data,
Workspace* ws) { const Device& device,
Workspace* ws) {
Tensor* tensor = ws->CreateTensor(name); Tensor* tensor = ws->CreateTensor(name);
tensor->Reshape(shape); tensor->Reshape(shape);
if (device.device_type() == CUDA) { if (device.device_type() == 1) {
CUDAContext context(device.device_id()); CUDAContext context(device.device_id());
context.SwitchToDevice(); context.SwitchToDevice();
tensor->mutable_data<T, CUDAContext>(); tensor->mutable_data<T, CUDAContext>();
context.Memcpy<CUDAContext, CPUContext>(tensor->nbytes(), context.Memcpy<CUDAContext, CPUContext>(tensor->nbytes(),
tensor->raw_mutable_data<CUDAContext>(), tensor->raw_mutable_data<CUDAContext>(),
static_cast<const void*>(data)); static_cast<const void*>(data));
} else if (device.device_type() == CPU) { } else if (device.device_type() == 0) {
CPUContext context; CPUContext context;
tensor->mutable_data<T, CPUContext>(); tensor->mutable_data<T, CPUContext>();
context.Memcpy<CPUContext, CPUContext>(tensor->nbytes(), context.Memcpy<CPUContext, CPUContext>(tensor->nbytes(),
...@@ -125,7 +135,9 @@ void FeedTensor(const std::string& name, ...@@ -125,7 +135,9 @@ void FeedTensor(const std::string& name,
} }
} }
void TransplantCaffeModel(const std::string& input_model, const std::string& output_model) { void TransplantCaffeModel(
const std::string& input_model,
const std::string& output_model) {
TensorProtos protos; TensorProtos protos;
NetParameter net_param; NetParameter net_param;
ReadProtoFromBinaryFile(input_model.c_str(), &net_param); ReadProtoFromBinaryFile(input_model.c_str(), &net_param);
...@@ -151,13 +163,16 @@ void TransplantCaffeModel(const std::string& input_model, const std::string& out ...@@ -151,13 +163,16 @@ void TransplantCaffeModel(const std::string& input_model, const std::string& out
<< ", size: " << blob.data_size(); << ", size: " << blob.data_size();
} }
} }
std::fstream output(output_model, std::ios::out | std::ios::trunc | std::ios::binary); std::fstream output(output_model,
std::ios::out | std::ios::trunc | std::ios::binary);
CHECK(protos.SerializeToOstream(&output)); CHECK(protos.SerializeToOstream(&output));
LOG(INFO) << "save the model @: " << output_model << "......"; LOG(INFO) << "save the model @: " << output_model << "......";
LOG(INFO) << "model format: DragonMoel"; LOG(INFO) << "model format: DragonMoel";
} }
void LoadDragonmodel(const std::string& model_file, Workspace* ws){ void LoadDragonmodel(
const std::string& model_file,
Workspace* ws){
TensorProtos tensors; TensorProtos tensors;
ReadProtoFromBinaryFile(model_file.c_str(), &tensors); ReadProtoFromBinaryFile(model_file.c_str(), &tensors);
LOG(INFO) << "Restore From Model @: " << model_file << "......"; LOG(INFO) << "Restore From Model @: " << model_file << "......";
...@@ -190,7 +205,9 @@ void LoadDragonmodel(const std::string& model_file, Workspace* ws){ ...@@ -190,7 +205,9 @@ void LoadDragonmodel(const std::string& model_file, Workspace* ws){
} }
} }
void LoadCaffemodel(const std::string& model_file, Workspace* ws){ void LoadCaffemodel(
const std::string& model_file,
Workspace* ws){
NetParameter net_param; NetParameter net_param;
ReadProtoFromBinaryFile(model_file.c_str(), &net_param); ReadProtoFromBinaryFile(model_file.c_str(), &net_param);
std::string scope = ""; std::string scope = "";
...@@ -231,14 +248,17 @@ void LoadCaffemodel(const std::string& model_file, Workspace* ws){ ...@@ -231,14 +248,17 @@ void LoadCaffemodel(const std::string& model_file, Workspace* ws){
} }
} }
void RunGraph(const std::string& graph_name, Workspace* ws) { void RunGraph(
const std::string& graph_name,
Workspace* ws) {
ws->RunGraph(graph_name, "", ""); ws->RunGraph(graph_name, "", "");
} }
template <typename T> template <typename T>
T* FetchTensor(const std::string& name, T* FetchTensor(
vector<TIndex>& shape, const std::string& name,
Workspace* ws){ vector<TIndex>& shape,
Workspace* ws){
if (!ws->HasTensor(name)){ if (!ws->HasTensor(name)){
LOG(FATAL) << "Tensor(" << name << ")" LOG(FATAL) << "Tensor(" << name << ")"
<< " doesn't exist, try create it before."; << " doesn't exist, try create it before.";
...@@ -251,13 +271,11 @@ T* FetchTensor(const std::string& name, ...@@ -251,13 +271,11 @@ T* FetchTensor(const std::string& name,
shape = tensor->dims(); shape = tensor->dims();
void* data = malloc(tensor->nbytes()); void* data = malloc(tensor->nbytes());
if (tensor->memory_state() == MixedMemory::STATE_AT_CUDA) { if (tensor->memory_state() == MixedMemory::STATE_AT_CUDA) {
CUDAContext::Memcpy<CPUContext, CUDAContext>(tensor->nbytes(), CUDAContext::Memcpy<CPUContext, CUDAContext>(
data, tensor->nbytes(), data, tensor->raw_data<CUDAContext>());
tensor->raw_data<CUDAContext>());
} else { } else {
CPUContext::Memcpy<CPUContext, CPUContext>(tensor->nbytes(), CPUContext::Memcpy<CPUContext, CPUContext>(
data, tensor->nbytes(), data, tensor->raw_data<CPUContext>());
tensor->raw_data<CPUContext>());
} }
return static_cast<T*>(data); return static_cast<T*>(data);
} }
...@@ -266,4 +284,30 @@ void SetLogLevel(const std::string& level) { ...@@ -266,4 +284,30 @@ void SetLogLevel(const std::string& level) {
SetLogDestination(StrToLogSeverity(level)); SetLogDestination(StrToLogSeverity(level));
} }
template float* FetchTensor<float>(
const std::string&,
std::vector<TIndex>&,
Workspace*);
template void FeedTensor<float>(
const std::string&,
const std::vector<TIndex>&,
const float*,
const Device&,
Workspace*);
template void FeedTensor<int>(
const std::string&,
const std::vector<TIndex>&,
const int*,
const Device&,
Workspace*);
template void FeedTensor<uint8_t>(
const std::string&,
const std::vector<TIndex>&,
const uint8_t*,
const Device&,
Workspace*);
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -29,18 +29,16 @@ typedef int64_t TIndex; ...@@ -29,18 +29,16 @@ typedef int64_t TIndex;
class Workspace; class Workspace;
class Device { class Device {
enum DeviceType { CPU, CUDA };
public: public:
EXPORT Device(); EXPORT Device();
EXPORT explicit Device(std::string device_type); EXPORT explicit Device(std::string device_type);
EXPORT Device(std::string device_type, int device_id); EXPORT Device(std::string device_type, int device_id);
EXPORT const DeviceType& device_type() const { return device_type_; } EXPORT const int& device_type() const { return device_type_; }
EXPORT const int device_id() const { return device_id_; } EXPORT const int device_id() const { return device_id_; }
private: private:
DeviceType device_type_; int device_type_;
int device_id_; int device_id_;
}; };
...@@ -52,53 +50,48 @@ EXPORT void ReleaseWorkspace(const std::string& name); ...@@ -52,53 +50,48 @@ EXPORT void ReleaseWorkspace(const std::string& name);
EXPORT void MoveWorkspace(Workspace* main, Workspace* sub); EXPORT void MoveWorkspace(Workspace* main, Workspace* sub);
EXPORT std::string CreateGraph(const std::string& graph_file, Workspace* ws); EXPORT std::string CreateGraph(
const std::string& graph_file,
Workspace* ws);
EXPORT std::string CreateGraph(const std::string& graph_file, const Device& device, Workspace* ws); EXPORT std::string CreateGraph(
const std::string& graph_file,
const Device& device,
Workspace* ws);
EXPORT void RunGraph(const std::string& graph_name, Workspace* ws); EXPORT void RunGraph(
const std::string& graph_name,
Workspace* ws);
EXPORT void CreateTensor(const std::string& name, Workspace* ws); EXPORT void CreateTensor(
const std::string& name,
Workspace* ws);
template <typename T> template <typename T>
void FeedTensor(const std::string& name, EXPORT void FeedTensor(
const std::vector<TIndex>& shape, const std::string& name,
const T* data, const std::vector<TIndex>& shape,
const Device& device, const T* data,
Workspace* ws); const Device& device,
Workspace* ws);
template <typename T> template <typename T>
T* FetchTensor(const std::string& name, EXPORT T* FetchTensor(
std::vector<TIndex>& shape, const std::string& name,
Workspace* ws); std::vector<TIndex>& shape,
Workspace* ws);
template EXPORT float* FetchTensor(const std::string&,
std::vector<TIndex>&, EXPORT void LoadCaffemodel(
Workspace*); const std::string& model_file,
Workspace* ws);
template EXPORT void FeedTensor(const std::string&,
const std::vector<TIndex>&, EXPORT void TransplantCaffeModel(
const float*, const std::string& input_model,
const Device&, const std::string& output_model);
Workspace*);
EXPORT void LoadDragonmodel(
template EXPORT void FeedTensor(const std::string&, const std::string& model_file,
const std::vector<TIndex>&, Workspace* ws);
const int*,
const Device&,
Workspace*);
template EXPORT void FeedTensor(const std::string&,
const std::vector<TIndex>&,
const uint8_t*,
const Device&,
Workspace*);
EXPORT void LoadCaffemodel(const std::string& model_file, Workspace* ws);
EXPORT void TransplantCaffeModel(const std::string& input_model, const std::string& output_model);
EXPORT void LoadDragonmodel(const std::string& model_file, Workspace* ws);
EXPORT void SetLogLevel(const std::string& level); EXPORT void SetLogLevel(const std::string& level);
......
...@@ -231,6 +231,7 @@ PyMethodDef* GetAllMethods() { ...@@ -231,6 +231,7 @@ PyMethodDef* GetAllMethods() {
PYFUNC(RenameTensorCC), PYFUNC(RenameTensorCC),
PYFUNC(TensorFromShapeCC), PYFUNC(TensorFromShapeCC),
PYFUNC(TensorFromPyArrayCC), PYFUNC(TensorFromPyArrayCC),
PYFUNC(TensorFromTensorCC),
PYFUNC(GetTensorNameCC), PYFUNC(GetTensorNameCC),
PYFUNC(GetTensorInfoCC), PYFUNC(GetTensorInfoCC),
PYFUNC(FeedTensorCC), PYFUNC(FeedTensorCC),
......
...@@ -152,6 +152,55 @@ PyObject* TensorFromPyArrayCC(PyObject* self, PyObject* args) { ...@@ -152,6 +152,55 @@ PyObject* TensorFromPyArrayCC(PyObject* self, PyObject* args) {
Py_RETURN_TRUE; Py_RETURN_TRUE;
} }
PyObject* TensorFromTensorCC(PyObject* self, PyObject* args) {
char* dst_name, *src_name;
PyObject* py_dst_ctx = nullptr, *py_src_ctx = nullptr;
if (!PyArg_ParseTuple(args, "ssOO",
&dst_name, &src_name, &py_dst_ctx, &py_src_ctx)) {
PyErr_SetString(PyExc_ValueError,
"Failed to create tensor from tensor.\n"
"Excepted the (dest, src) name and context.");
return nullptr;
}
DeviceOption dst_ctx, src_ctx;
dst_ctx.ParseFromString(PyBytes_AsStringEx(py_dst_ctx));
src_ctx.ParseFromString(PyBytes_AsStringEx(py_src_ctx));
Tensor* srcT = ws()->GetTensor(src_name);
Tensor* dstT = ws()->CreateTensor(dst_name);
dstT->ReshapeLike(*srcT);
dstT->SetMeta(srcT->meta());
if (dst_ctx.device_type() == DeviceType::CUDA) {
if (src_ctx.device_type() == DeviceType::CUDA) {
// CUDA <- CUDA
CUDAContext::Memcpy<CUDAContext, CUDAContext>(
srcT->nbytes(),
dstT->raw_mutable_data<CUDAContext>(),
srcT->raw_data<CUDAContext>());
} else {
// CUDA <- CPU
CUDAContext::Memcpy<CUDAContext, CUDAContext>(
srcT->nbytes(),
dstT->raw_mutable_data<CUDAContext>(),
srcT->raw_data<CPUContext>());
}
} else {
if (src_ctx.device_type() == DeviceType::CUDA) {
// CPU <- CUDA
CUDAContext::Memcpy<CUDAContext, CUDAContext>(
srcT->nbytes(),
dstT->raw_mutable_data<CPUContext>(),
srcT->raw_data<CUDAContext>());
} else {
// CPU <- CPU
CUDAContext::Memcpy<CUDAContext, CUDAContext>(
srcT->nbytes(),
dstT->raw_mutable_data<CPUContext>(),
srcT->raw_data<CPUContext>());
}
}
Py_RETURN_TRUE;
}
inline PyObject* TensorToPyArrayCC(PyObject* self, PyObject* args) { inline PyObject* TensorToPyArrayCC(PyObject* self, PyObject* args) {
Tensor* tensor = ws()->GetTensor(ParseName(self, args)); Tensor* tensor = ws()->GetTensor(ParseName(self, args));
CHECK_GT(tensor->count(), 0); CHECK_GT(tensor->count(), 0);
...@@ -183,7 +232,8 @@ inline PyObject* TensorToPyArrayExCC(PyObject* self, PyObject* args) { ...@@ -183,7 +232,8 @@ inline PyObject* TensorToPyArrayExCC(PyObject* self, PyObject* args) {
return nullptr; return nullptr;
} }
auto* data = const_cast<void*>(tensor->raw_data<CPUContext>()); auto* data = const_cast<void*>(tensor->raw_data<CPUContext>());
PyObject* array = PyArray_SimpleNewFromData(tensor->ndim(), dims.data(), npy_type, data); PyObject* array = PyArray_SimpleNewFromData(
tensor->ndim(), dims.data(), npy_type, data);
Py_XINCREF(array); Py_XINCREF(array);
return array; return array;
} }
...@@ -202,7 +252,8 @@ inline PyObject* ToCUDATensorCC(PyObject* self, PyObject* args) { ...@@ -202,7 +252,8 @@ inline PyObject* ToCUDATensorCC(PyObject* self, PyObject* args) {
char* cname; char* cname;
int device_id; int device_id;
if (!PyArg_ParseTuple(args, "si", &cname, &device_id)) { if (!PyArg_ParseTuple(args, "si", &cname, &device_id)) {
PyErr_SetString(PyExc_ValueError, "Excepted the tensor name and device id."); PyErr_SetString(PyExc_ValueError,
"Excepted the tensor name and device id.");
return nullptr; return nullptr;
} }
Tensor* t = ws()->GetTensor(cname); Tensor* t = ws()->GetTensor(cname);
......
...@@ -23,7 +23,8 @@ inline const int TypeMetaToNPY(const TypeMeta& meta) { ...@@ -23,7 +23,8 @@ inline const int TypeMetaToNPY(const TypeMeta& meta) {
{ TypeMeta::Id<int64_t>(), NPY_INT64 }, { TypeMeta::Id<int64_t>(), NPY_INT64 },
{ TypeMeta::Id<double>(), NPY_FLOAT64 }, { TypeMeta::Id<double>(), NPY_FLOAT64 },
{ TypeMeta::Id<float16>(), NPY_FLOAT16 }, { TypeMeta::Id<float16>(), NPY_FLOAT16 },
{ TypeMeta::Id<uint8_t>(), NPY_UINT8 } { TypeMeta::Id<uint8_t>(), NPY_UINT8 },
{ TypeMeta::Id<char>(), NPY_INT8 }
}; };
return m2npy_type_map.count(meta.id()) ? m2npy_type_map[meta.id()] : -1; return m2npy_type_map.count(meta.id()) ? m2npy_type_map[meta.id()] : -1;
} }
...@@ -35,7 +36,8 @@ inline const TypeMeta& TypeNPYToMeta(int npy_type) { ...@@ -35,7 +36,8 @@ inline const TypeMeta& TypeNPYToMeta(int npy_type) {
{ NPY_INT64, TypeMeta::Make<int64_t>() }, { NPY_INT64, TypeMeta::Make<int64_t>() },
{ NPY_FLOAT64, TypeMeta::Make<double>() }, { NPY_FLOAT64, TypeMeta::Make<double>() },
{ NPY_FLOAT16, TypeMeta::Make<float16>() }, { NPY_FLOAT16, TypeMeta::Make<float16>() },
{ NPY_UINT8, TypeMeta::Make<uint8_t>() } { NPY_UINT8, TypeMeta::Make<uint8_t>() },
{ NPY_INT8, TypeMeta::Make<char>() },
}; };
static TypeMeta unknown_type; static TypeMeta unknown_type;
return npy2m_type_map.count(npy_type) ? npy2m_type_map[npy_type] : unknown_type; return npy2m_type_map.count(npy_type) ? npy2m_type_map[npy_type] : unknown_type;
......
...@@ -24,6 +24,7 @@ from dragon.core.utils import MakeDeviceOption ...@@ -24,6 +24,7 @@ from dragon.core.utils import MakeDeviceOption
__all__ = [ __all__ = [
'FromShape', 'FromShape',
'SetShape', 'SetShape',
'FromTensor',
'FromPyArray', 'FromPyArray',
'SetPyArray', 'SetPyArray',
'ToPyArray', 'ToPyArray',
...@@ -113,6 +114,40 @@ def SetShape(tensor, shape, dtype='float32'): ...@@ -113,6 +114,40 @@ def SetShape(tensor, shape, dtype='float32'):
TensorFromShapeCC(_stringify_tensor(tensor), shape, dtype) TensorFromShapeCC(_stringify_tensor(tensor), shape, dtype)
def FromTensor(src, src_ctx=None, name=None, ctx=None):
"""Create a Tensor from a existing tensor.
Parameters
----------
src_ctx : str
The name of source tensor.
src_ctx : dragon_pb2.DeviceOption
The context of source tensor.
name : str
The optional tensor name for destination tensor.
ctx : dragon_pb2.DeviceOption
The context for destination tensor.
Returns
-------
Tensor
The tensor with the same data as source.
References
----------
The wrapper of ``TensorFromTensorCC``.
"""
if name is None: tensor = Tensor(name=name)
else: tensor = Tensor(_name=name)
if src_ctx is None: src_ctx = MakeDeviceOption(0, 0) # CPUContext
if ctx is None: ctx = MakeDeviceOption(0, 0) # CPUContext
TensorFromTensorCC(
_stringify_tensor(tensor), _stringify_tensor(src),
_stringify_proto(ctx), _stringify_proto(src_ctx))
return tensor
def FromPyArray(array, name=None): def FromPyArray(array, name=None):
"""Create a Tensor from a existing Array. """Create a Tensor from a existing Array.
...@@ -120,7 +155,7 @@ def FromPyArray(array, name=None): ...@@ -120,7 +155,7 @@ def FromPyArray(array, name=None):
Parameters Parameters
---------- ----------
array : np.ndarray array : ndarray
The array for creating the tensor. The array for creating the tensor.
name : str name : str
The optional tensor name. The optional tensor name.
...@@ -152,7 +187,7 @@ def SetPyArray(tensor, array): ...@@ -152,7 +187,7 @@ def SetPyArray(tensor, array):
---------- ----------
tensor : Tensor, str or None tensor : Tensor, str or None
The specific tensor to use. The specific tensor to use.
array : numpy.ndarray array : ndarray
The array for creating the tensor. The array for creating the tensor.
Returns Returns
...@@ -179,7 +214,7 @@ def ToPyArray(tensor): ...@@ -179,7 +214,7 @@ def ToPyArray(tensor):
Returns Returns
------- -------
numpy.ndarray ndarray
The array sharing the memory with original tensor. The array sharing the memory with original tensor.
References References
...@@ -202,7 +237,7 @@ def ToPyArrayEx(tensor): ...@@ -202,7 +237,7 @@ def ToPyArrayEx(tensor):
Returns Returns
------- -------
numpy.ndarray ndarray
The array sharing the memory with original tensor. The array sharing the memory with original tensor.
References References
......
...@@ -149,7 +149,8 @@ List Brief ...@@ -149,7 +149,8 @@ List Brief
`OneHot`_ Generate the one-hot representation of inputs. `OneHot`_ Generate the one-hot representation of inputs.
`Flatten`_ Flatten the input along the given axes. `Flatten`_ Flatten the input along the given axes.
`Reshape`_ Reshape the dimensions of input. `Reshape`_ Reshape the dimensions of input.
`ExpandDims`_ ExpandDims interface of NDArray. `Squeeze`_ Remove the dimensions with size 1.
`ExpandDims`_ Expand the new dimension with size 1 to specific axis.
`Shape`_ Get the dynamic shape of a Tensor. `Shape`_ Get the dynamic shape of a Tensor.
`Arange`_ Return a vector of elements by arange. `Arange`_ Return a vector of elements by arange.
=============== ====================================================================== =============== ======================================================================
...@@ -285,6 +286,7 @@ List Brief ...@@ -285,6 +286,7 @@ List Brief
.. _OneHot: operators/ndarray.html#dragon.operators.ndarray.OneHot .. _OneHot: operators/ndarray.html#dragon.operators.ndarray.OneHot
.. _Flatten: operators/ndarray.html#dragon.operators.ndarray.Flatten .. _Flatten: operators/ndarray.html#dragon.operators.ndarray.Flatten
.. _Reshape: operators/ndarray.html#dragon.operators.ndarray.Reshape .. _Reshape: operators/ndarray.html#dragon.operators.ndarray.Reshape
.. _Squeeze: operators/ndarray.html#dragon.operators.ndarray.Squeeze
.. _ExpandDims: operators/ndarray.html#dragon.operators.ndarray.ExpandDims .. _ExpandDims: operators/ndarray.html#dragon.operators.ndarray.ExpandDims
.. _Shape: operators/ndarray.html#dragon.operators.ndarray.Shape .. _Shape: operators/ndarray.html#dragon.operators.ndarray.Shape
.. _Arange: operators/ndarray.html#dragon.operators.ndarray.Arange .. _Arange: operators/ndarray.html#dragon.operators.ndarray.Arange
......
...@@ -97,7 +97,7 @@ class DataReader(Process): ...@@ -97,7 +97,7 @@ class DataReader(Process):
self._db.close() self._db.close()
self._db.open(self._source) self._db.open(self._source)
self._cur_idx = target_idx self._cur_idx = target_idx
self._db.set(str(self._cur_idx).zfill(self._db_zfill)) self._db.set(str(self._cur_idx).zfill(self._zfill))
def reset(self): def reset(self):
"""Reset the cursor and environment. """Reset the cursor and environment.
...@@ -112,12 +112,12 @@ class DataReader(Process): ...@@ -112,12 +112,12 @@ class DataReader(Process):
self._cur_chunk_idx = 0 self._cur_chunk_idx = 0
self._start_idx = int(self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx]) self._start_idx = int(self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx])
self._start_idx = int(self._start_idx * self._chunk_size) self._start_idx = int(self._start_idx * self._chunk_size)
if self._start_idx >= self._db_size: self.next_chunk() if self._start_idx >= self._num_entries: self.next_chunk()
self._end_idx = self._start_idx + self._chunk_size self._end_idx = self._start_idx + self._chunk_size
self._end_idx = min(self._db_size, self._end_idx) self._end_idx = min(self._num_entries, self._end_idx)
else: else:
self._start_idx = 0 self._start_idx = 0
self._end_idx = self._db_size self._end_idx = self._num_entries
self.redirect(self._start_idx) self.redirect(self._start_idx)
...@@ -145,10 +145,10 @@ class DataReader(Process): ...@@ -145,10 +145,10 @@ class DataReader(Process):
else: else:
self._start_idx = self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx] self._start_idx = self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx]
self._start_idx = self._start_idx * self._chunk_size self._start_idx = self._start_idx * self._chunk_size
if self._start_idx >= self._db_size: self.next_chunk() if self._start_idx >= self._num_entries: self.next_chunk()
else: else:
self._end_idx = self._start_idx + self._chunk_size self._end_idx = self._start_idx + self._chunk_size
self._end_idx = min(self._db_size, self._end_idx) self._end_idx = min(self._num_entries, self._end_idx)
self.redirect(self._start_idx) self.redirect(self._start_idx)
def run(self): def run(self):
...@@ -165,14 +165,14 @@ class DataReader(Process): ...@@ -165,14 +165,14 @@ class DataReader(Process):
# init db # init db
self._db = LMDB() self._db = LMDB()
self._db.open(self._source) self._db.open(self._source)
self._db_zfill = self._db.zfill() self._zfill = self._db.zfill()
self._db_size = self._db.num_entries() self._num_entries = self._db.num_entries()
self._epoch_size = int(self._db_size / self._num_parts + 1) self._epoch_size = int(self._num_entries / self._num_parts + 1)
if self._use_shuffle: if self._use_shuffle:
if self._chunk_size == 1: if self._chunk_size == 1:
# each chunk has at most 1 record [For Fully Shuffle] # each chunk has at most 1 record [For Fully Shuffle]
self._num_shuffle_parts = int(self._db_size / self._chunk_size / self._num_parts) + 1 self._num_shuffle_parts = int(self._num_entries / self._chunk_size / self._num_parts) + 1
else: else:
if self._use_shuffle and self._chunk_size == -1: if self._use_shuffle and self._chunk_size == -1:
# search a optimal chunk size by chunks [For Chunk Shuffle] # search a optimal chunk size by chunks [For Chunk Shuffle]
...@@ -182,12 +182,12 @@ class DataReader(Process): ...@@ -182,12 +182,12 @@ class DataReader(Process):
self._chunk_size = min_chunk_size self._chunk_size = min_chunk_size
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 / self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20))) (self._num_parts * self._chunk_size << 20)))
self._chunk_size = int(self._db_size / self._num_shuffle_parts / self._num_parts + 1) self._chunk_size = int(self._num_entries / self._num_shuffle_parts / self._num_parts + 1)
else: else:
# each chunk has at most K records [For Multiple Nodes] # each chunk has at most K records [For Multiple Nodes]
# note that if ``shuffle`` and ``multiple_nodes`` are all ``False``, # note that if ``shuffle`` and ``multiple_nodes`` are all ``False``,
# ``chunk_size`` and ``num_shuffle_parts`` are meaningless # ``chunk_size`` and ``num_shuffle_parts`` are meaningless
self._chunk_size = int(self._db_size / self._num_parts) + 1 self._chunk_size = int(self._num_entries / self._num_parts) + 1
self._num_shuffle_parts = 1 self._num_shuffle_parts = 1
self._perm = np.arange(self._num_shuffle_parts) self._perm = np.arange(self._num_shuffle_parts)
......
...@@ -727,11 +727,11 @@ def Reshape(inputs, shape, shape_like=None, **kwargs): ...@@ -727,11 +727,11 @@ def Reshape(inputs, shape, shape_like=None, **kwargs):
Examples Examples
-------- --------
>>> a = Tensor(shape=[1, 2, 3, 4]).Variable() >>> a = Tensor(shape=[1, 2, 3, 4]).Variable()
>>> print Reshape(a, shape=[6, 4]) >>> print(Reshape(a, shape=[6, 4]))
>>> [6, 4] >>> [6, 4]
>>> b = Reshape(a, shape=[-1, 4]) # shape will be [6, 4] in the backend >>> b = Reshape(a, shape=[-1, 4]) # shape will be [6, 4] in the backend
>>> print b.shape >>> print(b.shape)
>>> [1, 4] # fake dimension at axis 0 >>> [1, 4] # fake dimension at axis 0
""" """
...@@ -766,15 +766,58 @@ def Reshape(inputs, shape, shape_like=None, **kwargs): ...@@ -766,15 +766,58 @@ def Reshape(inputs, shape, shape_like=None, **kwargs):
return output return output
def ExpandDims(inputs, axis=-1, **kwargs): def Squeeze(inputs, axis=None, **kwargs):
"""ExpandDims interface of NDArray. """Remove the dimensions with size 1.
Set ``axis`` to remove the specific position.
Parameters
----------
inputs : Tensor
The input tensor.
axis : int or None
The specific axis to remove.
Returns
-------
Tensor
The output tensor.
Examples
--------
>>> a = Tensor(shape=[2, 1, 3, 4]).Variable()
>>> print(Squeeze(a).shape)
>>> print(Squeeze(a, axis=0).shape)
"""
CheckInputs(inputs, 1)
arguments = ParseArguments(locals())
output = Tensor.CreateOperator(nout=1, op_type='Squeeze', **arguments)
if inputs.shape is not None:
output_shape = []
if axis: axis += (0 if axis >= 0 else len(inputs.shape))
for idx, dim in enumerate(inputs.shape[:]):
if dim != 1 or \
(axis and dim == 1 and idx != axis):
output_shape.append(dim)
output.shape = output_shape
return output
def ExpandDims(inputs, axis, **kwargs):
"""Expand the new dimension with size 1 to specific axis.
Negative ``axis`` is equal to ``axis = axis + num_axes + 1``.
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
The input tensor. The input tensor.
axis : int axis : int
The insert position of new dimension. Default is ``-1`` (Push Back). The insert axis of new dimension.
Returns Returns
------- -------
...@@ -784,9 +827,8 @@ def ExpandDims(inputs, axis=-1, **kwargs): ...@@ -784,9 +827,8 @@ def ExpandDims(inputs, axis=-1, **kwargs):
Examples Examples
-------- --------
>>> a = Tensor(shape=[1, 2, 3, 4]).Variable() >>> a = Tensor(shape=[1, 2, 3, 4]).Variable()
>>> print ExpandDims(a).shape >>> print(ExpandDims(a).shape)
>>> print(ExpandDims(a, axis=2).shape)
>>> print ExpandDims(a, axis=2).shape
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
...@@ -796,7 +838,8 @@ def ExpandDims(inputs, axis=-1, **kwargs): ...@@ -796,7 +838,8 @@ def ExpandDims(inputs, axis=-1, **kwargs):
if inputs.shape is not None: if inputs.shape is not None:
output.shape = inputs.shape[:] output.shape = inputs.shape[:]
if axis == -1 or axis >= len(inputs.shape): axis += (0 if axis >= 0 else len(inputs.shape) + 1)
if axis < 0 or axis >= len(inputs.shape):
output.shape.append(np.long(1)) output.shape.append(np.long(1))
else: output.shape.insert(axis, np.long(1)) else: output.shape.insert(axis, np.long(1))
......
...@@ -129,6 +129,7 @@ OneHot = ndarray.OneHot ...@@ -129,6 +129,7 @@ OneHot = ndarray.OneHot
Flatten = ndarray.Flatten Flatten = ndarray.Flatten
Reshape = ndarray.Reshape Reshape = ndarray.Reshape
ExpandDims = ndarray.ExpandDims ExpandDims = ndarray.ExpandDims
Squeeze = ndarray.Squeeze
Shape = ndarray.Shape Shape = ndarray.Shape
Arange = ndarray.Arange Arange = ndarray.Arange
......
syntax = "proto2"; syntax = "proto2";
package dragon;
message TensorProto { message TensorProto {
repeated int32 dims = 1; repeated int32 dims = 1;
enum DataType { enum DataType {
......
...@@ -18,14 +18,14 @@ _sym_db = _symbol_database.Default() ...@@ -18,14 +18,14 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='dragon.proto', name='dragon.proto',
package='', package='dragon',
serialized_pb=_b('\n\x0c\x64ragon.proto\"\xf7\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x05\x12/\n\tdata_type\x18\x02 \x01(\x0e\x32\x15.TensorProto.DataType:\x05\x46LOAT\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x11\n\tbyte_data\x18\x05 \x01(\x0c\x12\x13\n\x0bstring_data\x18\x06 \x03(\x0c\x12\x0c\n\x04name\x18\x07 \x01(\t\"C\n\x08\x44\x61taType\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05INT32\x10\x02\x12\x08\n\x04\x42YTE\x10\x03\x12\n\n\x06STRING\x10\x04\x12\x0b\n\x07\x46LOAT16\x10\x0c\",\n\x0cTensorProtos\x12\x1c\n\x06protos\x18\x01 \x03(\x0b\x32\x0c.TensorProto\"\x80\x01\n\x08\x41rgument\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x05\x12\x0b\n\x03i64\x18\t \x01(\x03\x12\t\n\x01s\x18\x04 \x01(\t\x12\t\n\x01\x62\x18\x08 \x01(\x08\x12\x0e\n\x06\x66loats\x18\x05 \x03(\x02\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0f\n\x07strings\x18\x07 \x03(\t\"s\n\x0c\x44\x65viceOption\x12%\n\x0b\x64\x65vice_type\x18\x01 \x01(\x0e\x32\x0b.DeviceType:\x03\x43PU\x12\x14\n\tdevice_id\x18\x02 \x01(\x05:\x01\x30\x12\x16\n\x0brandom_seed\x18\x03 \x01(\r:\x01\x33\x12\x0e\n\x06\x65ngine\x18\x04 \x01(\t\"\x86\x01\n\x0bOperatorDef\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x16\n\x03\x61rg\x18\x05 \x03(\x0b\x32\t.Argument\x12$\n\rdevice_option\x18\x06 \x01(\x0b\x32\r.DeviceOption\"=\n\x0eGradientTarget\x12\x0c\n\x04\x63ost\x18\x01 \x01(\t\x12\x0b\n\x03wrt\x18\x02 \x01(\t\x12\x10\n\x08\x65xternal\x18\x03 \x01(\t\"R\n\x0cUpdateTarget\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0e\n\x06tensor\x18\x03 \x03(\t\x12\x16\n\x03\x61rg\x18\x04 \x03(\x0b\x32\t.Argument\"\x8d\x02\n\x0cTensorFiller\x12\x0e\n\x06tensor\x18\x01 \x01(\t\x12\x16\n\x04type\x18\x02 \x01(\t:\x08\x63onstant\x12\x10\n\x05value\x18\x03 \x01(\x02:\x01\x30\x12\x0e\n\x03low\x18\x04 \x01(\x02:\x01\x30\x12\x0f\n\x04high\x18\x05 \x01(\x02:\x01\x31\x12\x0f\n\x04mean\x18\x06 \x01(\x02:\x01\x30\x12\x0e\n\x03std\x18\x07 \x01(\x02:\x01\x31\x12\x10\n\x05scale\x18\x08 \x01(\x02:\x01\x33\x12\x39\n\rvariance_norm\x18\t \x01(\x0e\x32\x1a.TensorFiller.VarianceNorm:\x06\x46\x41N_IN\"4\n\x0cVarianceNorm\x12\n\n\x06\x46\x41N_IN\x10\x00\x12\x0b\n\x07\x46\x41N_OUT\x10\x01\x12\x0b\n\x07\x46\x41N_AVG\x10\x02\"\xd8\x01\n\x08GraphDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x18\n\x02op\x18\x02 \x03(\x0b\x32\x0c.OperatorDef\x12\x12\n\ngraph_type\x18\x03 \x01(\t\x12$\n\rdevice_option\x18\x05 \x01(\x0b\x32\r.DeviceOption\x12\x16\n\x03\x61rg\x18\x06 \x03(\x0b\x32\t.Argument\x12\x0e\n\x06target\x18\x07 \x03(\t\x12!\n\x08g_target\x18\x08 \x03(\x0b\x32\x0f.GradientTarget\x12\x1f\n\x08u_target\x18\t \x03(\x0b\x32\r.UpdateTarget*+\n\nDeviceType\x12\x07\n\x03\x43PU\x10\x00\x12\x08\n\x04\x43UDA\x10\x01\x12\n\n\x06OPENCL\x10\x02') serialized_pb=_b('\n\x0c\x64ragon.proto\x12\x06\x64ragon\"\xfe\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x05\x12\x36\n\tdata_type\x18\x02 \x01(\x0e\x32\x1c.dragon.TensorProto.DataType:\x05\x46LOAT\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x11\n\tbyte_data\x18\x05 \x01(\x0c\x12\x13\n\x0bstring_data\x18\x06 \x03(\x0c\x12\x0c\n\x04name\x18\x07 \x01(\t\"C\n\x08\x44\x61taType\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05INT32\x10\x02\x12\x08\n\x04\x42YTE\x10\x03\x12\n\n\x06STRING\x10\x04\x12\x0b\n\x07\x46LOAT16\x10\x0c\"3\n\x0cTensorProtos\x12#\n\x06protos\x18\x01 \x03(\x0b\x32\x13.dragon.TensorProto\"\x80\x01\n\x08\x41rgument\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x05\x12\x0b\n\x03i64\x18\t \x01(\x03\x12\t\n\x01s\x18\x04 \x01(\t\x12\t\n\x01\x62\x18\x08 \x01(\x08\x12\x0e\n\x06\x66loats\x18\x05 \x03(\x02\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0f\n\x07strings\x18\x07 \x03(\t\"z\n\x0c\x44\x65viceOption\x12,\n\x0b\x64\x65vice_type\x18\x01 \x01(\x0e\x32\x12.dragon.DeviceType:\x03\x43PU\x12\x14\n\tdevice_id\x18\x02 \x01(\x05:\x01\x30\x12\x16\n\x0brandom_seed\x18\x03 \x01(\r:\x01\x33\x12\x0e\n\x06\x65ngine\x18\x04 \x01(\t\"\x94\x01\n\x0bOperatorDef\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x1d\n\x03\x61rg\x18\x05 \x03(\x0b\x32\x10.dragon.Argument\x12+\n\rdevice_option\x18\x06 \x01(\x0b\x32\x14.dragon.DeviceOption\"=\n\x0eGradientTarget\x12\x0c\n\x04\x63ost\x18\x01 \x01(\t\x12\x0b\n\x03wrt\x18\x02 \x01(\t\x12\x10\n\x08\x65xternal\x18\x03 \x01(\t\"Y\n\x0cUpdateTarget\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0e\n\x06tensor\x18\x03 \x03(\t\x12\x1d\n\x03\x61rg\x18\x04 \x03(\x0b\x32\x10.dragon.Argument\"\x94\x02\n\x0cTensorFiller\x12\x0e\n\x06tensor\x18\x01 \x01(\t\x12\x16\n\x04type\x18\x02 \x01(\t:\x08\x63onstant\x12\x10\n\x05value\x18\x03 \x01(\x02:\x01\x30\x12\x0e\n\x03low\x18\x04 \x01(\x02:\x01\x30\x12\x0f\n\x04high\x18\x05 \x01(\x02:\x01\x31\x12\x0f\n\x04mean\x18\x06 \x01(\x02:\x01\x30\x12\x0e\n\x03std\x18\x07 \x01(\x02:\x01\x31\x12\x10\n\x05scale\x18\x08 \x01(\x02:\x01\x33\x12@\n\rvariance_norm\x18\t \x01(\x0e\x32!.dragon.TensorFiller.VarianceNorm:\x06\x46\x41N_IN\"4\n\x0cVarianceNorm\x12\n\n\x06\x46\x41N_IN\x10\x00\x12\x0b\n\x07\x46\x41N_OUT\x10\x01\x12\x0b\n\x07\x46\x41N_AVG\x10\x02\"\xfb\x01\n\x08GraphDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1f\n\x02op\x18\x02 \x03(\x0b\x32\x13.dragon.OperatorDef\x12\x12\n\ngraph_type\x18\x03 \x01(\t\x12+\n\rdevice_option\x18\x05 \x01(\x0b\x32\x14.dragon.DeviceOption\x12\x1d\n\x03\x61rg\x18\x06 \x03(\x0b\x32\x10.dragon.Argument\x12\x0e\n\x06target\x18\x07 \x03(\t\x12(\n\x08g_target\x18\x08 \x03(\x0b\x32\x16.dragon.GradientTarget\x12&\n\x08u_target\x18\t \x03(\x0b\x32\x14.dragon.UpdateTarget*+\n\nDeviceType\x12\x07\n\x03\x43PU\x10\x00\x12\x08\n\x04\x43UDA\x10\x01\x12\n\n\x06OPENCL\x10\x02')
) )
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DEVICETYPE = _descriptor.EnumDescriptor( _DEVICETYPE = _descriptor.EnumDescriptor(
name='DeviceType', name='DeviceType',
full_name='DeviceType', full_name='dragon.DeviceType',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
values=[ values=[
...@@ -44,8 +44,8 @@ _DEVICETYPE = _descriptor.EnumDescriptor( ...@@ -44,8 +44,8 @@ _DEVICETYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=1335, serialized_start=1427,
serialized_end=1378, serialized_end=1470,
) )
_sym_db.RegisterEnumDescriptor(_DEVICETYPE) _sym_db.RegisterEnumDescriptor(_DEVICETYPE)
...@@ -57,7 +57,7 @@ OPENCL = 2 ...@@ -57,7 +57,7 @@ OPENCL = 2
_TENSORPROTO_DATATYPE = _descriptor.EnumDescriptor( _TENSORPROTO_DATATYPE = _descriptor.EnumDescriptor(
name='DataType', name='DataType',
full_name='TensorProto.DataType', full_name='dragon.TensorProto.DataType',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
values=[ values=[
...@@ -84,14 +84,14 @@ _TENSORPROTO_DATATYPE = _descriptor.EnumDescriptor( ...@@ -84,14 +84,14 @@ _TENSORPROTO_DATATYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=197, serialized_start=212,
serialized_end=264, serialized_end=279,
) )
_sym_db.RegisterEnumDescriptor(_TENSORPROTO_DATATYPE) _sym_db.RegisterEnumDescriptor(_TENSORPROTO_DATATYPE)
_TENSORFILLER_VARIANCENORM = _descriptor.EnumDescriptor( _TENSORFILLER_VARIANCENORM = _descriptor.EnumDescriptor(
name='VarianceNorm', name='VarianceNorm',
full_name='TensorFiller.VarianceNorm', full_name='dragon.TensorFiller.VarianceNorm',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
values=[ values=[
...@@ -110,63 +110,63 @@ _TENSORFILLER_VARIANCENORM = _descriptor.EnumDescriptor( ...@@ -110,63 +110,63 @@ _TENSORFILLER_VARIANCENORM = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=1062, serialized_start=1119,
serialized_end=1114, serialized_end=1171,
) )
_sym_db.RegisterEnumDescriptor(_TENSORFILLER_VARIANCENORM) _sym_db.RegisterEnumDescriptor(_TENSORFILLER_VARIANCENORM)
_TENSORPROTO = _descriptor.Descriptor( _TENSORPROTO = _descriptor.Descriptor(
name='TensorProto', name='TensorProto',
full_name='TensorProto', full_name='dragon.TensorProto',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='dims', full_name='TensorProto.dims', index=0, name='dims', full_name='dragon.TensorProto.dims', index=0,
number=1, type=5, cpp_type=1, label=3, number=1, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='data_type', full_name='TensorProto.data_type', index=1, name='data_type', full_name='dragon.TensorProto.data_type', index=1,
number=2, type=14, cpp_type=8, label=1, number=2, type=14, cpp_type=8, label=1,
has_default_value=True, default_value=1, has_default_value=True, default_value=1,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='float_data', full_name='TensorProto.float_data', index=2, name='float_data', full_name='dragon.TensorProto.float_data', index=2,
number=3, type=2, cpp_type=6, label=3, number=3, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='int32_data', full_name='TensorProto.int32_data', index=3, name='int32_data', full_name='dragon.TensorProto.int32_data', index=3,
number=4, type=5, cpp_type=1, label=3, number=4, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='byte_data', full_name='TensorProto.byte_data', index=4, name='byte_data', full_name='dragon.TensorProto.byte_data', index=4,
number=5, type=12, cpp_type=9, label=1, number=5, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='string_data', full_name='TensorProto.string_data', index=5, name='string_data', full_name='dragon.TensorProto.string_data', index=5,
number=6, type=12, cpp_type=9, label=3, number=6, type=12, cpp_type=9, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='name', full_name='TensorProto.name', index=6, name='name', full_name='dragon.TensorProto.name', index=6,
number=7, type=9, cpp_type=9, label=1, number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -184,20 +184,20 @@ _TENSORPROTO = _descriptor.Descriptor( ...@@ -184,20 +184,20 @@ _TENSORPROTO = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=17, serialized_start=25,
serialized_end=264, serialized_end=279,
) )
_TENSORPROTOS = _descriptor.Descriptor( _TENSORPROTOS = _descriptor.Descriptor(
name='TensorProtos', name='TensorProtos',
full_name='TensorProtos', full_name='dragon.TensorProtos',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='protos', full_name='TensorProtos.protos', index=0, name='protos', full_name='dragon.TensorProtos.protos', index=0,
number=1, type=11, cpp_type=10, label=3, number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -214,76 +214,76 @@ _TENSORPROTOS = _descriptor.Descriptor( ...@@ -214,76 +214,76 @@ _TENSORPROTOS = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=266, serialized_start=281,
serialized_end=310, serialized_end=332,
) )
_ARGUMENT = _descriptor.Descriptor( _ARGUMENT = _descriptor.Descriptor(
name='Argument', name='Argument',
full_name='Argument', full_name='dragon.Argument',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='name', full_name='Argument.name', index=0, name='name', full_name='dragon.Argument.name', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='f', full_name='Argument.f', index=1, name='f', full_name='dragon.Argument.f', index=1,
number=2, type=2, cpp_type=6, label=1, number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='i', full_name='Argument.i', index=2, name='i', full_name='dragon.Argument.i', index=2,
number=3, type=5, cpp_type=1, label=1, number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='i64', full_name='Argument.i64', index=3, name='i64', full_name='dragon.Argument.i64', index=3,
number=9, type=3, cpp_type=2, label=1, number=9, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='s', full_name='Argument.s', index=4, name='s', full_name='dragon.Argument.s', index=4,
number=4, type=9, cpp_type=9, label=1, number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='b', full_name='Argument.b', index=5, name='b', full_name='dragon.Argument.b', index=5,
number=8, type=8, cpp_type=7, label=1, number=8, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False, has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='floats', full_name='Argument.floats', index=6, name='floats', full_name='dragon.Argument.floats', index=6,
number=5, type=2, cpp_type=6, label=3, number=5, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='ints', full_name='Argument.ints', index=7, name='ints', full_name='dragon.Argument.ints', index=7,
number=6, type=5, cpp_type=1, label=3, number=6, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='strings', full_name='Argument.strings', index=8, name='strings', full_name='dragon.Argument.strings', index=8,
number=7, type=9, cpp_type=9, label=3, number=7, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -300,41 +300,41 @@ _ARGUMENT = _descriptor.Descriptor( ...@@ -300,41 +300,41 @@ _ARGUMENT = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=313, serialized_start=335,
serialized_end=441, serialized_end=463,
) )
_DEVICEOPTION = _descriptor.Descriptor( _DEVICEOPTION = _descriptor.Descriptor(
name='DeviceOption', name='DeviceOption',
full_name='DeviceOption', full_name='dragon.DeviceOption',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='device_type', full_name='DeviceOption.device_type', index=0, name='device_type', full_name='dragon.DeviceOption.device_type', index=0,
number=1, type=14, cpp_type=8, label=1, number=1, type=14, cpp_type=8, label=1,
has_default_value=True, default_value=0, has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='device_id', full_name='DeviceOption.device_id', index=1, name='device_id', full_name='dragon.DeviceOption.device_id', index=1,
number=2, type=5, cpp_type=1, label=1, number=2, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=0, has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='random_seed', full_name='DeviceOption.random_seed', index=2, name='random_seed', full_name='dragon.DeviceOption.random_seed', index=2,
number=3, type=13, cpp_type=3, label=1, number=3, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=3, has_default_value=True, default_value=3,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='engine', full_name='DeviceOption.engine', index=3, name='engine', full_name='dragon.DeviceOption.engine', index=3,
number=4, type=9, cpp_type=9, label=1, number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -351,55 +351,55 @@ _DEVICEOPTION = _descriptor.Descriptor( ...@@ -351,55 +351,55 @@ _DEVICEOPTION = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=443, serialized_start=465,
serialized_end=558, serialized_end=587,
) )
_OPERATORDEF = _descriptor.Descriptor( _OPERATORDEF = _descriptor.Descriptor(
name='OperatorDef', name='OperatorDef',
full_name='OperatorDef', full_name='dragon.OperatorDef',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='input', full_name='OperatorDef.input', index=0, name='input', full_name='dragon.OperatorDef.input', index=0,
number=1, type=9, cpp_type=9, label=3, number=1, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='output', full_name='OperatorDef.output', index=1, name='output', full_name='dragon.OperatorDef.output', index=1,
number=2, type=9, cpp_type=9, label=3, number=2, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='name', full_name='OperatorDef.name', index=2, name='name', full_name='dragon.OperatorDef.name', index=2,
number=3, type=9, cpp_type=9, label=1, number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='type', full_name='OperatorDef.type', index=3, name='type', full_name='dragon.OperatorDef.type', index=3,
number=4, type=9, cpp_type=9, label=1, number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='arg', full_name='OperatorDef.arg', index=4, name='arg', full_name='dragon.OperatorDef.arg', index=4,
number=5, type=11, cpp_type=10, label=3, number=5, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='device_option', full_name='OperatorDef.device_option', index=5, name='device_option', full_name='dragon.OperatorDef.device_option', index=5,
number=6, type=11, cpp_type=10, label=1, number=6, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -416,34 +416,34 @@ _OPERATORDEF = _descriptor.Descriptor( ...@@ -416,34 +416,34 @@ _OPERATORDEF = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=561, serialized_start=590,
serialized_end=695, serialized_end=738,
) )
_GRADIENTTARGET = _descriptor.Descriptor( _GRADIENTTARGET = _descriptor.Descriptor(
name='GradientTarget', name='GradientTarget',
full_name='GradientTarget', full_name='dragon.GradientTarget',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='cost', full_name='GradientTarget.cost', index=0, name='cost', full_name='dragon.GradientTarget.cost', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='wrt', full_name='GradientTarget.wrt', index=1, name='wrt', full_name='dragon.GradientTarget.wrt', index=1,
number=2, type=9, cpp_type=9, label=1, number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='external', full_name='GradientTarget.external', index=2, name='external', full_name='dragon.GradientTarget.external', index=2,
number=3, type=9, cpp_type=9, label=1, number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -460,41 +460,41 @@ _GRADIENTTARGET = _descriptor.Descriptor( ...@@ -460,41 +460,41 @@ _GRADIENTTARGET = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=697, serialized_start=740,
serialized_end=758, serialized_end=801,
) )
_UPDATETARGET = _descriptor.Descriptor( _UPDATETARGET = _descriptor.Descriptor(
name='UpdateTarget', name='UpdateTarget',
full_name='UpdateTarget', full_name='dragon.UpdateTarget',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='name', full_name='UpdateTarget.name', index=0, name='name', full_name='dragon.UpdateTarget.name', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='type', full_name='UpdateTarget.type', index=1, name='type', full_name='dragon.UpdateTarget.type', index=1,
number=2, type=9, cpp_type=9, label=1, number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='tensor', full_name='UpdateTarget.tensor', index=2, name='tensor', full_name='dragon.UpdateTarget.tensor', index=2,
number=3, type=9, cpp_type=9, label=3, number=3, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='arg', full_name='UpdateTarget.arg', index=3, name='arg', full_name='dragon.UpdateTarget.arg', index=3,
number=4, type=11, cpp_type=10, label=3, number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -511,76 +511,76 @@ _UPDATETARGET = _descriptor.Descriptor( ...@@ -511,76 +511,76 @@ _UPDATETARGET = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=760, serialized_start=803,
serialized_end=842, serialized_end=892,
) )
_TENSORFILLER = _descriptor.Descriptor( _TENSORFILLER = _descriptor.Descriptor(
name='TensorFiller', name='TensorFiller',
full_name='TensorFiller', full_name='dragon.TensorFiller',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='tensor', full_name='TensorFiller.tensor', index=0, name='tensor', full_name='dragon.TensorFiller.tensor', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='type', full_name='TensorFiller.type', index=1, name='type', full_name='dragon.TensorFiller.type', index=1,
number=2, type=9, cpp_type=9, label=1, number=2, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("constant").decode('utf-8'), has_default_value=True, default_value=_b("constant").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='value', full_name='TensorFiller.value', index=2, name='value', full_name='dragon.TensorFiller.value', index=2,
number=3, type=2, cpp_type=6, label=1, number=3, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=0, has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='low', full_name='TensorFiller.low', index=3, name='low', full_name='dragon.TensorFiller.low', index=3,
number=4, type=2, cpp_type=6, label=1, number=4, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=0, has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='high', full_name='TensorFiller.high', index=4, name='high', full_name='dragon.TensorFiller.high', index=4,
number=5, type=2, cpp_type=6, label=1, number=5, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=1, has_default_value=True, default_value=1,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='mean', full_name='TensorFiller.mean', index=5, name='mean', full_name='dragon.TensorFiller.mean', index=5,
number=6, type=2, cpp_type=6, label=1, number=6, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=0, has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='std', full_name='TensorFiller.std', index=6, name='std', full_name='dragon.TensorFiller.std', index=6,
number=7, type=2, cpp_type=6, label=1, number=7, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=1, has_default_value=True, default_value=1,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='scale', full_name='TensorFiller.scale', index=7, name='scale', full_name='dragon.TensorFiller.scale', index=7,
number=8, type=2, cpp_type=6, label=1, number=8, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=3, has_default_value=True, default_value=3,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='variance_norm', full_name='TensorFiller.variance_norm', index=8, name='variance_norm', full_name='dragon.TensorFiller.variance_norm', index=8,
number=9, type=14, cpp_type=8, label=1, number=9, type=14, cpp_type=8, label=1,
has_default_value=True, default_value=0, has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -598,69 +598,69 @@ _TENSORFILLER = _descriptor.Descriptor( ...@@ -598,69 +598,69 @@ _TENSORFILLER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=845, serialized_start=895,
serialized_end=1114, serialized_end=1171,
) )
_GRAPHDEF = _descriptor.Descriptor( _GRAPHDEF = _descriptor.Descriptor(
name='GraphDef', name='GraphDef',
full_name='GraphDef', full_name='dragon.GraphDef',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='name', full_name='GraphDef.name', index=0, name='name', full_name='dragon.GraphDef.name', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='op', full_name='GraphDef.op', index=1, name='op', full_name='dragon.GraphDef.op', index=1,
number=2, type=11, cpp_type=10, label=3, number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='graph_type', full_name='GraphDef.graph_type', index=2, name='graph_type', full_name='dragon.GraphDef.graph_type', index=2,
number=3, type=9, cpp_type=9, label=1, number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='device_option', full_name='GraphDef.device_option', index=3, name='device_option', full_name='dragon.GraphDef.device_option', index=3,
number=5, type=11, cpp_type=10, label=1, number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='arg', full_name='GraphDef.arg', index=4, name='arg', full_name='dragon.GraphDef.arg', index=4,
number=6, type=11, cpp_type=10, label=3, number=6, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='target', full_name='GraphDef.target', index=5, name='target', full_name='dragon.GraphDef.target', index=5,
number=7, type=9, cpp_type=9, label=3, number=7, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='g_target', full_name='GraphDef.g_target', index=6, name='g_target', full_name='dragon.GraphDef.g_target', index=6,
number=8, type=11, cpp_type=10, label=3, number=8, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='u_target', full_name='GraphDef.u_target', index=7, name='u_target', full_name='dragon.GraphDef.u_target', index=7,
number=9, type=11, cpp_type=10, label=3, number=9, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -677,8 +677,8 @@ _GRAPHDEF = _descriptor.Descriptor( ...@@ -677,8 +677,8 @@ _GRAPHDEF = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=1117, serialized_start=1174,
serialized_end=1333, serialized_end=1425,
) )
_TENSORPROTO.fields_by_name['data_type'].enum_type = _TENSORPROTO_DATATYPE _TENSORPROTO.fields_by_name['data_type'].enum_type = _TENSORPROTO_DATATYPE
...@@ -709,63 +709,63 @@ DESCRIPTOR.enum_types_by_name['DeviceType'] = _DEVICETYPE ...@@ -709,63 +709,63 @@ DESCRIPTOR.enum_types_by_name['DeviceType'] = _DEVICETYPE
TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), dict( TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), dict(
DESCRIPTOR = _TENSORPROTO, DESCRIPTOR = _TENSORPROTO,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:TensorProto) # @@protoc_insertion_point(class_scope:dragon.TensorProto)
)) ))
_sym_db.RegisterMessage(TensorProto) _sym_db.RegisterMessage(TensorProto)
TensorProtos = _reflection.GeneratedProtocolMessageType('TensorProtos', (_message.Message,), dict( TensorProtos = _reflection.GeneratedProtocolMessageType('TensorProtos', (_message.Message,), dict(
DESCRIPTOR = _TENSORPROTOS, DESCRIPTOR = _TENSORPROTOS,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:TensorProtos) # @@protoc_insertion_point(class_scope:dragon.TensorProtos)
)) ))
_sym_db.RegisterMessage(TensorProtos) _sym_db.RegisterMessage(TensorProtos)
Argument = _reflection.GeneratedProtocolMessageType('Argument', (_message.Message,), dict( Argument = _reflection.GeneratedProtocolMessageType('Argument', (_message.Message,), dict(
DESCRIPTOR = _ARGUMENT, DESCRIPTOR = _ARGUMENT,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:Argument) # @@protoc_insertion_point(class_scope:dragon.Argument)
)) ))
_sym_db.RegisterMessage(Argument) _sym_db.RegisterMessage(Argument)
DeviceOption = _reflection.GeneratedProtocolMessageType('DeviceOption', (_message.Message,), dict( DeviceOption = _reflection.GeneratedProtocolMessageType('DeviceOption', (_message.Message,), dict(
DESCRIPTOR = _DEVICEOPTION, DESCRIPTOR = _DEVICEOPTION,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:DeviceOption) # @@protoc_insertion_point(class_scope:dragon.DeviceOption)
)) ))
_sym_db.RegisterMessage(DeviceOption) _sym_db.RegisterMessage(DeviceOption)
OperatorDef = _reflection.GeneratedProtocolMessageType('OperatorDef', (_message.Message,), dict( OperatorDef = _reflection.GeneratedProtocolMessageType('OperatorDef', (_message.Message,), dict(
DESCRIPTOR = _OPERATORDEF, DESCRIPTOR = _OPERATORDEF,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:OperatorDef) # @@protoc_insertion_point(class_scope:dragon.OperatorDef)
)) ))
_sym_db.RegisterMessage(OperatorDef) _sym_db.RegisterMessage(OperatorDef)
GradientTarget = _reflection.GeneratedProtocolMessageType('GradientTarget', (_message.Message,), dict( GradientTarget = _reflection.GeneratedProtocolMessageType('GradientTarget', (_message.Message,), dict(
DESCRIPTOR = _GRADIENTTARGET, DESCRIPTOR = _GRADIENTTARGET,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:GradientTarget) # @@protoc_insertion_point(class_scope:dragon.GradientTarget)
)) ))
_sym_db.RegisterMessage(GradientTarget) _sym_db.RegisterMessage(GradientTarget)
UpdateTarget = _reflection.GeneratedProtocolMessageType('UpdateTarget', (_message.Message,), dict( UpdateTarget = _reflection.GeneratedProtocolMessageType('UpdateTarget', (_message.Message,), dict(
DESCRIPTOR = _UPDATETARGET, DESCRIPTOR = _UPDATETARGET,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:UpdateTarget) # @@protoc_insertion_point(class_scope:dragon.UpdateTarget)
)) ))
_sym_db.RegisterMessage(UpdateTarget) _sym_db.RegisterMessage(UpdateTarget)
TensorFiller = _reflection.GeneratedProtocolMessageType('TensorFiller', (_message.Message,), dict( TensorFiller = _reflection.GeneratedProtocolMessageType('TensorFiller', (_message.Message,), dict(
DESCRIPTOR = _TENSORFILLER, DESCRIPTOR = _TENSORFILLER,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:TensorFiller) # @@protoc_insertion_point(class_scope:dragon.TensorFiller)
)) ))
_sym_db.RegisterMessage(TensorFiller) _sym_db.RegisterMessage(TensorFiller)
GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', (_message.Message,), dict( GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', (_message.Message,), dict(
DESCRIPTOR = _GRAPHDEF, DESCRIPTOR = _GRAPHDEF,
__module__ = 'dragon_pb2' __module__ = 'dragon_pb2'
# @@protoc_insertion_point(class_scope:GraphDef) # @@protoc_insertion_point(class_scope:dragon.GraphDef)
)) ))
_sym_db.RegisterMessage(GraphDef) _sym_db.RegisterMessage(GraphDef)
......
...@@ -14,7 +14,7 @@ from __future__ import division ...@@ -14,7 +14,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
version = '0.2.2' version = '0.2.2'
full_version = '0.2.2.9' full_version = '0.2.2.10'
release = False release = False
if not release: if not release:
......
...@@ -115,8 +115,8 @@ class Module(object): ...@@ -115,8 +115,8 @@ class Module(object):
def _load_state_dict_key_mismatch(self, full_name, name, is_missing): def _load_state_dict_key_mismatch(self, full_name, name, is_missing):
pass pass
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True, verbose=True):
logger.info('Load the state dict from numpy arrays.') if verbose: logger.info('Load the state dict.')
def submodule_key_mismatch(full_name, is_missing): def submodule_key_mismatch(full_name, is_missing):
module = self module = self
names = full_name.split(".") names = full_name.split(".")
...@@ -131,9 +131,6 @@ class Module(object): ...@@ -131,9 +131,6 @@ class Module(object):
own_state = self.state_dict() own_state = self.state_dict()
for name, param in state_dict.items(): for name, param in state_dict.items():
if name in own_state: if name in own_state:
if not isinstance(param, np.ndarray):
raise ValueError('PyTorch@Dragon can only load params '
'that saved as numpy array.')
state_shape = own_state[name].shape state_shape = own_state[name].shape
param_shape = param.shape param_shape = param.shape
if state_shape != param_shape: if state_shape != param_shape:
...@@ -145,8 +142,15 @@ class Module(object): ...@@ -145,8 +142,15 @@ class Module(object):
raise ValueError('DType of state({}) is {}, \n' raise ValueError('DType of state({}) is {}, \n'
'While load from a PyArray of {}.'.format(name, 'While load from a PyArray of {}.'.format(name,
own_state[name].dtype, str(param.dtype))) own_state[name].dtype, str(param.dtype)))
dg.workspace.FeedTensor(own_state[name].name, param) if isinstance(param, Tensor):
logger.info('* Tensor({}) loaded, Size: ({})'.format(name, own_state[name].copy_(param)
elif isinstance(param, np.ndarray):
dg.tensor_utils.SetPyArray(own_state[name], param)
else:
raise ValueError('Excepted the type of source state is either '
'torch.Tensor or numpy.ndarray, got {}.'.format(type(param)))
if verbose:
logger.info('* Tensor({}) loaded, Size: ({})'.format(name,
', '.join([str(d) for d in param_shape]))) ', '.join([str(d) for d in param_shape])))
if strict: if strict:
missing = set(own_state.keys()) - set(state_dict.keys()) missing = set(own_state.keys()) - set(state_dict.keys())
......
...@@ -18,7 +18,7 @@ from dragon.vm.torch.module import Module ...@@ -18,7 +18,7 @@ from dragon.vm.torch.module import Module
from dragon.vm.torch.tensor import Parameter from dragon.vm.torch.tensor import Parameter
from .modules.conv import Conv2d, ConvTranspose2d from .modules.conv import Conv2d, ConvTranspose2d
from .modules.pooling import MaxPool2d, AvgPool2d from .modules.pooling import MaxPool2d, AvgPool2d
from .modules.activation import ReLU, Sigmoid, Softmax from .modules.activation import ReLU, LeakyReLU, Sigmoid, Softmax
from .modules.linear import Linear from .modules.linear import Linear
from .modules.loss import CrossEntropyLoss from .modules.loss import CrossEntropyLoss
from .modules.container import Container, Sequential, ModuleList from .modules.container import Container, Sequential, ModuleList
......
...@@ -35,6 +35,26 @@ class ReLU(Module): ...@@ -35,6 +35,26 @@ class ReLU(Module):
return self.run(inputs, outputs) return self.run(inputs, outputs)
class LeakyReLU(Module):
def __init__(self, negative_slope=0.01, inplace=False):
super(LeakyReLU, self).__init__()
self._negative_slope = negative_slope
self._inplace = inplace
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Relu',
'n_inputs': 1, 'n_outputs': 1,
'arguments': {'slope': self._negative_slope}
}
def forward(self, x):
inputs = [x]; self.unify_devices(inputs)
outputs = [x if self._inplace else self.register_output(x.dtype)]
return self.run(inputs, outputs)
class Sigmoid(Module): class Sigmoid(Module):
def __init__(self, inplace=False): def __init__(self, inplace=False):
super(Sigmoid, self).__init__() super(Sigmoid, self).__init__()
......
...@@ -19,7 +19,9 @@ from .arithmetic import ( ...@@ -19,7 +19,9 @@ from .arithmetic import (
) )
from .ndarray import ( from .ndarray import (
sum, mean, argmin, argmax, max, topk, cat, gather squeeze, unsqueeze,
sum, mean, argmin, argmax, max, topk,
cat, gather,
) )
from .vision import ( from .vision import (
......
...@@ -13,6 +13,8 @@ from __future__ import absolute_import ...@@ -13,6 +13,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.vm.torch.constants import CTX_TO_DEVICE_OPTION
from dragon.core.tensor_utils import FromTensor
from dragon.vm.torch.tensor import Tensor, Size from dragon.vm.torch.tensor import Tensor, Size
from dragon.vm.torch.execute_engine import RunOperator from dragon.vm.torch.execute_engine import RunOperator
...@@ -20,9 +22,11 @@ from dragon.vm.torch.ops.factory import get_module ...@@ -20,9 +22,11 @@ from dragon.vm.torch.ops.factory import get_module
from dragon.vm.torch.autograd.grad_mode import no_grad from dragon.vm.torch.autograd.grad_mode import no_grad
from dragon.vm.torch.ops.primitive import MakeContext from dragon.vm.torch.ops.primitive import MakeContext
from dragon.vm.torch.ops.arithmetic import _fundamental, _rfundamental from dragon.vm.torch.ops.arithmetic import _fundamental, _rfundamental
from dragon.vm.torch.ops.control_flow import _copy from dragon.vm.torch.ops.ndarray import (
from dragon.vm.torch.ops.ndarray import \ reshape, squeeze, unsqueeze,
(reshape, _permute, _repeat, _fill, _reduce, _arg_reduce, _crop) _permute, _repeat, _crop,
_fill, _reduce, _arg_reduce,
)
from dragon.vm.torch.ops.modules.dtype import AsType from dragon.vm.torch.ops.modules.dtype import AsType
...@@ -33,13 +37,15 @@ from dragon.vm.torch.ops.modules.dtype import AsType ...@@ -33,13 +37,15 @@ from dragon.vm.torch.ops.modules.dtype import AsType
############################################## ##############################################
def copy_(self, src): def copy_(self, src, non_blocking=False):
"""Copy the elements from ``src`` into this tensor and return ``self``. """Copy the elements from ``src`` into this tensor and return ``self``.
Parameters Parameters
---------- ----------
src : vm.torch.Tensor src : vm.torch.Tensor
The source tensor. The source tensor.
non_blocking : boolean
Whether to copy asynchronously between CPU and GPU.
Returns Returns
------- -------
...@@ -47,7 +53,10 @@ def copy_(self, src): ...@@ -47,7 +53,10 @@ def copy_(self, src):
The ``self`` tensor. The ``self`` tensor.
""" """
return _copy(self, src) FromTensor(
src, CTX_TO_DEVICE_OPTION[tuple(src._ctx)],
self.name, CTX_TO_DEVICE_OPTION[tuple(self._ctx)])
return self
Tensor.copy_ = copy_ Tensor.copy_ = copy_
...@@ -308,6 +317,75 @@ Tensor.__rtruediv__ = rdiv ...@@ -308,6 +317,75 @@ Tensor.__rtruediv__ = rdiv
############################################## ##############################################
def _squeeze(self, dim=None):
"""Returns a tensor with all the dimensions of input of size 1 removed.
Parameters
----------
dim : int
The optional dim to remove.
Returns
-------
vm.torch.Tensor
The new tensor.
"""
return squeeze(self, dim=dim)
def _squeeze_(self, dim=None):
"""Inplace of ``Tensor.squeeze()``
Parameters
----------
dim : int
The optional dim to remove.
Returns
-------
vm.torch.Tensor
The self.
"""
return squeeze(self, dim=dim, out=self)
def _unsqueeze(self, dim):
"""Returns a tensor with a dimension of size 1 inserted at the specified position.
Parameters
----------
dim : int
The dim to insert.
Returns
-------
vm.torch.Tensor
The new tensor.
"""
return unsqueeze(self, dim=dim)
def _unsqueeze_(self, dim=None):
"""Inplace of ``Tensor.unsqueeze()``
Parameters
----------
dim : int
The optional dim to remove.
Returns
-------
vm.torch.Tensor
The self.
"""
return unsqueeze(self, dim=dim, out=self)
def view(self, *args): def view(self, *args):
if self._static_shape: if self._static_shape:
raise RuntimeError('Can not view a leaf variable, it owns the static sizes.') raise RuntimeError('Can not view a leaf variable, it owns the static sizes.')
...@@ -353,6 +431,10 @@ def min(self, dim=None, keepdim=False): ...@@ -353,6 +431,10 @@ def min(self, dim=None, keepdim=False):
return _arg_reduce(self, 'MIN', dim, keepdim) return _arg_reduce(self, 'MIN', dim, keepdim)
Tensor.squeeze = _squeeze
Tensor.squeeze_ = _squeeze_
Tensor.unsqueeze = _unsqueeze
Tensor.unsqueeze_ = _unsqueeze_
Tensor.view = view Tensor.view = view
Tensor.view_as = view_as Tensor.view_as = view_as
Tensor.permute = permute Tensor.permute = permute
...@@ -412,6 +494,8 @@ Tensor.double = lambda self: _type_to(self, dtype='float64', inplace=False) ...@@ -412,6 +494,8 @@ Tensor.double = lambda self: _type_to(self, dtype='float64', inplace=False)
Tensor.double_ = lambda self: _type_to(self, dtype='float64', inplace=True) Tensor.double_ = lambda self: _type_to(self, dtype='float64', inplace=True)
Tensor.byte = lambda self: _type_to(self, dtype='uint8', inplace=False) Tensor.byte = lambda self: _type_to(self, dtype='uint8', inplace=False)
Tensor.byte_ = lambda self: _type_to(self, dtype='uint8', inplace=True) Tensor.byte_ = lambda self: _type_to(self, dtype='uint8', inplace=True)
Tensor.char = lambda self: _type_to(self, dtype='int8', inplace=False)
Tensor.char_ = lambda self: _type_to(self, dtype='int8', inplace=True)
Tensor.int = lambda self: _type_to(self, dtype='int32', inplace=False) Tensor.int = lambda self: _type_to(self, dtype='int32', inplace=False)
Tensor.int_ = lambda self: _type_to(self, dtype='int32', inplace=True) Tensor.int_ = lambda self: _type_to(self, dtype='int32', inplace=True)
Tensor.long = lambda self: _type_to(self, dtype='int64', inplace=False) Tensor.long = lambda self: _type_to(self, dtype='int64', inplace=False)
......
...@@ -10,13 +10,4 @@ ...@@ -10,13 +10,4 @@
# ------------------------------------------------------------ # ------------------------------------------------------------
from dragon.vm.torch.ops.primitive import MakeContext from dragon.vm.torch.ops.primitive import MakeContext
from dragon.vm.torch.ops.factory import get_module from dragon.vm.torch.ops.factory import get_module
from dragon.vm.torch.ops.modules.control_flow import Copy \ No newline at end of file
def _copy(dst, src):
if id(dst) == id(src): return dst
ctx = MakeContext(inputs=[dst])
key = 'torch/ops/copy/{}:{}'.format(ctx[0].lower(), ctx[1])
module = get_module(Copy, key, ctx)
return module.forward(dst, src)
\ No newline at end of file
...@@ -14,6 +14,7 @@ from __future__ import division ...@@ -14,6 +14,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.vm.torch.ops.modules.base import BaseModule from dragon.vm.torch.ops.modules.base import BaseModule
from dragon.vm.torch.tensor import ReferneceTensor
class Fill(BaseModule): class Fill(BaseModule):
...@@ -69,13 +70,61 @@ class Reshape(BaseModule): ...@@ -69,13 +70,61 @@ class Reshape(BaseModule):
def forward(self, x, shape): def forward(self, x, shape):
inputs = [x]; self.unify_devices(inputs) inputs = [x]; self.unify_devices(inputs)
outputs = [self.register_output(x.dtype)] outputs = [ReferneceTensor(x)]
if shape is not None: if shape is not None:
for ix, d in enumerate(shape): for ix, d in enumerate(shape):
self.set_argument_i(self.shape[ix], d) self.set_argument_i(self.shape[ix], d)
return self.run(inputs, outputs) return self.run(inputs, outputs)
class Squeeze(BaseModule):
def __init__(self, key, ctx, **kwargs):
super(Squeeze, self).__init__(key, ctx, **kwargs)
self.dim = kwargs.get('dim', None)
self.register_arguments()
self.register_op()
def register_arguments(self):
"""No Arguments for squeeze op."""
pass
def register_op(self):
self.op_meta = {
'op_type': 'Squeeze',
'n_inputs': 1, 'n_outputs': 1,
'arguments': {'axis': self.dim}
}
def forward(self, x, out=None):
inputs = [x]; self.unify_devices(inputs)
outputs = [out] if out else [ReferneceTensor(x)]
return self.run(inputs, outputs)
class UnSqueeze(BaseModule):
def __init__(self, key, ctx, **kwargs):
super(UnSqueeze, self).__init__(key, ctx, **kwargs)
self.dim = kwargs.get('dim', None)
self.register_arguments()
self.register_op()
def register_arguments(self):
"""No Arguments for squeeze op."""
pass
def register_op(self):
self.op_meta = {
'op_type': 'ExpandDims',
'n_inputs': 1, 'n_outputs': 1,
'arguments': {'axis': self.dim}
}
def forward(self, x, out=None):
inputs = [x]; self.unify_devices(inputs)
outputs = [out] if out else [ReferneceTensor(x)]
return self.run(inputs, outputs)
class Permute(BaseModule): class Permute(BaseModule):
def __init__(self, key, ctx, **kwargs): def __init__(self, key, ctx, **kwargs):
super(Permute, self).__init__(key, ctx, **kwargs) super(Permute, self).__init__(key, ctx, **kwargs)
......
...@@ -15,7 +15,8 @@ from __future__ import print_function ...@@ -15,7 +15,8 @@ from __future__ import print_function
from dragon.vm.torch.ops.primitive import MakeContext, CanonicalAxis from dragon.vm.torch.ops.primitive import MakeContext, CanonicalAxis
from dragon.vm.torch.ops.factory import get_module from dragon.vm.torch.ops.factory import get_module
from dragon.vm.torch.ops.modules.shape import Reshape, Fill, Permute, Repeat from dragon.vm.torch.ops.modules.shape import \
Reshape, Squeeze, UnSqueeze, Fill, Permute, Repeat
from dragon.vm.torch.ops.modules.reduce import Reduce, ArgReduce from dragon.vm.torch.ops.modules.reduce import Reduce, ArgReduce
from dragon.vm.torch.ops.modules.crop import Crop from dragon.vm.torch.ops.modules.crop import Crop
from dragon.vm.torch.ops.modules.axis import Concat, Gather from dragon.vm.torch.ops.modules.axis import Concat, Gather
...@@ -29,6 +30,22 @@ def reshape(input, shape, shape_like=None): ...@@ -29,6 +30,22 @@ def reshape(input, shape, shape_like=None):
return module.forward(input, shape) return module.forward(input, shape)
def squeeze(input, dim=None, out=None):
ctx = MakeContext(inputs=[input])
key = 'torch/ops/squeeze/{}:{}/dim:{}'.format(
ctx[0].lower(), ctx[1], dim if dim else 'None')
module = get_module(Squeeze, key, ctx, dim=dim)
return module.forward(input, out=out)
def unsqueeze(input, dim, out=None):
ctx = MakeContext(inputs=[input])
key = 'torch/ops/unsqueeze/{}:{}/dim:{}'.format(
ctx[0].lower(), ctx[1], dim if dim else 'None')
module = get_module(UnSqueeze, key, ctx, dim=dim)
return module.forward(input, out=out)
def _permute(input, perms=None): def _permute(input, perms=None):
ctx = MakeContext(inputs=[input]); len_perms = len(perms) if perms else 0 ctx = MakeContext(inputs=[input]); len_perms = len(perms) if perms else 0
key = 'torch/ops/permute/{}:{}/n_dims:#{}'.format(ctx[0].lower(), ctx[1], len_perms) key = 'torch/ops/permute/{}:{}/n_dims:#{}'.format(ctx[0].lower(), ctx[1], len_perms)
......
...@@ -51,7 +51,8 @@ def _with_file_like(f, mode, body): ...@@ -51,7 +51,8 @@ def _with_file_like(f, mode, body):
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)): (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
new_fd = True new_fd = True
dir = os.path.dirname(f) dir = os.path.dirname(f)
if not os.path.exists(dir): os.makedirs(dir) # Bug fix: empty directory, i.e., under the work directory
if dir != '' and not os.path.exists(dir): os.makedirs(dir)
f = open(f, mode) f = open(f, mode)
try: try:
return body(f) return body(f)
......
...@@ -13,6 +13,8 @@ from __future__ import absolute_import ...@@ -13,6 +13,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys
import copy
import numpy as np import numpy as np
import dragon as dg import dragon as dg
import dragon.core.tensor_utils as tensor_utils import dragon.core.tensor_utils as tensor_utils
...@@ -23,9 +25,11 @@ from dragon.vm.torch.constants import CTX_TO_DEVICE_OPTION ...@@ -23,9 +25,11 @@ from dragon.vm.torch.constants import CTX_TO_DEVICE_OPTION
from .c_apis import * from .c_apis import *
__all__ = ['Tensor', 'Parameter', __all__ = [
'Tensor', 'Parameter',
'FloatTensor', 'DoubleTensor', 'FloatTensor', 'DoubleTensor',
'IntTensor', 'LongTensor', 'ByteTensor', 'IntTensor', 'LongTensor',
'ByteTensor', 'CharTensor',
] ]
...@@ -48,6 +52,9 @@ class Tensor(object): ...@@ -48,6 +52,9 @@ class Tensor(object):
self._requires_grad = kwargs.get('requires_grad', False) self._requires_grad = kwargs.get('requires_grad', False)
self._dg_tensor = kwargs.get('dg_tensor', None) self._dg_tensor = kwargs.get('dg_tensor', None)
self._own_storage = kwargs.get('own_storage', True) self._own_storage = kwargs.get('own_storage', True)
# Hold it to lock shared objects(i.e., tensor with same storage)
self._ref_objects = []
# Owned by the leaf variables(i.e. Can not be Reshaped) # Owned by the leaf variables(i.e. Can not be Reshaped)
self._static_shape = None self._static_shape = None
# Owned by the grad required variables # Owned by the grad required variables
...@@ -541,6 +548,71 @@ class Tensor(object): ...@@ -541,6 +548,71 @@ class Tensor(object):
# # # #
############################################## ##############################################
def squeeze(self, dim=None):
"""Returns a tensor with all the dimensions of input of size 1 removed.
Parameters
----------
dim : int
The optional dim to remove.
Returns
-------
vm.torch.Tensor
The new tensor.
"""
raise NotImplementedError('Refer torch.ops.builtin._squeeze')
def squeeze_(self, dim=None):
"""Inplace of ``Tensor.squeeze()``
Parameters
----------
dim : int
The optional dim to remove.
Returns
-------
vm.torch.Tensor
The self.
"""
raise NotImplementedError('Refer torch.ops.builtin._squeeze_')
def unsqueeze(self, dim):
"""Returns a tensor with a dimension of size 1 inserted at the specified position.
Parameters
----------
dim : int
The dim to insert.
Returns
-------
vm.torch.Tensor
The new tensor.
"""
raise NotImplementedError('Refer torch.ops.builtin._unsqueeze')
def unsqueeze_(self, dim):
"""Inplace of ``Tensor.unsqueeze()``
Parameters
----------
dim : int
The dim to insert.
Returns
-------
vm.torch.Tensor
The self.
"""
raise NotImplementedError('Refer torch.ops.builtin._unsqueeze_')
def view(self, *args): def view(self, *args):
"""Return a new tensor with the same data but a different size. """Return a new tensor with the same data but a different size.
...@@ -605,13 +677,15 @@ class Tensor(object): ...@@ -605,13 +677,15 @@ class Tensor(object):
""" """
raise NotImplementedError('Refer torch.ops.builtin.repeat') raise NotImplementedError('Refer torch.ops.builtin.repeat')
def copy_(self, src): def copy_(self, src, non_blocking=False):
"""Copy the elements from ``src`` into this tensor and return ``self``. """Copy the elements from ``src`` into this tensor and return ``self``.
Parameters Parameters
---------- ----------
src : vm.torch.Tensor src : vm.torch.Tensor
The source tensor. The source tensor.
non_blocking : boolean
Whether to copy asynchronously between CPU and GPU.
Returns Returns
------- -------
...@@ -1034,6 +1108,28 @@ class Tensor(object): ...@@ -1034,6 +1108,28 @@ class Tensor(object):
""" """
raise NotImplementedError('Refer torch.ops.builtin.byte_') raise NotImplementedError('Refer torch.ops.builtin.byte_')
def char(self):
"""Return a ``int8`` tensor with elements of ``self``.
Returns
-------
vm.torch.Tensor
The byte tensor.
"""
raise NotImplementedError('Refer torch.ops.builtin.char')
def char_(self):
"""Inplace of ``Tensor.char()``.
Returns
-------
vm.torch.Tensor
The byte tensor.
"""
raise NotImplementedError('Refer torch.ops.builtin.char_')
############################################## ##############################################
# # # #
# AUTO-GRAD # # AUTO-GRAD #
...@@ -1126,6 +1222,11 @@ def ByteTensor(*args, **kwargs): ...@@ -1126,6 +1222,11 @@ def ByteTensor(*args, **kwargs):
return Tensor(*args, **kwargs) return Tensor(*args, **kwargs)
def CharTensor(*args, **kwargs):
kwargs['dtype'] = 'int8'
return Tensor(*args, **kwargs)
_DTYPE_TO_TENSOR = { _DTYPE_TO_TENSOR = {
'float16': HalfTensor, 'float16': HalfTensor,
'float32': FloatTensor, 'float32': FloatTensor,
...@@ -1133,6 +1234,7 @@ _DTYPE_TO_TENSOR = { ...@@ -1133,6 +1234,7 @@ _DTYPE_TO_TENSOR = {
'int32': IntTensor, 'int32': IntTensor,
'int64': LongTensor, 'int64': LongTensor,
'uint8': ByteTensor, 'uint8': ByteTensor,
'int8': CharTensor,
} }
...@@ -1158,6 +1260,23 @@ def RuntimeTensor(name, dtype='float32', ctx=None): ...@@ -1158,6 +1260,23 @@ def RuntimeTensor(name, dtype='float32', ctx=None):
return constructor(dg_tensor=name, ctx=ctx) return constructor(dg_tensor=name, ctx=ctx)
def ReferneceTensor(src):
"""Create a reference from source tensor.
Commonly used to hold the same storage but takes different sizes,
i.e., view, squeeze, and unsqueeze.
"""
constructor = _DTYPE_TO_TENSOR[src._dtype]
ref = constructor(dg_tensor=src.name, ctx=src._ctx)
name = '{}/id:{}'.format(
src.name.replace('[TPool]', '[Ref]'), id(ref))
dg.workspace.CreateTensor(name)
ref._dg_tensor, ref._own_storage = name, False
ref._ref_objects.append(src)
return ref
############################################## ##############################################
# # # #
# Tensor-Extension # # Tensor-Extension #
......
...@@ -23,7 +23,7 @@ def from_numpy(data): ...@@ -23,7 +23,7 @@ def from_numpy(data):
Parameters Parameters
---------- ----------
data : numpy.ndarray data : ndarray
The nd-array with various data type. The nd-array with various data type.
Return Return
...@@ -113,4 +113,5 @@ __NUMPY_TYPE_TO_TORCH = { ...@@ -113,4 +113,5 @@ __NUMPY_TYPE_TO_TORCH = {
'int32': 'IntTensor', 'int32': 'IntTensor',
'int64': 'LongTensor', 'int64': 'LongTensor',
'uint8': 'ByteTensor', 'uint8': 'ByteTensor',
'int8': 'CharTensor',
} }
\ No newline at end of file
...@@ -97,7 +97,7 @@ class DataReader(Process): ...@@ -97,7 +97,7 @@ class DataReader(Process):
self._db.close() self._db.close()
self._db.open(self._source) self._db.open(self._source)
self._cur_idx = target_idx self._cur_idx = target_idx
self._db.set(str(self._cur_idx).zfill(self._db_zfill)) self._db.set(str(self._cur_idx).zfill(self._zfill))
def reset(self): def reset(self):
"""Reset the cursor and environment. """Reset the cursor and environment.
...@@ -112,12 +112,12 @@ class DataReader(Process): ...@@ -112,12 +112,12 @@ class DataReader(Process):
self._cur_chunk_idx = 0 self._cur_chunk_idx = 0
self._start_idx = int(self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx]) self._start_idx = int(self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx])
self._start_idx = int(self._start_idx * self._chunk_size) self._start_idx = int(self._start_idx * self._chunk_size)
if self._start_idx >= self._db_size: self.next_chunk() if self._start_idx >= self._num_entries: self.next_chunk()
self._end_idx = self._start_idx + self._chunk_size self._end_idx = self._start_idx + self._chunk_size
self._end_idx = min(self._db_size, self._end_idx) self._end_idx = min(self._num_entries, self._end_idx)
else: else:
self._start_idx = 0 self._start_idx = 0
self._end_idx = self._db_size self._end_idx = self._num_entries
self.redirect(self._start_idx) self.redirect(self._start_idx)
...@@ -145,10 +145,10 @@ class DataReader(Process): ...@@ -145,10 +145,10 @@ class DataReader(Process):
else: else:
self._start_idx = self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx] self._start_idx = self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx]
self._start_idx = self._start_idx * self._chunk_size self._start_idx = self._start_idx * self._chunk_size
if self._start_idx >= self._db_size: self.next_chunk() if self._start_idx >= self._num_entries: self.next_chunk()
else: else:
self._end_idx = self._start_idx + self._chunk_size self._end_idx = self._start_idx + self._chunk_size
self._end_idx = min(self._db_size, self._end_idx) self._end_idx = min(self._num_entries, self._end_idx)
self.redirect(self._start_idx) self.redirect(self._start_idx)
def run(self): def run(self):
...@@ -165,14 +165,14 @@ class DataReader(Process): ...@@ -165,14 +165,14 @@ class DataReader(Process):
# init db # init db
self._db = LMDB() self._db = LMDB()
self._db.open(self._source) self._db.open(self._source)
self._db_size = int(self._db.get('size')) self._zfill = self._db.zfill()
self._db_zfill = int(self._db.get('zfill')) self._num_entries = self._db.num_entries()
self._epoch_size = int(self._db_size / self._num_parts + 1) self._epoch_size = int(self._num_entries / self._num_parts + 1)
if self._use_shuffle: if self._use_shuffle:
if self._chunk_size == 1: if self._chunk_size == 1:
# each chunk has at most 1 record [For Fully Shuffle] # each chunk has at most 1 record [For Fully Shuffle]
self._num_shuffle_parts = int(self._db_size / self._chunk_size / self._num_parts) + 1 self._num_shuffle_parts = int(self._num_entries / self._chunk_size / self._num_parts) + 1
else: else:
if self._use_shuffle and self._chunk_size == -1: if self._use_shuffle and self._chunk_size == -1:
# search a optimal chunk size by chunks [For Chunk Shuffle] # search a optimal chunk size by chunks [For Chunk Shuffle]
...@@ -182,12 +182,12 @@ class DataReader(Process): ...@@ -182,12 +182,12 @@ class DataReader(Process):
self._chunk_size = min_chunk_size self._chunk_size = min_chunk_size
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 / self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20))) (self._num_parts * self._chunk_size << 20)))
self._chunk_size = int(self._db_size / self._num_shuffle_parts / self._num_parts + 1) self._chunk_size = int(self._num_entries / self._num_shuffle_parts / self._num_parts + 1)
else: else:
# each chunk has at most K records [For Multiple Nodes] # each chunk has at most K records [For Multiple Nodes]
# note that if ``shuffle`` and ``multiple_nodes`` are all ``False``, # note that if ``shuffle`` and ``multiple_nodes`` are all ``False``,
# ``chunk_size`` and ``num_shuffle_parts`` are meaningless # ``chunk_size`` and ``num_shuffle_parts`` are meaningless
self._chunk_size = int(self._db_size / self._num_parts) + 1 self._chunk_size = int(self._num_entries / self._num_parts) + 1
self._num_shuffle_parts = 1 self._num_shuffle_parts = 1
self._perm = np.arange(self._num_shuffle_parts) self._perm = np.arange(self._num_shuffle_parts)
......
...@@ -42,7 +42,7 @@ find_modules() ...@@ -42,7 +42,7 @@ find_modules()
setup(name = 'dragon', setup(name = 'dragon',
version='0.2.2.9', version='0.2.2.10',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework', description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/seetaresearch/Dragon', url='https://github.com/seetaresearch/Dragon',
author='Ting Pan', author='Ting Pan',
......
...@@ -114,7 +114,7 @@ inline void GenerateGridAnchors( ...@@ -114,7 +114,7 @@ inline void GenerateGridAnchors(
/******************** Proposal ********************/ /******************** Proposal ********************/
template <typename T, class Context> template <typename T, class Context>
inline void GenerateProposals( void GenerateProposals(
const int A, const int A,
const int feat_h, const int feat_h,
const int feat_w, const int feat_w,
...@@ -129,7 +129,7 @@ inline void GenerateProposals( ...@@ -129,7 +129,7 @@ inline void GenerateProposals(
T* proposals); T* proposals);
template <typename T, class Context> template <typename T, class Context>
inline void GenerateProposals_v2( void GenerateProposals_v2(
const int total_anchors, const int total_anchors,
const float im_h, const float im_h,
const float im_w, const float im_w,
......
...@@ -34,7 +34,7 @@ void ProposalOp<Context>::RunWithType() { ...@@ -34,7 +34,7 @@ void ProposalOp<Context>::RunWithType() {
rcnn::GenerateProposals<T, Context>( rcnn::GenerateProposals<T, Context>(
A, feat_height, feat_width, strides[0], A, feat_height, feat_width, strides[0],
im_height, im_width, min_box_h, min_box_w, im_height, im_width, min_box_h, min_box_w,
Input(0).template data<T, Context>() + num_proposals, Input(0).template data<T, Context>(),
Input(1).template data<T, Context>(), Input(1).template data<T, Context>(),
anchors_.template mutable_data<T, Context>(), anchors_.template mutable_data<T, Context>(),
proposals_.template mutable_data<T, Context>()); proposals_.template mutable_data<T, Context>());
...@@ -59,9 +59,9 @@ void ProposalOp<Context>::RunWithType() { ...@@ -59,9 +59,9 @@ void ProposalOp<Context>::RunWithType() {
CHECK_EQ(strides.size(), scales.size()) CHECK_EQ(strides.size(), scales.size())
<< "\nGiven " << strides.size() << " strides and " << "\nGiven " << strides.size() << " strides and "
<< scales.size() << " scales"; << scales.size() << " scales";
// cls_probs: [1, 2, total_proposals] // cls_probs: [1, total_proposals]
// bbox_deltas: [1, 4, total_proposals] // bbox_deltas: [1, 4, total_proposals]
TIndex total_proposals = Input(-3).dim(2), acc_proposals = 0; TIndex total_proposals = Input(-3).dim(1), acc_proposals = 0;
const TIndex pre_nms_topn = std::min(total_proposals, pre_nms_top_n);; const TIndex pre_nms_topn = std::min(total_proposals, pre_nms_top_n);;
proposals_.Reshape({ total_proposals, 5 }); proposals_.Reshape({ total_proposals, 5 });
auto* proposals = proposals_.template mutable_data<T, CPUContext>(); auto* proposals = proposals_.template mutable_data<T, CPUContext>();
...@@ -93,7 +93,7 @@ void ProposalOp<Context>::RunWithType() { ...@@ -93,7 +93,7 @@ void ProposalOp<Context>::RunWithType() {
rcnn::GenerateProposals_v2<T, Context>(total_proposals, rcnn::GenerateProposals_v2<T, Context>(total_proposals,
im_height, im_width, min_box_h, min_box_w, im_height, im_width, min_box_h, min_box_w,
Input(-3).template data<T, Context>() + total_proposals, Input(-3).template data<T, Context>(),
Input(-2).template data<T, Context>(), Input(-2).template data<T, Context>(),
proposals_.template mutable_data<T, Context>()); proposals_.template mutable_data<T, Context>());
...@@ -113,7 +113,7 @@ void ProposalOp<Context>::RunWithType() { ...@@ -113,7 +113,7 @@ void ProposalOp<Context>::RunWithType() {
} }
total_rois += num_rois; total_rois += num_rois;
Ydata += (num_rois * 5); Ydata += (num_rois * 5);
im_info += 3; im_info += Input(-1).dim(1);
} }
Output(0)->Reshape(vector<TIndex>({ total_rois, 5 })); Output(0)->Reshape(vector<TIndex>({ total_rois, 5 }));
...@@ -148,9 +148,9 @@ void ProposalOp<Context>::RunWithType() { ...@@ -148,9 +148,9 @@ void ProposalOp<Context>::RunWithType() {
template <class Context> template <class Context>
void ProposalOp<Context>::RunOnDevice() { void ProposalOp<Context>::RunOnDevice() {
num_images = Input(0).dim(0); num_images = Input(0).dim(0);
CHECK_EQ(Input(-1).count(), num_images * 3) CHECK_EQ(Input(-1).dim(0), num_images)
<< "\nExcepted " << num_images * 3 << " groups image info, " << "\nExcepted " << num_images << " groups image info, "
<< "but got " << Input(-1).count() / 3 << "."; << "but got " << Input(-1).dim(0) << ".";
roi_indices_.Reshape({ post_nms_top_n }); roi_indices_.Reshape({ post_nms_top_n });
Output(0)->Reshape({ num_images * post_nms_top_n, 5 }); Output(0)->Reshape({ num_images * post_nms_top_n, 5 });
......
...@@ -231,17 +231,24 @@ GraphDef Graph::Share(const GraphDef& optimized_graph) { ...@@ -231,17 +231,24 @@ GraphDef Graph::Share(const GraphDef& optimized_graph) {
GraphDef g; g.CopyFrom(optimized_graph); GraphDef g; g.CopyFrom(optimized_graph);
// actually we need a white list
Set<string> whitelist;
for (auto& target : optimized_graph.target())
whitelist.insert(target);
// rename to create in-place // rename to create in-place
for (int i = 0; i < optimized_graph.op_size(); i++) { for (int i = 0; i < optimized_graph.op_size(); i++) {
const OperatorDef& op = optimized_graph.op(i); const OperatorDef& op = optimized_graph.op(i);
for (int j = 0; j < op.input_size(); j++) { for (int j = 0; j < op.input_size(); j++) {
if (renamed_.count(op.input(j)) && if (whitelist.count(op.input(j)) == 0 &&
renamed_.count(op.input(j)) &&
ws()->SetProxy(op.input(j), renamed_[op.input(j)])) ws()->SetProxy(op.input(j), renamed_[op.input(j)]))
*g.mutable_op(i)->mutable_input(j) *g.mutable_op(i)->mutable_input(j)
= renamed_[op.input(j)]; = renamed_[op.input(j)];
} }
for (int j = 0; j < op.output_size(); j++) { for (int j = 0; j < op.output_size(); j++) {
if (renamed_.count(op.output(j)) && if (whitelist.count(op.output(j)) == 0 &&
renamed_.count(op.output(j)) &&
ws()->SetProxy(op.output(j), renamed_[op.output(j)])) ws()->SetProxy(op.output(j), renamed_[op.output(j)]))
*g.mutable_op(i)->mutable_output(j) *g.mutable_op(i)->mutable_output(j)
= renamed_[op.output(j)]; = renamed_[op.output(j)];
......
...@@ -17,7 +17,7 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() { ...@@ -17,7 +17,7 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() {
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->ReshapeLike(losses); Output(0)->ReshapeLike(losses);
Output(0)->template Copy<Context, Context>(losses); Output(0)->template CopyFrom<Context>(losses);
return; return;
} }
......
...@@ -19,7 +19,7 @@ void SigmoidFocalLossOp<Context>::RunWithType() { ...@@ -19,7 +19,7 @@ void SigmoidFocalLossOp<Context>::RunWithType() {
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->ReshapeLike(losses); Output(0)->ReshapeLike(losses);
Output(0)->template Copy<Context, Context>(losses); Output(0)->template CopyFrom<Context>(losses);
return; return;
} }
......
...@@ -24,7 +24,7 @@ void SoftmaxFocalLossOp<Context>::RunWithType() { ...@@ -24,7 +24,7 @@ void SoftmaxFocalLossOp<Context>::RunWithType() {
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->ReshapeLike(losses); Output(0)->ReshapeLike(losses);
Output(0)->template Copy<Context, Context>(losses); Output(0)->template CopyFrom<Context>(losses);
return; return;
} }
......
...@@ -59,7 +59,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() { ...@@ -59,7 +59,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() {
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->ReshapeLike(losses); Output(0)->ReshapeLike(losses);
Output(0)->template Copy<Context, Context>(losses); Output(0)->template CopyFrom<Context>(losses);
return; return;
} }
...@@ -167,7 +167,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() { ...@@ -167,7 +167,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
auto* dXdataF32 = Output(0)->template data<float, Context>(); auto* dXdataF32 = Output(0)->template data<float, Context>();
auto* dXdataF16 = prob->template mutable_data<float16, Context>(); auto* dXdataF16 = prob->template mutable_data<float16, Context>();
kernel::TypeA2B<float, float16, Context>(Output(0)->count(), dXdataF32, dXdataF16); kernel::TypeA2B<float, float16, Context>(Output(0)->count(), dXdataF32, dXdataF16);
Output(0)->template Copy<Context, Context>(*prob); Output(0)->template CopyFrom<Context>(*prob);
} }
} else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" }); } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
......
...@@ -68,7 +68,7 @@ template <class Context> ...@@ -68,7 +68,7 @@ template <class Context>
void StopGradientOp<Context>::RunOnDevice() { void StopGradientOp<Context>::RunOnDevice() {
if (Output(0)->name() != Input(0).name()) { if (Output(0)->name() != Input(0).name()) {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
Output(0)->template Copy<Context, Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0));
} }
} }
......
...@@ -14,7 +14,7 @@ void MPIBroadcastOp<Context>::RunWithType() { ...@@ -14,7 +14,7 @@ void MPIBroadcastOp<Context>::RunWithType() {
auto* Xdata = Input(0).template mutable_data<T, CPUContext>(); auto* Xdata = Input(0).template mutable_data<T, CPUContext>();
#endif #endif
MPI_Bcast(Xdata, Input(0).count(), mpi_dtype(), comm_root, comm); MPI_Bcast(Xdata, Input(0).count(), mpi_dtype(), comm_root, comm);
Output(0)->template Copy<Context, Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0));
} else { } else {
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
......
...@@ -8,7 +8,7 @@ namespace dragon { ...@@ -8,7 +8,7 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void MPIGatherOp<Context>::RunWithType() { void MPIGatherOp<Context>::RunWithType() {
if (comm_rank == comm_root) { if (comm_rank == comm_root) {
Output(comm_rank)->template Copy<Context, Context>(Input(0)); Output(comm_rank)->template CopyFrom<Context>(Input(0));
for (int i = 0; i < comm_size; i++) { for (int i = 0; i < comm_size; i++) {
if (i == comm_root) continue; if (i == comm_root) continue;
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
...@@ -76,7 +76,7 @@ OPERATOR_SCHEMA(MPIGather).NumInputs(1).NumOutputs(1, INT_MAX); ...@@ -76,7 +76,7 @@ OPERATOR_SCHEMA(MPIGather).NumInputs(1).NumOutputs(1, INT_MAX);
template <class Context> template <typename T> template <class Context> template <typename T>
void MPIGatherGradientOp<Context>::RunWithType() { void MPIGatherGradientOp<Context>::RunWithType() {
if (comm_rank == comm_root) { if (comm_rank == comm_root) {
Output(0)->template Copy<Context, Context>(Input(this->comm_rank + 1)); Output(0)->template CopyFrom<Context>(Input(this->comm_rank + 1));
for (int i = 0; i < comm_size; i++) { for (int i = 0; i < comm_size; i++) {
if (i == comm_root) continue; if (i == comm_root) continue;
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
......
...@@ -125,7 +125,7 @@ void CropOp<Context>::RunOnDevice() { ...@@ -125,7 +125,7 @@ void CropOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
Output(0)->template Copy<Context, Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0));
// squeeze dimensions // squeeze dimensions
vector<TIndex> squeeze_shape; vector<TIndex> squeeze_shape;
for (int i = 0; i < keep_dims.size(); i++) for (int i = 0; i < keep_dims.size(); i++)
...@@ -229,7 +229,7 @@ void CropGradientOp<Context>::RunOnDevice() { ...@@ -229,7 +229,7 @@ void CropGradientOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(-1)); Output(0)->ReshapeLike(Input(-1));
Output(0)->template Copy<Context, Context>(Input(-1)); Output(0)->template CopyFrom<Context>(Input(-1));
return; return;
} }
......
#include "core/workspace.h" #include "core/workspace.h"
#include "operators/ndarray/expand_dims_op.h" #include "operators/ndarray/dimension_op.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void ExpandDimsOp<Context>::RunOnDevice() { void ExpandDimsOp<Context>::RunOnDevice() {
TIndex _axis_ = axis >= 0 ? axis :
axis + (TIndex)Input(0).ndim() + 1;
vector<TIndex> dims = Input(0).dims(); vector<TIndex> dims = Input(0).dims();
if (axis == -1 || axis >= (int)dims.size()) dims.push_back(1); if (_axis_ < 0 ||
else dims.insert(dims.begin() + axis, 1); _axis_ >= (TIndex)dims.size())
// save Xshape dims.push_back(1);
Tensor* sv = ws()->CreateTensor( else dims.insert(dims.begin() + _axis_, 1);
"/mnt/" + anchor() + "/expand_dims/x_shape");
sv->Reshape({ (TIndex)Input(0).ndim() });
auto* Sdata = sv->template mutable_data<TIndex, CPUContext>();
for (int i = 0; i < Input(0).ndim(); i++) Sdata[i] = Input(0).dim(i);
Output(0)->Reshape(dims); Output(0)->Reshape(dims);
if (Output(0)->name() != Input(0).name()) Output(0)->SetMeta(Input(0).meta());
Output(0)->template Copy<Context, Context>(Input(0)); Output(0)->Share(Input(0).memory());
} }
DEPLOY_CPU(ExpandDims); DEPLOY_CPU(ExpandDims);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ExpandDims); DEPLOY_CUDA(ExpandDims);
#endif #endif
OPERATOR_SCHEMA(ExpandDims) OPERATOR_SCHEMA(ExpandDims).NumInputs(1).NumOutputs(1);
.NumInputs(1).NumOutputs(1)
.Inplace({ { 0, 0 } });
template <class Context>
void ExpandDimsGradientOp<Context>::RunOnDevice() {
Tensor* sv = ws()->GetTensor(
"/mnt/" + anchor() + "/expand_dims/x_shape");
auto* Sdata = sv->template mutable_data<TIndex, CPUContext>();
vector<TIndex> x_shape(sv->count());
for (int i = 0; i < sv->count(); i++) x_shape[i] = Sdata[i];
Output(0)->Reshape(x_shape);
if (Output(0)->name() != Input(-1).name())
Output(0)->template Copy<Context, Context>(Input(-1));
}
DEPLOY_CPU(ExpandDimsGradient); DEPLOY_CPU(ExpandDimsGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ExpandDimsGradient); DEPLOY_CUDA(ExpandDimsGradient);
#endif #endif
OPERATOR_SCHEMA(ExpandDimsGradient) OPERATOR_SCHEMA(ExpandDimsGradient)
.NumInputs(1).NumOutputs(1) .NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
.Inplace({ { 0, 0 } });
class GetExpandDimsGradient final : public GradientMakerBase { class GetExpandDimsGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetExpandDimsGradient); GRADIENT_MAKER_CTOR(GetExpandDimsGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
vector<string> {GO(0)}, vector<string> {I(0), GO(0)},
vector<string> {GI(0)}); vector<string> {GI(0)});
} }
}; };
......
#include "core/workspace.h" #include "core/workspace.h"
#include "operators/ndarray/flatten_op.h" #include "operators/ndarray/dimension_op.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void FlattenOp<Context>::SqueezeRun() { void FlattenOp<Context>::RunOnDevice() {
vector<TIndex> output_dims; vector<TIndex> output_dims;
for (int i = 0; i < axis; i++) if (keep_axes != INT_MAX) {
output_dims.push_back(Input(0).dim(i)); CHECK_LE(keep_axes, (int)Input(0).ndim())
if (num_axes < 1) { << "\nThe total number of axes is " + Input(0).ndim()
output_dims.push_back(Input(0).count(axis)); << ", can not keep " + keep_axes << " .";
int i = 0;
for (; i < keep_axes - 1; i++)
output_dims.push_back(Input(0).dim(i));
if (Input(0).count(i) != 1)
output_dims.push_back(Input(0).count(i));
} else { } else {
TIndex count = Input(0).count(axis, axis + num_axes); for (int i = 0; i < axis; i++)
output_dims.push_back(count);
for (int i = axis + num_axes; i < Input(0).ndim(); i++)
output_dims.push_back(Input(0).dim(i)); output_dims.push_back(Input(0).dim(i));
if (num_axes < 1) {
output_dims.push_back(Input(0).count(axis));
} else {
TIndex count = Input(0).count(axis, axis + num_axes);
output_dims.push_back(count);
for (int i = axis + num_axes; i < Input(0).ndim(); i++)
output_dims.push_back(Input(0).dim(i));
}
} }
Output(0)->Reshape(output_dims); Output(0)->Reshape(output_dims);
if (Output(0)->name() != Input(0).name()) Output(0)->SetMeta(Input(0).meta());
Output(0)->template Copy<Context, Context>(Input(0)); Output(0)->Share(Input(0).memory());
}
template <class Context>
void FlattenOp<Context>::KeepRun() {
CHECK_LE(keep_axes, (int)Input(0).ndim())
<< "\nThe total number of axes is " + Input(0).ndim()
<< ", can not keep " + keep_axes << " .";
vector<TIndex> output_dims;
int i = 0;
for (; i < keep_axes - 1; i++)
output_dims.push_back(Input(0).dim(i));
if (Input(0).count(i) != 1)
output_dims.push_back(Input(0).count(i));
if (Output(0)->name() != Input(0).name())
Output(0)->template Copy<Context, Context>(Input(0));
}
template <class Context>
void FlattenOp<Context>::RunOnDevice() {
// save Xshape
Tensor* sv = ws()->CreateTensor(
"/mnt/" + anchor() + "/flatten/x_shape");
sv->Reshape({ (TIndex)Input(0).ndim() });
auto* Sdata = sv->template mutable_data<TIndex, CPUContext>();
for (int i = 0; i < Input(0).ndim(); i++)
Sdata[i] = Input(0).dim(i);
if (keep_axes != INT_MAX) KeepRun();
else SqueezeRun();
} }
DEPLOY_CPU(Flatten); DEPLOY_CPU(Flatten);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(Flatten); DEPLOY_CUDA(Flatten);
#endif #endif
OPERATOR_SCHEMA(Flatten) OPERATOR_SCHEMA(Flatten).NumInputs(1).NumOutputs(1);
.NumInputs(1).NumOutputs(1)
.Inplace({ { 0, 0 } });
template <class Context>
void FlattenGradientOp<Context>::RunOnDevice() {
Tensor* sv = ws()->GetTensor(
"/mnt/" + anchor() + "/flatten/x_shape");
auto* Sdata = sv->template mutable_data<TIndex, CPUContext>();
vector<TIndex> x_shape(sv->count());
for (int i = 0; i < sv->count(); i++) x_shape[i] = Sdata[i];
Output(0)->Reshape(x_shape);
if (Output(0)->name() != Input(-1).name())
Output(0)->template Copy<Context, Context>(Input(-1));
}
DEPLOY_CPU(FlattenGradient); DEPLOY_CPU(FlattenGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(FlattenGradient); DEPLOY_CUDA(FlattenGradient);
#endif #endif
OPERATOR_SCHEMA(FlattenGradient) OPERATOR_SCHEMA(FlattenGradient)
.NumInputs(1).NumOutputs(1) .NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
.Inplace({ { 0, 0 } });
class GetFlattenGradient final : public GradientMakerBase { class GetFlattenGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetFlattenGradient); GRADIENT_MAKER_CTOR(GetFlattenGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
vector<string> {GO(0)}, vector<string> {I(0), GO(0)},
vector<string> {GI(0)}); vector<string> {GI(0)});
} }
}; };
......
...@@ -61,7 +61,7 @@ void PadOp<Context>::RunOnDevice() { ...@@ -61,7 +61,7 @@ void PadOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
Output(0)->template Copy<Context, Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0));
return; return;
} }
...@@ -175,7 +175,7 @@ void PadGradientOp<Context>::RunOnDevice() { ...@@ -175,7 +175,7 @@ void PadGradientOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(-1)); Output(0)->ReshapeLike(Input(-1));
Output(0)->template Copy<Context, Context>(Input(-1)); Output(0)->template CopyFrom<Context>(Input(-1));
return; return;
} }
......
...@@ -39,7 +39,7 @@ void RandomPickOp<Context>::RunOnDevice() { ...@@ -39,7 +39,7 @@ void RandomPickOp<Context>::RunOnDevice() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
Output(1)->ReshapeLike(*pick_indices); Output(1)->ReshapeLike(*pick_indices);
Output(1)->template Copy<Context, Context>(*pick_indices); Output(1)->template CopyFrom<Context>(*pick_indices);
} }
} }
......
#include "core/workspace.h" #include "core/workspace.h"
#include "operators/ndarray/reshape_op.h" #include "operators/ndarray/dimension_op.h"
namespace dragon { namespace dragon {
...@@ -67,50 +67,31 @@ void ReshapeOp<Context>::RunOnDevice() { ...@@ -67,50 +67,31 @@ void ReshapeOp<Context>::RunOnDevice() {
<< "\nCan not change the total size." << "\nCan not change the total size."
<< Input(0).DimString() << Input(0).DimString()
<< " -> " << DimString(new_shape); << " -> " << DimString(new_shape);
// save Xshape Output(0)->Reshape(new_shape);
Tensor* sv = ws()->CreateTensor( Output(0)->SetMeta(Input(0).meta());
"/mnt/" + anchor() + "/reshape/x_shape"); Output(0)->Share(Input(0).memory());
sv->Reshape({ (TIndex)Input(0).ndim() });
auto* Sdata = sv->template mutable_data<TIndex, CPUContext>();
for (int i = 0; i < Input(0).ndim(); i++) Sdata[i] = Input(0).dim(i);
Output(0)->Reshape(new_shape);
if (Output(0)->name() != Input(0).name())
Output(0)->template Copy<Context, Context>(Input(0));
} }
DEPLOY_CPU(Reshape); DEPLOY_CPU(Reshape);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(Reshape); DEPLOY_CUDA(Reshape);
#endif #endif
OPERATOR_SCHEMA(Reshape) OPERATOR_SCHEMA(Reshape).NumInputs(1).NumOutputs(1);
.NumInputs(1).NumOutputs(1)
.Inplace({ { 0, 0 } });
template <class Context>
void ReshapeGradientOp<Context>::RunOnDevice() {
Tensor* sv = ws()->GetTensor(
"/mnt/" + anchor() + "/reshape/x_shape");
auto* Sdata = sv->template mutable_data<TIndex, CPUContext>();
vector<TIndex> x_shape(sv->count());
for (int i = 0; i < sv->count(); i++) x_shape[i] = Sdata[i];
Output(0)->Reshape(x_shape);
if (Output(0)->name() != Input(-1).name())
Output(0)->template Copy<Context, Context>(Input(-1));
}
DEPLOY_CPU(ReshapeGradient); DEPLOY_CPU(ReshapeGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ReshapeGradient); DEPLOY_CUDA(ReshapeGradient);
#endif #endif
OPERATOR_SCHEMA(ReshapeGradient).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } }); OPERATOR_SCHEMA(ReshapeGradient)
.NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
class GetReshapeGradient final : public GradientMakerBase { class GetReshapeGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetReshapeGradient); GRADIENT_MAKER_CTOR(GetReshapeGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
vector<string> {GO(0)}, vector<string> {I(0), GO(0)},
vector<string> {GI(0)}); vector<string> {GI(0)});
} }
}; };
......
#include "core/workspace.h"
#include "operators/ndarray/dimension_op.h"
namespace dragon {
template <class Context>
void SqueezeOp<Context>::RunOnDevice() {
TIndex _axis_ = axis >= 0 ? axis :
axis + (TIndex)Input(0).ndim();
vector<TIndex> dims;
for (int i = 0; i < Input(0).ndim(); i++)
if ((Input(0).dim(i) != 1) ||
(_axis_ != INT_MAX &&
Input(0).dim(i) == 1 &&
i != _axis_))
dims.push_back(Input(0).dim(i));
Output(0)->Reshape(dims);
Output(0)->SetMeta(Input(0).meta());
Output(0)->Share(Input(0).memory());
}
DEPLOY_CPU(Squeeze);
#ifdef WITH_CUDA
DEPLOY_CUDA(Squeeze);
#endif
OPERATOR_SCHEMA(Squeeze).NumInputs(1).NumOutputs(1);
DEPLOY_CPU(SqueezeGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(SqueezeGradient);
#endif
OPERATOR_SCHEMA(SqueezeGradient)
.NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
class GetSqueezeGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetSqueezeGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), GO(0)},
vector<string> {GI(0)});
}
};
REGISTER_GRADIENT(Squeeze, GetSqueezeGradient);
} // namespace dragon
\ No newline at end of file
...@@ -35,7 +35,7 @@ void TileOp<Context>::RunOnDevice() { ...@@ -35,7 +35,7 @@ void TileOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
Output(0)->template Copy<Context, Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0));
return; return;
} }
...@@ -96,7 +96,7 @@ void TileGradientOp<Context>::RunOnDevice() { ...@@ -96,7 +96,7 @@ void TileGradientOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(-1)); Output(0)->ReshapeLike(Input(-1));
Output(0)->template Copy<Context, Context>(Input(-1)); Output(0)->template CopyFrom<Context>(Input(-1));
return; return;
} }
......
...@@ -17,11 +17,11 @@ template <class Context> template <typename T> ...@@ -17,11 +17,11 @@ template <class Context> template <typename T>
void LRNOp<Context>::SplitRunWithType() { void LRNOp<Context>::SplitRunWithType() {
sqr_in = ws()->CreateTensor("/mnt/" + anchor() + "/sqr/in"); sqr_in = ws()->CreateTensor("/mnt/" + anchor() + "/sqr/in");
sqr_in->ReshapeLike(Input(0)); sqr_in->ReshapeLike(Input(0));
sqr_in->template Copy<Context, Context>(Input(0)); sqr_in->template CopyFrom<Context>(Input(0));
prod_in = ws()->CreateTensor("/mnt/" + anchor() + "/prod/in"); prod_in = ws()->CreateTensor("/mnt/" + anchor() + "/prod/in");
prod_in->ReshapeLike(Input(0)); prod_in->ReshapeLike(Input(0));
prod_in->template Copy<Context, Context>(Input(0)); prod_in->template CopyFrom<Context>(Input(0));
} }
template <class Context> template <typename T> template <class Context> template <typename T>
......
syntax = "proto2"; syntax = "proto2";
package dragon;
message BlobShape { message BlobShape {
repeated int64 dim = 1 [packed = true]; repeated int64 dim = 1 [packed = true];
} }
......
syntax = "proto2"; syntax = "proto2";
package dragon;
message TensorProto { message TensorProto {
repeated int32 dims = 1; repeated int32 dims = 1;
enum DataType { enum DataType {
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!