Commit 3dfb6ea5 by Ting PAN

Enhance tensor memory mapping with an optional offset

Summary:
This commit adds a feature to map memory and load data with offset.
Tensor takes mapped memory is readonly and will drop it on next mutation.
1 parent 654febe3
...@@ -47,15 +47,15 @@ void UnifiedMemory::ToCUDA(size_t size) { ...@@ -47,15 +47,15 @@ void UnifiedMemory::ToCUDA(size_t size) {
} }
} }
const void* UnifiedMemory::cpu_data(size_t size) { const void* UnifiedMemory::cpu_data(size_t size, size_t offset) {
ToCPU(size); ToCPU(size);
return (const void*)cpu_ptr_; return (const void*)((uint8_t*)cpu_ptr_ + offset);
} }
const void* UnifiedMemory::cuda_data(size_t size) { const void* UnifiedMemory::cuda_data(size_t size, size_t offset) {
SwitchToCUDADevice(CUDAContext::current_device()); SwitchToCUDADevice(CUDAContext::current_device());
ToCUDA(size); ToCUDA(size);
return (const void*)cuda_ptr_; return (const void*)((uint8_t*)cuda_ptr_ + offset);
} }
void* UnifiedMemory::mutable_cpu_data(size_t size) { void* UnifiedMemory::mutable_cpu_data(size_t size) {
......
...@@ -89,10 +89,10 @@ class DRAGON_API UnifiedMemory { ...@@ -89,10 +89,10 @@ class DRAGON_API UnifiedMemory {
Map<string, string> info() const; Map<string, string> info() const;
/*! \brief Return the const cpu data */ /*! \brief Return the const cpu data */
const void* cpu_data(size_t size = 0); const void* cpu_data(size_t size = 0, size_t offset = 0);
/*! \brief Return the const cuda data */ /*! \brief Return the const cuda data */
const void* cuda_data(size_t size = 0); const void* cuda_data(size_t size = 0, size_t offset = 0);
/*! \brief Return the mutable cpu data */ /*! \brief Return the mutable cpu data */
void* mutable_cpu_data(size_t size = 0); void* mutable_cpu_data(size_t size = 0);
......
...@@ -52,8 +52,8 @@ Tensor& OperatorBase::Input(int i) { ...@@ -52,8 +52,8 @@ Tensor& OperatorBase::Input(int i) {
Tensor* OperatorBase::Output(int i) { Tensor* OperatorBase::Output(int i) {
CHECK_LT(i, (int)outputs_.size()); CHECK_LT(i, (int)outputs_.size());
CHECK_GE(i, -(int)outputs_.size()); CHECK_GE(i, -(int)outputs_.size());
if (i >= 0) return outputs_[i]; if (i >= 0) return outputs_[i]->MapFrom(nullptr);
return outputs_[i + outputs_.size()]; return outputs_[i + outputs_.size()]->MapFrom(nullptr);
} }
Tensor* OperatorBase::Output(int i, const vec32_t& inputs) { Tensor* OperatorBase::Output(int i, const vec32_t& inputs) {
...@@ -114,14 +114,15 @@ OperatorBase* OperatorBase::DeriveFrom(const OperatorDef& def) { ...@@ -114,14 +114,15 @@ OperatorBase* OperatorBase::DeriveFrom(const OperatorDef& def) {
template <class Context> template <class Context>
void Operator<Context>::Prepare() { void Operator<Context>::Prepare() {
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); ++i) {
if (Input(i).version() >= 0) { auto& X = *inputs_[i];
if (X.version() >= 0) {
const auto& name = def().input(i); const auto& name = def().input(i);
auto ver_pos = name.find("/ver:"); auto ver_pos = name.find("/ver:");
auto version = std::atoi(name.substr(ver_pos + 5).c_str()); auto version = std::atoi(name.substr(ver_pos + 5).c_str());
if (version == Input(i).version()) continue; if (version == X.version()) continue;
LOG(DEBUG) << "Excepted version of Tensor(" + Input(i).name() + ") " LOG(DEBUG) << "Excepted version of Tensor(" + X.name() + ") "
<< "is " << version << ", got " << Input(i).version() << "is " << version << ", got " << X.version()
<< ". Recompute."; << ". Recompute.";
Tensor* flag = workspace()->GetTensor("flagged/recomp"); Tensor* flag = workspace()->GetTensor("flagged/recomp");
flag->mutable_data<bool, CPUContext>()[0] = true; flag->mutable_data<bool, CPUContext>()[0] = true;
...@@ -136,12 +137,13 @@ void Operator<Context>::Prepare() { ...@@ -136,12 +137,13 @@ void Operator<Context>::Prepare() {
template <class Context> template <class Context>
void Operator<Context>::Release() { void Operator<Context>::Release() {
for (int i = 0; i < OutputSize(); i++) { for (int i = 0; i < OutputSize(); ++i) {
if (Output(i)->version() >= 0) { auto* Y = outputs_[i];
if (Y->version() >= 0) {
const auto& name = def().output(i); const auto& name = def().output(i);
auto ver_pos = name.find("/ver:"); auto ver_pos = name.find("/ver:");
auto version = std::atoi(name.substr(ver_pos + 5).c_str()); auto version = std::atoi(name.substr(ver_pos + 5).c_str());
Output(i)->set_version(version); Y->set_version(version);
} }
} }
} }
......
...@@ -137,18 +137,21 @@ class DRAGON_API Tensor { ...@@ -137,18 +137,21 @@ class DRAGON_API Tensor {
} }
/*! \brief Map memory from a tensor */ /*! \brief Map memory from a tensor */
Tensor* MapFrom(Tensor* other) { Tensor* MapFrom(Tensor* other, size_t offset = 0) {
if (other == nullptr) { if (other == nullptr) {
if (mapped_memory_ != nullptr) {
mapped_memory_ = nullptr; mapped_memory_ = nullptr;
capacity_ = (memory_ != nullptr ? memory_->size() : 0); capacity_ = (memory_ != nullptr ? memory_->size() : 0);
offset_ = 0;
}
} else { } else {
auto* new_memory = other->memory(); auto* new_memory = other->memory();
if (new_memory != nullptr) { if (new_memory != nullptr) {
CHECK_LE(size_, new_memory->size()) CHECK_LE(size_, new_memory->size())
<< "\nMap from a memory with smaller capacity."; << "\nMap from a memory with smaller capacity.";
memory_.reset();
mapped_memory_ = new_memory; mapped_memory_ = new_memory;
capacity_ = new_memory->size(); capacity_ = new_memory->size();
offset_ = offset;
} }
} }
return this; return this;
...@@ -158,10 +161,10 @@ class DRAGON_API Tensor { ...@@ -158,10 +161,10 @@ class DRAGON_API Tensor {
void Reset() { void Reset() {
dims_.clear(); dims_.clear();
strides_.clear(); strides_.clear();
memory_.reset();
meta_ = TypeMeta(); meta_ = TypeMeta();
size_ = capacity_ = 0; memory_.reset();
mapped_memory_ = nullptr; mapped_memory_ = nullptr;
size_ = capacity_ = offset_ = 0;
if (ExternalDeleter != nullptr) { if (ExternalDeleter != nullptr) {
ExternalDeleter(); ExternalDeleter();
ExternalDeleter = nullptr; ExternalDeleter = nullptr;
...@@ -295,11 +298,8 @@ class DRAGON_API Tensor { ...@@ -295,11 +298,8 @@ class DRAGON_API Tensor {
/*! \brief Return the memory */ /*! \brief Return the memory */
UnifiedMemory* memory(bool required = false, bool owned = false) { UnifiedMemory* memory(bool required = false, bool owned = false) {
if (capacity_ < size_ * meta_.itemsize()) { if (capacity_ < size_ * meta_.itemsize()) {
if (mapped_memory_ != nullptr) {
mapped_memory_ = nullptr; mapped_memory_ = nullptr;
} else {
memory_.reset(); memory_.reset();
}
capacity_ = 0; capacity_ = 0;
} }
auto* ptr = (owned || !mapped_memory_ ? memory_.get() : mapped_memory_); auto* ptr = (owned || !mapped_memory_ ? memory_.get() : mapped_memory_);
...@@ -317,9 +317,9 @@ class DRAGON_API Tensor { ...@@ -317,9 +317,9 @@ class DRAGON_API Tensor {
const void* raw_data() { const void* raw_data() {
const auto context_type = TypeMeta::Id<Context>(); const auto context_type = TypeMeta::Id<Context>();
if (context_type == TypeMeta::Id<CPUContext>()) { if (context_type == TypeMeta::Id<CPUContext>()) {
return memory(true)->cpu_data(nbytes()); return memory(true)->cpu_data(nbytes(), offset_);
} else if (context_type == TypeMeta::Id<CUDAContext>()) { } else if (context_type == TypeMeta::Id<CUDAContext>()) {
return memory(true)->cuda_data(nbytes()); return memory(true)->cuda_data(nbytes(), offset_);
} else { } else {
LOG(FATAL) << "Unsupported context type."; LOG(FATAL) << "Unsupported context type.";
return nullptr; return nullptr;
...@@ -393,7 +393,9 @@ class DRAGON_API Tensor { ...@@ -393,7 +393,9 @@ class DRAGON_API Tensor {
if (memory != memory_.get()) { if (memory != memory_.get()) {
memory_.reset(memory); memory_.reset(memory);
} }
mapped_memory_ = nullptr;
capacity_ = memory->size(); capacity_ = memory->size();
offset_ = 0;
} }
} }
...@@ -410,6 +412,9 @@ class DRAGON_API Tensor { ...@@ -410,6 +412,9 @@ class DRAGON_API Tensor {
/*! \brief The byte length of memory */ /*! \brief The byte length of memory */
size_t capacity_ = 0; size_t capacity_ = 0;
/*! \brief The byte offset of memory */
size_t offset_ = 0;
/*! \brief The tensor version */ /*! \brief The tensor version */
int version_ = -1; int version_ = -1;
......
...@@ -40,6 +40,10 @@ void SplitOp<Context>::DoRunWithType() { ...@@ -40,6 +40,10 @@ void SplitOp<Context>::DoRunWithType() {
auto* Y = Output(i); auto* Y = Output(i);
if (Y->has_name()) { if (Y->has_name()) {
Y_dims[axis] = size_splits[i]; Y_dims[axis] = size_splits[i];
if (!copy_chunks_ && axis == 0) {
Y->Reshape(Y_dims)->set_meta(X.meta())->MapFrom(
&X, sizeof(T) * input_offset);
} else {
math::CopyMatrix( math::CopyMatrix(
X.count(0, axis), X.count(0, axis),
size_splits[i] * X.count(axis + 1), size_splits[i] * X.count(axis + 1),
...@@ -49,6 +53,7 @@ void SplitOp<Context>::DoRunWithType() { ...@@ -49,6 +53,7 @@ void SplitOp<Context>::DoRunWithType() {
Y->Reshape(Y_dims)->template mutable_data<T, Context>(), Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
}
input_offset += size_splits[i] * X.count(axis + 1); input_offset += size_splits[i] * X.count(axis + 1);
} }
} }
......
...@@ -20,13 +20,18 @@ namespace dragon { ...@@ -20,13 +20,18 @@ namespace dragon {
template <class Context> template <class Context>
class SplitOp final : public Operator<Context> { class SplitOp final : public Operator<Context> {
public: public:
SIMPLE_CTOR_DTOR(SplitOp); SplitOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
copy_chunks_(OP_SINGLE_ARG(int64_t, "copy", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
int64_t copy_chunks_;
}; };
template <class Context> template <class Context>
......
...@@ -565,6 +565,7 @@ def sort_args(**kwargs): ...@@ -565,6 +565,7 @@ def sort_args(**kwargs):
def split_args(**kwargs): def split_args(**kwargs):
return { return {
'axis': kwargs.get('axis', 0), 'axis': kwargs.get('axis', 0),
'copy': kwargs.get('copy', True),
'size_splits': kwargs.get('size_splits', None), 'size_splits': kwargs.get('size_splits', None),
'slice_points': kwargs.get('slice_points', None), 'slice_points': kwargs.get('slice_points', None),
} }
......
...@@ -1528,6 +1528,7 @@ def split( ...@@ -1528,6 +1528,7 @@ def split(
num_or_size_splits, num_or_size_splits,
axis=0, axis=0,
slice_points=None, slice_points=None,
copy=True,
**kwargs **kwargs
): ):
r"""Split input into chunks along the given axis. r"""Split input into chunks along the given axis.
...@@ -1568,6 +1569,8 @@ def split( ...@@ -1568,6 +1569,8 @@ def split(
The axis to split. The axis to split.
slice_points : Sequence[int], optional slice_points : Sequence[int], optional
The optional slice points. The optional slice points.
copy : bool, optional, default=True
Copy or create the views of input.
Returns Returns
------- -------
...@@ -1588,10 +1591,11 @@ def split( ...@@ -1588,10 +1591,11 @@ def split(
% len(slice_points)) % len(slice_points))
if context.executing_eagerly(): if context.executing_eagerly():
return OpLib.execute( return OpLib.execute(
'Split', inputs, outputs=[None] * num_splits, 'Split', inputs, outputs=[None] * num_splits, axis=axis,
axis=axis, size_splits=size_splits, slice_points=slice_points) size_splits=size_splits, slice_points=slice_points, copy=copy)
return OpLib.add('Split', inputs, num_outputs=num_splits, axis=axis, return OpLib.add('Split', inputs, num_outputs=num_splits, axis=axis,
size_splits=size_splits, slice_points=slice_points, **kwargs) size_splits=size_splits, slice_points=slice_points,
copy=copy, **kwargs)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
......
...@@ -339,8 +339,9 @@ inline void CollapseTransposeAxes( ...@@ -339,8 +339,9 @@ inline void CollapseTransposeAxes(
new_dims.begin(), new_dims.end(), [](int x) { return x == -1; }); new_dims.begin(), new_dims.end(), [](int x) { return x == -1; });
new_dims.erase(erase_iter, new_dims.end()); new_dims.erase(erase_iter, new_dims.end());
for (int i = 0; i < new_axes.size(); ++i) { for (int i = 0; i < new_axes.size(); ++i) {
const auto axis = new_axes[i];
for (auto collapse_axis : collapse_axes) { for (auto collapse_axis : collapse_axes) {
if (new_axes[i] > collapse_axis) new_axes[i]--; if (axis > collapse_axis) new_axes[i]--;
} }
} }
} }
......
...@@ -250,7 +250,7 @@ def channel_normalize(input, mean, std, dim=-1, dtype='float32', dims=None): ...@@ -250,7 +250,7 @@ def channel_normalize(input, mean, std, dim=-1, dtype='float32', dims=None):
ndim=len(dims) if dims is not None else 0, perm=dims) ndim=len(dims) if dims is not None else 0, perm=dims)
def chunk(tensor, chunks, dim=0): def chunk(tensor, chunks, dim=0, copy=True):
"""Split input into a specific number of chunks. """Split input into a specific number of chunks.
Examples: Examples:
...@@ -277,6 +277,8 @@ def chunk(tensor, chunks, dim=0): ...@@ -277,6 +277,8 @@ def chunk(tensor, chunks, dim=0):
The number of chunks to split. The number of chunks to split.
dim : int, optional, default=0 dim : int, optional, default=0
The dimension to split. The dimension to split.
copy : bool, optional, default=True
Copy or create the views of input.
Returns Returns
------- -------
...@@ -286,7 +288,7 @@ def chunk(tensor, chunks, dim=0): ...@@ -286,7 +288,7 @@ def chunk(tensor, chunks, dim=0):
""" """
return FunctionLib.apply( return FunctionLib.apply(
'Split', tensor.device, [tensor], outputs=[None] * chunks, 'Split', tensor.device, [tensor], outputs=[None] * chunks,
axis=dim, size_split=None) axis=dim, size_split=None, copy=copy)
def cumsum(input, dim, out=None): def cumsum(input, dim, out=None):
...@@ -1117,7 +1119,7 @@ def sort(input, dim=-1, descending=False, out=None): ...@@ -1117,7 +1119,7 @@ def sort(input, dim=-1, descending=False, out=None):
outputs=out if out else [None, None], axis=dim, descending=descending) outputs=out if out else [None, None], axis=dim, descending=descending)
def split(tensor, split_size_or_sections, dim=0): def split(tensor, split_size_or_sections, dim=0, copy=True):
"""Split input into chunks along the given dimension. """Split input into chunks along the given dimension.
Either size of every chunk or each chunk will be accepted: Either size of every chunk or each chunk will be accepted:
...@@ -1146,6 +1148,8 @@ def split(tensor, split_size_or_sections, dim=0): ...@@ -1146,6 +1148,8 @@ def split(tensor, split_size_or_sections, dim=0):
The number or size of chunks. The number or size of chunks.
dim : int, optional, default=0 dim : int, optional, default=0
The dimension to split. The dimension to split.
copy : bool, optional, default=True
Copy or create the views of input.
Returns Returns
------- -------
...@@ -1167,7 +1171,7 @@ def split(tensor, split_size_or_sections, dim=0): ...@@ -1167,7 +1171,7 @@ def split(tensor, split_size_or_sections, dim=0):
size_splits[-1] = size - (split_size_or_sections * (num_splits - 1)) size_splits[-1] = size - (split_size_or_sections * (num_splits - 1))
return FunctionLib.apply( return FunctionLib.apply(
'Split', tensor.device, [tensor], outputs=[None] * num_splits, 'Split', tensor.device, [tensor], outputs=[None] * num_splits,
axis=dim, size_splits=size_splits) axis=dim, size_splits=size_splits, copy=copy)
def squeeze(input, dim=None, out=None): def squeeze(input, dim=None, out=None):
......
...@@ -573,7 +573,7 @@ def char_(self): ...@@ -573,7 +573,7 @@ def char_(self):
return array_ops.cast(self, 'int8', self) return array_ops.cast(self, 'int8', self)
def chunk(self, chunks, dim=0): def chunk(self, chunks, dim=0, copy=True):
"""Split self into several parts along the given dim. """Split self into several parts along the given dim.
Parameters Parameters
...@@ -582,6 +582,8 @@ def chunk(self, chunks, dim=0): ...@@ -582,6 +582,8 @@ def chunk(self, chunks, dim=0):
The number of chunks to split. The number of chunks to split.
dim : int, optional, default=0 dim : int, optional, default=0
The dimension to split. The dimension to split.
copy : bool, optional, default=True
Copy or create the views of input.
Returns Returns
------- -------
...@@ -589,7 +591,7 @@ def chunk(self, chunks, dim=0): ...@@ -589,7 +591,7 @@ def chunk(self, chunks, dim=0):
The output chunks. The output chunks.
""" """
return array_ops.chunk(self, chunks, dim) return array_ops.chunk(self, chunks, dim, copy)
def clamp(self, min=None, max=None): def clamp(self, min=None, max=None):
...@@ -2420,7 +2422,7 @@ def sort(self, dim=-1, descending=False): ...@@ -2420,7 +2422,7 @@ def sort(self, dim=-1, descending=False):
return array_ops.sort(self, dim, descending) return array_ops.sort(self, dim, descending)
def split(self, split_size_or_sections, dim=0): def split(self, split_size_or_sections, dim=0, copy=True):
"""Return the split chunks along the given dimension. """Return the split chunks along the given dimension.
Parameters Parameters
...@@ -2434,13 +2436,15 @@ def split(self, split_size_or_sections, dim=0): ...@@ -2434,13 +2436,15 @@ def split(self, split_size_or_sections, dim=0):
------- -------
Sequence[dragon.vm.torch.Tensor] Sequence[dragon.vm.torch.Tensor]
The output tensors. The output tensors.
copy : bool, optional, default=True
Copy or create the views of input.
See Also See Also
-------- --------
`torch.split(...)`_ `torch.split(...)`_
""" """
return array_ops.split(self, split_size_or_sections, dim) return array_ops.split(self, split_size_or_sections, dim, copy)
def sqrt(self): def sqrt(self):
......
...@@ -713,7 +713,7 @@ class Tensor(object): ...@@ -713,7 +713,7 @@ class Tensor(object):
""" """
def chunk(self, chunks, dim=0): def chunk(self, chunks, dim=0, copy=True):
"""Split self into several parts along the given dim. """Split self into several parts along the given dim.
Parameters Parameters
...@@ -722,6 +722,8 @@ class Tensor(object): ...@@ -722,6 +722,8 @@ class Tensor(object):
The number of chunks to split. The number of chunks to split.
dim : int, optional, default=0 dim : int, optional, default=0
The dimension to split. The dimension to split.
copy : bool, optional, default=True
Copy or create the views of input.
Returns Returns
------- -------
...@@ -2566,7 +2568,7 @@ class Tensor(object): ...@@ -2566,7 +2568,7 @@ class Tensor(object):
""" """
def split(self, split_size_or_sections, dim=0): def split(self, split_size_or_sections, dim=0, copy=True):
"""Return the split chunks along the given dimension. """Return the split chunks along the given dimension.
Parameters Parameters
...@@ -2575,6 +2577,8 @@ class Tensor(object): ...@@ -2575,6 +2577,8 @@ class Tensor(object):
The number or size of chunks. The number or size of chunks.
dim : int, optional, default=0 dim : int, optional, default=0
The dimension to split. The dimension to split.
copy : bool, optional, default=True
Copy or create the views of input.
Returns Returns
------- -------
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!