SeetaResearch / Dragon
Commit a739c49b authored Apr 16, 2019 by Ting PAN

Optimize the CuDNNDepthwiseConv2d

1 parent 0936a502
Showing 22 changed files with 374 additions and 224 deletions
Dragon/include/operators/vision/conv_op.h
Dragon/include/operators/vision/depthwise_conv_op.h
Dragon/include/utils/op_kernel.h
Dragon/python/dragon/core/helper.py
Dragon/python/dragon/core/tensor.py
Dragon/python/dragon/core/workspace.py
Dragon/python/dragon/operators/activation.py
Dragon/python/dragon/operators/arithmetic.py
Dragon/python/dragon/operators/array.py
Dragon/python/dragon/operators/data.py
Dragon/python/dragon/operators/loss.py
Dragon/python/dragon/operators/norm.py
Dragon/python/dragon/operators/recurrent.py
Dragon/python/dragon/operators/vision.py
Dragon/python/dragon/utils/vision/data_batch.py
Dragon/python/dragon/utils/vision/data_transformer.py
Dragon/python/dragon/vm/tensorflow/ops/standard_ops.py
Dragon/src/kernels/recurrent/lstm_cell_op_kernel.cc
Dragon/src/kernels/recurrent/lstm_cell_op_kernel.cu
Dragon/src/operators/array/flatten_op.cc
Dragon/src/operators/recurrent/lstm_cell_op.cc
Dragon/src/operators/vision/cudnn_depthwise_conv2d_op.cc
Dragon/include/operators/vision/conv_op.h
@@ -107,7 +107,7 @@ class CuDNNConv2dOp final : public Conv2dOp<Context> {
 };

 template <class Context>
-class CuDNNConv2dGradientOp final : public Conv2dGradientOp<Context> {
+class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
  public:
     CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws)
         : Conv2dGradientOp<Context>(def, ws),
Dragon/include/operators/vision/depthwise_conv_op.h
@@ -13,7 +13,7 @@
 #ifndef DRAGON_OPERATORS_VISION_DEPTHWISE_CONV_OP_H_
 #define DRAGON_OPERATORS_VISION_DEPTHWISE_CONV_OP_H_

-#include "operators/vision/conv_op_base.h"
+#include "operators/vision/conv_op.h"

 namespace dragon {

@@ -62,10 +62,10 @@ template <class Context>
 class CuDNNDepthwiseConv2dOp final : public DepthwiseConv2dOp<Context> {
  public:
-    CuDNNDepthwiseConv2dOp(const OperatorDef& def, Workspace* ws)
-        : DepthwiseConv2dOp<Context>(def, ws) {
+    CuDNNDepthwiseConv2dOp(
+        const OperatorDef& def, Workspace* ws)
+        : DepthwiseConv2dOp<Context>(def, ws) {
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
     }

@@ -86,12 +86,12 @@ class CuDNNDepthwiseConv2dOp final
 template <class Context>
 class CuDNNDepthwiseConv2dGradientOp final
-    : public DepthwiseConv2dGradientOp<Context> {
+    : public CuDNNConv2dGradientOp<Context> {
  public:
-    CuDNNDepthwiseConv2dGradientOp(const OperatorDef& def, Workspace* ws)
-        : DepthwiseConv2dGradientOp<Context>(def, ws) {
+    CuDNNDepthwiseConv2dGradientOp(
+        const OperatorDef& def, Workspace* ws)
+        : CuDNNConv2dGradientOp<Context>(def, ws) {
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
     }
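The two hunks above carry the substance of this commit: `final` is dropped from CuDNNConv2dGradientOp in conv_op.h so it can be subclassed, and CuDNNDepthwiseConv2dGradientOp now derives from CuDNNConv2dGradientOp instead of DepthwiseConv2dGradientOp, letting the backward pass fall through to cuDNN's grouped-convolution path. A depthwise convolution is simply a grouped convolution whose group count equals the channel count; the NumPy sketch below illustrates that equivalence (an illustrative reconstruction, not code from the repository).

import numpy as np

def conv2d_grouped(x, w, group):
    # x: (N, C_in, H, W), w: (C_out, C_in // group, kH, kW); stride 1, no padding.
    N, C_in, H, W = x.shape
    C_out, _, kH, kW = w.shape
    cin_g, cout_g = C_in // group, C_out // group
    out = np.zeros((N, C_out, H - kH + 1, W - kW + 1), x.dtype)
    for g in range(group):
        xg = x[:, g * cin_g:(g + 1) * cin_g]
        wg = w[g * cout_g:(g + 1) * cout_g]
        for i in range(out.shape[2]):
            for j in range(out.shape[3]):
                patch = xg[:, :, i:i + kH, j:j + kW]
                out[:, g * cout_g:(g + 1) * cout_g, i, j] = \
                    np.tensordot(patch, wg, axes=([1, 2, 3], [1, 2, 3]))
    return out

x = np.random.randn(2, 4, 8, 8).astype('float32')
w = np.random.randn(4, 1, 3, 3).astype('float32')
# A depthwise convolution is the group == channels case of a grouped convolution.
y = conv2d_grouped(x, w, group=4)
assert y.shape == (2, 4, 6, 6)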
Dragon/include/utils/op_kernel.h
@@ -941,22 +941,20 @@ void GroupNormBackward(
 template <typename T, class Context>
 void LSTMCell(
-    const int           count,
     const int           N,
     const int           C,
     const T*            cx,
-    T*                  xact,
+    T*                  actx,
     T*                  c,
     T*                  h,
     Context*            ctx);

 template <typename T, class Context>
 void LSTMCellGrad(
-    const int           count,
     const int           N,
     const int           C,
     const T*            cx,
-    const T*            xact,
+    const T*            actx,
     const T*            c,
     const T*            dc,
     const T*            dh,
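The declarations above drop the separate count argument: the kernels can recover every element count they need from N and C, because the gate pre-activation buffer actx holds four gates per cell while cx, c and h hold one value per cell. A quick check of those counts with assumed sizes:

N, C = 32, 128    # batch size and hidden size (assumed values)
print(N * 4 * C)  # elements in actx ([i | f | o | g] gates): 16384
print(N * C)      # elements in cx, c and h: 4096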
Dragon/python/dragon/core/helper.py
@@ -637,6 +637,35 @@ class OperatorHelper(object):
         return outputs

     @classmethod
+    def _apply_Flatten(cls, arguments, inputs, outputs):
+        outputs[0].dtype = inputs[0].dtype
+        keep_axes = arguments['keep_axes']
+        axis, num_axes = arguments['axis'], arguments['num_axes']
+        try:
+            fake_shape = inputs[0].shape[:]
+            fake_shape = [1 if dim is None else dim for dim in fake_shape]
+            if keep_axes is not None:
+                keep_axes = min(keep_axes, len(inputs.shape))
+                total_count = numpy.prod(fake_shape)
+                outputs[0].shape = []
+                for i in range(keep_axes - 1):
+                    outputs[0].shape.append(inputs[0].shape[i])
+                    total_count *= fake_shape[i]
+                if total_count != 1:
+                    outputs[0].shape.append(total_count)
+            else:
+                if num_axes == -1:
+                    num_axes = len(inputs[0].shape) - axis
+                num_axes = max(num_axes, 1)
+                num_flatten = numpy.prod(fake_shape[axis : axis + num_axes])
+                outputs[0].shape = \
+                    inputs[0].shape[:axis] + [num_flatten] \
+                        + inputs[0].shape[axis + num_axes:]
+        except:
+            pass
+        return outputs
+
+    @classmethod
     def _apply_Reshape(cls, arguments, inputs, outputs):
         outputs[0].dtype = inputs[0].dtype
         shape = arguments['dims']
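The new OperatorHelper._apply_Flatten hook takes over the output-shape inference that the Flatten wrapper in array.py used to do inline (see that hunk further down). A small standalone check of the axis/num_axes arithmetic, mirroring the helper with plain NumPy:

import numpy as np

def flatten_shape(shape, axis=0, num_axes=-1):
    # Collapse `num_axes` dimensions starting at `axis` into one,
    # treating unknown (None) dimensions as 1, as _apply_Flatten does.
    fake_shape = [1 if dim is None else dim for dim in shape]
    if num_axes == -1: num_axes = len(shape) - axis
    num_axes = max(num_axes, 1)
    num_flatten = int(np.prod(fake_shape[axis:axis + num_axes]))
    return shape[:axis] + [num_flatten] + shape[axis + num_axes:]

print(flatten_shape([1, 2, 3, 4]))          # [24], as in the Flatten docstring
print(flatten_shape([2, 3, 4, 5], axis=1))  # [2, 60]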
Dragon/python/dragon/core/tensor.py
@@ -507,17 +507,21 @@ class Tensor(object):
     def _from_constants(self, value):
         if not isinstance(value, numpy.ndarray):
             try:
-                value = numpy.array(value,
-                    dtype=self.dtype if self.dtype else 'float32')
+                value = numpy.array(
+                    value, dtype=self.dtype
+                        if self.dtype else 'float32')
             except:
                 raise TypeError('Can not convert the value to Tensor or numpy array.')
-        ref_tensor = Tensor.Ref(
-            name=_workspace.GetDummyName(
-                'Constant', domain='Tensor', zero_based=False),
-            shape=list(value.shape), dtype=str(value.dtype))
-        ref_tensor.set_value(value)
-        return ref_tensor
+        return Tensor.Ref(
+            name=_workspace.GetDummyName(
+                basename='Constant',
+                domain='Tensor',
+                zero_based=False,
+            ),
+            shape=list(value.shape),
+            dtype=str(value.dtype),
+        ).set_value(value)

     def __add__(self, other):
         """Calculate x + y.
Dragon/python/dragon/core/workspace.py
@@ -770,8 +770,9 @@ def _stringify_tensor(obj):
 class _DefaultWorkspaceStack(_tls.Stack):
-    """A thread-local stack of objects for
-    providing an implicit default workspace."""
+    """A thread-local stack of objects for
+    providing an implicit default workspace.
+    """

     def __init__(self):
         super(_DefaultWorkspaceStack, self).__init__()
         self._global_default_workspace = None
Dragon/python/dragon/operators/activation.py
@@ -165,7 +165,7 @@ def Tanh(inputs, **kwargs):
 @OpSchema.Inputs(1)
-@ArgumentHelper.Desc('prob', as_target=False)
+@ArgumentHelper.Desc('prob', as_target=True)
 def Dropout(inputs, prob=0.5, scale=True, **kwargs):
     """Randomly set a unit into zero. `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
Dragon/python/dragon/operators/arithmetic.py
@@ -516,4 +516,4 @@ def MovingAverage(inputs, decay, **kwargs):
         The outputs, i.e., the *y*.

     """
-    return Accumulate(inputs, 1 - decay, decay, **kwargs)
\ No newline at end of file
+    return Accumulate(inputs, 1. - decay, decay, **kwargs)
\ No newline at end of file
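The only change here is the float literal (1 becomes 1.). For reference, the update this wrapper expresses is y ← decay·y + (1 − decay)·x, assuming Accumulate(x, alpha, beta) computes y = alpha·x + beta·y; a NumPy sketch:

import numpy as np

def moving_average(y, x, decay):
    # y = decay * y + (1. - decay) * x, matching Accumulate(inputs, 1. - decay, decay).
    return decay * y + (1. - decay) * x

y = np.zeros(3)
for x in (np.ones(3), 2. * np.ones(3)):
    y = moving_average(y, x, decay=0.9)
print(y)  # [0.29 0.29 0.29]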
Dragon/python/dragon/operators/array.py
@@ -38,9 +38,11 @@ def Gather(inputs, indices, axis=0, **kwargs):
     """
     arguments = ParseArgs(locals())
+    if not isinstance(indices, Tensor):
+        indices = Tensor.Ref('', dtype='int64') \
+            ._from_constants(indices)
     arguments['inputs'], arguments['indices'] = \
-        [arguments['inputs'], Tensor.Convert(indices, dtype='int64')], None
+        [arguments['inputs'], indices], None
     return Tensor.CreateOperator('Gather', **arguments)

@@ -48,9 +50,13 @@ def Gather(inputs, indices, axis=0, **kwargs):
 @ArgumentHelper.RepeatedDesc('starts')
 @ArgumentHelper.RepeatedDesc('sizes')
 def Crop(
-    inputs, starts=None, sizes=None,
-        start_axis=None, offsets=None,
-            shape_like=None, **kwargs
+    inputs,
+    starts=None,
+    sizes=None,
+    start_axis=None,
+    offsets=None,
+    shape_like=None,
+    **kwargs
 ):
     """Crop the input according to the given starts and sizes.

@@ -274,7 +280,14 @@ def Mean(inputs, axes=None, keep_dims=False, **kwargs):
 @OpSchema.Inputs(1)
-def _ArgReduce(inputs, axis=None, operation='ARGMAX', top_k=1, keep_dims=False, **kwargs):
+def _ArgReduce(
+    inputs,
+    axis=None,
+    operation='ARGMAX',
+    top_k=1,
+    keep_dims=False,
+    **kwargs
+):
     arguments = ParseArgs(locals())
     arguments['axis'] = arguments['axis'] if arguments else INT_MAX
     return Tensor.CreateOperator('ArgReduce', num_outputs=2, **arguments)

@@ -577,33 +590,7 @@ def Flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs):
     >>> [24]

     """
-    arguments = ParseArgs(locals())
-    output = Tensor.CreateOperator(op_type='Flatten', **arguments)
-    if inputs.shape is not None:
-        fake_shape = inputs.shape[:]
-        fake_shape = [1 if dim is None else dim for dim in fake_shape]
-        if keep_axes is not None:
-            if keep_axes > len(inputs.shape):
-                raise ValueError(
-                    'The total number of axes is {}, can not keep {}.'
-                        .format(len(inputs.shape), keep_axes))
-            total_count = np.prod(fake_shape)
-            output.shape = []
-            for i in range(keep_axes - 1):
-                output.shape.append(inputs.shape[i])
-                total_count *= fake_shape[i]
-            if total_count != 1:
-                output.shape.append(total_count)
-        else:
-            if num_axes == -1: num_axes = len(inputs.shape) - axis
-            elif num_axes == 0:
-                raise ValueError('num_axes must > 0 or be -1.')
-            num_flatten = np.prod(fake_shape[axis : axis + num_axes])
-            output.shape = inputs.shape[:axis] + [num_flatten] + inputs.shape[axis + num_axes:]
-    return output
+    return Tensor.CreateOperator(op_type='Flatten', **ParseArgs(locals()))

 @OpSchema.Inputs(1)

@@ -676,20 +663,7 @@ def Squeeze(inputs, axis=None, **kwargs):
     >>> print(Squeeze(a, axis=0).shape)

     """
-    arguments = ParseArgs(locals())
-    output = Tensor.CreateOperator(op_type='Squeeze', **arguments)
-    if inputs.shape is not None:
-        output_shape = []
-        if axis: axis += (0 if axis >= 0 else len(inputs.shape))
-        for idx, dim in enumerate(inputs.shape[:]):
-            if dim != 1 or \
-                (axis and dim == 1 and idx != axis):
-                    output_shape.append(dim)
-        output.shape = output_shape
-    return output
+    return Tensor.CreateOperator(op_type='Squeeze', **ParseArgs(locals()))

 @OpSchema.Inputs(1)
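Two things happen in this file: Gather now converts plain Python lists or NumPy arrays of indices into an int64 constant tensor through _from_constants, and Flatten/Squeeze drop their inline shape inference in favor of the helper hooks shown earlier. The gather semantics themselves are unchanged; a NumPy reminder of what taking indices along an axis does:

import numpy as np

x = np.arange(12).reshape(4, 3)
indices = [0, 2]                    # a plain Python list, now accepted directly by Gather
print(np.take(x, indices, axis=0))  # rows 0 and 2: [[0 1 2], [6 7 8]]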
Dragon/python/dragon/operators/data.py
@@ -84,8 +84,13 @@ def LMDBData(**kwargs):
 @OpSchema.Inputs(1)
 def ImageData(
-    inputs, mean_values=None, std_values=None, dtype='float32', data_format='NCHW', **kwargs
+    inputs,
+    mean_values=None,
+    std_values=None,
+    dtype='float32',
+    data_format='NCHW',
+    **kwargs
 ):
     """Process the images from 4D raw data.

     Note that we assume the data format of raw data is **NHWC**.
Dragon/python/dragon/operators/loss.py
@@ -19,8 +19,12 @@ from .activation import Softmax
 @OpSchema.Inputs(2)
 def NLLLoss(
-    inputs, axis=1, normalization='VALID', ignore_labels=(), **kwargs
+    inputs,
+    axis=1,
+    normalization='VALID',
+    ignore_labels=(),
+    **kwargs
 ):
     """Compute the negative likelihood loss with sparse labels.

     **Type Constraints**:

@@ -36,7 +40,7 @@ def NLLLoss(
     axis : int, optional
         The axis to apply softmax, can be negative.
     normalization : {'UNIT', 'FULL', 'VALID', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.
     ignore_labels : sequence of int, optional, default=()
         The label id to ignore.

@@ -55,8 +59,12 @@ def NLLLoss(
 @OpSchema.Inputs(2)
 def SparseSoftmaxCrossEntropy(
-    inputs, axis=1, normalization='VALID', ignore_labels=(), **kwargs
+    inputs,
+    axis=1,
+    normalization='VALID',
+    ignore_labels=(),
+    **kwargs
 ):
     """Compute the softmax cross entropy with sparse labels.

     **Type Constraints**:

@@ -72,7 +80,7 @@ def SparseSoftmaxCrossEntropy(
     axis : int, optional
         The axis to apply softmax, can be negative.
     normalization : {'UNIT', 'FULL', 'VALID', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.
     ignore_labels : sequence of int, optional, default=()
         The label id to ignore.

@@ -100,7 +108,7 @@ def SigmoidCrossEntropy(inputs, normalization='VALID', **kwargs):
     inputs : sequence of Tensor
         The inputs, represent [logits, targets].
     normalization : {'UNIT', 'FULL', 'VALID', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.

     Returns
     -------

@@ -128,7 +136,7 @@ def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs):
     axis : int, optional
         The axis to apply softmax, can be negative.
     normalization : {'UNIT', 'FULL', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.

     Returns
     -------

@@ -158,7 +166,7 @@ def SmoothL1Loss(inputs, beta=1.0, normalization='BATCH_SIZE', **kwargs):
     beta : float, optional
         The transition point from L1 to L2 loss
     normalization : {'FULL', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.

     Returns
     -------

@@ -182,7 +190,7 @@ def L1Loss(inputs, scale=1., normalization='BATCH_SIZE', **kwargs):
     scale : float, optional
         The scale factor applying on the reduced loss.
     normalization : {'FULL', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.

     Returns
     -------

@@ -206,7 +214,7 @@ def L2Loss(inputs, scale=1., normalization='BATCH_SIZE', **kwargs):
     scale : float, optional
         The scale factor applying on the reduced loss.
     normalization : {'FULL', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.

     Returns
     -------

@@ -219,8 +227,14 @@ def L2Loss(inputs, scale=1., normalization='BATCH_SIZE', **kwargs):
 @OpSchema.Inputs(2)
 def SigmoidFocalLoss(
-    inputs, axis=1, normalization='VALID',
-        alpha=0.25, gamma=2.0, neg_id=0, **kwargs
+    inputs,
+    axis=1,
+    normalization='VALID',
+    alpha=0.25,
+    gamma=2.0,
+    neg_id=0,
+    **kwargs
 ):
     """Compute the sigmoid focal loss with sparse labels. `[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`_.

     **Type Constraints**: *float32*

@@ -232,7 +246,7 @@ def SigmoidFocalLoss(
     axis : int, optional
         The axis to apply softmax, can be negative.
     normalization : {'UNIT', 'FULL', 'VALID', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.
     alpha : float, optional, default=0.25
         The scale factor on the rare class.
     gamma : float, optional, default=2.0

@@ -255,8 +269,15 @@ def SigmoidFocalLoss(
 @OpSchema.Inputs(2)
 def SoftmaxFocalLoss(
-    inputs, axis=1, normalization='VALID', ignore_labels=(),
-        alpha=0.25, gamma=2.0, neg_id=0, **kwargs
+    inputs,
+    axis=1,
+    normalization='VALID',
+    ignore_labels=(),
+    alpha=0.25,
+    gamma=2.0,
+    neg_id=0,
+    **kwargs
 ):
     """Compute the softmax focal loss with sparse labels. `[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`_.

     **Type Constraints**: *float32*

@@ -268,7 +289,7 @@ def SoftmaxFocalLoss(
     axis : int, optional
         The axis to apply softmax, can be negative.
     normalization : {'UNIT', 'FULL', 'VALID', 'BATCH_SIZE', 'NONE'}, optional
-        The normalization method.
+        The method of normalization.
     ignore_labels : sequence of int, optional, default=()
         The label id to ignore.
     alpha : float, optional, default=0.25

@@ -293,8 +314,12 @@ def SoftmaxFocalLoss(
 @OpSchema.Inputs(2)
 def CTCLoss(
-    inputs, blank_first=True, padding_mask=-1, use_softmax=True, **kwargs
+    inputs,
+    blank_first=True,
+    padding_mask=-1,
+    use_softmax=True,
+    **kwargs
 ):
     """Compute the ctc loss with batched variable length of labels. `[Graves & Gomez, 2006] <http://www.cs.utoronto.ca/~graves/icml_2006.pdf>`_.

     The data format of inputs should be *[T, N, C]*.

@@ -329,5 +354,6 @@ def CTCLoss(
     """
     arguments = ParseArgs(locals())
-    if use_softmax: arguments['inputs'][0] = Softmax(arguments['inputs'][0], axis=2)
+    if use_softmax: arguments['inputs'][0] = \
+        Softmax(arguments['inputs'][0], axis=2)
     return Tensor.CreateOperator('CTCLoss', **arguments)
\ No newline at end of file
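Most hunks in this file only re-wrap signatures and reword "The normalization method." into "The method of normalization.". The one computational detail worth keeping in mind is that CTCLoss expects *[T, N, C]* inputs and, when use_softmax=True, first runs Softmax over axis=2; a NumPy illustration of that preprocessing step with assumed shapes:

import numpy as np

T_steps, N, C = 10, 4, 20        # time steps, batch size, classes (assumed values)
logits = np.random.randn(T_steps, N, C)
# Softmax over the class axis (axis=2), as CTCLoss applies when use_softmax=True.
probs = np.exp(logits - logits.max(axis=2, keepdims=True))
probs /= probs.sum(axis=2, keepdims=True)
assert np.allclose(probs.sum(axis=2), 1.0)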
Dragon/python/dragon/operators/norm.py
@@ -18,8 +18,13 @@ from . import *
 @OpSchema.Inputs(5)
 def BatchNorm(
-    inputs, axis=-1, momentum=0.9, eps=1e-5, use_stats=-1, **kwargs
+    inputs,
+    axis=-1,
+    momentum=0.9,
+    eps=1e-5,
+    use_stats=-1,
+    **kwargs
 ):
     """Batch Normalization. `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.

     We enforce the number of inputs should be *5*, i.e.,
Dragon/python/dragon/operators/recurrent.py
@@ -29,8 +29,16 @@ class RNN(RNNBase):
     >>> outputs, hidden = rnn(x)

     """
-    def __init__(self, input_size, hidden_size, nonlinearity='relu',
-        num_layers=1, bidirectional=False, dropout=0, name=None):
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        nonlinearity='relu',
+        num_layers=1,
+        bidirectional=False,
+        dropout=0,
+        name=None,
+    ):
         """Construct a RNN instance.

         Parameters

@@ -57,8 +65,10 @@ class RNN(RNNBase):
         """
         mode = 'rnn_relu' if nonlinearity == 'relu' else 'rnn_tanh'
-        super(RNN, self).__init__(mode, input_size, hidden_size,
-            num_layers, bidirectional, dropout, name)
+        super(RNN, self).__init__(
+            mode, input_size, hidden_size,
+            num_layers, bidirectional, dropout, name,
+        )

 class LSTM(RNNBase):

@@ -73,8 +83,15 @@ class LSTM(RNNBase):
     >>> outputs, hidden = rnn(x)

     """
-    def __init__(self, input_size, hidden_size, num_layers=1,
-        bidirectional=False, dropout=0, name=None):
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        num_layers=1,
+        bidirectional=False,
+        dropout=0,
+        name=None,
+    ):
         """Construct a LSTM instance.

         Parameters

@@ -98,8 +115,10 @@ class LSTM(RNNBase):
             The wrapper of general RNN.

         """
-        super(LSTM, self).__init__('lstm', input_size, hidden_size,
-            num_layers, bidirectional, dropout, name)
+        super(LSTM, self).__init__(
+            'lstm', input_size, hidden_size,
+            num_layers, bidirectional, dropout, name,
+        )

 class GRU(RNNBase):

@@ -114,8 +133,15 @@ class GRU(RNNBase):
     >>> outputs, hidden = rnn(x)

     """
-    def __init__(self, input_size, hidden_size, num_layers=1,
-        bidirectional=False, dropout=0, name=None):
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        num_layers=1,
+        bidirectional=False,
+        dropout=0,
+        name=None,
+    ):
         """Construct a GRU instance.

         Parameters

@@ -139,8 +165,10 @@ class GRU(RNNBase):
             The wrapper of general RNN.

         """
-        super(GRU, self).__init__('gru', input_size, hidden_size,
-            num_layers, bidirectional, dropout, name)
+        super(GRU, self).__init__(
+            'gru', input_size, hidden_size,
+            num_layers, bidirectional, dropout, name,
+        )

 @OpSchema.Inputs(2)

@@ -160,4 +188,5 @@ def LSTMCell(inputs, **kwargs):
         The outputs, ``h`` and ``c`` respectively.

     """
-    return Tensor.CreateOperator('LSTMCell', num_outputs=2, **ParseArgs(locals()))
\ No newline at end of file
+    return Tensor.CreateOperator(
+        'LSTMCell', num_outputs=2, **ParseArgs(locals()))
\ No newline at end of file
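The RNN, LSTM and GRU changes are signature re-wraps only; the constructors still forward (mode, input_size, hidden_size, num_layers, bidirectional, dropout, name) to RNNBase. A hedged usage sketch based on the class docstrings (the import path and input layout are assumptions, not taken from this diff):

# Assumed import path for the fused RNN wrappers defined in this file.
from dragon.operators.recurrent import LSTM

rnn = LSTM(input_size=32, hidden_size=64, num_layers=1, bidirectional=False)
# x: a sequence tensor (cuDNN-style RNNs are typically time-major, [T, N, input_size]).
# Per the class docstring:
# >>> outputs, hidden = rnn(x)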
Dragon/python/dragon/operators/vision.py
@@ -33,9 +33,17 @@ def _normalize_pads(value, rank):
 @OpSchema.Inputs(2, 3)
 def Conv2d(
-    inputs, num_output, kernel_shape, strides=1, pads=0,
-        dilations=1, group=1, padding='VALID', data_format='NCHW', **kwargs
+    inputs,
+    num_output,
+    kernel_shape,
+    strides=1,
+    pads=0,
+    dilations=1,
+    group=1,
+    padding='VALID',
+    data_format='NCHW',
+    **kwargs
 ):
     """2D Convolution.

     The spatial output dimension of convolution can be computed as follows:

@@ -99,8 +107,15 @@ def Conv2d(
 @OpSchema.Inputs(2, 3)
 def DepthwiseConv2d(
-    inputs, num_output, kernel_shape=3, strides=1, pads=0, padding='VALID', data_format='NCHW', **kwargs
+    inputs,
+    num_output,
+    kernel_shape=3,
+    strides=1,
+    pads=0,
+    padding='VALID',
+    data_format='NCHW',
+    **kwargs
 ):
     """Depthwise 2D Convolution. `[Chollet, 2016] <https://arxiv.org/abs/1610.02357>`_.

     Set ``padding`` to *VALID* will use the value of ``pads``.

@@ -149,10 +164,19 @@ def DepthwiseConv2d(
 @ArgumentHelper.RepeatedDesc('output_padding')
 @ArgumentHelper.RepeatedDesc('output_shape')
 def ConvTranspose2d(
-    inputs, num_output, kernel_shape, strides=1, pads=0,
-        dilations=1, group=1, output_padding=None, output_shape=None,
-            padding='VALID', data_format='NCHW', **kwargs
+    inputs,
+    num_output,
+    kernel_shape,
+    strides=1,
+    pads=0,
+    dilations=1,
+    group=1,
+    output_padding=None,
+    output_shape=None,
+    padding='VALID',
+    data_format='NCHW',
+    **kwargs
 ):
     """2D Deconvolution.

     The spatial output dimension of deconvolution can be computed as follows:

@@ -224,8 +248,17 @@ def ConvTranspose2d(
 @OpSchema.Inputs(1)
 def Pool2d(
-    inputs, kernel_shape, strides, pads=0, padding='VALID', ceil_mode=True, mode='MAX', data_format='NCHW', global_pooling=False, **kwargs
+    inputs,
+    kernel_shape,
+    strides,
+    pads=0,
+    padding='VALID',
+    ceil_mode=True,
+    mode='MAX',
+    data_format='NCHW',
+    global_pooling=False,
+    **kwargs
 ):
     """2D Pooling, MAX or AVG.

     The spatial output dimension of pooling can be computed as follows:

@@ -308,7 +341,14 @@ def ROIPool(inputs, pool_h, pool_w, spatial_scale=1.0, **kwargs):
 @OpSchema.Inputs(2)
-def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, sampling_ratio=2, **kwargs):
+def ROIAlign(
+    inputs,
+    pool_h=0,
+    pool_w=0,
+    spatial_scale=1.0,
+    sampling_ratio=2,
+    **kwargs
+):
     """AVG RoIAlign. `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.

     **Type Constraints**: (*float16*, *float32*)

@@ -337,8 +377,15 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, sampling_ratio=2, **
 @OpSchema.Inputs(1)
 def LRN(
-    inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANNELS', data_format='NCHW', **kwargs
+    inputs,
+    local_size=5,
+    alpha=0.0001,
+    beta=0.75,
+    k=2.0,
+    mode='ACROSS_CHANNELS',
+    data_format='NCHW',
+    **kwargs
 ):
     """Local Response Normalization. `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_.

     **Type Constraints**: (*float16*, *float32*)

@@ -379,8 +426,14 @@ def LRN(
 @OpSchema.Inputs(1)
 @ArgumentHelper.RepeatedDesc('dsize')
 def NNResize(
-    inputs, dsize, shape_like=None, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
+    inputs,
+    dsize,
+    shape_like=None,
+    fy=-1.0,
+    fx=-1.0,
+    data_format='NCHW',
+    **kwargs
 ):
     """Resize the image with Nearest-Neighbor method.

     Set ``dsize`` to None if you want to use ``shape_like`` or ``fy/fx``.

@@ -430,8 +483,14 @@ def NNResize(
 @OpSchema.Inputs(1)
 @ArgumentHelper.RepeatedDesc('dsize')
 def BilinearResize(
-    inputs, dsize, shape_like=None, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
+    inputs,
+    dsize,
+    shape_like=None,
+    fy=-1.0,
+    fx=-1.0,
+    data_format='NCHW',
+    **kwargs
 ):
     """Resize the image with Bi-linear method.

     Set ``dsize`` to None if you want to use ``shape_like`` or ``fy/fx``.

@@ -508,8 +567,14 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
 @OpSchema.Inputs(1)
 @ArgumentHelper.Desc('keep_prob', as_target=False)
 def DropBlock2d(
-    inputs, block_size=7, keep_prob=0.9, alpha=1., decrement=0., data_format='NCHW', **kwargs
+    inputs,
+    block_size=7,
+    keep_prob=0.9,
+    alpha=1.,
+    decrement=0.,
+    data_format='NCHW',
+    **kwargs
 ):
     """Randomly drop the outputs according to the spatial blocks. `[Ghiasi et.al, 2018] <https://arxiv.org/abs/1810.12890>`_.

     Set the ``decrement`` to schedule ``keep_prob`` for each iteration.
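All of the vision hunks are signature re-wraps. The docstrings refer to the spatial output-size formula; the conventional one for a 2D convolution is out = floor((in + 2·pad − dilation·(kernel − 1) − 1) / stride) + 1 (stated here from common convention, not quoted from the docstring), and a small helper makes it easy to sanity-check:

def conv2d_out_dim(in_size, kernel, stride=1, pad=0, dilation=1):
    # Conventional 2D convolution output-size formula.
    return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

print(conv2d_out_dim(224, kernel=3, stride=1, pad=1))  # 224
print(conv2d_out_dim(224, kernel=3, stride=2, pad=1))  # 112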
Dragon/python/dragon/utils/vision/data_batch.py
@@ -53,6 +53,8 @@ class DataBatch(object):
         The value to fill when padding is valid.
     crop_size : int, optional, default=0
         The cropping size.
+    cutout_size : int, optional, default=0
+        The square size to cutout.
     mirror : bool, optional, default=False
         Whether to mirror(flip horizontally) images.
     color_augmentation : bool, optional, default=False
Dragon/python/dragon/utils/vision/data_transformer.py
@@ -47,6 +47,8 @@ class DataTransformer(multiprocessing.Process):
         The value to fill when padding is valid.
     crop_size : int, optional, default=0
         The cropping size.
+    cutout_size : int, optional, default=0
+        The square size to cutout.
     mirror : bool, optional, default=False
         Whether to mirror(flip horizontally) images.
     color_augmentation : bool, optional, default=False

@@ -65,6 +67,7 @@ class DataTransformer(multiprocessing.Process):
         self._padding = kwargs.get('padding', 0)
         self._fill_value = kwargs.get('fill_value', 127)
         self._crop_size = kwargs.get('crop_size', 0)
+        self._cutout_size = kwargs.get('cutout_size', 0)
         self._mirror = kwargs.get('mirror', False)
         self._color_aug = kwargs.get('color_augmentation', False)
         self._min_random_scale = kwargs.get('min_random_scale', 1.0)

@@ -127,6 +130,13 @@ class DataTransformer(multiprocessing.Process):
             im = im[h_off : h_off + self._crop_size,
                     w_off : w_off + self._crop_size, :]

+        # CutOut
+        if self._cutout_size > 0:
+            h_off = numpy.random.randint(im.shape[0])
+            w_off = numpy.random.randint(im.shape[1])
+            im[h_off : h_off + self._cutout_size,
+               w_off : w_off + self._cutout_size, :] = self._fill_value
+
         # Random mirror
         if self._mirror:
             if numpy.random.randint(0, 2) > 0:
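The new cutout branch picks a random offset anywhere in the image and overwrites a square of side cutout_size with the fill value (the square is clipped at the border when the offset lands near an edge). A standalone NumPy sketch of the same augmentation:

import numpy as np

def cutout(im, cutout_size, fill_value=127):
    # Mirrors DataTransformer: random offset, square region set to the fill value.
    h_off = np.random.randint(im.shape[0])
    w_off = np.random.randint(im.shape[1])
    im[h_off:h_off + cutout_size, w_off:w_off + cutout_size, :] = fill_value
    return im

im = np.zeros((32, 32, 3), dtype='uint8')
im = cutout(im, cutout_size=8)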
Dragon/python/dragon/vm/tensorflow/ops/standard_ops.py
@@ -22,5 +22,4 @@ from dragon.vm.tensorflow.ops.random_ops import *
 from dragon.vm.tensorflow.ops.math_ops import *
 from dragon.vm.tensorflow.ops.array_ops import *
 from dragon.vm.tensorflow.ops.control_flow_ops import *
 from dragon.vm.tensorflow.ops.nn_ops import *
 from dragon.vm.tensorflow.ops.gradients_impl import gradients
\ No newline at end of file
Dragon/src/kernels/recurrent/lstm_cell_op_kernel.cc
@@ -8,40 +8,39 @@ namespace kernel {
 /*! LSTMCell <T = float32, Device = CPU> */

 template <typename T>
-T _SigmoidUnit(T x) { return T(1) / (T(1) + exp(-x)); }
+T _s(T x) { return T(1) / (T(1) + exp(-x)); }

 template <> void LSTMCell<float, CPUContext>(
-    const int           count,
     const int           N,
     const int           C,
     const float*        cx,
-    float*              xact,
+    float*              actx,
     float*              c,
     float*              h,
     CPUContext*         ctx) {
     float i, f, o, c_;
     int f_offset = C, o_offset = 2 * C,
         c_offset = 3 * C, x_offset = 4 * C;
     for (int n = 0; n < N; ++n) {
         for (int idx = 0; idx < C; ++idx) {
-            xact[idx] = i = _SigmoidUnit<float>(xact[idx]);
-            xact[idx + f_offset] = f = _SigmoidUnit<float>(xact[idx + f_offset]);
-            xact[idx + o_offset] = o = _SigmoidUnit<float>(xact[idx + o_offset]);
-            xact[idx + c_offset] = c_ = tanh(xact[idx + c_offset]);
+            actx[idx] = i = _s<float>(actx[idx]);
+            actx[idx + f_offset] = f = _s<float>(actx[idx + f_offset]);
+            actx[idx + o_offset] = o = _s<float>(actx[idx + o_offset]);
+            actx[idx + c_offset] = c_ = tanh(actx[idx + c_offset]);
             c_ = c[idx] = f * cx[idx] + i * c_;
             h[idx] = o * tanh(c_);
         }
-        cx += C; xact += x_offset;
-        c += C; h += C;
+        cx += C; actx += x_offset;
+        c += C; h += C;
     }
 }

 /*! LSTMCellGrad <T = float32, Device = CPU> */

 template <> void LSTMCellGrad<float, CPUContext>(
-    const int           count,
     const int           N,
     const int           C,
     const float*        cx,
-    const float*        xact,
+    const float*        actx,
     const float*        c,
     const float*        dc,
     const float*        dh,

@@ -49,21 +48,19 @@ template <> void LSTMCellGrad<float, CPUContext>(
     float*              dx,
     CPUContext*         ctx) {
     float i, f, o, g, tanh_c, dcx_sum_term;
     int f_offset = C, o_offset = 2 * C,
         c_offset = 3 * C, x_offset = 4 * C;
     for (int n = 0; n < N; ++n) {
         for (int idx = 0; idx < C; ++idx) {
-            i = xact[idx];
-            f = xact[idx + f_offset];
-            o = xact[idx + o_offset];
-            g = xact[idx + c_offset];
+            i = actx[idx];
+            f = actx[idx + f_offset];
+            o = actx[idx + o_offset];
+            g = actx[idx + c_offset];
             // BPTT compute the dc_{t-1} at the time of t
             // dc_{t-1} = dl / d(h_{t}) * d(h_{t}) / d(c_{t}) * d(c_{t}) / d(c_{t-1})
             //     + d(c_{t+1}) / d(c_{t}) * d(c_{t}) / d(c_{t-1})
             //     = (dl / d(h_{t}) * d(h_{t}) / d(c_{t}) + d(c_{t+1}) / d(c_{t}))
             //         * d(c_{t}) / d(c_{t-1})
             tanh_c = tanh(c[idx]);
             dcx_sum_term = dh[idx] * o * (1 - tanh_c * tanh_c) + dc[idx];
             dcx[idx] = dcx_sum_term * f;

@@ -72,7 +69,8 @@ template <> void LSTMCellGrad<float, CPUContext>(
             dx[idx + o_offset] = dh[idx] * tanh_c * o * (1 - o);
             dx[idx + c_offset] = dcx_sum_term * i * (1 - g * g);
         }
-        cx += C; xact += x_offset;
-        c += C; dc += C; dh += C;
+        cx += C; actx += x_offset;
+        c += C; dc += C; dh += C;
         dcx += C; dx += x_offset;
     }
 }
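The CPU kernel changes above are renames only (xact to actx, _SigmoidUnit to _s) plus the dropped count argument; the cell math is untouched. For reference, the same forward step in NumPy, with the gate layout [i | f | o | g] along the last axis matching the kernel's offsets:

import numpy as np

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

def lstm_cell_forward(cx, actx):
    # cx:   [N, C]   previous cell state
    # actx: [N, 4*C] pre-activations laid out as [i | f | o | g]
    i, f, o, g = np.split(actx, 4, axis=1)
    i, f, o, g = sigmoid(i), sigmoid(f), sigmoid(o), np.tanh(g)
    c = f * cx + i * g          # c = f * cx + i * c_, as in the kernel
    h = o * np.tanh(c)          # h = o * tanh(c)
    return c, h

cx = np.random.randn(2, 3).astype('float32')
actx = np.random.randn(2, 12).astype('float32')
c, h = lstm_cell_forward(cx, actx)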
Dragon/src/kernels/recurrent/lstm_cell_op_kernel.cu
@@ -11,94 +11,91 @@ namespace kernel {
 template <typename T>
 __global__ void _LSTMCellAct(
-    const int           count,
+    const int           nthreads,
     const int           c_offset,
     const int           x_offset,
-    T*                  xact) {
-    CUDA_1D_KERNEL_LOOP(idx, count) {
-        const int offset = idx % x_offset;
-        xact[idx] = offset < c_offset ?
-            ((T)1 / ((T)1 + exp(-xact[idx])))
-                : tanh(xact[idx]);
+    T*                  actx) {
+    CUDA_1D_KERNEL_LOOP(i, nthreads) {
+        const int offset = i % x_offset;
+        actx[i] = offset < c_offset ?
+            (T(1) / (T(1) + exp(-actx[i])))
+                : tanh(actx[i]);
     }
 }

 template <typename T>
 __global__ void _LSTMCellGate(
-    const int           count,
+    const int           nthreads,
     const int           hidden_size,
-    const int           o_offset,  // 2 * hidden_size
-    const int           c_offset,  // 3 * hidden_size
-    const int           x_offset,  // 4 * hidden_size
+    const int           o_offset,
+    const int           c_offset,
+    const int           x_offset,
     const T*            cx,
-    const T*            xact,
+    const T*            actx,
     T*                  c,
     T*                  h) {
-    CUDA_1D_KERNEL_LOOP(idx, count) {
+    CUDA_1D_KERNEL_LOOP(idx, nthreads) {
         const int n = idx / hidden_size;
         const int offset = idx % hidden_size;
-        const T* x = xact + n * x_offset;
-        const T i = x[offset];
-        const T f = x[offset + hidden_size];
-        const T o = x[offset + o_offset];
-        T c_ = x[offset + c_offset];
+        const T* actx_ = actx + n * x_offset;
+        const T i = actx_[offset];
+        const T f = actx_[offset + hidden_size];
+        const T o = actx_[offset + o_offset];
+        T c_ = actx_[offset + c_offset];
         c_ = c[idx] = f * cx[idx] + i * c_;
         h[idx] = o * tanh(c_);
     }
 }

 template <> void LSTMCell<float, CUDAContext>(
-    const int           count,
     const int           N,
     const int           C,
     const float*        cx,
-    float*              xact,
+    float*              actx,
     float*              c,
     float*              h,
     CUDAContext*        ctx) {
-    const int o_offset = 2 * C,
-        c_offset = 3 * C, x_offset = 4 * C;
+    auto o_offset = 2 * C, c_offset = 3 * C,
+        x_offset = 4 * C, NC = N * C;
     _LSTMCellAct<float>
-        << < CUDA_BLOCKS(count * 4), CUDA_THREADS,
+        << < CUDA_BLOCKS(NC * 4), CUDA_THREADS,
             0, ctx->cuda_stream() >> >
-        (count * 4, c_offset, x_offset, xact);
+        (NC * 4, c_offset, x_offset, actx);
     _LSTMCellGate<float>
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        << < CUDA_BLOCKS(NC), CUDA_THREADS,
             0, ctx->cuda_stream() >> >
-        (count, C, o_offset, c_offset, x_offset,
-            cx, xact, c, h);
+        (NC, C, o_offset, c_offset,
+            x_offset, cx, actx, c, h);
 }

 /*! LSTMCellGrad <T = float32, Device = CUDA> */

 template <typename T>
 __global__ void _LSTMCellGateGrad(
-    const int           count,
+    const int           nthreads,
     const int           hidden_size,
     const int           o_offset,
     const int           c_offset,
     const int           x_offset,
     const T*            cx,
-    const T*            xact,
+    const T*            actx,
     const T*            c,
     const T*            dc,
     const T*            dh,
     T*                  dcx,
     T*                  dx) {
-    CUDA_1D_KERNEL_LOOP(idx, count) {
+    CUDA_1D_KERNEL_LOOP(idx, nthreads) {
         const int n = idx / hidden_size;
         const int offset = idx % hidden_size;
-        const T* xact_ = xact + n * x_offset;
+        const T* actx_ = actx + n * x_offset;
         T* dx_ = dx + n * x_offset;
-        const T i = xact_[offset];
-        const T f = xact_[offset + hidden_size];
-        const T o = xact_[offset + o_offset];
-        const T g = xact_[offset + c_offset];
+        const T i = actx_[offset];
+        const T f = actx_[offset + hidden_size];
+        const T o = actx_[offset + o_offset];
+        const T g = actx_[offset + c_offset];
         const T tanh_c = tanh(c[idx]);
         const T dcx_sum_term =
-            dh[idx] * o * (1 - tanh_c * tanh_c) + dc[idx];
+            dh[idx] * o * (T(1) - tanh_c * tanh_c) + dc[idx];
         dcx[idx] = dcx_sum_term * f;
         dx_[offset] = dcx_sum_term * g;
         dx_[offset + hidden_size] = dcx_sum_term * cx[idx];

@@ -109,44 +106,44 @@ __global__ void _LSTMCellGateGrad(
 template <typename T>
 __global__ void _LSTMCellActGrad(
-    const int           count,
+    const int           nthreads,
     const int           c_offset,
     const int           x_offset,
-    const T*            xact,
+    const T*            actx,
     T* dx) {
-    CUDA_1D_KERNEL_LOOP(idx, count) {
-        const int offset = idx % x_offset;
-        const T val = xact[idx];
-        if (offset < c_offset) dx[idx] = dx[idx] * val * (T(1) - val);
-        else dx[idx] = dx[idx] * (T(1) - val * val);
+    CUDA_1D_KERNEL_LOOP(i, nthreads) {
+        const T val = actx[i];
+        const int offset = i % x_offset;
+        if (offset < c_offset) {
+            dx[i] = dx[i] * val * (T(1) - val);
+        } else {
+            dx[i] = dx[i] * (T(1) - val * val);
+        }
     }
 }

 template <> void LSTMCellGrad<float, CUDAContext>(
-    const int           count,
     const int           N,
     const int           C,
     const float*        cx,
-    const float*        xact,
+    const float*        actx,
     const float*        c,
     const float*        dc,
     const float*        dh,
     float*              dcx,
     float*              dx,
     CUDAContext*        ctx) {
-    const int o_offset = 2 * C,
-        c_offset = 3 * C, x_offset = 4 * C;
+    auto o_offset = 2 * C, c_offset = 3 * C,
+        x_offset = 4 * C, NC = N * C;
     _LSTMCellGateGrad<float>
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        << < CUDA_BLOCKS(NC), CUDA_THREADS,
             0, ctx->cuda_stream() >> >
-        (count, C, o_offset, c_offset, x_offset,
-            cx, xact, c, dc, dh, dcx, dx);
+        (NC, C, o_offset, c_offset, x_offset,
+            cx, actx, c, dc, dh, dcx, dx);
     _LSTMCellActGrad<float>
-        << < CUDA_BLOCKS(count * 4), CUDA_THREADS,
+        << < CUDA_BLOCKS(NC * 4), CUDA_THREADS,
             0, ctx->cuda_stream() >> >
-        (count * 4, c_offset, x_offset, xact, dx);
+        (NC * 4, c_offset, x_offset, actx, dx);
 }

 }  // namespace kernel
Dragon/src/operators/array/flatten_op.cc
@@ -17,8 +17,9 @@ void FlattenOp<Context>::RunOnDevice() {
     vector<int64_t> output_dims;
     if (keep_axes != INT_MAX) {
         CHECK_LE(keep_axes, Input(0).ndim())
-            << "\nThe total number of axes is " + Input(0).ndim()
-            << ", can not keep " + keep_axes << " .";
+            << "\nThe total number of axes is "
+            << Input(0).ndim() << ", can not keep "
+            << keep_axes << " .";
         int i = 0;
         for (; i < keep_axes - 1; i++)
             output_dims.push_back(Input(0).dim(i));
Dragon/src/operators/recurrent/lstm_cell_op.cc
@@ -12,9 +12,10 @@ void LSTMCellOp<Context>::RunWithType() {
     auto* Hdata = Output(0)->template mutable_data<T, Context>();
     auto* Cdata = Output(1)->template mutable_data<T, Context>();

-    kernel::LSTMCell(Input(1).count(), Input(1).dim(0),
-        Input(1).ndim() == 2 ? Input(1).dim(1) : Input(1).dim(2),
-        HXdata, Xdata, Cdata, Hdata, ctx());
+    kernel::LSTMCell(Input(1).dim(0),
+        Input(1).ndim() == 2 ?
+            Input(1).dim(1) : Input(1).dim(2),
+        HXdata, Xdata, Cdata, Hdata, ctx());
 }

 template <class Context>

@@ -47,10 +48,11 @@ void LSTMCellGradientOp<Context>::RunWithType() {
             cast::to<T>(0.f), dCdata, ctx());
     }

-    kernel::LSTMCellGrad(Input(1).count(), Input(1).dim(0),
-        Input(1).ndim() == 2 ? Input(1).dim(1) : Input(1).dim(2),
-        HXdata, Xdata, Cdata, dCdata, dHdata,
-        dHXdata, dXdata, ctx());
+    kernel::LSTMCellGrad(Input(1).dim(0),
+        Input(1).ndim() == 2 ?
+            Input(1).dim(1) : Input(1).dim(2),
+        HXdata, Xdata, Cdata, dCdata, dHdata,
+        dHXdata, dXdata, ctx());
 }

 template <class Context>
Dragon/src/operators/vision/cudnn_depthwise_conv2d_op.cc
@@ -113,6 +113,11 @@ template <class Context>
 void CuDNNDepthwiseConv2dGradientOp<Context>::RunOnDevice() {
     group = channels = data_format == "NCHW" ?
         Input(0).dim(1) : Input(0).dim(-1);
+#if CUDNN_VERSION_MIN(7, 0, 0)
+    // The group implementation of CuDNN is faster
+    // Enable if CuDNN >= 7.0
+    return CuDNNConv2dGradientOp<Context>::RunOnDevice();
+#endif
     GradientReshape();

     if (XIsType(Input(0), float)) RunWithType<float>();