Commit 3dfb6ea5 by Ting PAN

Enhance tensor memory mapping with an optional offset

Summary:
This commit adds support for mapping memory and loading data at an optional byte offset.
A tensor that takes mapped memory is read-only and drops the mapping on its next mutation.
1 parent 654febe3
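
For context, here is a minimal sketch of the user-facing behavior, based only on the API changes in this diff. It assumes eager execution; `dragon.constant` is used purely for illustration and is not part of this commit.

    import numpy as np
    import dragon

    x = dragon.constant(np.arange(12, dtype='float32').reshape(4, 3))

    # Default behavior is unchanged: each chunk owns a fresh copy of its data.
    a, b = dragon.split(x, num_or_size_splits=2, axis=0)

    # With copy=False (and axis == 0), each chunk maps into x's memory at a
    # byte offset instead of copying. Per the summary above, mapped chunks are
    # read-only and drop the mapping on their next mutation.
    c, d = dragon.split(x, num_or_size_splits=2, axis=0, copy=False)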
......@@ -47,15 +47,15 @@ void UnifiedMemory::ToCUDA(size_t size) {
}
}
const void* UnifiedMemory::cpu_data(size_t size) {
const void* UnifiedMemory::cpu_data(size_t size, size_t offset) {
ToCPU(size);
return (const void*)cpu_ptr_;
return (const void*)((uint8_t*)cpu_ptr_ + offset);
}
const void* UnifiedMemory::cuda_data(size_t size) {
const void* UnifiedMemory::cuda_data(size_t size, size_t offset) {
SwitchToCUDADevice(CUDAContext::current_device());
ToCUDA(size);
return (const void*)cuda_ptr_;
return (const void*)((uint8_t*)cuda_ptr_ + offset);
}
void* UnifiedMemory::mutable_cpu_data(size_t size) {
......
......@@ -89,10 +89,10 @@ class DRAGON_API UnifiedMemory {
Map<string, string> info() const;
/*! \brief Return the const cpu data */
const void* cpu_data(size_t size = 0);
const void* cpu_data(size_t size = 0, size_t offset = 0);
/*! \brief Return the const cuda data */
const void* cuda_data(size_t size = 0);
const void* cuda_data(size_t size = 0, size_t offset = 0);
/*! \brief Return the mutable cpu data */
void* mutable_cpu_data(size_t size = 0);
......
......@@ -52,8 +52,8 @@ Tensor& OperatorBase::Input(int i) {
Tensor* OperatorBase::Output(int i) {
CHECK_LT(i, (int)outputs_.size());
CHECK_GE(i, -(int)outputs_.size());
if (i >= 0) return outputs_[i];
return outputs_[i + outputs_.size()];
if (i >= 0) return outputs_[i]->MapFrom(nullptr);
return outputs_[i + outputs_.size()]->MapFrom(nullptr);
}
Tensor* OperatorBase::Output(int i, const vec32_t& inputs) {
......@@ -114,14 +114,15 @@ OperatorBase* OperatorBase::DeriveFrom(const OperatorDef& def) {
template <class Context>
void Operator<Context>::Prepare() {
for (int i = 0; i < InputSize(); i++) {
if (Input(i).version() >= 0) {
for (int i = 0; i < InputSize(); ++i) {
auto& X = *inputs_[i];
if (X.version() >= 0) {
const auto& name = def().input(i);
auto ver_pos = name.find("/ver:");
auto version = std::atoi(name.substr(ver_pos + 5).c_str());
if (version == Input(i).version()) continue;
LOG(DEBUG) << "Excepted version of Tensor(" + Input(i).name() + ") "
<< "is " << version << ", got " << Input(i).version()
if (version == X.version()) continue;
LOG(DEBUG) << "Excepted version of Tensor(" + X.name() + ") "
<< "is " << version << ", got " << X.version()
<< ". Recompute.";
Tensor* flag = workspace()->GetTensor("flagged/recomp");
flag->mutable_data<bool, CPUContext>()[0] = true;
......@@ -136,12 +137,13 @@ void Operator<Context>::Prepare() {
template <class Context>
void Operator<Context>::Release() {
for (int i = 0; i < OutputSize(); i++) {
if (Output(i)->version() >= 0) {
for (int i = 0; i < OutputSize(); ++i) {
auto* Y = outputs_[i];
if (Y->version() >= 0) {
const auto& name = def().output(i);
auto ver_pos = name.find("/ver:");
auto version = std::atoi(name.substr(ver_pos + 5).c_str());
Output(i)->set_version(version);
Y->set_version(version);
}
}
}
......
......@@ -137,18 +137,21 @@ class DRAGON_API Tensor {
}
/*! \brief Map memory from a tensor */
Tensor* MapFrom(Tensor* other) {
Tensor* MapFrom(Tensor* other, size_t offset = 0) {
if (other == nullptr) {
if (mapped_memory_ != nullptr) {
mapped_memory_ = nullptr;
capacity_ = (memory_ != nullptr ? memory_->size() : 0);
offset_ = 0;
}
} else {
auto* new_memory = other->memory();
if (new_memory != nullptr) {
CHECK_LE(size_, new_memory->size())
<< "\nMap from a memory with smaller capacity.";
memory_.reset();
mapped_memory_ = new_memory;
capacity_ = new_memory->size();
offset_ = offset;
}
}
return this;
......@@ -158,10 +161,10 @@ class DRAGON_API Tensor {
void Reset() {
dims_.clear();
strides_.clear();
memory_.reset();
meta_ = TypeMeta();
size_ = capacity_ = 0;
memory_.reset();
mapped_memory_ = nullptr;
size_ = capacity_ = offset_ = 0;
if (ExternalDeleter != nullptr) {
ExternalDeleter();
ExternalDeleter = nullptr;
......@@ -295,11 +298,8 @@ class DRAGON_API Tensor {
/*! \brief Return the memory */
UnifiedMemory* memory(bool required = false, bool owned = false) {
if (capacity_ < size_ * meta_.itemsize()) {
if (mapped_memory_ != nullptr) {
mapped_memory_ = nullptr;
} else {
memory_.reset();
}
capacity_ = 0;
}
auto* ptr = (owned || !mapped_memory_ ? memory_.get() : mapped_memory_);
......@@ -317,9 +317,9 @@ class DRAGON_API Tensor {
const void* raw_data() {
const auto context_type = TypeMeta::Id<Context>();
if (context_type == TypeMeta::Id<CPUContext>()) {
return memory(true)->cpu_data(nbytes());
return memory(true)->cpu_data(nbytes(), offset_);
} else if (context_type == TypeMeta::Id<CUDAContext>()) {
return memory(true)->cuda_data(nbytes());
return memory(true)->cuda_data(nbytes(), offset_);
} else {
LOG(FATAL) << "Unsupported context type.";
return nullptr;
......@@ -393,7 +393,9 @@ class DRAGON_API Tensor {
if (memory != memory_.get()) {
memory_.reset(memory);
}
mapped_memory_ = nullptr;
capacity_ = memory->size();
offset_ = 0;
}
}
......@@ -410,6 +412,9 @@ class DRAGON_API Tensor {
/*! \brief The byte length of memory */
size_t capacity_ = 0;
/*! \brief The byte offset of memory */
size_t offset_ = 0;
/*! \brief The tensor version */
int version_ = -1;
......
......@@ -40,6 +40,10 @@ void SplitOp<Context>::DoRunWithType() {
auto* Y = Output(i);
if (Y->has_name()) {
Y_dims[axis] = size_splits[i];
if (!copy_chunks_ && axis == 0) {
Y->Reshape(Y_dims)->set_meta(X.meta())->MapFrom(
&X, sizeof(T) * input_offset);
} else {
math::CopyMatrix(
X.count(0, axis),
size_splits[i] * X.count(axis + 1),
......@@ -49,6 +53,7 @@ void SplitOp<Context>::DoRunWithType() {
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
}
}
input_offset += size_splits[i] * X.count(axis + 1);
}
}
......
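Note that the view path above is guarded by `!copy_chunks_ && axis == 0`: only chunks split along the leading axis are contiguous in memory, so splits along any other axis still go through `math::CopyMatrix` even when `copy=False` is requested. A hedged Python-level illustration (eager mode and the `x` tensor from the earlier sketch assumed):

    # Axis 0: chunks are contiguous slabs, so they are mapped at byte offsets
    # (sizeof(T) * input_offset) into the input's memory.
    views = dragon.split(x, num_or_size_splits=2, axis=0, copy=False)

    # Axis 1: chunks are strided, not contiguous, so the operator falls back
    # to the copying branch and the copy flag has no visible effect.
    copies = dragon.split(x, num_or_size_splits=3, axis=1, copy=False)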
......@@ -20,13 +20,18 @@ namespace dragon {
template <class Context>
class SplitOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(SplitOp);
SplitOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
copy_chunks_(OP_SINGLE_ARG(int64_t, "copy", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
int64_t copy_chunks_;
};
template <class Context>
......
......@@ -565,6 +565,7 @@ def sort_args(**kwargs):
def split_args(**kwargs):
return {
'axis': kwargs.get('axis', 0),
'copy': kwargs.get('copy', True),
'size_splits': kwargs.get('size_splits', None),
'slice_points': kwargs.get('slice_points', None),
}
......
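A quick, hypothetical call just to show what the updated helper produces; the flag defaults to `True`, so callers that never pass `copy` keep the copying behavior.

    args = split_args(axis=1, copy=False, size_splits=[2, 2])
    # -> {'axis': 1, 'copy': False, 'size_splits': [2, 2], 'slice_points': None}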
......@@ -1528,6 +1528,7 @@ def split(
num_or_size_splits,
axis=0,
slice_points=None,
copy=True,
**kwargs
):
r"""Split input into chunks along the given axis.
......@@ -1568,6 +1569,8 @@ def split(
The axis to split.
slice_points : Sequence[int], optional
The optional slice points.
copy : bool, optional, default=True
Whether to copy chunks or create views of the input.
Returns
-------
......@@ -1588,10 +1591,11 @@ def split(
% len(slice_points))
if context.executing_eagerly():
return OpLib.execute(
'Split', inputs, outputs=[None] * num_splits,
axis=axis, size_splits=size_splits, slice_points=slice_points)
'Split', inputs, outputs=[None] * num_splits, axis=axis,
size_splits=size_splits, slice_points=slice_points, copy=copy)
return OpLib.add('Split', inputs, num_outputs=num_splits, axis=axis,
size_splits=size_splits, slice_points=slice_points, **kwargs)
size_splits=size_splits, slice_points=slice_points,
copy=copy, **kwargs)
@OpSchema.num_inputs(1)
......
......@@ -339,8 +339,9 @@ inline void CollapseTransposeAxes(
new_dims.begin(), new_dims.end(), [](int x) { return x == -1; });
new_dims.erase(erase_iter, new_dims.end());
for (int i = 0; i < new_axes.size(); ++i) {
const auto axis = new_axes[i];
for (auto collapse_axis : collapse_axes) {
if (new_axes[i] > collapse_axis) new_axes[i]--;
if (axis > collapse_axis) new_axes[i]--;
}
}
}
......
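This one-line fix matters when more than one collapsed axis precedes the same entry: the old code compared against the already-decremented value, so it could stop decrementing too early. A small Python re-enactment of the loop (not the shipped C++) makes the difference concrete:

    def adjust_axes(new_axes, collapse_axes, fixed=True):
        """Re-enact the axis-adjustment loop from CollapseTransposeAxes."""
        new_axes = list(new_axes)
        for i in range(len(new_axes)):
            axis = new_axes[i]  # snapshot of the original axis, as in the fix
            for collapse_axis in collapse_axes:
                ref = axis if fixed else new_axes[i]
                if ref > collapse_axis:
                    new_axes[i] -= 1
        return new_axes

    print(adjust_axes([2], [0, 1], fixed=False))  # [1]: second test sees 1 > 1
    print(adjust_axes([2], [0, 1], fixed=True))   # [0]: both tests see the original 2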
......@@ -250,7 +250,7 @@ def channel_normalize(input, mean, std, dim=-1, dtype='float32', dims=None):
ndim=len(dims) if dims is not None else 0, perm=dims)
def chunk(tensor, chunks, dim=0):
def chunk(tensor, chunks, dim=0, copy=True):
"""Split input into a specific number of chunks.
Examples:
......@@ -277,6 +277,8 @@ def chunk(tensor, chunks, dim=0):
The number of chunks to split.
dim : int, optional, default=0
The dimension to split.
copy : bool, optional, default=True
Whether to copy chunks or create views of the input.
Returns
-------
......@@ -286,7 +288,7 @@ def chunk(tensor, chunks, dim=0):
"""
return FunctionLib.apply(
'Split', tensor.device, [tensor], outputs=[None] * chunks,
axis=dim, size_split=None)
axis=dim, size_split=None, copy=copy)
def cumsum(input, dim, out=None):
......@@ -1117,7 +1119,7 @@ def sort(input, dim=-1, descending=False, out=None):
outputs=out if out else [None, None], axis=dim, descending=descending)
def split(tensor, split_size_or_sections, dim=0):
def split(tensor, split_size_or_sections, dim=0, copy=True):
"""Split input into chunks along the given dimension.
Either a single chunk size or a sequence of per-chunk sizes is accepted:
......@@ -1146,6 +1148,8 @@ def split(tensor, split_size_or_sections, dim=0):
The number or size of chunks.
dim : int, optional, default=0
The dimension to split.
copy : bool, optional, default=True
Whether to copy chunks or create views of the input.
Returns
-------
......@@ -1167,7 +1171,7 @@ def split(tensor, split_size_or_sections, dim=0):
size_splits[-1] = size - (split_size_or_sections * (num_splits - 1))
return FunctionLib.apply(
'Split', tensor.device, [tensor], outputs=[None] * num_splits,
axis=dim, size_splits=size_splits)
axis=dim, size_splits=size_splits, copy=copy)
def squeeze(input, dim=None, out=None):
......
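The torch-style front end exposes the same flag on both functions. A hedged usage sketch; the `dragon.vm.torch` import path and `torch.tensor` constructor are assumed here for illustration:

    from dragon.vm import torch

    t = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

    parts = torch.chunk(t, 2, dim=0, copy=False)   # two views along dim 0
    halves = torch.split(t, 1, dim=0, copy=False)  # same result via a chunk size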
......@@ -573,7 +573,7 @@ def char_(self):
return array_ops.cast(self, 'int8', self)
def chunk(self, chunks, dim=0):
def chunk(self, chunks, dim=0, copy=True):
"""Split self into several parts along the given dim.
Parameters
......@@ -582,6 +582,8 @@ def chunk(self, chunks, dim=0):
The number of chunks to split.
dim : int, optional, default=0
The dimension to split.
copy : bool, optional, default=True
Whether to copy chunks or create views of the input.
Returns
-------
......@@ -589,7 +591,7 @@ def chunk(self, chunks, dim=0):
The output chunks.
"""
return array_ops.chunk(self, chunks, dim)
return array_ops.chunk(self, chunks, dim, copy)
def clamp(self, min=None, max=None):
......@@ -2420,7 +2422,7 @@ def sort(self, dim=-1, descending=False):
return array_ops.sort(self, dim, descending)
def split(self, split_size_or_sections, dim=0):
def split(self, split_size_or_sections, dim=0, copy=True):
"""Return the split chunks along the given dimension.
Parameters
......@@ -2434,13 +2436,15 @@ def split(self, split_size_or_sections, dim=0):
-------
Sequence[dragon.vm.torch.Tensor]
The output tensors.
copy : bool, optional, default=True
Whether to copy chunks or create views of the input.
See Also
--------
`torch.split(...)`_
"""
return array_ops.split(self, split_size_or_sections, dim)
return array_ops.split(self, split_size_or_sections, dim, copy)
def sqrt(self):
......
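The method forms forward the flag in the same way (same assumptions as the sketch above):

    a, b = t.chunk(2, dim=0, copy=False)
    first, second = t.split(1, dim=0, copy=False)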
......@@ -713,7 +713,7 @@ class Tensor(object):
"""
def chunk(self, chunks, dim=0):
def chunk(self, chunks, dim=0, copy=True):
"""Split self into several parts along the given dim.
Parameters
......@@ -722,6 +722,8 @@ class Tensor(object):
The number of chunks to split.
dim : int, optional, default=0
The dimension to split.
copy : bool, optional, default=True
Whether to copy chunks or create views of the input.
Returns
-------
......@@ -2566,7 +2568,7 @@ class Tensor(object):
"""
def split(self, split_size_or_sections, dim=0):
def split(self, split_size_or_sections, dim=0, copy=True):
"""Return the split chunks along the given dimension.
Parameters
......@@ -2575,6 +2577,8 @@ class Tensor(object):
The number or size of chunks.
dim : int, optional, default=0
The dimension to split.
copy : bool, optional, default=True
Whether to copy chunks or create views of the input.
Returns
-------
......