Commit b37d4e5e by Ting PAN

Fix the bug on dropout cuda kernel

Summary:
We forgot to handle the inplace case that generated
the random elements on the output(i.e. the input).

Besides, this commit also fixes the omitted `RunOnDevice` for cudnn activations,
which will rightly dispatches the implementation.
1 parent c1b8f912
......@@ -63,16 +63,6 @@ debughtml:
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/python."
deployhtml:
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/python
rm -rf $(BUILDDIR)/doctrees
rm -rf $(BUILDDIR)/python/_sources
rm -rf $(BUILDDIR)/python/.buildinfo
rm -rf $(BUILDDIR)/python/objects.inv
rm -rf $(BUILDDIR)/python/py-modindex.html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/python."
dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
@echo
......
......@@ -5,7 +5,7 @@
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
......@@ -47,22 +47,23 @@ author = 'Ting Pan\\\\tingpan@seetatech.com'
# HTML
html_theme = 'seeta'
html_title = ""
html_short_title = ""
html_title = ''
html_short_title = ''
html_logo = '../_static/images/dragon.png'
html_favicon = '../_static/favicon.ico'
html_scaled_image_link = False
html_copy_source = False
html_show_sourcelink = False
html_show_sphinx = False
html_show_copyright = False
html_scaled_image_link = False
html_theme_options = {
'navbar_links': {
'Install': path_to('../../install', 1),
'API': [
('C++', path_to('../cpp', 1)),
('C++', path_to('../cc', 1)),
('Python', path_to('', 1))
],
'Github': 'https://github.com/seetaresearch/Dragon',
'Github': 'https://github.com/seetaresearch/dragon',
},
'navbar_logo_link': path_to('../..', 1),
'sidebar_title': 'Python v0.3.0',
......@@ -70,7 +71,7 @@ html_theme_options = {
'breadcrumb_links': [
('Dragon', path_to('../..', 1)),
('API', path_to('../../versions', 1)),
('Python', path_to('../../api/python', 1)),
('Python', path_to('', 1)),
],
}
html_sidebars = {
......
......@@ -83,10 +83,6 @@ Tensor* OperatorBase::Buffer(const string& name) {
return ws()->CreateTensor(unique_name(name));
}
bool OperatorBase::IsInputOutputAlias(int i, int j) {
return ((void*)&Input(i) == (void*)Output(j));
}
string OperatorBase::TypeString(const Tensor& tensor, const Set<string>& types)
const {
std::stringstream ss;
......
......@@ -74,9 +74,6 @@ class DRAGON_API OperatorBase {
return (int)outputs_.size();
}
/*! \brief Whether the output is an alias of input */
bool IsInputOutputAlias(int i, int j);
/*! \brief Return the value of the specified argument */
template <typename T>
T Arg(const string& name, const T& default_value);
......@@ -256,24 +253,23 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
name(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {} \
virtual ~name() {}
#define USE_OPERATOR_BASE_FUNCTIONS \
using OperatorBase::SwitchToPhase; \
using OperatorBase::Input; \
using OperatorBase::Output; \
using OperatorBase::Buffer; \
using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \
using OperatorBase::IsInputOutputAlias; \
using OperatorBase::DebugString; \
using OperatorBase::TypeString; \
using OperatorBase::name; \
using OperatorBase::type; \
using OperatorBase::phase; \
using OperatorBase::dtype; \
using OperatorBase::data_format; \
using OperatorBase::handle; \
using OperatorBase::unique_name; \
using OperatorBase::def; \
#define USE_OPERATOR_BASE_FUNCTIONS \
using OperatorBase::SwitchToPhase; \
using OperatorBase::Input; \
using OperatorBase::Output; \
using OperatorBase::Buffer; \
using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \
using OperatorBase::DebugString; \
using OperatorBase::TypeString; \
using OperatorBase::name; \
using OperatorBase::type; \
using OperatorBase::phase; \
using OperatorBase::dtype; \
using OperatorBase::data_format; \
using OperatorBase::handle; \
using OperatorBase::unique_name; \
using OperatorBase::def; \
using OperatorBase::ws
#define USE_OPERATOR_FUNCTIONS \
......
......@@ -57,7 +57,6 @@ void _DropPath<float16>(
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -90,7 +90,6 @@ void DropPath<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -82,6 +82,7 @@ void _Dropout<float16>(
const T* x, \
uint8_t* mask, \
T* y, \
uint32_t* scratch, \
CPUContext* ctx) { \
_Dropout(count, cast::to<T>(prob), cast::to<T>(scale), x, mask, y, ctx); \
}
......@@ -89,7 +90,6 @@ void _Dropout<float16>(
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -40,28 +40,31 @@ __global__ void _ApplyMask<half>(
template <typename T>
__global__ void _Dropout(
const int nthreads,
const T prob,
const uint32_t threshold,
const T scale,
const T* x,
const uint32_t* r,
uint8_t* mask,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = x[i] * (T)(mask[i] = (y[i] > prob)) * scale;
y[i] = x[i] * (T)(mask[i] = (r[i] > threshold)) * scale;
}
}
template <>
__global__ void _Dropout<half>(
const int nthreads,
const half prob,
const uint32_t threshold,
const half scale,
const half* x,
const uint32_t* r,
uint8_t* mask,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const uint8_t m = mask[i] = __hgt(y[i], prob);
y[i] = __hmul(__hmul(x[i], scale), __float2half((float)m));
y[i] = __hmul(
__hmul(x[i], scale),
__float2half((float)(mask[i] = (r[i] > threshold))));
#endif
}
}
......@@ -94,13 +97,15 @@ void Dropout<float16, CUDAContext>(
const float16* x,
uint8_t* mask,
float16* y,
uint32_t* scratch,
CUDAContext* ctx) {
math::RandomUniform(count, 0.f, 1.f, y, ctx);
math::RandomUniform(count, 0.f, 1.f, scratch, ctx);
_Dropout<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
cast::to<half>(prob),
static_cast<uint32_t>(UINT_MAX * prob),
cast::to<half>(scale),
reinterpret_cast<const half*>(x),
scratch,
mask,
reinterpret_cast<half*>(y));
}
......@@ -125,15 +130,16 @@ void Dropout<float16, CUDAContext>(
const T* x, \
uint8_t* mask, \
T* y, \
uint32_t* scratch, \
CUDAContext* ctx) { \
math::RandomUniform(count, 0.f, 1.f, y, ctx); \
math::RandomUniform(count, 0.f, 1.f, scratch, ctx); \
auto threshold = static_cast<uint32_t>(UINT_MAX * prob); \
_Dropout<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, cast::to<T>(prob), cast::to<T>(scale), x, mask, y); \
count, threshold, cast::to<T>(scale), x, scratch, mask, y); \
}
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -30,6 +30,7 @@ void DropoutOp<Context>::DoRunWithType() {
X.template data<T, Context>(),
Buffer("mask")->template mutable_data<uint8_t, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ws()->template data<uint32_t, Context>({X.count()})[0],
ctx());
} else {
LOG(FATAL) << "Unknown Phase: " << phase();
......
......@@ -79,6 +79,8 @@ class CuDNNDropoutOp final : public DropoutOp<Context> {
CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......@@ -106,6 +108,8 @@ class CuDNNDropoutGradientOp final : public DropoutGradientOp<Context> {
CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......
......@@ -13,7 +13,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CHECK(this->use_scale_) << "\nCuDNN only supports the scaled dropout.";
CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
CuDNNSetTensorDesc<T>(&input_desc_, {X.count(), 1, 1, 1});
if (phase() == "TEST") {
Y->ReshapeLike(X)->CopyFrom(X, ctx());
......@@ -36,7 +36,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
states_size,
rng_seed_));
} else {
X_states->Reshape({states_size});
X_states->Reshape({(int64_t)states_size});
CUDNN_CHECK(cudnnSetDropoutDescriptor(
dropout_desc_,
ctx()->cudnn_handle(),
......@@ -47,10 +47,10 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
}
}
// Allocate the reserve buffer
// Allocate for the reserve space
size_t reserve_size;
CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size));
Buffer("mask")->Reshape({reserve_size});
auto* X_mask = Buffer("mask")->Reshape({(int64_t)reserve_size});
CUDNN_CHECK(cudnnDropoutForward(
ctx()->cudnn_handle(),
......@@ -59,7 +59,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
X.template data<T, Context>(),
input_desc_,
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
Buffer("mask")->template mutable_data<uint8_t, Context>(),
X_mask->template mutable_data<uint8_t, Context>(),
reserve_size));
} else {
LOG(FATAL) << "Unknown Phase: " << phase();
......@@ -67,10 +67,15 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
}
template <class Context>
void CuDNNDropoutOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, dY.dims());
CuDNNSetTensorDesc<T>(&input_desc_, {dY.count(), 1, 1, 1});
if (phase() == "TEST") {
NOT_IMPLEMENTED;
......@@ -98,10 +103,11 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
}
}
// Allocate the reserve buffer
// Check the reserve space
size_t reserve_size;
CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size));
Buffer("mask")->Reshape({reserve_size});
auto* X_mask = Buffer("mask");
CHECK_EQ(X_mask->size(), reserve_size);
CUDNN_CHECK(cudnnDropoutBackward(
ctx()->cudnn_handle(),
......@@ -110,13 +116,18 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
dY.template data<T, Context>(),
input_desc_,
dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
Buffer("mask")->template mutable_data<uint8_t, Context>(),
X_mask->template mutable_data<uint8_t, Context>(),
reserve_size));
} else {
LOG(FATAL) << "Unknown Phase: " << phase();
}
}
template <class Context>
void CuDNNDropoutGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Dropout);
DEPLOY_CUDNN(DropoutGradient);
......
......@@ -69,6 +69,8 @@ class CuDNNEluOp final : public EluOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......@@ -94,6 +96,8 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......
......@@ -24,6 +24,11 @@ void CuDNNEluOp<Context>::DoRunWithType() {
}
template <class Context>
void CuDNNEluOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void CuDNNEluGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -44,6 +49,11 @@ void CuDNNEluGradientOp<Context>::DoRunWithType() {
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
}
template <class Context>
void CuDNNEluGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Elu);
DEPLOY_CUDNN(EluGradient);
......
......@@ -26,11 +26,11 @@ class ReluOp : public Operator<Context> {
max_value_(OpArg<float>("max_value", 0.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
void RunOnDevice() override;
protected:
float alpha_, max_value_;
};
......@@ -44,11 +44,11 @@ class ReluGradientOp : public Operator<Context> {
max_value_(OpArg<float>("max_value", 0.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
void RunOnDevice() override;
protected:
float alpha_, max_value_;
};
......
......@@ -60,6 +60,8 @@ class CuDNNSigmoidOp final : public SigmoidOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......@@ -85,6 +87,8 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......
......@@ -34,6 +34,11 @@ void CuDNNSigmoidOp<Context>::DoRunWithType() {
}
template <class Context>
void CuDNNSigmoidOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void CuDNNSigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -70,6 +75,11 @@ void CuDNNSigmoidGradientOp<Context>::DoRunWithType() {
#endif
}
template <class Context>
void CuDNNSigmoidGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Sigmoid);
DEPLOY_CUDNN(SigmoidGradient);
......
......@@ -56,6 +56,8 @@ class CuDNNSoftmaxOp final : public SoftmaxOp<Context> {
CuDNNDestroyTensorDesc(&input_desc_);
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......@@ -76,6 +78,8 @@ class CuDNNSoftmaxGradientOp final : public SoftmaxGradientOp<Context> {
CuDNNDestroyTensorDesc(&input_desc_);
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......
......@@ -26,6 +26,11 @@ void CuDNNSoftmaxOp<Context>::DoRunWithType() {
}
template <class Context>
void CuDNNSoftmaxOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void CuDNNSoftmaxGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -48,6 +53,11 @@ void CuDNNSoftmaxGradientOp<Context>::DoRunWithType() {
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
}
template <class Context>
void CuDNNSoftmaxGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Softmax);
DEPLOY_CUDNN(SoftmaxGradient);
......
......@@ -60,6 +60,8 @@ class CuDNNTanhOp final : public TanhOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......@@ -85,6 +87,8 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
......
......@@ -34,6 +34,11 @@ void CuDNNTanhOp<Context>::DoRunWithType() {
}
template <class Context>
void CuDNNTanhOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void CuDNNTanhGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -70,6 +75,11 @@ void CuDNNTanhGradientOp<Context>::DoRunWithType() {
#endif
}
template <class Context>
void CuDNNTanhGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Tanh);
DEPLOY_CUDNN(TanhGradient);
......
......@@ -38,6 +38,7 @@ void Dropout(
const T* x,
uint8_t* mask,
T* y,
uint32_t* scratch,
Context* ctx);
/* activation.drop_block */
......
......@@ -18,4 +18,3 @@ from dragon.vm.tensorflow.core.framework.framework_lib import *
from dragon.vm.tensorflow.core.ops import losses
from dragon.vm.tensorflow.core.ops import nn
from dragon.vm.tensorflow.core.ops.standard_ops import *
from dragon.vm.tensorflow.core.training import training as train
......@@ -222,7 +222,7 @@ class Layer(module.Module):
The optional variable name.
shape : Sequence[int], optional
The variable shape.
dtype : dragon.vm.tensorflow.dtypes.DType, optional
dtype : str, optional
The optional data type.
initializer : Union[callable, str], optional
The optional initializer.
......
......@@ -121,7 +121,6 @@ class Optimizer(updater.Updater):
# Increase the iterations.
self._iterations += 1
return self
def _create_hypers(self):
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy
from dragon.core.framework import workspace
from dragon.core.ops import framework_ops
from dragon.vm.tensorflow.core.framework import ops
class _DecayBase(object):
def __init__(self):
self.kwargs_str = ''
def set(self, tensor, value, dtype=None):
workspace.feed_tensor(tensor, value, dtype=dtype, enforce_cpu=True)
def get(self, tensor):
return workspace.fetch_tensor(tensor)
class _PiecewiseConstant(_DecayBase):
def __init__(self):
super(_PiecewiseConstant, self).__init__()
def setup(self, *args, **kwargs):
arguments = eval(self.kwargs_str)
self.boundaries = arguments['boundaries']
self.values = arguments['values']
def run(self, inputs, outputs):
gs = self.get(inputs[0])
for at in range(len(self.boundaries) - 1, -1, -1):
if gs >= self.boundaries[at]:
self.set(outputs[0], self.values[at + 1], dtype='float32')
return
class _ExponentialDecay(_DecayBase):
def __init__(self):
super(_ExponentialDecay, self).__init__()
def setup(self, *args, **kwargs):
arguments = eval(self.kwargs_str)
self.learning_rate = arguments['learning_rate']
self.decay_steps = arguments['decay_steps']
self.decay_rate = arguments['decay_rate']
self.staircase = arguments['staircase']
def run(self, inputs, outputs):
gs = self.get(inputs[0])
f = gs // self.decay_steps if self.staircase \
else float(gs) / self.decay_steps
new_lr = self.learning_rate * (self.decay_rate ** f)
self.set(outputs[0], new_lr, dtype='float32')
class _NaturalExpDecay(_DecayBase):
def __init__(self):
super(_NaturalExpDecay, self).__init__()
def setup(self, *args, **kwargs):
arguments = eval(self.kwargs_str)
self.learning_rate = arguments['learning_rate']
self.decay_steps = arguments['decay_steps']
self.decay_rate = arguments['decay_rate']
self.staircase = arguments['staircase']
def run(self, inputs, outputs):
gs = self.get(inputs[0])
f = gs // self.decay_steps if self.staircase \
else float(gs) / self.decay_steps
new_lr = self.learning_rate * math.exp(-self.decay_rate * f)
self.set(outputs[0], new_lr, dtype='float32')
class _CosineDecay(_DecayBase):
def __init__(self):
super(_CosineDecay, self).__init__()
def setup(self, *args, **kwargs):
arguments = eval(self.kwargs_str)
self.learning_rate = arguments['learning_rate']
self.decay_steps = arguments['decay_steps']
self.alpha = arguments['alpha']
def run(self, inputs, outputs):
gs = self.get(inputs[0])
global_step = min(gs, self.decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / self.decay_steps))
decayed = (1. - self.alpha) * cosine_decay + self.alpha
new_lr = self.learning_rate * decayed
self.set(outputs[0], new_lr, dtype='float32')
class _CosineDecayRestarts(_DecayBase):
def __init__(self):
super(_CosineDecayRestarts, self).__init__()
def setup(self, *args, **kwargs):
arguments = eval(self.kwargs_str)
self.learning_rate = arguments['learning_rate']
self.last_steps = 0
self.decay_steps = arguments['first_decay_steps']
self.t_mul, self.m_mul = arguments['t_mul'], arguments['m_mul']
self.alpha = arguments['alpha']
def run(self, inputs, outputs):
gs = self.get(inputs[0])
global_step = gs - self.last_steps
cosine_decay = 0.5 * (1. + math.cos(
math.pi * global_step / self.decay_steps))
decayed = (1. - self.alpha) * cosine_decay + self.alpha
new_lr = self.learning_rate * decayed
# Restarts
if global_step == self.decay_steps:
self.last_steps = gs + 1
self.decay_steps *= self.t_mul
self.learning_rate *= self.m_mul
self.set(outputs[0], new_lr, dtype='float32')
def piecewise_constant(
x,
boundaries,
values,
name=None,
):
if len(values) != len(boundaries) + 1:
raise ValueError('Excepted {} values, got {}.'.format(
len(boundaries) + 1, len(values)))
lr = framework_ops.python_plugin(
inputs=[ops.convert_to_tensor(x)],
module_name=__name__,
class_name='_PiecewiseConstant',
kwargs_str=str({
'boundaries': boundaries,
'values': values,
}),
name=name,
)
lr.set_value(numpy.array(values[0], dtype='float32'))
return lr
def exponential_decay(
learning_rate,
global_step,
decay_steps,
decay_rate,
staircase=False,
name=None,
):
lr = framework_ops.python_plugin(
inputs=[ops.convert_to_tensor(global_step)],
module_name=__name__,
class_name='_ExponentialDecay',
kwargs_str=str({
'learning_rate': learning_rate,
'decay_steps': decay_steps,
'decay_rate': decay_rate,
'staircase': staircase,
}),
name=name,
)
lr.set_value(numpy.array(learning_rate, dtype='float32'))
return lr
def natural_exp_decay(
learning_rate,
global_step,
decay_steps,
decay_rate,
staircase=False,
name=None,
):
lr = framework_ops.python_plugin(
inputs=[ops.convert_to_tensor(global_step)],
module_name=__name__,
class_name='_NaturalExpDecay',
kwargs_str=str({
'learning_rate': learning_rate,
'decay_steps': decay_steps,
'decay_rate': decay_rate,
'staircase': staircase,
}),
name=name,
)
lr.set_value(numpy.array(learning_rate, dtype='float32'))
return lr
def cosine_decay(
learning_rate,
global_step,
decay_steps,
alpha=0.0,
name=None,
):
lr = framework_ops.python_plugin(
inputs=[ops.convert_to_tensor(global_step)],
module_name=__name__,
class_name='_CosineDecay',
kwargs_str=str({
'learning_rate': learning_rate,
'decay_steps': decay_steps,
'alpha': alpha,
}),
name=name,
)
lr.set_value(numpy.array(learning_rate, dtype='float32'))
return lr
def cosine_decay_restarts(
learning_rate,
global_step,
first_decay_steps,
t_mul=2.0,
m_mul=1.0,
alpha=0.0,
name=None,
):
lr = framework_ops.python_plugin(
inputs=[ops.convert_to_tensor(global_step)],
module_name=__name__,
class_name='_CosineDecayRestarts',
kwargs_str=str({
'learning_rate': learning_rate,
'first_decay_steps': first_decay_steps,
't_mul': t_mul,
'm_mul': m_mul,
'alpha': alpha
}),
name=name,
)
lr.set_value(numpy.array(learning_rate, dtype='float32'))
return lr
# Alias
piecewise_constant_decay = piecewise_constant
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon import updaters
from dragon.core.framework import types
from dragon.core.framework import workspace
from dragon.vm.tensorflow.core.framework import ops
from dragon.vm.tensorflow.core.ops import variables
from dragon.vm.tensorflow.core.ops.gradients_impl import gradients
class Optimizer(object):
def __init__(self, use_locking, name):
if not name:
raise ValueError('Must specify the optimizer name.')
self._use_locking = use_locking
self._name = name
# Store the losses from gradients.
self._targets = []
# Store the external global step.
self._global_step = None
# Store the internal param updater.
self.updater = None
def apply_gradients(self, grads_and_vars, global_step=None):
self._global_step = global_step
grads_and_vars = list(grads_and_vars)
# Firstly, we should extract the potential decays.
l2_decays = []
for grad, var in grads_and_vars:
if hasattr(var, '__regularizer__'):
if var .__regularizer__ and \
var.__regularizer__.l2 > 0:
l2_decays.append(var.__regularizer__.l2)
# Find the base decay factor.
self.updater.l2_decay = \
base_l2_decay = min(l2_decays) \
if len(l2_decays) > 0 else -1.
# Add to targets.
targets = set()
for grad, var in grads_and_vars:
decay_multiplier = 0.
if hasattr(var, '__regularizer__'):
if var.__regularizer__ and \
var.__regularizer__.l2 > 0:
decay_multiplier = \
var.__regularizer__.l2 / base_l2_decay
self.updater.append((var, grad), decay_mult=decay_multiplier)
if var._grad_info is not None:
targets.update(var._grad_info.cost)
self._targets.extend(list(targets))
return self
@classmethod
def compute_gradients(cls, loss, var_list=None):
if var_list is None:
var_list = variables.trainable_variables() + \
ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)
grads = gradients(loss, var_list)
grads_and_vars = list(zip(grads, var_list))
return grads_and_vars
def get_name(self):
return self._name
def minimize(self, loss, global_step=None, var_list=None):
self._global_step = global_step
grads_and_vars = self.compute_gradients(loss, var_list)
return self.apply_gradients(grads_and_vars, global_step)
def _inc_global_step(self):
"""Increase the internal global step."""
if self._global_step is not None:
gs = int(self._global_step)
if types.is_tensor(self._global_step):
workspace.feed_tensor(
self._global_step, gs + 1,
enforce_cpu=True,
)
else:
self._global_step += 1
def _set_updater(self, cls, learning_rate, *args, **kwargs):
"""Set the updater and learning rate."""
if types.is_tensor(learning_rate):
base_lr = float(learning_rate)
self.updater = cls(base_lr, *args, **kwargs)
slot_lr = self.updater._slot + '/base_lr'
workspace.set_tensor_alias(learning_rate, slot_lr)
if types.is_symbolic_tensor(learning_rate):
self._targets.append(learning_rate)
else:
self.updater = cls(learning_rate, *args, **kwargs)
class GradientDescentOptimizer(Optimizer):
def __init__(
self,
learning_rate,
use_locking=False,
name='GradientDescent',
):
super(GradientDescentOptimizer, self).__init__(use_locking, name)
self._set_updater(updaters.SGD, learning_rate, momentum=0.,)
class MomentumOptimizer(Optimizer):
def __init__(
self,
learning_rate,
momentum,
use_locking=False,
name='Momentum',
use_nesterov=False,
):
super(MomentumOptimizer, self).__init__(use_locking, name)
cls = updaters.Nesterov if use_nesterov else updaters.SGD
self._set_updater(cls, learning_rate, momentum=momentum)
class AdamOptimizer(Optimizer):
def __init__(
self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
use_locking=False,
name='Adam',
):
super(AdamOptimizer, self).__init__(use_locking, name)
self._set_updater(
updaters.Adam,
learning_rate,
beta1=beta1,
beta2=beta2,
eps=epsilon,
)
class RMSPropOptimizer(Optimizer):
def __init__(
self,
learning_rate,
decay=0.9,
momentum=0.0,
epsilon=1e-10,
use_locking=False,
centered=False,
name='RMSProp',
):
super(RMSPropOptimizer, self).__init__(use_locking, name)
self._set_updater(
updaters.RMSProp,
learning_rate,
momentum=momentum,
decay=decay,
eps=epsilon,
)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# flake8: noqa
from dragon.vm.tensorflow.core.training.learning_rate_decay import cosine_decay
from dragon.vm.tensorflow.core.training.learning_rate_decay import cosine_decay_restarts
from dragon.vm.tensorflow.core.training.learning_rate_decay import exponential_decay
from dragon.vm.tensorflow.core.training.learning_rate_decay import natural_exp_decay
from dragon.vm.tensorflow.core.training.learning_rate_decay import piecewise_constant
from dragon.vm.tensorflow.core.training.learning_rate_decay import piecewise_constant_decay
from dragon.vm.tensorflow.core.training.optimizer import GradientDescentOptimizer
from dragon.vm.tensorflow.core.training.optimizer import MomentumOptimizer
from dragon.vm.tensorflow.core.training.optimizer import RMSPropOptimizer
from dragon.vm.tensorflow.core.training.optimizer import AdamOptimizer
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!