Commit 5c8da7f9 by Ting PAN

add nesterov updater

1 parent 4e937b6c
......@@ -8,13 +8,13 @@ CMAKE_MINIMUM_REQUIRED(VERSION 2.8.0)
# ---------------- User Config ----------------
# set optional libraries
option(WITH_CUDA "Set to ON use CUDA" ON)
option(WITH_CUDNN "Set to ON use CUDNN" OFF)
option(WITH_BLAS "Set to ON to use BLAS" OFF)
option(WITH_SSE "Set to ON to use SSE 4.1" ON)
option(WITH_MPI "Set to ON to use MPI" OFF)
option(WITH_MPI_CUDA "Set to ON to use MPI_CUDA_AWARE" OFF)
option(WITH_CUDA_FP16 "Set to ON to use FP16" ON)
option(WITH_CUDA "Set ON to use CUDA" ON)
option(WITH_CUDNN "Set ON to use CUDNN" OFF)
option(WITH_BLAS "Set ON to use BLAS" OFF)
option(WITH_SSE "Set ON to use SSE 4.1" ON)
option(WITH_MPI "Set ON to use MPI" OFF)
option(WITH_MPI_CUDA "Set ON to use MPI_CUDA_AWARE" OFF)
option(WITH_CUDA_FP16 "Set ON to use FP16" ON)
# set your 3rdparty
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
......
......@@ -24,9 +24,10 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
void ComputeRunWithFloat() override;
protected:
unique_ptr<Tensor> m, v, tmp;
float lr, beta1, beta2, eps, coeff;
int t;
unique_ptr<Tensor> m, v;
Tensor temp;
};
} // namespace dragon
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#include "operators/update/update_op_base.h"
namespace dragon {
template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
momentum(param("momentum")) {}
void ComputeRunWithFloat() override;
protected:
float lr, momentum;
unique_ptr<Tensor> history;
Tensor temp;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
\ No newline at end of file
......@@ -24,7 +24,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
protected:
float lr, decay, eps;
unique_ptr<Tensor> history;
Tensor buffer;
Tensor temp;
};
} // namespace dragon
......
......@@ -382,13 +382,24 @@ void AdamUpdate(Tensor* x,
const float eps,
const float lr);
/******************** update.nesterov_update ********************/
template <typename T, class Context>
void NesterovUpdate(const int count,
T* x,
T* h,
Tensor* t,
const float momentum,
const float lr,
Context* ctx);
/******************** update.rmsprop_update ********************/
template <typename T, class Context>
void RMSPropUpdate(const int count,
T* x,
T* h,
Tensor* t_buffer,
Tensor* t,
const float decay,
const float eps,
const float lr);
......
......@@ -92,7 +92,7 @@ def Pool2D(inputs, kernel_size, stride, pad=0, mode='MAX_POOLING', **kwargs):
:param kernel_size: a tuple or a int of the kernel size
:param stride: a tuple or a int of the stride size
:param pad: a tuple or a int of the zero-padding size
:param way: a string of 'MAX_POOLING' or 'AVG_POOLING'
:param mode: a string of 'MAX_POOLING' or 'AVG_POOLING'
:return: a 3D or 4D Tensor of the pooled output
"""
......
......@@ -63,6 +63,16 @@ class SGDUpdater(Updater):
self.echo()
class NesterovUpdater(Updater):
def __init__(self, base_lr=0.01, momentum=0.9, **kwargs):
super(NesterovUpdater, self).__init__(**kwargs)
self._hyper_params = dict({'base_lr': base_lr,
'momentum': momentum},
**self._hyper_params)
self._type = 'NesterovUpdate'
self.echo()
class RMSPropUpdater(Updater):
def __init__(self, base_lr=0.01, decay=0.9, eps=1e-8, **kwargs):
super(RMSPropUpdater, self).__init__(**kwargs)
......
......@@ -4,7 +4,7 @@
# Written by Ting Pan
# --------------------------------------------------------
from .solver import SGDSolver, RMSPropSolver, AdamSolver
from .solver import SGDSolver, NesterovSolver, RMSPropSolver, AdamSolver
from .net import Net, PartialNet
from .common import set_mode_cpu, set_mode_gpu, set_device, set_random_seed, \
root_solver, set_root_solver
......
......@@ -241,6 +241,32 @@ class SGDSolver(Solver):
if self._param.HasField(param):
self._update_param[param] = getattr(self._param, param)
class NesterovSolver(Solver):
def __init__(self, prototxt):
super(NesterovSolver, self).__init__(prototxt=prototxt)
self._updater = updaters.NesterovUpdater(**self._update_param)
# generates update targets
for layer, blobs in self._net.params.iteritems(): self._lr_blobs.extend(blobs)
for idx, blob in enumerate(self._lr_blobs):
if self._net._lr_mults[idx] > 0:
if blob.diff is None: continue
self._updater.append((blob.data, blob.diff),
self._net._lr_mults[idx], self._net._decay_mults[idx])
self.train = self._net.function
self.tests = [test_net.function for test_net in self._test_nets]
self.update = function(updater=self._updater)
def CheckUpdateParam(self):
super(NesterovSolver, self).CheckUpdateParam()
params = ['base_lr', 'momentum']
for param in params:
if self._param.HasField(param):
self._update_param[param] = getattr(self._param, param)
class RMSPropSolver(Solver):
def __init__(self, prototxt):
super(RMSPropSolver, self).__init__(prototxt=prototxt)
......@@ -264,6 +290,7 @@ class RMSPropSolver(Solver):
self._update_param['decay'] = self._param.rms_decay
self._update_param['eps'] = self._param.delta
class AdamSolver(Solver):
def __init__(self, prototxt):
super(AdamSolver, self).__init__(prototxt=prototxt)
......
......@@ -92,7 +92,7 @@ def conv2d(input, filter, strides, pads=(0, 0, 0, 0),
if data_format == 'NCHW':
output = ops.Conv2D([input, filter],
num_output=filter.shape[0],
kernel=filter.shape[2:],
kernel_size=filter.shape[2:],
stride=strides[2:],
pad=pads[2:])
return output
......@@ -127,10 +127,10 @@ def avg_pool(value, ksize, strides, pads=(0, 0, 0, 0),
if data_format == 'NCHW':
if pads is None: pads = 0
return ops.Pool2D(value,
kernel=ksize[2:],
kernel_size=ksize[2:],
stride=strides[2:],
pad=pads,
way='AVE')
mode='AVG_POOLING')
else: raise NotImplementedError()
......@@ -162,10 +162,10 @@ def max_pool(value, ksize, strides, pads=(0, 0, 0, 0),
if data_format == 'NCHW':
if pads is None: pads = 0
return ops.Pool2D(value,
kernel=ksize[2:],
kernel_size=ksize[2:],
stride=strides[2:],
pad=pads,
way='MAX')
mode='MAX_POOLING')
else: raise NotImplementedError()
......
......@@ -8,13 +8,18 @@ void AdamUpdateOp<Context>::ComputeRunWithFloat() {
if (!m.get()) {
m.reset(new Tensor()); m->ReshapeLike(input(0));
v.reset(new Tensor()); v->ReshapeLike(input(0));
tmp.reset(new Tensor()); tmp->ReshapeLike(input(0));
}
t++;
coeff = sqrt(1. - pow(beta2, t)) / (1. - pow(beta1, t));
lr = param("base_lr") * coeff * this->lr_mult;
kernel::AdamUpdate<float, Context>(&input(0), m.get(), v.get(), tmp.get(),
beta1, beta2, eps, lr);
kernel::AdamUpdate<float, Context>(&input(0),
m.get(),
v.get(),
&temp,
beta1,
beta2,
eps,
lr);
}
DEPLOY_CPU(AdamUpdate);
......
#include "operators/update/nesterov_update_op.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context>
void NesterovUpdateOp<Context>::ComputeRunWithFloat() {
if (!history.get()) {
history.reset(new Tensor());
history->ReshapeLike(input(0));
}
lr = param("base_lr") * this->lr_mult;
auto* dXdata = input(0).template mutable_data<float, Context>();
auto* Hdata = history->template mutable_data<float, Context>();
kernel::NesterovUpdate<float, Context>(input(0).count(),
dXdata,
Hdata,
&temp,
momentum,
lr,
&ctx());
}
DEPLOY_CPU(NesterovUpdate);
#ifdef WITH_CUDA
DEPLOY_CUDA(NesterovUpdate);
#endif
OPERATOR_SCHEMA(NesterovUpdate).NumInputs(1).NumOutputs(1);
NO_GRADIENT(NesterovUpdate);
} // namespace dragon
\ No newline at end of file
......@@ -15,8 +15,13 @@ void RMSPropUpdateOp<Context>::ComputeRunWithFloat() {
lr = param("base_lr") * this->lr_mult;
auto* dXdata = input(0).template mutable_data<float, Context>();
auto* Hdata = history->template mutable_data<float, Context>();
kernel::RMSPropUpdate<float, Context>(input(0).count(), dXdata, Hdata,
&buffer, decay, eps, lr);
kernel::RMSPropUpdate<float, Context>(input(0).count(),
dXdata,
Hdata,
&temp,
decay,
eps,
lr);
}
DEPLOY_CPU(RMSPropUpdate);
......
......@@ -895,7 +895,36 @@ template <> void AdamUpdate<float, CPUContext>(Tensor* x,
const float beta2,
const float eps,
const float lr) {
NOT_IMPLEMENTED;
TIndex count = x->count();
t->Reshape(vector<TIndex>(1, count));
auto* Xdata = x->mutable_data<float, CPUContext>();
auto* Mdata = m->mutable_data<float, CPUContext>();
auto* Vdata = v->mutable_data<float, CPUContext>();
auto* Tdata = t->mutable_data<float, CPUContext>();
math::Axpby<float, CPUContext>(count, 1.0 - beta1, Xdata, beta1, Mdata);
math::Mul<float, CPUContext>(count, Xdata, Xdata, Tdata);
math::Axpby<float, CPUContext>(count, 1.0 - beta2, Tdata, beta2, Vdata);
math::Sqrt<float, CPUContext>(count, Vdata, Tdata);
math::AddScalar<float, CPUContext>(count, eps, Tdata);
math::Div<float, CPUContext>(count, Mdata, Tdata, Tdata);
math::Scale<float, CPUContext>(count, lr, Tdata, Xdata);
}
/******************** update.nesterov_update ********************/
template <> void NesterovUpdate<float, CPUContext>(const int count,
float* x,
float* h,
Tensor* t,
const float momentum,
const float lr,
CPUContext* ctx) {
t->Reshape(vector<TIndex>(1, count));
float* Tdata = t->mutable_data<float, CPUContext>();
ctx->Copy<float, CPUContext, CPUContext>(count, Tdata, h);
math::Axpby<float, CPUContext>(count, lr, x, momentum, h);
math::Axpby<float, CPUContext>(count, 1.0 + momentum, h, -momentum, Tdata);
ctx->Copy<float, CPUContext, CPUContext>(count, x, Tdata);
}
/******************** update.rmsprop_update ********************/
......@@ -903,18 +932,18 @@ template <> void AdamUpdate<float, CPUContext>(Tensor* x,
template <> void RMSPropUpdate<float, CPUContext>(const int count,
float* x,
float* h,
Tensor* t_buffer,
Tensor* t,
const float decay,
const float eps,
const float lr) {
t_buffer->Reshape(vector<TIndex>(1, count));
float* buffer = t_buffer->mutable_data<float, CPUContext>();
math::Square<float, CPUContext>(count, x, buffer);
math::Axpby<float, CPUContext>(count, 1.0 - decay, buffer, decay, h);
math::Sqrt<float, CPUContext>(count, h, buffer);
math::AddScalar<float, CPUContext>(count, eps, buffer);
math::Div<float, CPUContext>(count, x, buffer, buffer);
math::Axpby<float, CPUContext>(count, lr, buffer, 0.0, x);
t->Reshape(vector<TIndex>(1, count));
float* Tdata = t->mutable_data<float, CPUContext>();
math::Square<float, CPUContext>(count, x, Tdata);
math::Axpby<float, CPUContext>(count, 1.0 - decay, Tdata, decay, h);
math::Sqrt<float, CPUContext>(count, h, Tdata);
math::AddScalar<float, CPUContext>(count, eps, Tdata);
math::Div<float, CPUContext>(count, x, Tdata, Tdata);
math::Axpby<float, CPUContext>(count, lr, Tdata, 0.0, x);
}
/******************** utils.compare ********************/
......
......@@ -1647,7 +1647,7 @@ template <> void AdamUpdate<float, CUDAContext>(Tensor* x,
const float beta2,
const float eps,
const float lr) {
const int count = x->count();
TIndex count = x->count();
auto* Xdata = x->mutable_data<float, CUDAContext>();
auto* Mdata = m->mutable_data<float, CUDAContext>();
auto* Vdata = v->mutable_data<float, CUDAContext>();
......@@ -1662,6 +1662,35 @@ template <> void AdamUpdate<float, CUDAContext>(Tensor* x,
CUDA_POST_KERNEL_CHECK;
}
/******************** update.nesterov_update ********************/
template <typename T>
__global__ void _NesterovUpdate(const int n,
T* g,
T* h,
const T momentum,
const T lr) {
CUDA_KERNEL_LOOP(i, n) {
T hi = h[i];
T hi_new = h[i] = momentum * hi + lr * g[i];
g[i] = (1 + momentum) * hi_new - momentum * hi;
}
}
template <> void NesterovUpdate<float, CUDAContext>(const int count,
float* x,
float* h,
Tensor* t,
const float momentum,
const float lr,
CUDAContext* ctx) {
_NesterovUpdate<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
x,
h,
momentum,
lr);
CUDA_POST_KERNEL_CHECK;
}
/******************** update.rmsprop_update ********************/
template <typename T>
......@@ -1681,7 +1710,7 @@ __global__ void _RMSPropUpdate(const int n,
template <> void RMSPropUpdate<float, CUDAContext>(const int count,
float* x,
float* h,
Tensor* t_buffer,
Tensor* t,
const float decay,
const float eps,
const float lr) {
......
......@@ -104,8 +104,30 @@
8. Deploy
- Install Dragon
```Shell
cd Dragon
python setup.py install
```
``Hint``: If you do not have permission, try as follows:
```Shell
cd Dragon
python setup.py install --user
```
- Install protobuf
```Shell
pip install protobuf
```
- Install lmdb
```Shell
python Dragon/setup.py install
pip install lmdb
```
----
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!