Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
SeetaResearch
/
Dragon
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit 0936a502
authored
Apr 11, 2019
by
Ting PAN
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix the wrong grads of LSTMCell
1 parent
40e94d24
Hide whitespace changes
Inline
Side-by-side
Showing
23 changed files
with
504 additions
and
171 deletions
Dragon/python/dragon/core/tensor.py
Dragon/python/dragon/core/workspace.py
Dragon/python/dragon/operators/contrib/rcnn/ops.py
Dragon/python/dragon/operators/misc.py
Dragon/python/dragon/operators/rnn/rnn_param.py
Dragon/python/dragon/operators/rnn/rnn_wrapper.py
Dragon/python/dragon/vm/torch/autograd/variable.py
Dragon/python/dragon/vm/torch/nn/__init__.py
Dragon/python/dragon/vm/torch/nn/init.py
Dragon/python/dragon/vm/torch/nn/modules/affine.py
Dragon/python/dragon/vm/torch/nn/modules/batchnorm.py
Dragon/python/dragon/vm/torch/nn/modules/container.py
Dragon/python/dragon/vm/torch/nn/modules/conv.py
Dragon/python/dragon/vm/torch/nn/modules/depthwise_conv.py
Dragon/python/dragon/vm/torch/nn/modules/dropblock.py
Dragon/python/dragon/vm/torch/nn/modules/groupnorm.py
Dragon/python/dragon/vm/torch/nn/modules/loss.py
Dragon/python/dragon/vm/torch/nn/modules/pooling.py
Dragon/python/dragon/vm/torch/nn/modules/rnn.py
Dragon/python/dragon/vm/torch/onnx/utils.py
Dragon/python/dragon/vm/torch/ops/builtin.py
Dragon/src/operators/recurrent/lstm_cell_op.cc
Dragon/src/operators/vision/cudnn_conv2d_transpose_op.cc
Dragon/python/dragon/core/tensor.py
View file @
0936a50
...
...
@@ -771,16 +771,6 @@ class Tensor(object):
def
__hash__
(
self
):
return
id
(
self
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""Print the expressions.
Returns
-------
None
"""
return
self
.
debug_expressions
()
###############################################
# #
# Theano API #
...
...
@@ -1030,7 +1020,7 @@ class Tensor(object):
"""
expressions
=
dict
()
# 1
.
Collect inputs
# 1
)
Collect inputs
if
not
isinstance
(
inputs
,
list
):
inputs
=
[
inputs
]
for
input
in
inputs
:
for
op_idx
,
expr
in
input
.
expressions
.
items
():
...
...
@@ -1044,7 +1034,7 @@ class Tensor(object):
if
not
op_idx
in
expressions
:
expressions
[
op_idx
]
=
expr
# 2
.
Generate outputs
# 2
)
Generate outputs
outputs
=
[]
if
existing_outputs
is
None
:
name_scope
=
_scope
.
get_default_name_scope
()
...
...
@@ -1061,7 +1051,7 @@ class Tensor(object):
num_outputs
=
len
(
outputs
)
if
not
isinstance
(
outputs
,
list
):
outputs
=
[
outputs
]
# 3
.
Construct OperatorDef
# 3
)
Construct OperatorDef
inputs_name
=
[
input
.
name
for
input
in
inputs
]
outputs_name
=
[
output
.
name
for
output
in
outputs
]
op_idx
,
op_name
=
_helper
.
OperatorHelper
.
get_index_and_name
()
...
...
@@ -1073,7 +1063,7 @@ class Tensor(object):
expressions
[
op_idx
]
=
op_def
# 4
.
Add outputs
# 4
)
Add outputs
for
idx
,
output
in
enumerate
(
outputs
):
# Deliver expressions
output
.
expressions
=
expressions
...
...
@@ -1085,11 +1075,11 @@ class Tensor(object):
for
input
in
extra_inputs
:
output
.
extra_targets
.
add
(
input
.
name
)
# 5
.
Refine the shape and data type
# 5
)
Refine the shape and data type
outputs
=
_helper
.
OperatorHelper
.
apply
(
op_type
,
arguments
=
kwargs
,
inputs
=
inputs
,
outputs
=
outputs
)
# 6
.
Returns
# 6
)
Returns
if
num_outputs
>
1
:
return
outputs
elif
num_outputs
==
1
:
return
outputs
[
0
]
else
:
return
None
...
...
Dragon/python/dragon/core/workspace.py
View file @
0936a50
...
...
@@ -112,7 +112,7 @@ class Workspace(_C.Workspace):
This class is a fusion of *Workspace*, *Pool* and *tf.Graph*.
We find that they work in a similar way while named different.
We find that they work in a similar way while named different
ly
.
"""
def
__init__
(
self
,
name
=
''
):
...
...
@@ -514,8 +514,8 @@ def RunGraph(
Returns
-------
None, NDArray or list of NDA
rray
The outputs
, format as NDA
rray.
sequence of numpy.nda
rray
The outputs
which are copied to numpy a
rray.
See Also
--------
...
...
@@ -551,11 +551,11 @@ def Backward(
input_grads
=
None
,
ignored_grads
=
None
,
):
"""Compute the gradients of given input
flow
s.
"""Compute the gradients of given input
operator
s.
Parameters
----------
input_flow
: sequence of OperatorDef
forward_ops
: sequence of OperatorDef
The referring ops to generate gradients.
targets : sequence or str
The solving targets.
...
...
@@ -576,10 +576,13 @@ def Backward(
options
[
'log_meta_graph'
])
else
False
get_default_workspace
()
.
Backward
(
forward_ops
,
targets
,
input_grads
if
input_grads
else
[],
ignored_grads
if
ignored_grads
else
[],
options
[
'share_grads'
],
required_logging
)
forward_ops
,
targets
,
input_grads
if
input_grads
else
[],
ignored_grads
if
ignored_grads
else
[],
options
[
'share_grads'
],
required_logging
,
)
def
LogMetaGraph
(
graph_def
):
...
...
@@ -666,6 +669,7 @@ def Snapshot(
"""
file_path
=
prefix
+
filename
+
suffix
if
_mpi
.
Is_Init
():
if
not
_mpi
.
AllowSnapshot
():
return
file_path
=
file_path
+
'.rank.{}'
.
format
(
_mpi
.
Rank
())
...
...
Dragon/python/dragon/operators/contrib/rcnn/ops.py
View file @
0936a50
...
...
@@ -17,11 +17,21 @@ from dragon.operators import *
@OpSchema.Inputs
(
3
,
INT_MAX
)
def
Proposal
(
inputs
,
strides
,
ratios
,
scales
,
pre_nms_top_n
=
6000
,
post_nms_top_n
=
300
,
nms_thresh
=
0.7
,
min_size
=
16
,
min_level
=
2
,
max_level
=
5
,
canonical_scale
=
224
,
canonical_level
=
4
,
**
kwargs
):
def
Proposal
(
inputs
,
strides
,
ratios
,
scales
,
pre_nms_top_n
=
6000
,
post_nms_top_n
=
300
,
nms_thresh
=
0.7
,
min_size
=
16
,
min_level
=
2
,
max_level
=
5
,
canonical_scale
=
224
,
canonical_level
=
4
,
**
kwargs
):
"""Generate Regional Proposals, introduced by `[Ren et.al, 2015] <https://arxiv.org/abs/1506.01497>`_.
Multi-Level proposals was introduced by `[Lin et.al, 2017] <https://arxiv.org/abs/1612.03144>`_.
...
...
Dragon/python/dragon/operators/misc.py
View file @
0936a50
...
...
@@ -130,7 +130,7 @@ def Accuracy(inputs, top_k=1, axis=1, ignore_labels=(), **kwargs):
**Type Constraints**:
* logits (*float
16*, *float
32*)
* logits (*float32*)
* labels (*float32*, *int64*)
...
...
Dragon/python/dragon/operators/rnn/rnn_param.py
View file @
0936a50
...
...
@@ -18,9 +18,17 @@ from .. import *
@OpSchema.Inputs
(
2
)
def
RNNParamSet
(
inputs
,
layer_id
,
param_id
,
param_type
,
rnn_mode
,
input_size
,
hidden_size
,
num_layers
=
1
,
num_directions
=
1
,
**
kwargs
):
inputs
,
layer_id
,
param_id
,
param_type
,
rnn_mode
,
input_size
,
hidden_size
,
num_layers
=
1
,
num_directions
=
1
,
**
kwargs
):
arguments
=
ParseArgs
(
locals
())
arguments
[
'inputs'
]
=
inputs
[
1
]
arguments
[
'existing_outputs'
]
=
inputs
[
0
]
...
...
Dragon/python/dragon/operators/rnn/rnn_wrapper.py
View file @
0936a50
...
...
@@ -14,20 +14,28 @@ from __future__ import division
from
__future__
import
print_function
import
numpy
import
dragon
import
warnings
from
dragon.core.tensor
import
Tensor
from
dragon.core.tensor_utils
import
FromShape
from
dragon.operators.rnn.rnn_param
import
RNNParamSet
from
dragon
import
config
as
_cfg
from
dragon.core
import
workspace
as
_workspace
from
dragon.core.tensor
import
Tensor
as
_Tensor
from
dragon.core
import
tensor_utils
as
_tensor_utils
from
.rnn_param
import
RNNParamSet
class
RNNBase
(
object
):
"""A simple class wrapping general RNN ops."""
def
__init__
(
self
,
mode
,
input_size
,
hidden_size
,
num_layers
=
1
,
bidirectional
=
False
,
dropout
=
0
,
name
=
None
):
mode
,
input_size
,
hidden_size
,
num_layers
=
1
,
bidirectional
=
False
,
dropout
=
0
,
name
=
None
,
):
eligible_rnn_modes
=
(
'rnn_tanh'
,
'rnn_relu'
,
'lstm'
,
'gru'
)
if
mode
.
lower
()
not
in
eligible_rnn_modes
:
raise
ValueError
(
'Unknown rnn mode: {}.'
...
...
@@ -50,7 +58,7 @@ class RNNBase(object):
if
self
.
mode
==
'lstm'
:
gate_size
=
4
*
self
.
hidden_size
elif
self
.
mode
==
'gru'
:
gate_size
=
3
*
self
.
hidden_size
else
:
gate_size
=
self
.
hidden_size
# 1
.
Plan weights
# 1
)
Plan weights
self
.
_matrix_shape
,
self
.
_bias_shape
=
[],
[]
for
layer
in
range
(
self
.
num_layers
):
for
direction
in
range
(
self
.
num_directions
):
...
...
@@ -64,16 +72,18 @@ class RNNBase(object):
# Bw (0 ~ 3), Br (4 ~ 7)
self
.
_bias_shape
.
extend
([
b_ih_shape
,
b_hh_shape
])
# 2
.
Compute total number of parameters
# 2
)
Compute total number of parameters
self
.
_weights_count
=
0
for
shape
in
self
.
_matrix_shape
+
self
.
_bias_shape
:
self
.
_weights_count
+=
numpy
.
prod
(
shape
)
# 3. Register the packed weights
self
.
weights
=
FromShape
(
shape
=
[
self
.
_weights_count
],
name
=
self
.
name
+
'/weights'
if
self
.
name
else
None
)
# 3) Register the packed weights
self
.
weights
=
_tensor_utils
.
FromShape
(
shape
=
[
self
.
_weights_count
],
name
=
self
.
name
+
'/weights'
if
self
.
name
else
None
,
)
# 4
.
Create the initialization grids
# 4
)
Create the initialization grids
if
self
.
mode
==
'lstm'
:
num_params_per_layer
=
8
elif
self
.
mode
==
'gru'
:
num_params_per_layer
=
6
else
:
num_params_per_layer
=
2
...
...
@@ -121,8 +131,14 @@ class RNNBase(object):
# #
##############################################
def
set_param
(
self
,
layer
=
0
,
direction
=
0
,
param_id
=
0
,
type
=
'matrix'
,
initializer
=
None
):
def
set_param
(
self
,
layer
=
0
,
direction
=
0
,
param_id
=
0
,
type
=
'matrix'
,
initializer
=
None
,
):
if
type
==
'matrix'
:
self
.
_matrix_init_grids
[
layer
][
direction
][
param_id
]
=
initializer
elif
type
==
'bias'
:
...
...
@@ -130,20 +146,35 @@ class RNNBase(object):
else
:
raise
ValueError
(
'Unknown param type: '
+
type
)
def
_set_param
(
self
,
layer_id
,
param_id
,
param_type
,
param
):
def
_set_param
(
self
,
layer_id
,
param_id
,
param_type
,
param
,
):
if
isinstance
(
param
,
numpy
.
ndarray
):
param_temp
=
dragon
.
Tensor
.
Ref
(
'/tmp/rnn_param'
)
param_temp
=
_
Tensor
.
Ref
(
'/tmp/rnn_param'
)
param_temp
.
set_value
(
param
)
param
=
param_temp
else
:
raise
ValueError
(
'Excepted a numpy array.'
)
self
.
weights
.
expressions
=
dict
()
# Clear cached expressions
outputs
=
RNNParamSet
([
self
.
weights
,
param
],
layer_id
,
param_id
,
param_type
,
rnn_mode
=
self
.
mode
,
input_size
=
self
.
input_size
,
hidden_size
=
self
.
hidden_size
,
num_layers
=
self
.
num_layers
,
num_directions
=
self
.
num_directions
)
for
k
,
v
in
outputs
.
expressions
.
items
():
dragon
.
workspace
.
RunOperator
(
v
)
outputs
=
RNNParamSet
(
inputs
=
[
self
.
weights
,
param
],
layer_id
=
layer_id
,
param_id
=
param_id
,
param_type
=
param_type
,
rnn_mode
=
self
.
mode
,
input_size
=
self
.
input_size
,
hidden_size
=
self
.
hidden_size
,
num_layers
=
self
.
num_layers
,
num_directions
=
self
.
num_directions
,
)
for
k
,
v
in
outputs
.
expressions
.
items
():
_workspace
.
RunOperator
(
v
)
def
_reset_params
(
self
):
numpy
.
random
.
seed
(
dragon
.
confi
g
.
GetRandomSeed
())
numpy
.
random
.
seed
(
_cf
g
.
GetRandomSeed
())
if
self
.
mode
==
'lstm'
:
num_gates
=
4
elif
self
.
mode
==
'gru'
:
num_gates
=
3
else
:
num_gates
=
1
...
...
@@ -162,15 +193,29 @@ class RNNBase(object):
matrix_shape
=
self
.
_matrix_shape
[
packed_id
][:]
bias_shape
=
self
.
_bias_shape
[
packed_id
][:]
matrix_shape
[
0
]
=
bias_shape
[
0
]
=
int
(
matrix_shape
[
0
]
/
num_gates
)
self
.
_set_param
(
layer_id
=
pseudo_layer_id
,
param_id
=
param_id
,
param_type
=
'matrix'
,
param
=
matrix_init
(
matrix_shape
))
self
.
_set_param
(
layer_id
=
pseudo_layer_id
,
param_id
=
param_id
,
param_type
=
'bias'
,
param
=
bias_init
(
bias_shape
))
self
.
_set_param
(
layer_id
=
pseudo_layer_id
,
param_id
=
param_id
,
param_type
=
'matrix'
,
param
=
matrix_init
(
matrix_shape
),
)
self
.
_set_param
(
layer_id
=
pseudo_layer_id
,
param_id
=
param_id
,
param_type
=
'bias'
,
param
=
bias_init
(
bias_shape
),
)
self
.
weights
.
expressions
=
weights_states
self
.
_init_params
=
True
def
create
(
self
,
x
,
hx
=
None
,
cx
=
None
,
required_hidden
=
True
,
required_cell
=
False
):
def
create
(
self
,
x
,
hx
=
None
,
cx
=
None
,
required_hidden
=
True
,
required_cell
=
False
,
):
"""Return outputs of this rnn.
Parameters
...
...
@@ -187,9 +232,9 @@ class RNNBase(object):
Return ``y``, ``hidden``, ``cell`` if ``True``.
"""
if
hx
and
not
isinstance
(
hx
,
Tensor
):
if
hx
and
not
isinstance
(
hx
,
_
Tensor
):
raise
TypeError
(
'Excepted hx as a Tensor, got {}.'
.
format
(
type
(
hx
)))
if
cx
and
not
isinstance
(
cx
,
Tensor
):
if
cx
and
not
isinstance
(
cx
,
_
Tensor
):
raise
TypeError
(
'Excepted cx as a Tensor, got {}.'
.
format
(
type
(
cx
)))
if
not
self
.
_init_params
:
self
.
_reset_params
()
...
...
@@ -211,7 +256,8 @@ class RNNBase(object):
elif
required_hidden
:
num_outputs
=
2
else
:
num_outputs
=
1
return
Tensor
.
CreateOperator
(
num_outputs
=
num_outputs
,
**
arguments
)
return
_Tensor
.
CreateOperator
(
num_outputs
=
num_outputs
,
**
arguments
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
create
(
*
args
,
**
kwargs
)
\ No newline at end of file
Dragon/python/dragon/vm/torch/autograd/variable.py
View file @
0936a50
...
...
@@ -60,10 +60,10 @@ def backward(self, gradient=None):
_tensor_utils
.
FromArray
(
gradient
.
numpy
(
True
),
self
.
name
+
'_grad'
)
input_grads
.
append
(
self
.
name
+
'_grad'
)
# 3
.
Dispatch the backward ops
# 3
)
Dispatch the backward ops
_backward_impl
(
forward_ops
,
targets
,
input_grads
,
ignored_grads
)
# 4
.
Release resources
# 4
)
Release resources
# We should release both the operator handles and tensors
for
forward_op
in
forward_ops
:
_get_operator_pool
()
.
put
(
forward_op
.
name
)
...
...
Dragon/python/dragon/vm/torch/nn/__init__.py
View file @
0936a50
...
...
@@ -29,7 +29,7 @@ from .modules.activation import (
)
from
.modules.loss
import
(
BCEWithLogitsLoss
,
BCEWithLogitsLoss
,
SCEWithLogitsLoss
,
NLLLoss
,
CrossEntropyLoss
,
L1Loss
,
MSELoss
,
SmoothL1Loss
,
SigmoidFocalLoss
,
SoftmaxFocalLoss
,
...
...
Dragon/python/dragon/vm/torch/nn/init.py
View file @
0936a50
...
...
@@ -24,7 +24,15 @@ from dragon.vm.torch.autograd.grad_mode import no_grad
def
calculate_gain
(
nonlinearity
,
param
=
None
):
linear_fns
=
[
'linear'
,
'conv1d'
,
'conv2d'
,
'conv3d'
,
'conv_transpose1d'
,
'conv_transpose2d'
,
'conv_transpose3d'
]
linear_fns
=
[
'linear'
,
'conv1d'
,
'conv2d'
,
'conv3d'
,
'conv_transpose1d'
,
'conv_transpose2d'
,
'conv_transpose3d'
,
]
if
nonlinearity
in
linear_fns
or
nonlinearity
==
'sigmoid'
:
return
1
elif
nonlinearity
==
'tanh'
:
...
...
@@ -34,7 +42,9 @@ def calculate_gain(nonlinearity, param=None):
elif
nonlinearity
==
'leaky_relu'
:
if
param
is
None
:
negative_slope
=
0.01
elif
not
isinstance
(
param
,
bool
)
and
isinstance
(
param
,
int
)
or
isinstance
(
param
,
float
):
elif
not
isinstance
(
param
,
bool
)
and
\
isinstance
(
param
,
int
)
or
\
isinstance
(
param
,
float
):
# True/False are instances of int, hence check above
negative_slope
=
param
else
:
...
...
Dragon/python/dragon/vm/torch/nn/modules/affine.py
View file @
0936a50
...
...
@@ -18,8 +18,14 @@ from dragon.vm.torch.ops.builtin import zeros, ones
class
Affine
(
Module
):
def
__init__
(
self
,
num_features
,
bias
=
True
,
fix_weight
=
False
,
fix_bias
=
False
,
inplace
=
False
):
def
__init__
(
self
,
num_features
,
bias
=
True
,
fix_weight
=
False
,
fix_bias
=
False
,
inplace
=
False
,
):
super
(
Affine
,
self
)
.
__init__
()
self
.
num_features
=
num_features
self
.
inplace
=
inplace
...
...
Dragon/python/dragon/vm/torch/nn/modules/batchnorm.py
View file @
0936a50
...
...
@@ -20,8 +20,14 @@ from dragon.vm.torch.module import RunOperator
class
_BatchNorm
(
Module
):
def
__init__
(
self
,
num_features
,
eps
=
1e-5
,
momentum
=
0.1
,
affine
=
True
,
track_running_stats
=
True
):
def
__init__
(
self
,
num_features
,
eps
=
1e-5
,
momentum
=
0.1
,
affine
=
True
,
track_running_stats
=
True
,
):
super
(
_BatchNorm
,
self
)
.
__init__
()
self
.
num_features
=
num_features
self
.
eps
=
eps
...
...
Dragon/python/dragon/vm/torch/nn/modules/container.py
View file @
0936a50
...
...
@@ -27,10 +27,8 @@ from dragon.vm.torch.nn import Module
class
Container
(
Module
):
def
__init__
(
self
,
**
kwargs
):
super
(
Container
,
self
)
.
__init__
()
# DeprecationWarning is ignored by default <sigh>
warnings
.
warn
(
"nn.Container is deprecated. All of it's functionality "
"is now implemented in nn.Module. Subclass that instead."
)
for
key
,
value
in
kwargs
.
items
():
...
...
@@ -38,7 +36,7 @@ class Container(Module):
class
Sequential
(
Module
):
r
"""A sequential container.
"""A sequential container.
Modules will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of modules can also be passed in.
...
...
@@ -59,8 +57,8 @@ class Sequential(Module):
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
"""
"""
def
__init__
(
self
,
*
args
):
super
(
Sequential
,
self
)
.
__init__
()
if
len
(
args
)
==
1
and
isinstance
(
args
[
0
],
OrderedDict
):
...
...
@@ -71,7 +69,7 @@ class Sequential(Module):
self
.
add_module
(
str
(
idx
),
module
)
def
_get_item_by_idx
(
self
,
iterator
,
idx
):
"""Get the idx-th item of the iterator"""
"""Get the idx-th item of the iterator
.
"""
size
=
len
(
self
)
idx
=
operator
.
index
(
idx
)
if
not
-
size
<=
idx
<
size
:
...
...
Dragon/python/dragon/vm/torch/nn/modules/conv.py
View file @
0936a50
...
...
@@ -21,8 +21,19 @@ from dragon.vm.torch.nn.modules.utils import _pair
class
_ConvNd
(
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
transposed
,
output_padding
,
groups
,
bias
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
transposed
,
output_padding
,
groups
,
bias
,
):
super
(
_ConvNd
,
self
)
.
__init__
()
if
in_channels
%
groups
!=
0
:
raise
ValueError
(
'in_channels must be divisible by groups'
)
...
...
@@ -94,36 +105,61 @@ class _ConvNd(Module):
class
Conv2d
(
_ConvNd
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
,
bias
=
True
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
,
bias
=
True
,
):
kernel_size
=
_pair
(
kernel_size
)
stride
=
_pair
(
stride
)
padding
=
_pair
(
padding
)
dilation
=
_pair
(
dilation
)
super
(
Conv2d
,
self
)
.
__init__
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
False
,
_pair
(
0
),
groups
,
bias
)
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
False
,
_pair
(
0
),
groups
,
bias
,
)
def
forward
(
self
,
input
):
inputs
=
[
input
,
self
.
weight
]
+
([
self
.
bias
]
if
self
.
bias
else
[])
inputs
=
[
input
,
self
.
weight
]
+
\
([
self
.
bias
]
if
self
.
bias
else
[])
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
class
ConvTranspose2d
(
_ConvNd
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
,
output_padding
=
0
,
groups
=
1
,
bias
=
True
,
dilation
=
1
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
,
output_padding
=
0
,
groups
=
1
,
bias
=
True
,
dilation
=
1
,
):
kernel_size
=
_pair
(
kernel_size
)
stride
=
_pair
(
stride
)
padding
=
_pair
(
padding
)
dilation
=
_pair
(
dilation
)
super
(
ConvTranspose2d
,
self
)
.
__init__
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
True
,
_pair
(
0
),
groups
,
bias
)
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
True
,
_pair
(
0
),
groups
,
bias
,
)
def
forward
(
self
,
input
):
inputs
=
[
input
,
self
.
weight
]
+
([
self
.
bias
]
if
self
.
bias
else
[])
inputs
=
[
input
,
self
.
weight
]
+
\
([
self
.
bias
]
if
self
.
bias
else
[])
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
\ No newline at end of file
Dragon/python/dragon/vm/torch/nn/modules/depthwise_conv.py
View file @
0936a50
...
...
@@ -21,8 +21,16 @@ from dragon.vm.torch.nn.modules.utils import _pair
class
_DepthwiseConvNd
(
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
output_padding
,
bias
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
output_padding
,
bias
,
):
super
(
_DepthwiseConvNd
,
self
)
.
__init__
()
if
in_channels
!=
out_channels
:
raise
ValueError
(
'in/out channels must be same'
)
...
...
@@ -75,17 +83,26 @@ class _DepthwiseConvNd(Module):
class
DepthwiseConv2d
(
_DepthwiseConvNd
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
,
bias
=
True
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
,
bias
=
True
,
):
kernel_size
=
_pair
(
kernel_size
)
stride
=
_pair
(
stride
)
padding
=
_pair
(
padding
)
super
(
DepthwiseConv2d
,
self
)
.
__init__
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
_pair
(
0
),
bias
)
stride
,
padding
,
_pair
(
0
),
bias
,
)
def
forward
(
self
,
input
):
inputs
=
[
input
,
self
.
weight
]
+
([
self
.
bias
]
if
self
.
bias
else
[])
inputs
=
[
input
,
self
.
weight
]
+
\
([
self
.
bias
]
if
self
.
bias
else
[])
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
\ No newline at end of file
Dragon/python/dragon/vm/torch/nn/modules/dropblock.py
View file @
0936a50
...
...
@@ -17,8 +17,14 @@ from dragon.vm.torch.nn import Module
class
DropBlock2d
(
Module
):
def
__init__
(
self
,
block_size
=
7
,
kp
=
0.9
,
alpha
=
1.
,
decrement
=
0.
,
inplace
=
False
):
def
__init__
(
self
,
block_size
=
7
,
kp
=
0.9
,
alpha
=
1.
,
decrement
=
0.
,
inplace
=
False
,
):
super
(
DropBlock2d
,
self
)
.
__init__
()
self
.
kp
=
kp
self
.
block_size
=
block_size
...
...
Dragon/python/dragon/vm/torch/nn/modules/groupnorm.py
View file @
0936a50
...
...
@@ -19,8 +19,13 @@ from dragon.vm.torch.ops.builtin import zeros, ones
class
_GroupNorm
(
Module
):
def
__init__
(
self
,
num_features
,
group
=
32
,
eps
=
1e-5
,
affine
=
True
):
def
__init__
(
self
,
num_features
,
group
=
32
,
eps
=
1e-5
,
affine
=
True
,
):
super
(
_GroupNorm
,
self
)
.
__init__
()
self
.
num_features
=
num_features
self
.
group
=
group
...
...
Dragon/python/dragon/vm/torch/nn/modules/loss.py
View file @
0936a50
...
...
@@ -22,7 +22,12 @@ from dragon.vm.torch.nn.functional import _Reduction
class
_Loss
(
Module
):
def
__init__
(
self
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
def
__init__
(
self
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
_Loss
,
self
)
.
__init__
()
if
size_average
is
not
None
or
reduce
is
not
None
:
self
.
reduction
=
_Reduction
.
legacy_get_string
(
size_average
,
reduce
)
...
...
@@ -31,17 +36,31 @@ class _Loss(Module):
class
_WeightedLoss
(
_Loss
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
_WeightedLoss
,
self
)
.
__init__
(
size_average
,
reduce
,
reduction
)
self
.
weight
=
weight
if
weight
is
not
None
:
raise
NotImplementedError
(
'WeightedLoss has been not implemented yet.'
)
raise
NotImplementedError
(
'WeightedLoss has been not implemented yet.'
)
class
NLLLoss
(
_WeightedLoss
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
ignore_index
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
super
(
NLLLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
ignore_index
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
NLLLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
self
.
ignore_index
=
ignore_index
self
.
normalization
=
{
'elementwise_mean'
:
'VALID'
,
...
...
@@ -60,17 +79,26 @@ class NLLLoss(_WeightedLoss):
}
def
forward
(
self
,
input
,
target
):
inputs
=
[
input
,
target
];
self
.
unify_devices
(
inputs
)
inputs
=
[
input
,
target
]
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
class
BCEWithLogitsLoss
(
_WeightedLoss
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
pos_weight
=
None
):
super
(
BCEWithLogitsLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
pos_weight
=
None
,
):
super
(
BCEWithLogitsLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
if
pos_weight
is
not
None
:
raise
NotImplementedError
(
'Positive weight has been not implemented yet.'
)
raise
NotImplementedError
(
'Positive weight has been not implemented yet.'
)
self
.
normalization
=
{
'elementwise_mean'
:
'VALID'
,
'sum'
:
'None'
,
...
...
@@ -86,15 +114,59 @@ class BCEWithLogitsLoss(_WeightedLoss):
}
def
forward
(
self
,
input
,
target
):
inputs
=
[
input
,
target
];
self
.
unify_devices
(
inputs
)
inputs
=
[
input
,
target
]
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
class
SCEWithLogitsLoss
(
_WeightedLoss
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
pos_weight
=
None
,
):
super
(
SCEWithLogitsLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
if
pos_weight
is
not
None
:
raise
NotImplementedError
(
'Positive weight has been not implemented yet.'
)
self
.
normalization
=
{
'elementwise_mean'
:
'VALID'
,
'sum'
:
'None'
,
'none'
:
'UNIT'
}[
self
.
reduction
]
self
.
register_op
()
def
register_op
(
self
):
self
.
op_meta
=
{
'op_type'
:
'SoftmaxCrossEntropy'
,
'arguments'
:
{
'axis'
:
1
,
'normalization'
:
self
.
normalization
,
},
}
def
forward
(
self
,
input
,
target
):
inputs
=
[
input
,
target
]
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
class
CrossEntropyLoss
(
_WeightedLoss
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
ignore_index
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
super
(
CrossEntropyLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
ignore_index
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
CrossEntropyLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
self
.
ignore_index
=
ignore_index
self
.
normalization
=
{
'elementwise_mean'
:
'VALID'
,
...
...
@@ -113,13 +185,19 @@ class CrossEntropyLoss(_WeightedLoss):
}
def
forward
(
self
,
input
,
target
):
inputs
=
[
input
,
target
];
self
.
unify_devices
(
inputs
)
inputs
=
[
input
,
target
]
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
class
L1Loss
(
_Loss
):
def
__init__
(
self
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
def
__init__
(
self
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
L1Loss
,
self
)
.
__init__
(
size_average
,
reduce
,
reduction
)
self
.
normalization
=
{
'elementwise_mean'
:
'BATCH_SIZE'
,
...
...
@@ -135,13 +213,19 @@ class L1Loss(_Loss):
}
def
forward
(
self
,
input
,
target
):
inputs
=
[
input
,
target
];
self
.
unify_devices
(
inputs
)
inputs
=
[
input
,
target
]
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
class
MSELoss
(
_Loss
):
def
__init__
(
self
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
def
__init__
(
self
,
size_average
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
MSELoss
,
self
)
.
__init__
(
size_average
,
reduce
,
reduction
)
self
.
normalization
=
{
'elementwise_mean'
:
'BATCH_SIZE'
,
...
...
@@ -158,14 +242,20 @@ class MSELoss(_Loss):
}
def
forward
(
self
,
input
,
target
):
inputs
=
[
input
,
target
];
self
.
unify_devices
(
inputs
)
inputs
=
[
input
,
target
]
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
class
SmoothL1Loss
(
_Loss
):
def
__init__
(
self
,
size_average
=
None
,
beta
=
1.0
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
def
__init__
(
self
,
size_average
=
None
,
beta
=
1.0
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
SmoothL1Loss
,
self
)
.
__init__
(
size_average
,
reduce
,
reduction
)
self
.
normalization
=
{
'elementwise_mean'
:
'BATCH_SIZE'
,
...
...
@@ -192,10 +282,19 @@ class SmoothL1Loss(_Loss):
class
SigmoidFocalLoss
(
_WeightedLoss
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
alpha
=
0.25
,
gamma
=
2.0
,
neg_id
=
0
,
ignore_index
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
super
(
SigmoidFocalLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
alpha
=
0.25
,
gamma
=
2.0
,
neg_id
=
0
,
ignore_index
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
SigmoidFocalLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
self
.
alpha
,
self
.
gamma
,
self
.
neg_id
=
alpha
,
gamma
,
neg_id
self
.
ignore_index
=
ignore_index
self
.
normalization
=
{
...
...
@@ -218,16 +317,26 @@ class SigmoidFocalLoss(_WeightedLoss):
}
def
forward
(
self
,
input
,
target
):
inputs
=
[
input
,
target
];
self
.
unify_devices
(
inputs
)
inputs
=
[
input
,
target
]
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
class
SoftmaxFocalLoss
(
_WeightedLoss
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
alpha
=
0.25
,
gamma
=
2.0
,
neg_id
=
0
,
ignore_index
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
):
super
(
SoftmaxFocalLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
def
__init__
(
self
,
weight
=
None
,
size_average
=
None
,
alpha
=
0.25
,
gamma
=
2.0
,
neg_id
=
0
,
ignore_index
=
None
,
reduce
=
None
,
reduction
=
'elementwise_mean'
,
):
super
(
SoftmaxFocalLoss
,
self
)
.
__init__
(
weight
,
size_average
,
reduce
,
reduction
)
self
.
alpha
,
self
.
gamma
,
self
.
neg_id
=
alpha
,
gamma
,
neg_id
self
.
ignore_index
=
ignore_index
self
.
normalization
=
{
...
...
@@ -250,6 +359,7 @@ class SoftmaxFocalLoss(_WeightedLoss):
}
def
forward
(
self
,
input
,
target
):
inputs
=
[
input
,
target
];
self
.
unify_devices
(
inputs
)
inputs
=
[
input
,
target
]
self
.
unify_devices
(
inputs
)
outputs
=
[
self
.
register_output
()]
return
self
.
run
(
inputs
,
outputs
)
\ No newline at end of file
Dragon/python/dragon/vm/torch/nn/modules/pooling.py
View file @
0936a50
...
...
@@ -18,8 +18,15 @@ from dragon.vm.torch.nn.modules.utils import _pair
class
_PoolNd
(
Module
):
def
__init__
(
self
,
kernel_size
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
return_indices
=
False
,
ceil_mode
=
False
):
def
__init__
(
self
,
kernel_size
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
return_indices
=
False
,
ceil_mode
=
False
,
):
super
(
_PoolNd
,
self
)
.
__init__
()
self
.
kernel_size
=
kernel_size
self
.
stride
=
stride
or
kernel_size
...
...
Dragon/python/dragon/vm/torch/nn/modules/rnn.py
View file @
0936a50
...
...
@@ -31,9 +31,17 @@ from dragon.vm.torch.ops.builtin import zeros as Zeros, xw_plus_b
class
RNNBase
(
Module
):
def
__init__
(
self
,
mode
,
input_size
,
hidden_size
,
num_layers
=
1
,
bias
=
True
,
batch_first
=
False
,
dropout
=
0
,
bidirectional
=
False
):
def
__init__
(
self
,
mode
,
input_size
,
hidden_size
,
num_layers
=
1
,
bias
=
True
,
batch_first
=
False
,
dropout
=
0
,
bidirectional
=
False
,
):
super
(
RNNBase
,
self
)
.
__init__
()
self
.
mode
=
mode
self
.
input_size
=
input_size
...
...
@@ -256,14 +264,23 @@ class RNN(RNNBase):
Examples
--------
>>> import dragon.vm.torch as torch
>>> rnn = RNN(32, 64, num_layers=1, bidirectional=True)
>>> x = torch.ones(8, 32, 256)
>>> outputs, hidden = rnn(x)
"""
def
__init__
(
self
,
input_size
,
hidden_size
,
nonlinearity
=
'relu'
,
num_layers
=
1
,
bias
=
True
,
batch_first
=
False
,
dropout
=
0
,
bidirectional
=
False
):
def
__init__
(
self
,
input_size
,
hidden_size
,
nonlinearity
=
'relu'
,
num_layers
=
1
,
bias
=
True
,
batch_first
=
False
,
dropout
=
0
,
bidirectional
=
False
,
):
"""Construct a RNN module.
Parameters
...
...
@@ -303,14 +320,22 @@ class LSTM(RNNBase):
Examples
--------
>>> import dragon.vm.torch as torch
>>> rnn = LSTM(32, 64, num_layers=2, bidirectional=True, dropout=0.5)
>>> x = torch.ones(8, 32, 256)
>>> outputs, hidden = rnn(x)
"""
def
__init__
(
self
,
input_size
,
hidden_size
,
num_layers
=
1
,
bias
=
True
,
batch_first
=
False
,
dropout
=
0
,
bidirectional
=
False
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_layers
=
1
,
bias
=
True
,
batch_first
=
False
,
dropout
=
0
,
bidirectional
=
False
,
):
"""Construct a LSTM module.
Parameters
...
...
@@ -347,14 +372,22 @@ class GRU(RNNBase):
Examples
--------
>>> import dragon.vm.torch as torch
>>> rnn = GRU(32, 64, num_layers=2, bidirectional=False)
>>> x = torch.ones(8, 32, 256)
>>> outputs, hidden = rnn(x)
"""
def
__init__
(
self
,
input_size
,
hidden_size
,
num_layers
=
1
,
bias
=
True
,
batch_first
=
False
,
dropout
=
0
,
bidirectional
=
False
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_layers
=
1
,
bias
=
True
,
batch_first
=
False
,
dropout
=
0
,
bidirectional
=
False
,
):
"""Construct a GRU module.
Parameters
...
...
@@ -413,15 +446,12 @@ class RNNCellBase(Module):
for
weight
in
self
.
parameters
():
weight
.
data
.
uniform_
(
-
stdv
,
stdv
)
from
.activation
import
Tanh
,
Sigmoid
class
LSTMCell
(
RNNCellBase
):
def
__init__
(
self
,
input_size
,
hidden_size
,
bias
=
True
):
super
(
LSTMCell
,
self
)
.
__init__
(
input_size
,
hidden_size
,
bias
,
num_chunks
=
4
)
self
.
register_op
()
self
.
tanh
=
Tanh
()
self
.
sigmoid
=
Sigmoid
()
def
register_op
(
self
):
self
.
op_meta
=
{
'op_type'
:
'LSTMCell'
,
'arguments'
:
{}}
...
...
Dragon/python/dragon/vm/torch/onnx/utils.py
View file @
0936a50
...
...
@@ -27,9 +27,13 @@ from dragon.vm.onnx import export_from_graph_def
def
export
(
model
,
args
,
f
,
verbose
=
False
,
input_names
=
None
,
output_names
=
None
,
opset_version
=
None
,
model
,
args
,
f
,
verbose
=
False
,
input_names
=
None
,
output_names
=
None
,
opset_version
=
None
,
):
"""Export a model into ONNX format.
...
...
Dragon/python/dragon/vm/torch/ops/builtin.py
View file @
0936a50
...
...
@@ -1128,8 +1128,14 @@ def _allreduce(grads):
return
module
.
forward
(
grads
)
def
_update
(
param
,
grad
,
op_type
,
slot
,
lr_mult
=
1.0
,
decay_mult
=
1.0
):
def
_update
(
param
,
grad
,
op_type
,
slot
,
lr_mult
=
1.0
,
decay_mult
=
1.0
,
):
dev
=
MakeDevice
(
inputs
=
[
param
])
key
=
'{}/{}/{}/{}'
.
format
(
op_type
,
dev
,
slot
,
param
.
name
)
module
=
get_module
(
Update
,
key
,
dev
,
op_type
=
op_type
,
...
...
@@ -1169,26 +1175,55 @@ def bilinear_resize(input, dsize, fx=-1.0, fy=-1.0):
return
_resize_2d
(
input
,
'BilinearResize'
,
dsize
,
fx
,
fy
)
def
roi_pool
(
feature
,
rois
,
pooled_h
,
pooled_w
,
spatial_scale
):
def
roi_pool
(
feature
,
rois
,
pooled_h
,
pooled_w
,
spatial_scale
,
):
dev
=
MakeDevice
(
inputs
=
[
feature
])
key
=
'RoIPool/{}/pool_h:{}/pool_w:{}/spatial_scale:{}'
.
format
(
dev
,
pooled_h
,
pooled_w
,
spatial_scale
)
key
=
'RoIPool/{}'
\
'/pool_h:{}'
\
'/pool_w:{}'
\
'/spatial_scale:{}'
\
.
format
(
dev
,
pooled_h
,
pooled_w
,
spatial_scale
)
module
=
get_module
(
RoIPool
,
key
,
dev
,
pooled_h
=
pooled_h
,
pooled_w
=
pooled_w
,
spatial_scale
=
spatial_scale
)
pooled_h
=
pooled_h
,
pooled_w
=
pooled_w
,
spatial_scale
=
spatial_scale
,
)
return
module
.
forward
(
feature
,
rois
)
def
roi_align
(
feature
,
rois
,
pooled_h
,
pooled_w
,
spatial_scale
,
sampling_ratio
=
2
):
def
roi_align
(
feature
,
rois
,
pooled_h
,
pooled_w
,
spatial_scale
,
sampling_ratio
=
2
,
):
dev
=
MakeDevice
(
inputs
=
[
feature
])
key
=
'RoIAlign/{}/pool_h:{}/pool_w:{}/'
\
'spatial_scale:{}/sampling_ratio:{}'
.
format
(
dev
,
pooled_h
,
pooled_w
,
spatial_scale
,
sampling_ratio
)
key
=
'RoIAlign/{}'
\
'/pool_h:{}'
\
'/pool_w:{}'
\
'/spatial_scale:{}'
\
'/sampling_ratio:{}'
\
.
format
(
dev
,
pooled_h
,
pooled_w
,
spatial_scale
,
sampling_ratio
)
module
=
get_module
(
RoIAlign
,
key
,
dev
,
pooled_h
=
pooled_h
,
pooled_w
=
pooled_w
,
spatial_scale
=
spatial_scale
,
sampling_ratio
=
sampling_ratio
)
pooled_h
=
pooled_h
,
pooled_w
=
pooled_w
,
spatial_scale
=
spatial_scale
,
sampling_ratio
=
sampling_ratio
,
)
return
module
.
forward
(
feature
,
rois
)
\ No newline at end of file
Dragon/src/operators/recurrent/lstm_cell_op.cc
View file @
0936a50
...
...
@@ -37,7 +37,7 @@ void LSTMCellGradientOp<Context>::RunWithType() {
auto
*
Xdata
=
Input
(
0
).
template
data
<
T
,
Context
>
();
auto
*
HXdata
=
Input
(
1
).
template
data
<
T
,
Context
>
();
auto
*
Cdata
=
Input
(
2
).
template
data
<
T
,
Context
>
();
auto
*
dHdata
=
Input
(
-
2
).
template
data
<
T
,
Context
>
();
auto
*
dHdata
=
Input
(
3
).
template
data
<
T
,
Context
>
();
auto
*
dCdata
=
Input
(
4
).
template
mutable_data
<
T
,
Context
>
();
auto
*
dXdata
=
Output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
auto
*
dHXdata
=
Output
(
1
)
->
template
mutable_data
<
T
,
Context
>
();
...
...
@@ -81,7 +81,7 @@ class GetLSTMCellGradient final : public GradientMakerBase {
GRADIENT_MAKER_CTOR
(
GetLSTMCellGradient
);
vector
<
OperatorDef
>
MakeDefs
()
override
{
return
SingleDef
(
def
.
type
()
+
"Gradient"
,
""
,
vector
<
string
>
({
I
(
0
),
I
(
1
),
O
(
0
),
GO
(
0
),
GO
(
1
)
}),
vector
<
string
>
({
I
(
0
),
I
(
1
),
O
(
1
),
GO
(
0
),
GO
(
1
)
}),
vector
<
string
>
({
GI
(
0
),
GI
(
1
)
}));
}
// Fill zero for dCNext
...
...
Dragon/src/operators/vision/cudnn_conv2d_transpose_op.cc
View file @
0936a50
...
...
@@ -163,7 +163,7 @@ void CuDNNConvTranspose2dOp<Context>::RunOnDevice() {
#if CUDNN_VERSION_MAX(6, 0, 0)
for
(
int
i
=
0
;
i
<
dilation
.
size
();
i
++
)
if
(
dilation
[
i
]
!=
1
)
return
Conv
2dTranspose
Op
<
Context
>::
RunOnDevice
();
return
Conv
Transpose2d
Op
<
Context
>::
RunOnDevice
();
#endif
ConvTranspose2dOp
<
Context
>::
Reshape
();
...
...
@@ -355,7 +355,7 @@ void CuDNNConvTranspose2dGradientOp<Context>::RunOnDevice() {
#if CUDNN_VERSION_MAX(6, 0, 0)
for
(
int
i
=
0
;
i
<
dilation
.
size
();
i
++
)
if
(
dilation
[
i
]
!=
1
)
return
Conv
2dTranspose
GradientOp
<
Context
>::
RunOnDevice
();
return
Conv
Transpose2d
GradientOp
<
Context
>::
RunOnDevice
();
#endif
ConvTranspose2dGradientOp
<
Context
>::
GradientReshape
();
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment