Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
SeetaResearch
/
Dragon
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit 5c8da7f9
authored
Jul 31, 2017
by
Ting PAN
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add nesterov updater
1 parent
4e937b6c
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
240 additions
and
35 deletions
Dragon/CMakeLists.txt
Dragon/include/operators/update/adam_update_op.h
Dragon/include/operators/update/nesterov_update_op.h
Dragon/include/operators/update/rmsprop_update_op.h
Dragon/include/utils/op_kernel.h
Dragon/python/operators/vision.py
Dragon/python/updaters.py
Dragon/python/vm/caffe/__init__.py
Dragon/python/vm/caffe/proto/caffe_pb2.py
Dragon/python/vm/caffe/solver.py
Dragon/python/vm/tensorflow/ops/nn_ops.py
Dragon/src/operators/update/adam_update_op.cc
Dragon/src/operators/update/nesterov_update_op.cc
Dragon/src/operators/update/rmsprop_update_op.cc
Dragon/src/utils/op_kernel.cc
Dragon/src/utils/op_kernel.cu
README.md
Dragon/CMakeLists.txt
View file @
5c8da7f
...
@@ -8,13 +8,13 @@ CMAKE_MINIMUM_REQUIRED(VERSION 2.8.0)
...
@@ -8,13 +8,13 @@ CMAKE_MINIMUM_REQUIRED(VERSION 2.8.0)
# ---------------- User Config ----------------
# ---------------- User Config ----------------
# set optional libraries
# set optional libraries
option
(
WITH_CUDA
"Set
to ON
use CUDA"
ON
)
option
(
WITH_CUDA
"Set
ON to
use CUDA"
ON
)
option
(
WITH_CUDNN
"Set
to ON
use CUDNN"
OFF
)
option
(
WITH_CUDNN
"Set
ON to
use CUDNN"
OFF
)
option
(
WITH_BLAS
"Set
to
ON to use BLAS"
OFF
)
option
(
WITH_BLAS
"Set ON to use BLAS"
OFF
)
option
(
WITH_SSE
"Set
to
ON to use SSE 4.1"
ON
)
option
(
WITH_SSE
"Set ON to use SSE 4.1"
ON
)
option
(
WITH_MPI
"Set
to
ON to use MPI"
OFF
)
option
(
WITH_MPI
"Set ON to use MPI"
OFF
)
option
(
WITH_MPI_CUDA
"Set
to
ON to use MPI_CUDA_AWARE"
OFF
)
option
(
WITH_MPI_CUDA
"Set ON to use MPI_CUDA_AWARE"
OFF
)
option
(
WITH_CUDA_FP16
"Set
to
ON to use FP16"
ON
)
option
(
WITH_CUDA_FP16
"Set ON to use FP16"
ON
)
# set your 3rdparty
# set your 3rdparty
set
(
3RDPARTY_DIR
${
PROJECT_SOURCE_DIR
}
/../3rdparty
)
set
(
3RDPARTY_DIR
${
PROJECT_SOURCE_DIR
}
/../3rdparty
)
...
...
Dragon/include/operators/update/adam_update_op.h
View file @
5c8da7f
...
@@ -24,9 +24,10 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
...
@@ -24,9 +24,10 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
void
ComputeRunWithFloat
()
override
;
void
ComputeRunWithFloat
()
override
;
protected
:
protected
:
unique_ptr
<
Tensor
>
m
,
v
,
tmp
;
float
lr
,
beta1
,
beta2
,
eps
,
coeff
;
float
lr
,
beta1
,
beta2
,
eps
,
coeff
;
int
t
;
int
t
;
unique_ptr
<
Tensor
>
m
,
v
;
Tensor
temp
;
};
};
}
// namespace dragon
}
// namespace dragon
...
...
Dragon/include/operators/update/nesterov_update_op.h
0 → 100644
View file @
5c8da7f
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#include "operators/update/update_op_base.h"
namespace
dragon
{
template
<
class
Context
>
class
NesterovUpdateOp
final
:
public
UpdateOpBase
<
Context
>
{
public
:
NesterovUpdateOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
UpdateOpBase
<
Context
>
(
op_def
,
ws
),
momentum
(
param
(
"momentum"
))
{}
void
ComputeRunWithFloat
()
override
;
protected
:
float
lr
,
momentum
;
unique_ptr
<
Tensor
>
history
;
Tensor
temp
;
};
}
// namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
\ No newline at end of file
Dragon/include/operators/update/rmsprop_update_op.h
View file @
5c8da7f
...
@@ -24,7 +24,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
...
@@ -24,7 +24,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
protected
:
protected
:
float
lr
,
decay
,
eps
;
float
lr
,
decay
,
eps
;
unique_ptr
<
Tensor
>
history
;
unique_ptr
<
Tensor
>
history
;
Tensor
buffer
;
Tensor
temp
;
};
};
}
// namespace dragon
}
// namespace dragon
...
...
Dragon/include/utils/op_kernel.h
View file @
5c8da7f
...
@@ -382,13 +382,24 @@ void AdamUpdate(Tensor* x,
...
@@ -382,13 +382,24 @@ void AdamUpdate(Tensor* x,
const
float
eps
,
const
float
eps
,
const
float
lr
);
const
float
lr
);
/******************** update.nesterov_update ********************/
template
<
typename
T
,
class
Context
>
void
NesterovUpdate
(
const
int
count
,
T
*
x
,
T
*
h
,
Tensor
*
t
,
const
float
momentum
,
const
float
lr
,
Context
*
ctx
);
/******************** update.rmsprop_update ********************/
/******************** update.rmsprop_update ********************/
template
<
typename
T
,
class
Context
>
template
<
typename
T
,
class
Context
>
void
RMSPropUpdate
(
const
int
count
,
void
RMSPropUpdate
(
const
int
count
,
T
*
x
,
T
*
x
,
T
*
h
,
T
*
h
,
Tensor
*
t
_buffer
,
Tensor
*
t
,
const
float
decay
,
const
float
decay
,
const
float
eps
,
const
float
eps
,
const
float
lr
);
const
float
lr
);
...
...
Dragon/python/operators/vision.py
View file @
5c8da7f
...
@@ -92,7 +92,7 @@ def Pool2D(inputs, kernel_size, stride, pad=0, mode='MAX_POOLING', **kwargs):
...
@@ -92,7 +92,7 @@ def Pool2D(inputs, kernel_size, stride, pad=0, mode='MAX_POOLING', **kwargs):
:param kernel_size: a tuple or a int of the kernel size
:param kernel_size: a tuple or a int of the kernel size
:param stride: a tuple or a int of the stride size
:param stride: a tuple or a int of the stride size
:param pad: a tuple or a int of the zero-padding size
:param pad: a tuple or a int of the zero-padding size
:param
way:
a string of 'MAX_POOLING' or 'AVG_POOLING'
:param
mode:
a string of 'MAX_POOLING' or 'AVG_POOLING'
:return: a 3D or 4D Tensor of the pooled output
:return: a 3D or 4D Tensor of the pooled output
"""
"""
...
...
Dragon/python/updaters.py
View file @
5c8da7f
...
@@ -63,6 +63,16 @@ class SGDUpdater(Updater):
...
@@ -63,6 +63,16 @@ class SGDUpdater(Updater):
self
.
echo
()
self
.
echo
()
class
NesterovUpdater
(
Updater
):
def
__init__
(
self
,
base_lr
=
0.01
,
momentum
=
0.9
,
**
kwargs
):
super
(
NesterovUpdater
,
self
)
.
__init__
(
**
kwargs
)
self
.
_hyper_params
=
dict
({
'base_lr'
:
base_lr
,
'momentum'
:
momentum
},
**
self
.
_hyper_params
)
self
.
_type
=
'NesterovUpdate'
self
.
echo
()
class
RMSPropUpdater
(
Updater
):
class
RMSPropUpdater
(
Updater
):
def
__init__
(
self
,
base_lr
=
0.01
,
decay
=
0.9
,
eps
=
1e-8
,
**
kwargs
):
def
__init__
(
self
,
base_lr
=
0.01
,
decay
=
0.9
,
eps
=
1e-8
,
**
kwargs
):
super
(
RMSPropUpdater
,
self
)
.
__init__
(
**
kwargs
)
super
(
RMSPropUpdater
,
self
)
.
__init__
(
**
kwargs
)
...
...
Dragon/python/vm/caffe/__init__.py
View file @
5c8da7f
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
# Written by Ting Pan
# Written by Ting Pan
# --------------------------------------------------------
# --------------------------------------------------------
from
.solver
import
SGDSolver
,
RMSPropSolver
,
AdamSolver
from
.solver
import
SGDSolver
,
NesterovSolver
,
RMSPropSolver
,
AdamSolver
from
.net
import
Net
,
PartialNet
from
.net
import
Net
,
PartialNet
from
.common
import
set_mode_cpu
,
set_mode_gpu
,
set_device
,
set_random_seed
,
\
from
.common
import
set_mode_cpu
,
set_mode_gpu
,
set_device
,
set_random_seed
,
\
root_solver
,
set_root_solver
root_solver
,
set_root_solver
...
...
Dragon/python/vm/caffe/proto/caffe_pb2.py
View file @
5c8da7f
This diff is collapsed.
Click to expand it.
Dragon/python/vm/caffe/solver.py
View file @
5c8da7f
...
@@ -241,6 +241,32 @@ class SGDSolver(Solver):
...
@@ -241,6 +241,32 @@ class SGDSolver(Solver):
if
self
.
_param
.
HasField
(
param
):
if
self
.
_param
.
HasField
(
param
):
self
.
_update_param
[
param
]
=
getattr
(
self
.
_param
,
param
)
self
.
_update_param
[
param
]
=
getattr
(
self
.
_param
,
param
)
class
NesterovSolver
(
Solver
):
def
__init__
(
self
,
prototxt
):
super
(
NesterovSolver
,
self
)
.
__init__
(
prototxt
=
prototxt
)
self
.
_updater
=
updaters
.
NesterovUpdater
(
**
self
.
_update_param
)
# generates update targets
for
layer
,
blobs
in
self
.
_net
.
params
.
iteritems
():
self
.
_lr_blobs
.
extend
(
blobs
)
for
idx
,
blob
in
enumerate
(
self
.
_lr_blobs
):
if
self
.
_net
.
_lr_mults
[
idx
]
>
0
:
if
blob
.
diff
is
None
:
continue
self
.
_updater
.
append
((
blob
.
data
,
blob
.
diff
),
self
.
_net
.
_lr_mults
[
idx
],
self
.
_net
.
_decay_mults
[
idx
])
self
.
train
=
self
.
_net
.
function
self
.
tests
=
[
test_net
.
function
for
test_net
in
self
.
_test_nets
]
self
.
update
=
function
(
updater
=
self
.
_updater
)
def
CheckUpdateParam
(
self
):
super
(
NesterovSolver
,
self
)
.
CheckUpdateParam
()
params
=
[
'base_lr'
,
'momentum'
]
for
param
in
params
:
if
self
.
_param
.
HasField
(
param
):
self
.
_update_param
[
param
]
=
getattr
(
self
.
_param
,
param
)
class
RMSPropSolver
(
Solver
):
class
RMSPropSolver
(
Solver
):
def
__init__
(
self
,
prototxt
):
def
__init__
(
self
,
prototxt
):
super
(
RMSPropSolver
,
self
)
.
__init__
(
prototxt
=
prototxt
)
super
(
RMSPropSolver
,
self
)
.
__init__
(
prototxt
=
prototxt
)
...
@@ -264,6 +290,7 @@ class RMSPropSolver(Solver):
...
@@ -264,6 +290,7 @@ class RMSPropSolver(Solver):
self
.
_update_param
[
'decay'
]
=
self
.
_param
.
rms_decay
self
.
_update_param
[
'decay'
]
=
self
.
_param
.
rms_decay
self
.
_update_param
[
'eps'
]
=
self
.
_param
.
delta
self
.
_update_param
[
'eps'
]
=
self
.
_param
.
delta
class
AdamSolver
(
Solver
):
class
AdamSolver
(
Solver
):
def
__init__
(
self
,
prototxt
):
def
__init__
(
self
,
prototxt
):
super
(
AdamSolver
,
self
)
.
__init__
(
prototxt
=
prototxt
)
super
(
AdamSolver
,
self
)
.
__init__
(
prototxt
=
prototxt
)
...
...
Dragon/python/vm/tensorflow/ops/nn_ops.py
View file @
5c8da7f
...
@@ -92,7 +92,7 @@ def conv2d(input, filter, strides, pads=(0, 0, 0, 0),
...
@@ -92,7 +92,7 @@ def conv2d(input, filter, strides, pads=(0, 0, 0, 0),
if
data_format
==
'NCHW'
:
if
data_format
==
'NCHW'
:
output
=
ops
.
Conv2D
([
input
,
filter
],
output
=
ops
.
Conv2D
([
input
,
filter
],
num_output
=
filter
.
shape
[
0
],
num_output
=
filter
.
shape
[
0
],
kernel
=
filter
.
shape
[
2
:],
kernel
_size
=
filter
.
shape
[
2
:],
stride
=
strides
[
2
:],
stride
=
strides
[
2
:],
pad
=
pads
[
2
:])
pad
=
pads
[
2
:])
return
output
return
output
...
@@ -127,10 +127,10 @@ def avg_pool(value, ksize, strides, pads=(0, 0, 0, 0),
...
@@ -127,10 +127,10 @@ def avg_pool(value, ksize, strides, pads=(0, 0, 0, 0),
if
data_format
==
'NCHW'
:
if
data_format
==
'NCHW'
:
if
pads
is
None
:
pads
=
0
if
pads
is
None
:
pads
=
0
return
ops
.
Pool2D
(
value
,
return
ops
.
Pool2D
(
value
,
kernel
=
ksize
[
2
:],
kernel
_size
=
ksize
[
2
:],
stride
=
strides
[
2
:],
stride
=
strides
[
2
:],
pad
=
pads
,
pad
=
pads
,
way
=
'AVE
'
)
mode
=
'AVG_POOLING
'
)
else
:
raise
NotImplementedError
()
else
:
raise
NotImplementedError
()
...
@@ -162,10 +162,10 @@ def max_pool(value, ksize, strides, pads=(0, 0, 0, 0),
...
@@ -162,10 +162,10 @@ def max_pool(value, ksize, strides, pads=(0, 0, 0, 0),
if
data_format
==
'NCHW'
:
if
data_format
==
'NCHW'
:
if
pads
is
None
:
pads
=
0
if
pads
is
None
:
pads
=
0
return
ops
.
Pool2D
(
value
,
return
ops
.
Pool2D
(
value
,
kernel
=
ksize
[
2
:],
kernel
_size
=
ksize
[
2
:],
stride
=
strides
[
2
:],
stride
=
strides
[
2
:],
pad
=
pads
,
pad
=
pads
,
way
=
'MAX
'
)
mode
=
'MAX_POOLING
'
)
else
:
raise
NotImplementedError
()
else
:
raise
NotImplementedError
()
...
...
Dragon/src/operators/update/adam_update_op.cc
View file @
5c8da7f
...
@@ -8,13 +8,18 @@ void AdamUpdateOp<Context>::ComputeRunWithFloat() {
...
@@ -8,13 +8,18 @@ void AdamUpdateOp<Context>::ComputeRunWithFloat() {
if
(
!
m
.
get
())
{
if
(
!
m
.
get
())
{
m
.
reset
(
new
Tensor
());
m
->
ReshapeLike
(
input
(
0
));
m
.
reset
(
new
Tensor
());
m
->
ReshapeLike
(
input
(
0
));
v
.
reset
(
new
Tensor
());
v
->
ReshapeLike
(
input
(
0
));
v
.
reset
(
new
Tensor
());
v
->
ReshapeLike
(
input
(
0
));
tmp
.
reset
(
new
Tensor
());
tmp
->
ReshapeLike
(
input
(
0
));
}
}
t
++
;
t
++
;
coeff
=
sqrt
(
1.
-
pow
(
beta2
,
t
))
/
(
1.
-
pow
(
beta1
,
t
));
coeff
=
sqrt
(
1.
-
pow
(
beta2
,
t
))
/
(
1.
-
pow
(
beta1
,
t
));
lr
=
param
(
"base_lr"
)
*
coeff
*
this
->
lr_mult
;
lr
=
param
(
"base_lr"
)
*
coeff
*
this
->
lr_mult
;
kernel
::
AdamUpdate
<
float
,
Context
>
(
&
input
(
0
),
m
.
get
(),
v
.
get
(),
tmp
.
get
(),
kernel
::
AdamUpdate
<
float
,
Context
>
(
&
input
(
0
),
beta1
,
beta2
,
eps
,
lr
);
m
.
get
(),
v
.
get
(),
&
temp
,
beta1
,
beta2
,
eps
,
lr
);
}
}
DEPLOY_CPU
(
AdamUpdate
);
DEPLOY_CPU
(
AdamUpdate
);
...
...
Dragon/src/operators/update/nesterov_update_op.cc
0 → 100644
View file @
5c8da7f
#include "operators/update/nesterov_update_op.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace
dragon
{
template
<
class
Context
>
void
NesterovUpdateOp
<
Context
>::
ComputeRunWithFloat
()
{
if
(
!
history
.
get
())
{
history
.
reset
(
new
Tensor
());
history
->
ReshapeLike
(
input
(
0
));
}
lr
=
param
(
"base_lr"
)
*
this
->
lr_mult
;
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
float
,
Context
>
();
auto
*
Hdata
=
history
->
template
mutable_data
<
float
,
Context
>
();
kernel
::
NesterovUpdate
<
float
,
Context
>
(
input
(
0
).
count
(),
dXdata
,
Hdata
,
&
temp
,
momentum
,
lr
,
&
ctx
());
}
DEPLOY_CPU
(
NesterovUpdate
);
#ifdef WITH_CUDA
DEPLOY_CUDA
(
NesterovUpdate
);
#endif
OPERATOR_SCHEMA
(
NesterovUpdate
).
NumInputs
(
1
).
NumOutputs
(
1
);
NO_GRADIENT
(
NesterovUpdate
);
}
//
namespace
dragon
\ No newline at end of file
Dragon/src/operators/update/rmsprop_update_op.cc
View file @
5c8da7f
...
@@ -15,8 +15,13 @@ void RMSPropUpdateOp<Context>::ComputeRunWithFloat() {
...
@@ -15,8 +15,13 @@ void RMSPropUpdateOp<Context>::ComputeRunWithFloat() {
lr
=
param
(
"base_lr"
)
*
this
->
lr_mult
;
lr
=
param
(
"base_lr"
)
*
this
->
lr_mult
;
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
float
,
Context
>
();
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
float
,
Context
>
();
auto
*
Hdata
=
history
->
template
mutable_data
<
float
,
Context
>
();
auto
*
Hdata
=
history
->
template
mutable_data
<
float
,
Context
>
();
kernel
::
RMSPropUpdate
<
float
,
Context
>
(
input
(
0
).
count
(),
dXdata
,
Hdata
,
kernel
::
RMSPropUpdate
<
float
,
Context
>
(
input
(
0
).
count
(),
&
buffer
,
decay
,
eps
,
lr
);
dXdata
,
Hdata
,
&
temp
,
decay
,
eps
,
lr
);
}
}
DEPLOY_CPU
(
RMSPropUpdate
);
DEPLOY_CPU
(
RMSPropUpdate
);
...
...
Dragon/src/utils/op_kernel.cc
View file @
5c8da7f
...
@@ -895,7 +895,36 @@ template <> void AdamUpdate<float, CPUContext>(Tensor* x,
...
@@ -895,7 +895,36 @@ template <> void AdamUpdate<float, CPUContext>(Tensor* x,
const
float
beta2
,
const
float
beta2
,
const
float
eps
,
const
float
eps
,
const
float
lr
)
{
const
float
lr
)
{
NOT_IMPLEMENTED
;
TIndex
count
=
x
->
count
();
t
->
Reshape
(
vector
<
TIndex
>
(
1
,
count
));
auto
*
Xdata
=
x
->
mutable_data
<
float
,
CPUContext
>
();
auto
*
Mdata
=
m
->
mutable_data
<
float
,
CPUContext
>
();
auto
*
Vdata
=
v
->
mutable_data
<
float
,
CPUContext
>
();
auto
*
Tdata
=
t
->
mutable_data
<
float
,
CPUContext
>
();
math
::
Axpby
<
float
,
CPUContext
>
(
count
,
1.0
-
beta1
,
Xdata
,
beta1
,
Mdata
);
math
::
Mul
<
float
,
CPUContext
>
(
count
,
Xdata
,
Xdata
,
Tdata
);
math
::
Axpby
<
float
,
CPUContext
>
(
count
,
1.0
-
beta2
,
Tdata
,
beta2
,
Vdata
);
math
::
Sqrt
<
float
,
CPUContext
>
(
count
,
Vdata
,
Tdata
);
math
::
AddScalar
<
float
,
CPUContext
>
(
count
,
eps
,
Tdata
);
math
::
Div
<
float
,
CPUContext
>
(
count
,
Mdata
,
Tdata
,
Tdata
);
math
::
Scale
<
float
,
CPUContext
>
(
count
,
lr
,
Tdata
,
Xdata
);
}
/******************** update.nesterov_update ********************/
template
<>
void
NesterovUpdate
<
float
,
CPUContext
>
(
const
int
count
,
float
*
x
,
float
*
h
,
Tensor
*
t
,
const
float
momentum
,
const
float
lr
,
CPUContext
*
ctx
)
{
t
->
Reshape
(
vector
<
TIndex
>
(
1
,
count
));
float
*
Tdata
=
t
->
mutable_data
<
float
,
CPUContext
>
();
ctx
->
Copy
<
float
,
CPUContext
,
CPUContext
>
(
count
,
Tdata
,
h
);
math
::
Axpby
<
float
,
CPUContext
>
(
count
,
lr
,
x
,
momentum
,
h
);
math
::
Axpby
<
float
,
CPUContext
>
(
count
,
1.0
+
momentum
,
h
,
-
momentum
,
Tdata
);
ctx
->
Copy
<
float
,
CPUContext
,
CPUContext
>
(
count
,
x
,
Tdata
);
}
}
/******************** update.rmsprop_update ********************/
/******************** update.rmsprop_update ********************/
...
@@ -903,18 +932,18 @@ template <> void AdamUpdate<float, CPUContext>(Tensor* x,
...
@@ -903,18 +932,18 @@ template <> void AdamUpdate<float, CPUContext>(Tensor* x,
template
<>
void
RMSPropUpdate
<
float
,
CPUContext
>
(
const
int
count
,
template
<>
void
RMSPropUpdate
<
float
,
CPUContext
>
(
const
int
count
,
float
*
x
,
float
*
x
,
float
*
h
,
float
*
h
,
Tensor
*
t
_buffer
,
Tensor
*
t
,
const
float
decay
,
const
float
decay
,
const
float
eps
,
const
float
eps
,
const
float
lr
)
{
const
float
lr
)
{
t
_buffer
->
Reshape
(
vector
<
TIndex
>
(
1
,
count
));
t
->
Reshape
(
vector
<
TIndex
>
(
1
,
count
));
float
*
buffer
=
t_buffer
->
mutable_data
<
float
,
CPUContext
>
();
float
*
Tdata
=
t
->
mutable_data
<
float
,
CPUContext
>
();
math
::
Square
<
float
,
CPUContext
>
(
count
,
x
,
buffer
);
math
::
Square
<
float
,
CPUContext
>
(
count
,
x
,
Tdata
);
math
::
Axpby
<
float
,
CPUContext
>
(
count
,
1.0
-
decay
,
buffer
,
decay
,
h
);
math
::
Axpby
<
float
,
CPUContext
>
(
count
,
1.0
-
decay
,
Tdata
,
decay
,
h
);
math
::
Sqrt
<
float
,
CPUContext
>
(
count
,
h
,
buffer
);
math
::
Sqrt
<
float
,
CPUContext
>
(
count
,
h
,
Tdata
);
math
::
AddScalar
<
float
,
CPUContext
>
(
count
,
eps
,
buffer
);
math
::
AddScalar
<
float
,
CPUContext
>
(
count
,
eps
,
Tdata
);
math
::
Div
<
float
,
CPUContext
>
(
count
,
x
,
buffer
,
buffer
);
math
::
Div
<
float
,
CPUContext
>
(
count
,
x
,
Tdata
,
Tdata
);
math
::
Axpby
<
float
,
CPUContext
>
(
count
,
lr
,
buffer
,
0.0
,
x
);
math
::
Axpby
<
float
,
CPUContext
>
(
count
,
lr
,
Tdata
,
0.0
,
x
);
}
}
/******************** utils.compare ********************/
/******************** utils.compare ********************/
...
...
Dragon/src/utils/op_kernel.cu
View file @
5c8da7f
...
@@ -1647,7 +1647,7 @@ template <> void AdamUpdate<float, CUDAContext>(Tensor* x,
...
@@ -1647,7 +1647,7 @@ template <> void AdamUpdate<float, CUDAContext>(Tensor* x,
const float beta2,
const float beta2,
const float eps,
const float eps,
const float lr) {
const float lr) {
const int
count = x->count();
TIndex
count = x->count();
auto* Xdata = x->mutable_data<float, CUDAContext>();
auto* Xdata = x->mutable_data<float, CUDAContext>();
auto* Mdata = m->mutable_data<float, CUDAContext>();
auto* Mdata = m->mutable_data<float, CUDAContext>();
auto* Vdata = v->mutable_data<float, CUDAContext>();
auto* Vdata = v->mutable_data<float, CUDAContext>();
...
@@ -1662,6 +1662,35 @@ template <> void AdamUpdate<float, CUDAContext>(Tensor* x,
...
@@ -1662,6 +1662,35 @@ template <> void AdamUpdate<float, CUDAContext>(Tensor* x,
CUDA_POST_KERNEL_CHECK;
CUDA_POST_KERNEL_CHECK;
}
}
/******************** update.nesterov_update ********************/
template <typename T>
__global__ void _NesterovUpdate(const int n,
T* g,
T* h,
const T momentum,
const T lr) {
CUDA_KERNEL_LOOP(i, n) {
T hi = h[i];
T hi_new = h[i] = momentum * hi + lr * g[i];
g[i] = (1 + momentum) * hi_new - momentum * hi;
}
}
template <> void NesterovUpdate<float, CUDAContext>(const int count,
float* x,
float* h,
Tensor* t,
const float momentum,
const float lr,
CUDAContext* ctx) {
_NesterovUpdate<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
x,
h,
momentum,
lr);
CUDA_POST_KERNEL_CHECK;
}
/******************** update.rmsprop_update ********************/
/******************** update.rmsprop_update ********************/
template <typename T>
template <typename T>
...
@@ -1681,7 +1710,7 @@ __global__ void _RMSPropUpdate(const int n,
...
@@ -1681,7 +1710,7 @@ __global__ void _RMSPropUpdate(const int n,
template <> void RMSPropUpdate<float, CUDAContext>(const int count,
template <> void RMSPropUpdate<float, CUDAContext>(const int count,
float* x,
float* x,
float* h,
float* h,
Tensor* t
_buffer
,
Tensor* t,
const float decay,
const float decay,
const float eps,
const float eps,
const float lr) {
const float lr) {
...
...
README.md
View file @
5c8da7f
...
@@ -104,8 +104,30 @@
...
@@ -104,8 +104,30 @@
8.
Deploy
8.
Deploy
-
Install Dragon
```Shell
cd Dragon
python setup.py install
```
``Hint``: If you do not have permission, try as follows:
```Shell
cd Dragon
python setup.py install --user
```
-
Install protobuf
```Shell
pip install protobuf
```
-
Install lmdb
```Shell
```Shell
python Dragon/setup.py install
pip install lmdb
```
```
----
----
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment