Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
SeetaResearch
/
Dragon
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit 4e937b6c
authored
Jul 31, 2017
by
Ting PAN
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add custom concat op for densenet
1 parent
8651e1b5
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
179 additions
and
158 deletions
Dragon/include/core/operator_gradient.h
Dragon/include/core/tensor.h
Dragon/include/operators/common/concat_op.h
Dragon/include/operators/update/update_op_base.h
Dragon/include/operators/vision/dense_concat_op.h
Dragon/python/operators/vision.py
Dragon/python/ops.py
Dragon/python/vm/caffe/io/image_fetcher.py
Dragon/python/vm/caffe/layers/__init__.py
Dragon/python/vm/caffe/layers/common.py
Dragon/src/operators/common/concat_op.cc
Dragon/src/operators/update/update_op_base.cc
Dragon/src/operators/vision/dense_concat_op.cc
Dragon/include/core/operator_gradient.h
View file @
4e937b6
...
@@ -74,7 +74,7 @@ class GradientMakerBase {
...
@@ -74,7 +74,7 @@ class GradientMakerBase {
const
vector
<
string
>&
g_outputs_
;
const
vector
<
string
>&
g_outputs_
;
};
};
// implemented in operator.c
pp
// implemented in operator.c
c
Gradient
MakeGradientForOp
(
const
OperatorDef
&
op_def
,
const
vector
<
string
>&
g_outputs
);
Gradient
MakeGradientForOp
(
const
OperatorDef
&
op_def
,
const
vector
<
string
>&
g_outputs
);
# define GRADIENT_MAKER_CTOR(name) \
# define GRADIENT_MAKER_CTOR(name) \
...
@@ -99,7 +99,7 @@ DECLARE_REGISTRY(NoGradientRegistry,
...
@@ -99,7 +99,7 @@ DECLARE_REGISTRY(NoGradientRegistry,
const
OperatorDef
&
,
const
OperatorDef
&
,
const
vector
<
string
>&
);
const
vector
<
string
>&
);
// define in the operator.c
pp
// define in the operator.c
c
#define REGISTER_GRADIENT(name, ...) \
#define REGISTER_GRADIENT(name, ...) \
REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
...
...
Dragon/include/core/tensor.h
View file @
4e937b6
...
@@ -189,6 +189,10 @@ class Tensor {
...
@@ -189,6 +189,10 @@ class Tensor {
memory_
.
reset
();
memory_
.
reset
();
}
}
void
Release
()
{
memory_
.
reset
();
}
private
:
private
:
vector
<
TIndex
>
dims_
;
vector
<
TIndex
>
dims_
;
TIndex
size_
=
0
,
capacity_
=
0
;
TIndex
size_
=
0
,
capacity_
=
0
;
...
...
Dragon/include/operators/common/concat_op.h
View file @
4e937b6
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
namespace
dragon
{
namespace
dragon
{
template
<
class
Context
>
template
<
class
Context
>
class
ConcatOp
final
:
public
Operator
<
Context
>
{
class
ConcatOp
:
public
Operator
<
Context
>
{
public
:
public
:
ConcatOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
ConcatOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
Operator
<
Context
>
(
op_def
,
ws
),
:
Operator
<
Context
>
(
op_def
,
ws
),
...
...
Dragon/include/operators/update/update_op_base.h
View file @
4e937b6
...
@@ -26,9 +26,7 @@ class UpdateOpBase : public Operator<Context> {
...
@@ -26,9 +26,7 @@ class UpdateOpBase : public Operator<Context> {
float
param
(
const
string
&
name
)
const
;
float
param
(
const
string
&
name
)
const
;
void
InitMPI
();
void
InitMPI
();
void
ShareBeforeRun
()
override
;
void
RunOnDevice
()
override
;
void
RunOnDevice
()
override
;
void
ClearAfterRun
()
override
;
template
<
typename
T
>
void
ReduceRunWithType
();
template
<
typename
T
>
void
ReduceRunWithType
();
template
<
typename
T
>
void
PreprocessRunWithType
();
template
<
typename
T
>
void
PreprocessRunWithType
();
virtual
void
ComputeRunWithFloat
()
=
0
;
virtual
void
ComputeRunWithFloat
()
=
0
;
...
...
Dragon/include/operators/vision/dense_concat_op.h
0 → 100644
View file @
4e937b6
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
#define DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
#include "operators/common/concat_op.h"
namespace
dragon
{
template
<
class
Context
>
class
DenseConcatOp
final
:
public
ConcatOp
<
Context
>
{
public
:
DenseConcatOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
ConcatOp
<
Context
>
(
op_def
,
ws
)
{
}
void
RunOnDevice
()
override
;
};
template
<
class
Context
>
class
DenseConcatGradientOp
:
public
ConcatGradientOp
<
Context
>
{
public
:
DenseConcatGradientOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
ConcatGradientOp
<
Context
>
(
op_def
,
ws
)
{}
void
ShareBeforeRun
()
override
;
void
RunOnDevice
()
override
;
void
ClearAfterRun
()
override
;
template
<
typename
T
>
void
RunWithType
();
};
}
// namespace dragon
#endif // DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
\ No newline at end of file
Dragon/python/operators/vision.py
View file @
4e937b6
...
@@ -194,9 +194,20 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
...
@@ -194,9 +194,20 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
return
output
return
output
def
DenseConcat
(
inputs
,
axis
=
1
,
**
kwargs
):
if
not
isinstance
(
inputs
,
list
)
or
len
(
inputs
)
!=
2
:
raise
RuntimeError
(
'DenseConcat Operator accepts 2 Tensors as inputs'
)
args
=
locals
();
kwargs
=
args
[
'kwargs'
]
del
args
[
'kwargs'
];
kwargs
=
dict
(
args
,
**
kwargs
)
kwargs
[
'num_input'
]
=
len
(
inputs
)
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'DenseConcat'
,
**
kwargs
)
if
all
(
input
.
shape
is
not
None
for
input
in
inputs
):
if
all
(
input
.
shape
[
axis
]
is
not
None
for
input
in
inputs
):
output
.
shape
=
inputs
[
0
]
.
shape
[:]
for
i
in
xrange
(
1
,
len
(
inputs
)):
output
.
shape
[
axis
]
+=
inputs
[
i
]
.
shape
[
axis
]
return
output
\ No newline at end of file
Dragon/python/ops.py
View file @
4e937b6
...
@@ -39,6 +39,7 @@ ROIAlign = vision.ROIAlign
...
@@ -39,6 +39,7 @@ ROIAlign = vision.ROIAlign
LRN
=
vision
.
LRN
LRN
=
vision
.
LRN
NNResize
=
vision
.
NNResize
NNResize
=
vision
.
NNResize
BiasAdd
=
vision
.
BiasAdd
BiasAdd
=
vision
.
BiasAdd
DenseConcat
=
vision
.
DenseConcat
# recurrent
# recurrent
LSTMUnit
=
recurrent
.
LSTMUnit
LSTMUnit
=
recurrent
.
LSTMUnit
...
...
Dragon/python/vm/caffe/io/image_fetcher.py
deleted
100644 → 0
View file @
8651e1b
# --------------------------------------------------------
# Caffe for Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
import
os
import
numpy
as
np
import
numpy.random
as
npr
from
multiprocessing
import
Process
import
dragon.config
as
config
import
dragon.core.mpi
as
mpi
from
__init__
import
GetProperty
class
Datum
(
object
):
def
__init__
(
self
):
self
.
_file
=
''
self
.
_label
=
None
class
ImageReader
(
Process
):
def
__init__
(
self
,
**
kwargs
):
super
(
ImageReader
,
self
)
.
__init__
()
self
.
_shuffle
=
GetProperty
(
kwargs
,
'shuffle'
,
False
)
self
.
_force_gray
=
GetProperty
(
kwargs
,
'force_gray'
,
False
)
self
.
_source
=
GetProperty
(
kwargs
,
'source'
,
''
)
self
.
_mean_value
=
GetProperty
(
kwargs
,
'mean_value'
,
[])
self
.
_scale
=
GetProperty
(
kwargs
,
'scale'
,
1.0
)
self
.
_mirror
=
GetProperty
(
kwargs
,
'mirror'
,
False
)
self
.
_phase
=
GetProperty
(
kwargs
,
'phase'
,
'TRAIN'
)
self
.
_crop_size
=
GetProperty
(
kwargs
,
'crop_size'
,
0
)
self
.
_step_val
=
GetProperty
(
kwargs
,
'step'
,
1
)
self
.
_random_seed
=
config
.
GetRandomSeed
()
self
.
_indices
=
[]
self
.
_cur_idx
=
0
if
mpi
.
is_init
():
idx
,
group
=
mpi
.
allow_parallel
()
if
idx
!=
-
1
:
# valid data parallel
rank
=
mpi
.
rank
()
self
.
_random_seed
+=
rank
# for shuffle
for
i
,
node
in
enumerate
(
group
):
if
rank
==
node
:
self
.
_cur_idx
=
i
if
not
kwargs
.
has_key
(
'step'
):
self
.
_step_val
=
len
(
group
)
self
.
_Q
=
None
self
.
ParseImageSet
()
self
.
daemon
=
True
def
cleanup
():
print
'Terminating DataReader......'
self
.
terminate
()
self
.
join
()
import
atexit
atexit
.
register
(
cleanup
)
def
ParseImageSet
(
self
):
if
not
os
.
path
.
exists
(
self
.
_source
):
raise
RuntimeError
(
'DataReader found the source does not exist'
)
with
open
(
self
.
_source
)
as
f
:
for
line
in
f
:
content
=
line
.
split
()
item
=
Datum
()
item
.
_file
=
content
[
0
]
if
len
(
content
)
>
1
:
item
.
_label
=
tuple
(
content
[
idx
]
for
idx
in
xrange
(
1
,
len
(
content
)))
else
:
item
.
_label
=
None
self
.
_indices
.
append
(
item
)
def
load_image
(
self
,
index
):
import
PIL.Image
as
Image
filepath
=
os
.
path
.
join
(
self
.
_source
,
self
.
_indices
[
index
]
.
_file
)
assert
os
.
path
.
exists
(
filepath
)
im
=
Image
.
open
(
filepath
)
im
=
np
.
array
(
im
,
dtype
=
np
.
float32
)
if
len
(
im
.
shape
)
<
3
:
im
=
im
[:,
:,
np
.
newaxis
]
if
self
.
_force_gray
:
im
=
im
[:,
:,
-
1
,
np
.
newaxis
]
else
:
if
im
.
shape
[
2
]
==
1
:
# copy to 3 channels
im
=
np
.
concatenate
([
im
,
im
,
im
],
axis
=
2
)
else
:
im
=
im
[:,
:,
::
-
1
]
# RGB -> BGR
# handle crop
if
self
.
_crop_size
>
0
:
assert
im
.
shape
[
0
]
>=
self
.
_crop_size
assert
im
.
shape
[
1
]
>=
self
.
_crop_size
if
self
.
_phase
==
0
:
h_off
=
npr
.
randint
(
im
.
shape
[
0
]
-
self
.
_crop_size
+
1
)
w_off
=
npr
.
randint
(
im
.
shape
[
1
]
-
self
.
_crop_size
+
1
)
else
:
h_off
=
(
im
.
shape
[
0
]
-
self
.
_crop_size
)
/
2
w_off
=
(
im
.
shape
[
1
]
-
self
.
_crop_size
)
/
2
im
=
im
[
h_off
:
h_off
+
self
.
_crop_size
,
w_off
:
w_off
+
self
.
_crop_size
,
:]
# handle mirror
if
self
.
_mirror
:
if
npr
.
randint
(
0
,
2
)
>
0
:
im
=
im
[:,
::
-
1
,
:]
# handle mean value
if
len
(
self
.
_mean_value
)
>
0
:
im
=
im
-
self
.
_mean_value
# handle scale
if
self
.
_scale
!=
1.0
:
im
=
im
*
self
.
_scale
return
im
def
load_image_label
(
self
,
index
):
im
=
self
.
load_image
(
index
)
label
=
self
.
_indices
[
index
]
.
_label
if
label
is
not
None
:
return
(
im
,
label
)
else
:
return
[
im
]
def
run
(
self
):
npr
.
seed
(
self
.
_random_seed
)
while
True
:
self
.
_Q
.
put
(
self
.
load_image_label
(
self
.
_cur_idx
))
if
self
.
_shuffle
:
self
.
_cur_idx
=
npr
.
randint
(
0
,
len
(
self
.
_indices
))
else
:
self
.
_cur_idx
=
(
self
.
_cur_idx
+
self
.
_step_val
)
%
len
(
self
.
_indices
)
Dragon/python/vm/caffe/layers/__init__.py
View file @
4e937b6
...
@@ -21,4 +21,4 @@ from common import InnerProductLayer, AccuracyLayer, BatchNormLayer, \
...
@@ -21,4 +21,4 @@ from common import InnerProductLayer, AccuracyLayer, BatchNormLayer, \
ReshapeLayer
,
EltwiseLayer
,
ScaleLayer
,
\
ReshapeLayer
,
EltwiseLayer
,
ScaleLayer
,
\
SoftmaxLayer
,
PermuteLayer
,
FlattenLayer
,
ConcatLayer
,
\
SoftmaxLayer
,
PermuteLayer
,
FlattenLayer
,
ConcatLayer
,
\
NormalizeLayer
,
InstanceNormLayer
,
TileLayer
,
\
NormalizeLayer
,
InstanceNormLayer
,
TileLayer
,
\
ExpandDimsLayer
,
ProposalLayer
ExpandDimsLayer
,
ProposalLayer
,
DenseConcatLayer
\ No newline at end of file
\ No newline at end of file
Dragon/python/vm/caffe/layers/common.py
View file @
4e937b6
...
@@ -89,6 +89,17 @@ class ConcatLayer(Layer):
...
@@ -89,6 +89,17 @@ class ConcatLayer(Layer):
return
ops
.
Concat
(
bottom
,
**
self
.
_param
)
return
ops
.
Concat
(
bottom
,
**
self
.
_param
)
class
DenseConcatLayer
(
Layer
):
def
__init__
(
self
,
LayerParameter
):
super
(
DenseConcatLayer
,
self
)
.
__init__
(
LayerParameter
)
param
=
LayerParameter
.
concat_param
self
.
_param
=
{
'axis'
:
param
.
axis
}
def
Setup
(
self
,
bottom
):
super
(
DenseConcatLayer
,
self
)
.
Setup
(
bottom
)
return
ops
.
DenseConcat
(
bottom
,
**
self
.
_param
)
class
CropLayer
(
Layer
):
class
CropLayer
(
Layer
):
def
__init__
(
self
,
LayerParameter
):
def
__init__
(
self
,
LayerParameter
):
super
(
CropLayer
,
self
)
.
__init__
(
LayerParameter
)
super
(
CropLayer
,
self
)
.
__init__
(
LayerParameter
)
...
...
Dragon/src/operators/common/concat_op.cc
View file @
4e937b6
...
@@ -12,11 +12,14 @@ void ConcatOp<Context>::RunWithType() {
...
@@ -12,11 +12,14 @@ void ConcatOp<Context>::RunWithType() {
TIndex
count
=
input
(
i
).
count
();
TIndex
count
=
input
(
i
).
count
();
x_concat_dim
=
input
(
i
).
dim
(
axis
);
x_concat_dim
=
input
(
i
).
dim
(
axis
);
kernel
::
Concat
<
T
,
Context
>
(
count
,
kernel
::
Concat
<
T
,
Context
>
(
count
,
outer_dim
,
inner_dim
,
outer_dim
,
x_concat_dim
,
y_concat_dim
,
inner_dim
,
concat_offset
,
x_concat_dim
,
Xdata
,
Ydata
,
y_concat_dim
,
&
ctx
());
concat_offset
,
Xdata
,
Ydata
,
&
ctx
());
concat_offset
+=
x_concat_dim
;
concat_offset
+=
x_concat_dim
;
}
}
}
}
...
@@ -24,7 +27,7 @@ void ConcatOp<Context>::RunWithType() {
...
@@ -24,7 +27,7 @@ void ConcatOp<Context>::RunWithType() {
template
<
class
Context
>
template
<
class
Context
>
void
ConcatOp
<
Context
>::
RunOnDevice
(){
void
ConcatOp
<
Context
>::
RunOnDevice
(){
concat_dims
=
input
(
0
).
dims
();
concat_dims
=
input
(
0
).
dims
();
for
(
int
i
=
1
;
i
<
nin
;
i
++
){
for
(
int
i
=
1
;
i
<
nin
;
i
++
)
{
CHECK_EQ
(
concat_dims
.
size
(),
input
(
i
).
ndim
())
CHECK_EQ
(
concat_dims
.
size
(),
input
(
i
).
ndim
())
<<
"
\n
all inputs must have the same ndim."
;
<<
"
\n
all inputs must have the same ndim."
;
for
(
int
j
=
0
;
j
<
concat_dims
.
size
();
j
++
){
for
(
int
j
=
0
;
j
<
concat_dims
.
size
();
j
++
){
...
@@ -59,17 +62,20 @@ OPERATOR_SCHEMA(Concat).NumInputs(1, INT_MAX).NumOutputs(1);
...
@@ -59,17 +62,20 @@ OPERATOR_SCHEMA(Concat).NumInputs(1, INT_MAX).NumOutputs(1);
template
<
class
Context
>
template
<
typename
T
>
template
<
class
Context
>
template
<
typename
T
>
void
ConcatGradientOp
<
Context
>::
RunWithType
()
{
void
ConcatGradientOp
<
Context
>::
RunWithType
()
{
auto
*
dYdata
=
input
(
-
1
).
template
data
<
T
,
Context
>
();
auto
*
dYdata
=
input
(
-
1
).
template
data
<
T
,
Context
>
();
for
(
int
i
=
0
;
i
<
nin
;
i
++
){
for
(
int
i
=
0
;
i
<
nin
;
i
++
)
{
x_concat_dim
=
input
(
i
).
dim
(
axis
);
x_concat_dim
=
input
(
i
).
dim
(
axis
);
if
(
output
(
i
)
->
name
()
!=
"ignore"
)
{
if
(
output
(
i
)
->
name
()
!=
"ignore"
)
{
auto
*
dXdata
=
output
(
i
)
->
template
mutable_data
<
T
,
Context
>
();
auto
*
dXdata
=
output
(
i
)
->
template
mutable_data
<
T
,
Context
>
();
TIndex
count
=
output
(
i
)
->
count
();
TIndex
count
=
output
(
i
)
->
count
();
kernel
::
ConcatGrad
<
T
,
Context
>
(
count
,
kernel
::
ConcatGrad
<
T
,
Context
>
(
count
,
outer_dim
,
inner_dim
,
outer_dim
,
x_concat_dim
,
y_concat_dim
,
inner_dim
,
concat_offset
,
x_concat_dim
,
dYdata
,
dXdata
,
y_concat_dim
,
&
ctx
());
concat_offset
,
dYdata
,
dXdata
,
&
ctx
());
}
}
concat_offset
+=
x_concat_dim
;
concat_offset
+=
x_concat_dim
;
}
}
...
@@ -132,5 +138,4 @@ public:
...
@@ -132,5 +138,4 @@ public:
};
};
REGISTER_GRADIENT
(
Concat
,
GetConcatGradient
);
REGISTER_GRADIENT
(
Concat
,
GetConcatGradient
);
}
// namespace dragon
}
//
namespace
dragon
\ No newline at end of file
Dragon/src/operators/update/update_op_base.cc
View file @
4e937b6
...
@@ -17,7 +17,7 @@ void UpdateOpBase<Context>::InitMPI() {
...
@@ -17,7 +17,7 @@ void UpdateOpBase<Context>::InitMPI() {
this
->
args
().
count
(
"root"
))
{
this
->
args
().
count
(
"root"
))
{
#ifdef WITH_MPI
#ifdef WITH_MPI
comm
=
(
MPI_Comm
)
OperatorBase
::
GetSingleArg
<
int64_t
>
(
"comm"
,
0
);
comm
=
(
MPI_Comm
)
OperatorBase
::
GetSingleArg
<
int64_t
>
(
"comm"
,
0
);
group
=
(
MPI_Group
)
OperatorBase
::
GetSingleArg
<
int64_t
>
(
"group"
,
0
);
group
=
(
MPI_Group
)
OperatorBase
::
GetSingleArg
<
int64_t
>
(
"group"
,
0
);
int
world_root
=
OperatorBase
::
GetSingleArg
<
int
>
(
"root"
,
0
);
int
world_root
=
OperatorBase
::
GetSingleArg
<
int
>
(
"root"
,
0
);
if
(
comm
==
MPI_COMM_NULL
)
return
;
if
(
comm
==
MPI_COMM_NULL
)
return
;
allow_parallel
=
true
;
allow_parallel
=
true
;
...
@@ -46,6 +46,7 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
...
@@ -46,6 +46,7 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
segment_ends
[
0
]
=
segment_sizes
[
0
];
segment_ends
[
0
]
=
segment_sizes
[
0
];
for
(
int
i
=
1
;
i
<
segment_ends
.
size
();
i
++
)
for
(
int
i
=
1
;
i
<
segment_ends
.
size
();
i
++
)
segment_ends
[
i
]
=
segment_sizes
[
i
]
+
segment_ends
[
i
-
1
];
segment_ends
[
i
]
=
segment_sizes
[
i
]
+
segment_ends
[
i
-
1
];
buffer
=
ws
()
->
GetBuffer
();
buffer
->
Reshape
(
vector
<
TIndex
>
(
1
,
segment_sizes
[
0
]));
buffer
->
Reshape
(
vector
<
TIndex
>
(
1
,
segment_sizes
[
0
]));
#ifdef WITH_CUDA_AWARE
#ifdef WITH_CUDA_AWARE
auto
*
Bdata
=
buffer
->
mutable_data
<
T
,
Context
>
();
auto
*
Bdata
=
buffer
->
mutable_data
<
T
,
Context
>
();
...
@@ -80,6 +81,7 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
...
@@ -80,6 +81,7 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
1.0
,
Bdata
,
segment_update
);
1.0
,
Bdata
,
segment_update
);
#endif // WITH_CUDA_AWARE
#endif // WITH_CUDA_AWARE
}
}
ws
()
->
ReleaseBuffer
(
buffer
);
// allgather
// allgather
for
(
int
i
=
0
;
i
<
comm_size
-
1
;
i
++
)
{
for
(
int
i
=
0
;
i
<
comm_size
-
1
;
i
++
)
{
...
@@ -205,14 +207,4 @@ template class UpdateOpBase<CPUContext>;
...
@@ -205,14 +207,4 @@ template class UpdateOpBase<CPUContext>;
template
class
UpdateOpBase
<
CUDAContext
>
;
template
class
UpdateOpBase
<
CUDAContext
>
;
#endif
#endif
template
<
class
Context
>
void
UpdateOpBase
<
Context
>::
ShareBeforeRun
()
{
buffer
=
ws
()
->
GetBuffer
();
}
template
<
class
Context
>
void
UpdateOpBase
<
Context
>::
ClearAfterRun
()
{
ws
()
->
ReleaseBuffer
(
buffer
);
}
}
//
namespace
dragon
}
//
namespace
dragon
\ No newline at end of file
Dragon/src/operators/vision/dense_concat_op.cc
0 → 100644
View file @
4e937b6
#include "operators/vision/dense_concat_op.h"
#include "core/workspace.h"
#include "utils/op_kernel.h"
namespace
dragon
{
template
<
class
Context
>
void
DenseConcatOp
<
Context
>::
RunOnDevice
()
{
ConcatOp
<
Context
>::
RunOnDevice
();
input
(
0
).
Release
();
// keep shape, just release mem
}
DEPLOY_CPU
(
DenseConcat
);
#ifdef WITH_CUDA
DEPLOY_CUDA
(
DenseConcat
);
#endif
OPERATOR_SCHEMA
(
DenseConcat
).
NumInputs
(
2
).
NumOutputs
(
1
);
template
<
class
Context
>
template
<
typename
T
>
void
DenseConcatGradientOp
<
Context
>::
RunWithType
()
{
// restore X1 from Y
auto
*
Ydata
=
input
(
-
2
).
template
data
<
T
,
Context
>
();
auto
*
Xdata
=
input
(
0
).
template
mutable_data
<
T
,
Context
>
();
this
->
x_concat_dim
=
input
(
0
).
dim
(
this
->
axis
);
TIndex
count
=
input
(
0
).
count
();
this
->
concat_dims
=
input
(
-
1
).
dims
();
this
->
y_concat_dim
=
this
->
concat_dims
[
this
->
axis
];
this
->
outer_dim
=
input
(
-
1
).
count
(
0
,
this
->
axis
);
this
->
inner_dim
=
input
(
-
1
).
count
(
this
->
axis
+
1
);
kernel
::
ConcatGrad
<
T
,
Context
>
(
count
,
this
->
outer_dim
,
this
->
inner_dim
,
this
->
x_concat_dim
,
this
->
y_concat_dim
,
0
,
Ydata
,
Xdata
,
&
ctx
());
}
template
<
class
Context
>
void
DenseConcatGradientOp
<
Context
>::
RunOnDevice
()
{
if
(
input
(
0
).
template
IsType
<
float
>
())
RunWithType
<
float
>
();
else
if
(
input
(
0
).
template
IsType
<
float16
>
())
RunWithType
<
float16
>
();
else
LOG
(
FATAL
)
<<
"unsupported input types."
;
ConcatGradientOp
<
Context
>::
RunOnDevice
();
}
template
<
class
Context
>
void
DenseConcatGradientOp
<
Context
>::
ShareBeforeRun
()
{
Tensor
*
dX
=
ws
()
->
GetBuffer
();
if
(
dX
!=
nullptr
)
output
(
0
)
->
Replace
(
*
dX
);
}
template
<
class
Context
>
void
DenseConcatGradientOp
<
Context
>::
ClearAfterRun
()
{
Tensor
*
dY
=
&
input
(
-
1
);
Tensor
*
Y
=
&
input
(
-
2
);
ws
()
->
ReleaseBuffer
(
dY
);
ws
()
->
ReleaseBuffer
(
Y
,
true
);
}
DEPLOY_CPU
(
DenseConcatGradient
);
#ifdef WITH_CUDA
DEPLOY_CUDA
(
DenseConcatGradient
);
#endif
OPERATOR_SCHEMA
(
DenseConcatGradient
).
NumInputs
(
4
).
NumOutputs
(
2
);
class
GetDenseConcatGradient
:
public
GradientMakerBase
{
public
:
GRADIENT_MAKER_CTOR
(
GetDenseConcatGradient
);
vector
<
OperatorDef
>
MakeDefs
()
override
{
return
SingleDef
(
def
.
type
()
+
"Gradient"
,
""
,
vector
<
string
>
{
I
(
0
),
I
(
1
),
O
(
0
),
GO
(
0
)},
vector
<
string
>
{
GI
(
0
),
GI
(
1
)});
}
};
REGISTER_GRADIENT
(
DenseConcat
,
GetDenseConcatGradient
);
}
//
namespace
dragon
\ No newline at end of file
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment