SeetaResearch / Dragon
Commit fe161546, authored Jun 11, 2018 by Ting PAN

Preliminary FP16 Training Support

1 parent 4f4ac2ef

Showing 27 changed files with 434 additions and 32 deletions
Dragon/include/core/common.h
Dragon/include/operators/loss/sparse_softmax_cross_entropy_op.h
Dragon/include/operators/update/adam_update_op.h
Dragon/include/operators/update/nesterov_update_op.h
Dragon/include/operators/update/rmsprop_update_op.h
Dragon/include/operators/update/sgd_update_op.h
Dragon/include/operators/update/update_op_base.h
Dragon/include/operators/vision/conv_op.h
Dragon/include/operators/vision/conv_transpose_op.h
Dragon/include/utils/cuda_device.h
Dragon/python/dragon/__init__.py
Dragon/python/dragon/version.py
Dragon/python/setup.py
Dragon/src/operators/loss/sparse_softmax_cross_entropy_op.cc
Dragon/src/operators/misc/accuracy_op.cc
Dragon/src/operators/misc/astype_op.cc
Dragon/src/operators/update/adam_update_op.cc
Dragon/src/operators/update/nesterov_update_op.cc
Dragon/src/operators/update/rmsprop_update_op.cc
Dragon/src/operators/update/sgd_update_op.cc
Dragon/src/operators/update/update_op_base.cc
Dragon/src/operators/vision/cudnn_conv2d_op.cc
Dragon/src/operators/vision/cudnn_conv2d_transpose_op.cc
Dragon/src/utils/math_functions_fp16.cu
Dragon/src/utils/op_kernel.cc
Dragon/src/utils/op_kernel.cu
Dragon/src/utils/op_kernel_fp16.cu
Dragon/include/core/common.h

@@ -49,6 +49,8 @@ using Map = std::unordered_map<Key, Value>;
 template <typename Value>
 using Set = std::unordered_set<Value>;

+#define DRAGON_VERSION 2204
+
 #define CONCATENATE_IMPL(s1, s2) s1##s2
 #define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1,s2)
 #define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__)
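Note: DRAGON_VERSION encodes the release as an integer; 2204 lines up with the 0.2.2.4 version strings set in setup.py and version.py below (an inference from this commit, not documented in it). A minimal compile-time guard a downstream extension might write against it, as a sketch:

#include "core/common.h"

#if !defined(DRAGON_VERSION) || DRAGON_VERSION < 2204
#error "FP16 training support requires Dragon 0.2.2.4+ (DRAGON_VERSION >= 2204)."
#endif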
Dragon/include/operators/loss/sparse_softmax_cross_entropy_op.h

@@ -33,6 +33,8 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
     USE_OPERATOR_FUNCTIONS(Context);

     void SoftmaxRun();
+    void SoftmaxRunFP16();
     void RunOnDevice() override;
     template <typename Tx, typename Ty> void RunWithType();
Dragon/include/operators/update/adam_update_op.h

@@ -25,6 +25,7 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
     USE_UPDATER_FUNCTIONS(Context);

     void ComputeRunWithFloat() override;
+    void ComputeRunWithFloat16() override;

 protected:
     int t;
     float lr, beta1, beta2, eps;
Dragon/include/operators/update/nesterov_update_op.h

@@ -25,6 +25,7 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> {
     USE_UPDATER_FUNCTIONS(Context);

     void ComputeRunWithFloat() override;
+    void ComputeRunWithFloat16() override;

 protected:
     float lr, momentum;
Dragon/include/operators/update/rmsprop_update_op.h

@@ -25,6 +25,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
     USE_UPDATER_FUNCTIONS(Context);

     void ComputeRunWithFloat() override;
+    void ComputeRunWithFloat16() override;

 protected:
     float lr, decay, eps;
Dragon/include/operators/update/sgd_update_op.h

@@ -26,6 +26,7 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
     USE_UPDATER_FUNCTIONS(Context);

     void ComputeRunWithFloat() override;
+    void ComputeRunWithFloat16() override;

 protected:
     float old_lr, lr, momentum, correction;
Dragon/include/operators/update/update_op_base.h

@@ -35,6 +35,7 @@ class UpdateOpBase : public Operator<Context> {
     void RunOnDevice() override;
     template <typename T> void PreprocessRunWithType();
     virtual void ComputeRunWithFloat() = 0;
+    virtual void ComputeRunWithFloat16() { LOG(FATAL) << "This Updater does not support FP16."; }
     template <typename T> void UpdateRunWithType();

 protected:
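Note: FP16 support is opt-in per updater. ComputeRunWithFloat stays pure virtual, while ComputeRunWithFloat16 gets a fatal default, so only subclasses that override it accept float16 tensors. A framework-free, self-contained restatement of the pattern (class names here are illustrative; the real overrides are in the four updaters below):

#include <cstdlib>
#include <iostream>

class UpdaterBase {
 public:
    virtual ~UpdaterBase() = default;
    virtual void ComputeRunWithFloat() = 0;   // every updater must provide FP32
    virtual void ComputeRunWithFloat16() {    // FP16 is opt-in, fatal by default
        std::cerr << "This Updater does not support FP16." << std::endl;
        std::abort();
    }
};

class MyUpdater final : public UpdaterBase {
 public:
    void ComputeRunWithFloat() override { /* launch float32 kernel */ }
    void ComputeRunWithFloat16() override { /* launch float16 kernel */ }  // enables the FP16 branch
};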
Dragon/include/operators/vision/conv_op.h

@@ -103,6 +103,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
 protected:
     cudnnHandle_t* handle;
     cudaStream_t* stream;
+    cudnnDataType_t compute_type;
     cudnnTensorFormat_t format;
     cudnnConvolutionFwdAlgo_t fwd_algo;
     cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;

@@ -164,6 +165,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
 protected:
     cudnnHandle_t* handle;
     cudaStream_t* stream;
+    cudnnDataType_t compute_type;
     cudnnTensorFormat_t format;
     cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
     cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
Dragon/include/operators/vision/conv_transpose_op.h

@@ -106,6 +106,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
 protected:
     cudnnHandle_t* handle;
     cudaStream_t* stream;
+    cudnnDataType_t compute_type;
     cudnnTensorFormat_t format;
     cudnnConvolutionBwdDataAlgo_t fwd_algo;
     cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;

@@ -167,6 +168,7 @@ public:
 protected:
     cudnnHandle_t* handle;
     cudaStream_t* stream;
+    cudnnDataType_t compute_type;
     cudnnTensorFormat_t format;
     cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
     cudnnConvolutionFwdAlgo_t bwd_data_algo;
Dragon/include/utils/cuda_device.h

@@ -116,6 +116,12 @@ inline const cudaDeviceProp& GetDeviceProperty(const int device_id) {
     return props.props[device_id];
 }

+inline bool CUDA_TRUE_FP16_AVAILABLE() {
+    int device = CUDA_CURRENT_DEVICE();
+    auto& prop = GetDeviceProperty(device);
+    return prop.major >= 6;
+}
+
 inline bool TENSOR_CORE_AVAILABLE() {
 #if CUDA_VERSION < 9000
     return false;
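Note: CUDA_TRUE_FP16_AVAILABLE() treats compute capability 6.x (Pascal) and newer as having native half arithmetic; the device kernels later in this commit additionally guard on __CUDA_ARCH__ >= 530, the minimum architecture for the half intrinsics. A standalone equivalent using only the CUDA runtime (a sketch outside Dragon's wrappers, error checking omitted):

#include <cuda_runtime.h>

bool true_fp16_available() {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major >= 6;  // Pascal or newer, as in CUDA_TRUE_FP16_AVAILABLE()
}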
Dragon/python/dragon/__init__.py

@@ -34,3 +34,7 @@ from dragon.core.scope import TensorScope as name_scope
 from dragon.core.scope import PhaseScope as phase_scope
 from dragon.core.scope import DeviceScope as device_scope
+
+# version
+from dragon.version import version
+__version__ = version
Dragon/python/dragon/version.py (new file, 0 → 100644)

# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#     <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

version = '0.2.2'
full_version = '0.2.2.4'
release = False

if not release:
    version = full_version
Dragon/python/setup.py

@@ -42,7 +42,7 @@ find_modules()
 setup(name = 'dragon',
-      version='0.2.2.3',
+      version='0.2.2.4',
       description='Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
       url='https://github.com/seetaresearch/Dragon',
       author='Ting Pan',
Dragon/src/operators/loss/sparse_softmax_cross_entropy_op.cc

@@ -20,6 +20,24 @@ void SparseSoftmaxCrossEntropyOp<Context>::SoftmaxRun() {
     softmax_op->Run();
 }

+template <class Context>
+void SparseSoftmaxCrossEntropyOp<Context>::SoftmaxRunFP16() {
+    Tensor* XF32 = ws()->CreateTensor("/mnt/" + anchor() + "/softmax/xf32");
+    XF32->ReshapeLike(Input(0));
+    auto* XdataF16 = Input(0).template data<float16, Context>();
+    auto* XdataF32 = XF32->template mutable_data<float, Context>();
+    kernel::TypeA2B<float16, float, Context>(Input(0).count(), XdataF16, XdataF32);
+    OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
+        vector<string>({ XF32->name() }),
+        vector<string>({ "/mnt/" + anchor() + "/softmax/prob" }));
+    softmax_def.add_arg()->CopyFrom(this->arg("axis"));
+    if (op_def().has_device_option())
+        softmax_def.mutable_device_option()->CopyFrom(op_def().device_option());
+    if (!softmax_op) softmax_op.reset(CreateOperator(softmax_def, ws()));
+    else softmax_op->MutableOp(softmax_def);
+    softmax_op->Run();
+}
+
 template <class Context> template <typename Tx, typename Ty>
 void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() {
     auto* prob_data = prob->template data<Tx, Context>();

@@ -59,11 +77,11 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() {
         << "\nNumber of predictions must match the number of labels.";
     valid.Reshape(vector<TIndex>(1, outer_dim * inner_dim));
     losses.Reshape(vector<TIndex>(1, outer_dim * inner_dim));
-    ws()->CreateTensor("/mnt/" + anchor() + "/softmax/prob");
-    SoftmaxRun();
-    prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax/prob");
+    prob = ws()->CreateTensor("/mnt/" + anchor() + "/softmax/prob");

-    if (XIsType(Input(0), float)) {
+    if (XIsType(Input(0), float) || XIsType(Input(0), float16)) {
+        if (XIsType(Input(0), float16)) SoftmaxRunFP16();
+        else SoftmaxRun();
         if (XIsType(Input(1), float)) RunWithType<float, float>();
         else if (XIsType(Input(1), int64_t)) RunWithType<float, int64_t>();
         else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" });

@@ -118,11 +136,17 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
     Output(0)->ReshapeLike(Input(0));
     valid.Reshape(vector<TIndex>(1, outer_dim * inner_dim));

-    if (XIsType(Input(0), float)) {
+    if (XIsType(Input(0), float) || XIsType(Input(0), float16)) {
         if (XIsType(Input(1), float)) RunWithType<float, float>();
         else if (XIsType(Input(1), int64_t)) RunWithType<float, int64_t>();
         else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" });
-    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
+        if (XIsType(Input(0), float16)) {
+            auto* dXdataF32 = Output(0)->template data<float, Context>();
+            auto* dXdataF16 = prob->template mutable_data<float16, Context>();
+            kernel::TypeA2B<float, float16, Context>(Output(0)->count(), dXdataF32, dXdataF16);
+            Output(0)->template Copy<Context, Context>(*prob);
+        }
+    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
 }

 DEPLOY_CPU(SparseSoftmaxCrossEntropyGradient);
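Note: the FP16 path here is cast-run-cast: the float16 logits are widened into a float32 scratch tensor with kernel::TypeA2B, softmax and the loss run in float32, and the gradient is narrowed back to float16 at the end. A self-contained sketch of such a conversion pair in plain CUDA (kernel names are hypothetical; Dragon's own device-side narrowing is the _TypeFloat2Half kernel added in math_functions_fp16.cu below):

#include <cuda_fp16.h>

// widen: fp16 -> fp32, grid-stride loop
__global__ void half_to_float(const int n, const half* src, float* dst) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n; i += blockDim.x * gridDim.x)
        dst[i] = __half2float(src[i]);
}

// narrow: fp32 -> fp16
__global__ void float_to_half(const int n, const float* src, half* dst) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n; i += blockDim.x * gridDim.x)
        dst[i] = __float2half(src[i]);
}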
Dragon/src/operators/misc/accuracy_op.cc

 #include <algorithm>

 #include "operators/misc/accuracy_op.h"
+#include "core/workspace.h"
+#include "utils/op_kernel.h"
 #include "utils/math_functions.h"

 namespace dragon {

 template <class Context> template <typename Tx, typename Ty>

@@ -13,7 +16,17 @@ void AccuracyOp<Context>::RunWithType() {
     Map<int, TIndex> num_per_class;

     TIndex acc = 0, count = 0;
-    auto* Xdata = Input(0).template data<Tx, CPUContext>();
+    const Tx* Xdata;
+    if (XIsType(Input(0), float16)) {
+        Tensor* XF32 = ws()->CreateTensor("/mnt/" + anchor() + "/accuracy/xf32");
+        XF32->ReshapeLike(Input(0));
+        auto* XdataF16 = Input(0).template data<float16, CPUContext>();
+        auto* XdataF32 = XF32->template mutable_data<float, CPUContext>();
+        kernel::TypeA2B<float16, float, CPUContext>(Input(0).count(), XdataF16, XdataF32);
+        Xdata = XdataF32;
+    } else Xdata = Input(0).template data<Tx, CPUContext>();
     auto* labels = Input(1).template data<Ty, CPUContext>();
     auto* ignores = ignore_labels.count() > 0 ?
         ignore_labels.data<int, CPUContext>() : nullptr;

@@ -60,11 +73,11 @@ void AccuracyOp<Context>::RunOnDevice() {
     Output(0)->Reshape(vector<TIndex>(1, 1));
     if (OutputSize() > 1) Output(1)->Reshape(vector<TIndex>(1, num_classes));

-    if (XIsType(Input(0), float)) {
+    if (XIsType(Input(0), float) || XIsType(Input(0), float16)) {
         if (XIsType(Input(1), float)) RunWithType<float, float>();
         else if (XIsType(Input(1), int64_t)) RunWithType<float, int64_t>();
         else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" });
-    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
+    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
 }

 DEPLOY_CPU(Accuracy);
Dragon/src/operators/misc/astype_op.cc

@@ -22,6 +22,7 @@ namespace dragon {
         auto* Ydata = buffer->template mutable_data<type_b, Context>(); \
         kernel::TypeA2B<type_a, type_b, Context>(Output(0)->count(), Xdata, Ydata); \
         Output(0)->template Copy<Context, Context>(*buffer); \
+        ws()->ReleaseBuffer(buffer); \
     } \
     return; \
 }
Dragon/src/operators/update/adam_update_op.cc

@@ -22,6 +22,24 @@ void AdamUpdateOp<Context>::ComputeRunWithFloat() {
         lr, beta1, beta2, eps, dXdata, Mdata, Vdata);
 }

+template <class Context>
+void AdamUpdateOp<Context>::ComputeRunWithFloat16() {
+    Tensor* m = ws()->CreateTensor("/mnt/" + Slot() + "/adam/m");
+    Tensor* v = ws()->CreateTensor("/mnt/" + Slot() + "/adam/v");
+    m->ReshapeLike(Input(0));
+    v->ReshapeLike(Input(0));
+    t++;
+    beta1 = Param("beta1"), beta2 = Param("beta2"), eps = Param("eps");
+    float coeff = sqrt(1. - pow(beta2, t)) / (1. - pow(beta1, t));
+    lr = Param("base_lr") * coeff * this->lr_mult;
+    auto* dXdata = Input(0).template mutable_data<float16, Context>();
+    auto* Mdata = m->mutable_data<float16, Context>();
+    auto* Vdata = v->mutable_data<float16, Context>();
+    kernel::AdamUpdate<float16, Context>(Input(0).count(),
+        lr, beta1, beta2, eps, dXdata, Mdata, Vdata);
+}
+
 DEPLOY_CPU(AdamUpdate);
 #ifdef WITH_CUDA
 DEPLOY_CUDA(AdamUpdate);
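Note: written out, this is the standard Adam rule with the bias correction folded into the learning rate, matching the _AdamUpdateHalf kernel added in op_kernel_fp16.cu below:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2, \qquad
g_t \leftarrow \frac{\alpha_t\, m_t}{\sqrt{v_t} + \epsilon},
\quad \alpha_t = \texttt{base\_lr} \cdot \texttt{lr\_mult} \cdot \frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}.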
Dragon/src/operators/update/nesterov_update_op.cc

@@ -17,6 +17,18 @@ void NesterovUpdateOp<Context>::ComputeRunWithFloat() {
         lr, momentum, dXdata, Hdata);
 }

+template <class Context>
+void NesterovUpdateOp<Context>::ComputeRunWithFloat16() {
+    Tensor* h = ws()->CreateTensor("/mnt/" + Slot() + "/nesterov/h");
+    h->ReshapeLike(Input(0));
+    lr = Param("base_lr") * this->lr_mult, momentum = Param("momentum");
+    auto* dXdata = Input(0).template mutable_data<float16, Context>();
+    auto* Hdata = h->template mutable_data<float16, Context>();
+    kernel::NesterovUpdate<float16, Context>(Input(0).count(),
+        lr, momentum, dXdata, Hdata);
+}
+
 DEPLOY_CPU(NesterovUpdate);
 #ifdef WITH_CUDA
 DEPLOY_CUDA(NesterovUpdate);
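Note: the matching _NesterovUpdateHalf kernel below implements the look-ahead momentum form on the velocity buffer h:

h_t = \mu h_{t-1} + \eta g_t, \qquad g_t \leftarrow (1 + \mu)\, h_t - \mu h_{t-1}.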
Dragon/src/operators/update/rmsprop_update_op.cc

@@ -17,6 +17,19 @@ void RMSPropUpdateOp<Context>::ComputeRunWithFloat() {
         lr, decay, eps, dXdata, Hdata);
 }

+template <class Context>
+void RMSPropUpdateOp<Context>::ComputeRunWithFloat16() {
+    Tensor* h = ws()->CreateTensor("/mnt/" + Slot() + "/rmsprop/h");
+    h->ReshapeLike(Input(0));
+    lr = Param("base_lr") * this->lr_mult;
+    decay = Param("decay"), eps = Param("eps");
+    auto* dXdata = Input(0).template mutable_data<float16, Context>();
+    auto* Hdata = h->template mutable_data<float16, Context>();
+    kernel::RMSPropUpdate<float16, Context>(Input(0).count(),
+        lr, decay, eps, dXdata, Hdata);
+}
+
 DEPLOY_CPU(RMSPropUpdate);
 #ifdef WITH_CUDA
 DEPLOY_CUDA(RMSPropUpdate);
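Note: the corresponding _RMSPropUpdateHalf kernel below computes, with \gamma = decay and \eta = lr:

h_t = \gamma h_{t-1} + (1 - \gamma)\, g_t^2, \qquad g_t \leftarrow \frac{\eta\, g_t}{\sqrt{h_t} + \epsilon}.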
Dragon/src/operators/update/sgd_update_op.cc

@@ -19,6 +19,19 @@ void SGDUpdateOp<Context>::ComputeRunWithFloat() {
         lr, momentum * correction, dXdata, Hdata);
 }

+template <class Context>
+void SGDUpdateOp<Context>::ComputeRunWithFloat16() {
+    Tensor* h = ws()->CreateTensor("/mnt/" + Slot() + "/sgd/h");
+    h->ReshapeLike(Input(0));
+    lr = Param("base_lr") * this->lr_mult, momentum = Param("momentum");
+    if (old_lr > 0) { correction = lr / old_lr; }
+    old_lr = lr;
+    auto* dXdata = Input(0).template mutable_data<float16, Context>();
+    auto* Hdata = h->template mutable_data<float16, Context>();
+    kernel::SGDUpdate<float16, Context>(Input(0).count(),
+        lr, momentum * correction, dXdata, Hdata);
+}
+
 DEPLOY_CPU(SGDUpdate);
 #ifdef WITH_CUDA
 DEPLOY_CUDA(SGDUpdate);
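Note: correction = lr / old_lr rescales the momentum term whenever the learning rate changes between steps, so the history h (which already carries the previous rate) stays consistent; the applied rule is

h_t = \mu \frac{\eta_t}{\eta_{t-1}}\, h_{t-1} + \eta_t g_t, \qquad g_t \leftarrow h_t.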
Dragon/src/operators/update/update_op_base.cc

 #include "operators/update/update_op_base.h"
 #include "core/workspace.h"
 #include "utils/math_functions.h"
+#include "utils/cast.h"

 namespace dragon {

@@ -27,10 +28,10 @@ void UpdateOpBase<Context>::PreprocessRunWithType() {
     clip_thresh = Param("clip_gradient");
     if (clip_thresh > 0) {
         auto* dXdata = Input(0).template mutable_data<T, Context>();
-        T sumsq_grad = math::Dot<T, Context>(Input(0).count(), dXdata, dXdata);
-        const T l2norm = sqrt(sumsq_grad);
+        float sumsq_grad = math::Dot<T, Context>(Input(0).count(), dXdata, dXdata);
+        const float l2norm = sqrt(sumsq_grad);
         if (l2norm > clip_thresh) {
-            T factor = clip_thresh / l2norm;
+            float factor = clip_thresh / l2norm;
             math::Scal<T, Context>(Input(0).count(), factor, dXdata);
         }
     }

@@ -48,7 +49,8 @@ void UpdateOpBase<Context>::UpdateRunWithType() {
     auto* dXdata = Input(0).template mutable_data<T, Context>();
     auto* Xdata = Output(0)->template mutable_data<T, Context>();
     math::Axpy<T, Context>(Output(0)->count(), -1.0, dXdata, Xdata);
-    if (zero_grad) math::Set<T, Context>(Input(0).count(), 0, dXdata);
+    T zeroT = dragon_cast<T, float>(0.f);
+    if (zero_grad) math::Set<T, Context>(Input(0).count(), zeroT, dXdata);
 }

 template <class Context>

@@ -62,7 +64,11 @@ void UpdateOpBase<Context>::RunOnDevice() {
         PreprocessRunWithType<float>();
         ComputeRunWithFloat();
         UpdateRunWithType<float>();
-    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
+    } else if (XIsType(Input(0), float16)) {
+        PreprocessRunWithType<float16>();
+        ComputeRunWithFloat16();
+        UpdateRunWithType<float16>();
+    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
 }

 template class UpdateOpBase<CPUContext>;
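Note: switching the accumulator type from T to float matters for the new float16 path: math::Dot<float16, ...> already returns a float (see math_functions_fp16.cu below), and keeping the squared norm in float16 would overflow or lose precision. The clipping itself is the usual L2 rule with \tau = clip_gradient:

\|g\|_2 = \sqrt{g \cdot g}, \qquad g \leftarrow \frac{\tau}{\|g\|_2}\, g \quad \text{if } \|g\|_2 > \tau.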
Dragon/src/operators/vision/cudnn_conv2d_op.cc

@@ -135,12 +135,14 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
     } else if (XIsType(Input(0), float16)) {
 #ifdef WITH_CUDA_FP16
 #if CUDNN_VERSION_MIN(6, 0, 0)
+        compute_type = CUDA_TRUE_FP16_AVAILABLE() ?
+            CUDNN_DATA_HALF : CUDNN_DATA_FLOAT;
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
             this->pad[0], this->pad[1],
             this->stride[0], this->stride[1],
             this->dilation[0], this->dilation[1],
-            CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));
+            CUDNN_CROSS_CORRELATION, compute_type));
 #else
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
             this->pad[0], this->pad[1],

@@ -317,12 +319,17 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
     } else if (XIsType(Input(0), float16)) {
 #ifdef WITH_CUDA_FP16
 #if CUDNN_VERSION_MIN(6, 0, 0)
+        // May encounter CUDNN_STATUS_BAD_PARAM when using CUDNN_DATA_HALF;
+        // keep CUDNN_DATA_FLOAT until cuDNN fixes this bug.
+        // compute_type = CUDA_TRUE_FP16_AVAILABLE() ?
+        //     CUDNN_DATA_HALF : CUDNN_DATA_FLOAT;
+        compute_type = CUDNN_DATA_FLOAT;
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
             this->pad[0], this->pad[1],
             this->stride[0], this->stride[1],
             this->dilation[0], this->dilation[1],
-            CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));
+            CUDNN_CROSS_CORRELATION, compute_type));
 #else
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
             this->pad[0], this->pad[1],
Dragon/src/operators/vision/cudnn_conv2d_transpose_op.cc

@@ -137,12 +137,14 @@ void CuDNNConv2dTransposeOp<Context>::RunOnDevice() {
     } else if (XIsType(Input(0), float16)) {
 #ifdef WITH_CUDA_FP16
 #if CUDNN_VERSION_MIN(6, 0, 0)
+        compute_type = CUDA_TRUE_FP16_AVAILABLE() ?
+            CUDNN_DATA_HALF : CUDNN_DATA_FLOAT;
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
             this->pad[0], this->pad[1],
             this->stride[0], this->stride[1],
             this->dilation[0], this->dilation[1],
-            CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));
+            CUDNN_CROSS_CORRELATION, compute_type));
 #else
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
             this->pad[0], this->pad[1],

@@ -321,12 +323,14 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() {
     } else if (XIsType(Input(0), float16)) {
 #ifdef WITH_CUDA_FP16
 #if CUDNN_VERSION_MIN(6, 0, 0)
+        compute_type = CUDA_TRUE_FP16_AVAILABLE() ?
+            CUDNN_DATA_HALF : CUDNN_DATA_FLOAT;
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
             this->pad[0], this->pad[1],
             this->stride[0], this->stride[1],
             this->dilation[0], this->dilation[1],
-            CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));
+            CUDNN_CROSS_CORRELATION, compute_type));
 #else
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
             this->pad[0], this->pad[1],
Dragon/src/utils/math_functions_fp16.cu

@@ -20,9 +20,7 @@ __global__ void _SetHalf(const int n, const T alpha, T* x) {
     }
 }

-template <> void Set<float16, CUDAContext>(const int n,
-    const float16 alpha,
-    float16* x) {
+template <> void Set<float16, CUDAContext>(const int n, const float16 alpha, float16* x) {
 #ifdef WITH_CUDA_FP16
     if (n % 2 == 0)
         _SetHalf<half2> << <GET_BLOCKS(n / 2), CUDA_NUM_THREADS >> >(n / 2,

@@ -36,18 +34,29 @@ template <> void Set<float16, CUDAContext>(const int n,
 #endif
 }

-template <> void RandomUniform<float16, CUDAContext>(const int n,
-    const float low,
-    const float high,
-    float16* x) {
-    NOT_IMPLEMENTED;
-}
+#ifdef WITH_CUDA_FP16
+__global__ void _TypeFloat2Half(const int n, const float* a, half* b) {
+    CUDA_KERNEL_LOOP(idx, n) {
+        b[idx] = __float2half(a[idx]);
+    }
+}
+#endif

 template <> void RandomNormal<float16, CUDAContext>(const int n,
     const float mu,
     const float sigma,
     float16* x) {
-    NOT_IMPLEMENTED;
+#ifdef WITH_CUDA_FP16
+    float* xf32 = (float*)CUDAContext::New(n * sizeof(float));
+    CURAND_CHECK(curandGenerateNormal(curand_generator(), xf32, n, mu, sigma));
+    _TypeFloat2Half << <GET_BLOCKS(n), CUDA_NUM_THREADS >> >(n,
+        xf32, reinterpret_cast<half*>(x));
+    CUDA_POST_KERNEL_CHECK;
+    CUDAContext::Delete(xf32);
+#else
+    CUDA_FP16_NOT_COMPILED;
+#endif
 }

 /******************** Level-1 ********************/

@@ -380,10 +389,9 @@ template <> void Scale<float16, CUDAContext>(const int n,
 template <> float Dot<float16, CUDAContext>(int n, const float16* a, const float16* b) {
 #ifdef WITH_CUDA_FP16
     float16 result;
-    CUBLAS_CHECK(cublasDotEx(cublas_handle(),
-        n,
-        &a, CUDA_R_16F, 1,
-        &b, CUDA_R_16F, 1,
+    CUBLAS_CHECK(cublasDotEx(cublas_handle(), n,
+        a, CUDA_R_16F, 1,
+        b, CUDA_R_16F, 1,
         &result, CUDA_R_16F,
         CUDA_R_32F));
     return dragon_cast<float, float16>(result);

@@ -491,6 +499,26 @@ template <> void Axpby<float16, CUDAContext>(const int n,
     Axpy<float16, CUDAContext>(n, alpha, x, y);
 }

+template <> void RandomUniform<float16, CUDAContext>(const int n,
+    const float low,
+    const float high,
+    float16* x) {
+#ifdef WITH_CUDA_FP16
+    float* xf32 = (float*)CUDAContext::New(n * sizeof(float));
+    CURAND_CHECK(curandGenerateUniform(curand_generator(), xf32, n));
+    _TypeFloat2Half << <GET_BLOCKS(n), CUDA_NUM_THREADS >> >(n,
+        xf32, reinterpret_cast<half*>(x));
+    CUDA_POST_KERNEL_CHECK;
+    float range = high - low;
+    if (range != float(1)) Scal<float16, CUDAContext>(n, range, x);
+    if (low != float(0)) AddScalar<float16, CUDAContext>(n, low, x);
+    CUDAContext::Delete(xf32);
+#else
+    CUDA_FP16_NOT_COMPILED;
+#endif
+}
+
 /******************** Level-3 ********************/

 template <> void Gemm<float16, CUDAContext>(const CBLAS_TRANSPOSE transA,
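Note: curandGenerateUniform fills (0, 1], so the new RandomUniform specialization draws in float32, converts to half, then maps onto the requested interval with the affine transform

x = \mathrm{low} + (\mathrm{high} - \mathrm{low})\, u, \qquad u \sim \mathcal{U}(0, 1],

applied as a Scal by the range followed by an AddScalar of the lower bound. (The Dot fix above also repairs a real bug: cublasDotEx expects the device pointers a and b, not the addresses of the host pointer variables.)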
Dragon/src/utils/op_kernel.cc

@@ -1838,6 +1838,15 @@ template <> void AdamUpdate<float, CPUContext>(const int count,
     _AdamUpdate<float>(count, lr, beta1, beta2, eps, g, m, v);
 }

+template <> void AdamUpdate<float16, CPUContext>(const int count,
+    const float lr,
+    const float beta1,
+    const float beta2,
+    const float eps,
+    float16* g, float16* m, float16* v) {
+    LOG(FATAL) << "float16 is unsupported for CPUContext.";
+}
+
 /******************** update.nesterov_update ********************/

 template <typename T>

@@ -1862,6 +1871,13 @@ template <> void NesterovUpdate<float, CPUContext>(const int count,
     _NesterovUpdate<float>(count, lr, momentum, g, h);
 }

+template <> void NesterovUpdate<float16, CPUContext>(const int count,
+    const float lr,
+    const float momentum,
+    float16* g, float16* h) {
+    LOG(FATAL) << "float16 is unsupported for CPUContext.";
+}
+
 /******************** update.rmsprop_update ********************/

 template <typename T>

@@ -1888,6 +1904,14 @@ template <> void RMSPropUpdate<float, CPUContext>(const int count,
     _RMSPropUpdate<float>(count, lr, decay, eps, g, h);
 }

+template <> void RMSPropUpdate<float16, CPUContext>(const int count,
+    const float lr,
+    const float decay,
+    const float eps,
+    float16* g, float16* h) {
+    LOG(FATAL) << "float16 is unsupported for CPUContext.";
+}
+
 /******************** update.sgd_update ********************/

 template <typename T>

@@ -1911,6 +1935,13 @@ template <> void SGDUpdate<float, CPUContext>(const int count,
     _SGDUpdate<float>(count, lr, momentum, g, h);
 }

+template <> void SGDUpdate<float16, CPUContext>(const int count,
+    const float lr,
+    const float momentum,
+    float16* g, float16* h) {
+    LOG(FATAL) << "float16 is unsupported for CPUContext.";
+}
+
 /******************** vision.bilinear_resize ********************/

 template <typename T>
Dragon/src/utils/op_kernel.cu
(changes not expanded in this view)
Dragon/src/utils/op_kernel_fp16.cu

@@ -482,6 +482,193 @@ template <> void TransposeGrad<float16, CUDAContext>(const int count,
 #endif
 }

/******************** update.adam_update ********************/

#ifdef WITH_CUDA_FP16
__global__ void _AdamUpdateHalf(const int count,
    const half lr,
    const half beta1,
    const half beta2,
    const half eps,
    half* g, half* m, half* v) {
    CUDA_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530
        half gi = g[i];
        half kOne = __float2half(1.f);
        half mi = m[i] = __hadd(__hmul(m[i], beta1), __hmul(gi, __hsub(kOne, beta1)));
        half vi = v[i] = __hadd(__hmul(v[i], beta2), __hmul(gi, __hmul(gi, __hsub(kOne, beta2))));
        g[i] = __hdiv(__hmul(lr, mi), __hadd(hsqrt(vi), eps));
#endif
    }
}
#endif

template <> void AdamUpdate<float16, CUDAContext>(const int count,
    const float lr,
    const float beta1,
    const float beta2,
    const float eps,
    float16* g, float16* m, float16* v) {
#ifdef WITH_CUDA_FP16
    _AdamUpdateHalf << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
        dragon_cast<half, float>(lr),
        dragon_cast<half, float>(beta1),
        dragon_cast<half, float>(beta2),
        dragon_cast<half, float>(eps),
        reinterpret_cast<half*>(g),
        reinterpret_cast<half*>(m),
        reinterpret_cast<half*>(v));
    CUDA_POST_KERNEL_CHECK;
#else
    CUDA_FP16_NOT_COMPILED;
#endif
}

/******************** update.nesterov_update ********************/

#ifdef WITH_CUDA_FP16
__global__ void _NesterovUpdateHalf(const int count,
    const half lr,
    const half momentum,
    half* g, half* h) {
    CUDA_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530
        half hi = h[i];
        half hi_new = h[i] = __hadd(__hmul(momentum, hi), __hmul(lr, g[i]));
        half kOne = __float2half(1.f);
        g[i] = __hsub(__hmul(__hadd(kOne, momentum), hi_new), __hmul(momentum, hi));
#endif
    }
}

__global__ void _NesterovUpdateHalf2(const int count,
    const half2 lr,
    const half2 momentum,
    half2* g, half2* h) {
    CUDA_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530
        half2 hi = h[i];
        half2 hi_new = h[i] = __hadd2(__hmul2(momentum, hi), __hmul2(lr, g[i]));
        half2 kOne = __float2half2_rn(1.f);
        g[i] = __hsub2(__hmul2(__hadd2(kOne, momentum), hi_new), __hmul2(momentum, hi));
#endif
    }
}
#endif

template <> void NesterovUpdate<float16, CUDAContext>(const int count,
    const float lr,
    const float momentum,
    float16* g, float16* h) {
#ifdef WITH_CUDA_FP16
    if (count % 2 == 0) {
        _NesterovUpdateHalf2 << <GET_BLOCKS(count / 2), CUDA_NUM_THREADS >> >(count / 2,
            dragon_cast<half2, float>(lr),
            dragon_cast<half2, float>(momentum),
            reinterpret_cast<half2*>(g),
            reinterpret_cast<half2*>(h));
    } else {
        _NesterovUpdateHalf << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
            dragon_cast<half, float>(lr),
            dragon_cast<half, float>(momentum),
            reinterpret_cast<half*>(g),
            reinterpret_cast<half*>(h));
    }
    CUDA_POST_KERNEL_CHECK;
#else
    CUDA_FP16_NOT_COMPILED;
#endif
}

/******************** update.rmsprop_update ********************/

#ifdef WITH_CUDA_FP16
__global__ void _RMSPropUpdateHalf(const int count,
    const half lr,
    const half decay,
    const half eps,
    half* g, half* h) {
    CUDA_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530
        half gi = g[i];
        half kOne = __float2half(1.f);
        half hi = h[i] = __hadd(__hmul(decay, h[i]), __hmul(__hmul(__hsub(kOne, decay), gi), gi));
        g[i] = __hdiv(__hmul(lr, g[i]), __hadd(hsqrt(hi), eps));
#endif
    }
}
#endif

template <> void RMSPropUpdate<float16, CUDAContext>(const int count,
    const float lr,
    const float decay,
    const float eps,
    float16* g, float16* h) {
#ifdef WITH_CUDA_FP16
    _RMSPropUpdateHalf << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
        dragon_cast<half, float>(lr),
        dragon_cast<half, float>(decay),
        dragon_cast<half, float>(eps),
        reinterpret_cast<half*>(g),
        reinterpret_cast<half*>(h));
    CUDA_POST_KERNEL_CHECK;
#else
    CUDA_FP16_NOT_COMPILED;
#endif
}

/******************** update.sgd_update ********************/

#ifdef WITH_CUDA_FP16
__global__ void _SGDUpdateHalf(const int count,
    const half lr,
    const half momentum,
    half* g, half* h) {
    CUDA_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530
        half hi = h[i];
        g[i] = h[i] = __hadd(__hmul(momentum, hi), __hmul(lr, g[i]));
#endif
    }
}

__global__ void _SGDUpdateHalf2(const int count,
    const half2 lr,
    const half2 momentum,
    half2* g, half2* h) {
    CUDA_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530
        half2 hi = h[i];
        g[i] = h[i] = __hadd2(__hmul2(momentum, hi), __hmul2(lr, g[i]));
#endif
    }
}
#endif

template <> void SGDUpdate<float16, CUDAContext>(const int count,
    const float lr,
    const float momentum,
    float16* g, float16* h) {
#ifdef WITH_CUDA_FP16
    if (count % 2 == 0) {
        _SGDUpdateHalf2 << <GET_BLOCKS(count / 2), CUDA_NUM_THREADS >> >(count / 2,
            dragon_cast<half2, float>(lr),
            dragon_cast<half2, float>(momentum),
            reinterpret_cast<half2*>(g),
            reinterpret_cast<half2*>(h));
    } else {
        _SGDUpdateHalf << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
            dragon_cast<half, float>(lr),
            dragon_cast<half, float>(momentum),
            reinterpret_cast<half*>(g),
            reinterpret_cast<half*>(h));
    }
    CUDA_POST_KERNEL_CHECK;
#else
    CUDA_FP16_NOT_COMPILED;
#endif
}

} // namespace kernel

} // namespace dragon
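Note: when count is even, the Nesterov and SGD launchers reinterpret the buffers as half2 and process two elements per thread with the paired intrinsics (__hadd2, __hmul2, ...), halving the launched index space. dragon_cast<half2, float> evidently broadcasts the scalar into both 16-bit lanes, which would be equivalent to this sketch (an assumption from usage, not shown in this commit):

#include <cuda_fp16.h>

// presumed behavior of dragon_cast<half2, float>: one scalar,
// replicated into both lanes, so a single half2 op updates
// two consecutive parameters
__device__ __forceinline__ half2 broadcast_half2(const float s) {
    return __float2half2_rn(s);
}

Also worth noting: the kernel bodies are guarded by __CUDA_ARCH__ >= 530, so on older devices the loops compile to no-ops and the parameters are silently left unchanged.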