Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
SeetaResearch
/
Dragon
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit c9db9eee
authored
Dec 14, 2017
by
Ting PAN
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix/Refactor the GroupConvolution on cuDNN
1 parent
6f2751b1
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
24 changed files
with
280 additions
and
67 deletions
Dragon/include/core/context.h
Dragon/include/core/context_cuda.h
Dragon/include/core/mixedmem.h
Dragon/include/core/tensor.h
Dragon/include/operators/norm/l2_norm_op.h
Dragon/include/operators/vision/conv_op.h
Dragon/include/operators/vision/conv_transpose_op.h
Dragon/include/operators/vision/lrn_op.h
Dragon/include/utils/cudnn_device.h
Dragon/python/dragon/core/tensor.py
Dragon/python/dragon/docs/install.rst
Dragon/python/dragon/operators/ndarray.py
Dragon/python/dragon/operators/norm.py
Dragon/python/dragon/operators/vision.py
Dragon/python/dragon/vm/caffe/layers/vision.py
Dragon/src/core/graph.cc
Dragon/src/core/mixedmem.cc
Dragon/src/core/operator.cc
Dragon/src/operators/norm/l2_norm_op.cc
Dragon/src/operators/vision/cudnn_conv2d_op.cc
Dragon/src/operators/vision/cudnn_conv2d_transpose_op.cc
Dragon/src/operators/vision/cudnn_lrn_op.cc
Dragon/src/operators/vision/lrn_op.cc
Dragon/src/utils/cudnn_device.cc
Dragon/include/core/context.h
View file @
c9db9ee
...
@@ -51,6 +51,9 @@ class CPUContext {
...
@@ -51,6 +51,9 @@ class CPUContext {
inline
static
void
Memcpy
(
size_t
nbytes
,
void
*
dst
,
const
void
*
src
)
{
memcpy
(
dst
,
src
,
nbytes
);
}
inline
static
void
Memcpy
(
size_t
nbytes
,
void
*
dst
,
const
void
*
src
)
{
memcpy
(
dst
,
src
,
nbytes
);
}
inline
static
void
Delete
(
void
*
data
)
{
free
(
data
);
}
inline
static
void
Delete
(
void
*
data
)
{
free
(
data
);
}
template
<
class
DstContext
,
class
SrcContext
>
inline
static
void
MemcpyAsync
(
size_t
nbytes
,
void
*
dst
,
const
void
*
src
)
{
NOT_IMPLEMENTED
;
}
template
<
typename
T
,
class
DstContext
,
class
SrcContext
>
template
<
typename
T
,
class
DstContext
,
class
SrcContext
>
inline
static
void
Copy
(
int
n
,
T
*
dst
,
const
T
*
src
)
{
inline
static
void
Copy
(
int
n
,
T
*
dst
,
const
T
*
src
)
{
if
(
dst
==
src
)
return
;
if
(
dst
==
src
)
return
;
...
...
Dragon/include/core/context_cuda.h
View file @
c9db9ee
...
@@ -108,11 +108,11 @@ class CUDAContext {
...
@@ -108,11 +108,11 @@ class CUDAContext {
CUDA_CHECK
(
cudaMemcpy
(
dst
,
src
,
nbytes
,
cudaMemcpyDefault
));
CUDA_CHECK
(
cudaMemcpy
(
dst
,
src
,
nbytes
,
cudaMemcpyDefault
));
}
}
template
<
class
DstContext
,
class
SrcContext
>
inline
static
void
MemcpyAsync
(
size_t
nbytes
,
void
*
dst
,
const
void
*
src
)
{
inline
static
void
MemcpyAsync
(
size_t
nbytes
,
void
*
dst
,
const
void
*
src
)
{
cudaStream_t
stream
;
cudaStream_t
stream
;
CUDA_CHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDA_CHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDA_CHECK
(
cudaMemcpyAsync
(
dst
,
src
,
nbytes
,
cudaMemcpyDefault
,
stream
));
CUDA_CHECK
(
cudaMemcpyAsync
(
dst
,
src
,
nbytes
,
cudaMemcpyDefault
,
stream
));
CUDA_CHECK
(
cudaStreamSynchronize
(
stream
));
CUDA_CHECK
(
cudaStreamDestroy
(
stream
));
CUDA_CHECK
(
cudaStreamDestroy
(
stream
));
}
}
...
...
Dragon/include/core/mixedmem.h
View file @
c9db9ee
Dragon/include/core/tensor.h
View file @
c9db9ee
...
@@ -215,7 +215,7 @@ class Tensor {
...
@@ -215,7 +215,7 @@ class Tensor {
TIndex
size_
=
0
,
capacity_
=
0
;
TIndex
size_
=
0
,
capacity_
=
0
;
TypeMeta
meta_
;
TypeMeta
meta_
;
string
name_
;
string
name_
;
shared_ptr
<
MixedMemory
>
memory_
;
shared_ptr
<
MixedMemory
>
memory_
,
host_memory_
;
MixedMemory
*
ex_memory_
=
nullptr
;
MixedMemory
*
ex_memory_
=
nullptr
;
bool
is_corrupted_
=
false
,
own_mem_
=
true
;
bool
is_corrupted_
=
false
,
own_mem_
=
true
;
};
};
...
...
Dragon/include/operators/norm/l2_norm_op.h
View file @
c9db9ee
...
@@ -18,7 +18,8 @@ class L2NormOp final : public Operator<Context> {
...
@@ -18,7 +18,8 @@ class L2NormOp final : public Operator<Context> {
:
Operator
<
Context
>
(
op_def
,
ws
),
:
Operator
<
Context
>
(
op_def
,
ws
),
axis
(
OperatorBase
::
GetSingleArg
<
int
>
(
"axis"
,
0
)),
axis
(
OperatorBase
::
GetSingleArg
<
int
>
(
"axis"
,
0
)),
num_axes
(
OperatorBase
::
GetSingleArg
<
int
>
(
"num_axes"
,
-
1
)),
num_axes
(
OperatorBase
::
GetSingleArg
<
int
>
(
"num_axes"
,
-
1
)),
eps
(
OperatorBase
::
GetSingleArg
<
float
>
(
"eps"
,
float
(
1e-5
)))
{}
eps
(
OperatorBase
::
GetSingleArg
<
float
>
(
"eps"
,
float
(
1e-5
))),
mode
(
OperatorBase
::
GetSingleArg
<
string
>
(
"mode"
,
"SUM"
))
{}
void
RunOnDevice
()
override
;
void
RunOnDevice
()
override
;
template
<
typename
T
>
void
RunWithType
();
template
<
typename
T
>
void
RunWithType
();
...
@@ -26,6 +27,7 @@ class L2NormOp final : public Operator<Context> {
...
@@ -26,6 +27,7 @@ class L2NormOp final : public Operator<Context> {
protected
:
protected
:
float
eps
;
float
eps
;
TIndex
axis
,
num_axes
,
end_axis
;
TIndex
axis
,
num_axes
,
end_axis
;
string
mode
;
bool
across_inner
;
bool
across_inner
;
Tensor
*
norm
,
*
buffer
,
*
multiplier
;
Tensor
*
norm
,
*
buffer
,
*
multiplier
;
TIndex
outer_dim
,
dim
,
inner_dim
,
spatial_dim
;
TIndex
outer_dim
,
dim
,
inner_dim
,
spatial_dim
;
...
...
Dragon/include/operators/vision/conv_op.h
View file @
c9db9ee
...
@@ -48,10 +48,15 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
...
@@ -48,10 +48,15 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
public
:
public
:
CuDNNConv2dOp
(
const
OperatorDef
&
def
,
Workspace
*
ws
)
CuDNNConv2dOp
(
const
OperatorDef
&
def
,
Workspace
*
ws
)
:
Conv2dOp
<
Context
>
(
def
,
ws
)
{
:
Conv2dOp
<
Context
>
(
def
,
ws
)
{
handle
=
new
cudnnHandle_t
[
this
->
group
];
#if CUDNN_VERSION_MIN(7, 0, 0)
stream
=
new
cudaStream_t
[
this
->
group
];
cudnn_group
=
1
;
#else
cudnn_group
=
this
->
group
;
#endif
handle
=
new
cudnnHandle_t
[
cudnn_group
];
stream
=
new
cudaStream_t
[
cudnn_group
];
ctx
().
SwitchToDevice
();
ctx
().
SwitchToDevice
();
for
(
int
g
=
0
;
g
<
this
->
group
;
g
++
)
{
for
(
int
g
=
0
;
g
<
cudnn_
group
;
g
++
)
{
CUDA_CHECK
(
cudaStreamCreate
(
&
stream
[
g
]));
CUDA_CHECK
(
cudaStreamCreate
(
&
stream
[
g
]));
CUDNN_CHECK
(
cudnnCreate
(
&
handle
[
g
]));
CUDNN_CHECK
(
cudnnCreate
(
&
handle
[
g
]));
CUDNN_CHECK
(
cudnnSetStream
(
handle
[
g
],
stream
[
g
]));
CUDNN_CHECK
(
cudnnSetStream
(
handle
[
g
],
stream
[
g
]));
...
@@ -78,7 +83,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
...
@@ -78,7 +83,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnnConvolutionDescriptor_t
conv_desc
;
cudnnConvolutionDescriptor_t
conv_desc
;
cudnnFilterDescriptor_t
filter_desc
;
cudnnFilterDescriptor_t
filter_desc
;
size_t
workspace_fwd_data_size
;
size_t
workspace_fwd_data_size
;
TIndex
bias_offset
;
TIndex
bias_offset
,
cudnn_group
;
};
};
template
<
class
Context
>
template
<
class
Context
>
...
@@ -86,9 +91,14 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
...
@@ -86,9 +91,14 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
public
:
public
:
CuDNNConv2dGradientOp
(
const
OperatorDef
&
def
,
Workspace
*
ws
)
CuDNNConv2dGradientOp
(
const
OperatorDef
&
def
,
Workspace
*
ws
)
:
Conv2dGradientOp
<
Context
>
(
def
,
ws
)
{
:
Conv2dGradientOp
<
Context
>
(
def
,
ws
)
{
handle
=
new
cudnnHandle_t
[
this
->
group
*
3
];
#if CUDNN_VERSION_MIN(7, 0, 0)
stream
=
new
cudaStream_t
[
this
->
group
*
3
];
cudnn_group
=
1
;
for
(
int
g
=
0
;
g
<
this
->
group
*
3
;
g
++
)
{
#else
cudnn_group
=
this
->
group
;
#endif
handle
=
new
cudnnHandle_t
[
cudnn_group
*
3
];
stream
=
new
cudaStream_t
[
cudnn_group
*
3
];
for
(
int
g
=
0
;
g
<
cudnn_group
*
3
;
g
++
)
{
CUDA_CHECK
(
cudaStreamCreate
(
&
stream
[
g
]));
CUDA_CHECK
(
cudaStreamCreate
(
&
stream
[
g
]));
CUDNN_CHECK
(
cudnnCreate
(
&
handle
[
g
]));
CUDNN_CHECK
(
cudnnCreate
(
&
handle
[
g
]));
CUDNN_CHECK
(
cudnnSetStream
(
handle
[
g
],
stream
[
g
]));
CUDNN_CHECK
(
cudnnSetStream
(
handle
[
g
],
stream
[
g
]));
...
@@ -116,7 +126,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
...
@@ -116,7 +126,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnnConvolutionDescriptor_t
conv_desc
;
cudnnConvolutionDescriptor_t
conv_desc
;
cudnnFilterDescriptor_t
filter_desc
;
cudnnFilterDescriptor_t
filter_desc
;
size_t
workspace_bwd_filter_size
,
workspace_bwd_data_size
;
size_t
workspace_bwd_filter_size
,
workspace_bwd_data_size
;
int
bias_offset
;
TIndex
bias_offset
,
cudnn_group
;
};
};
#endif // WITH_CUDNN
#endif // WITH_CUDNN
...
...
Dragon/include/operators/vision/conv_transpose_op.h
View file @
c9db9ee
...
@@ -52,8 +52,13 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
...
@@ -52,8 +52,13 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
public
:
public
:
CuDNNConv2dTransposeOp
(
const
OperatorDef
&
def
,
Workspace
*
ws
)
CuDNNConv2dTransposeOp
(
const
OperatorDef
&
def
,
Workspace
*
ws
)
:
Conv2dTransposeOp
<
Context
>
(
def
,
ws
)
{
:
Conv2dTransposeOp
<
Context
>
(
def
,
ws
)
{
handle
=
new
cudnnHandle_t
[
this
->
group
];
#if CUDNN_VERSION_MIN(7, 0, 0)
stream
=
new
cudaStream_t
[
this
->
group
];
cudnn_group
=
1
;
#else
cudnn_group
=
this
->
group
;
#endif
handle
=
new
cudnnHandle_t
[
cudnn_group
];
stream
=
new
cudaStream_t
[
cudnn_group
];
for
(
int
g
=
0
;
g
<
this
->
group
;
g
++
)
{
for
(
int
g
=
0
;
g
<
this
->
group
;
g
++
)
{
CUDA_CHECK
(
cudaStreamCreate
(
&
stream
[
g
]));
CUDA_CHECK
(
cudaStreamCreate
(
&
stream
[
g
]));
CUDNN_CHECK
(
cudnnCreate
(
&
handle
[
g
]));
CUDNN_CHECK
(
cudnnCreate
(
&
handle
[
g
]));
...
@@ -80,7 +85,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
...
@@ -80,7 +85,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnnConvolutionDescriptor_t
conv_desc
;
cudnnConvolutionDescriptor_t
conv_desc
;
cudnnFilterDescriptor_t
filter_desc
;
cudnnFilterDescriptor_t
filter_desc
;
size_t
workspace_fwd_data_size
;
size_t
workspace_fwd_data_size
;
int
bias_offset
;
TIndex
bias_offset
,
cudnn_group
;
};
};
template
<
class
Context
>
template
<
class
Context
>
...
@@ -88,9 +93,14 @@ class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context>
...
@@ -88,9 +93,14 @@ class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context>
public
:
public
:
CuDNNConv2dTransposeGradientOp
(
const
OperatorDef
&
def
,
Workspace
*
ws
)
CuDNNConv2dTransposeGradientOp
(
const
OperatorDef
&
def
,
Workspace
*
ws
)
:
Conv2dTransposeGradientOp
<
Context
>
(
def
,
ws
)
{
:
Conv2dTransposeGradientOp
<
Context
>
(
def
,
ws
)
{
handle
=
new
cudnnHandle_t
[
this
->
group
*
3
];
#if CUDNN_VERSION_MIN(7, 0, 0)
stream
=
new
cudaStream_t
[
this
->
group
*
3
];
cudnn_group
=
1
;
for
(
int
g
=
0
;
g
<
this
->
group
*
3
;
g
++
)
{
#else
cudnn_group
=
this
->
group
;
#endif
handle
=
new
cudnnHandle_t
[
cudnn_group
*
3
];
stream
=
new
cudaStream_t
[
cudnn_group
*
3
];
for
(
int
g
=
0
;
g
<
cudnn_group
*
3
;
g
++
)
{
CUDA_CHECK
(
cudaStreamCreate
(
&
stream
[
g
]));
CUDA_CHECK
(
cudaStreamCreate
(
&
stream
[
g
]));
CUDNN_CHECK
(
cudnnCreate
(
&
handle
[
g
]));
CUDNN_CHECK
(
cudnnCreate
(
&
handle
[
g
]));
CUDNN_CHECK
(
cudnnSetStream
(
handle
[
g
],
stream
[
g
]));
CUDNN_CHECK
(
cudnnSetStream
(
handle
[
g
],
stream
[
g
]));
...
@@ -117,7 +127,7 @@ public:
...
@@ -117,7 +127,7 @@ public:
cudnnConvolutionDescriptor_t
conv_desc
;
cudnnConvolutionDescriptor_t
conv_desc
;
cudnnFilterDescriptor_t
filter_desc
;
cudnnFilterDescriptor_t
filter_desc
;
size_t
workspace_bwd_filter_size
,
workspace_bwd_data_size
;
size_t
workspace_bwd_filter_size
,
workspace_bwd_data_size
;
int
bias_offset
;
TIndex
bias_offset
,
cudnn_group
;
};
};
#endif // WITH_CUDNN
#endif // WITH_CUDNN
...
...
Dragon/include/operators/vision/lrn_op.h
View file @
c9db9ee
...
@@ -18,11 +18,12 @@ class LRNOp : public Operator<Context> {
...
@@ -18,11 +18,12 @@ class LRNOp : public Operator<Context> {
public
:
public
:
LRNOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
LRNOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
Operator
<
Context
>
(
op_def
,
ws
),
:
Operator
<
Context
>
(
op_def
,
ws
),
mode
((
LRNMode
)
OperatorBase
::
GetSingleArg
<
int
>
(
"mode"
,
ACROSS_CHANNELS
)),
local_size
(
OperatorBase
::
GetSingleArg
<
int
>
(
"local_size"
,
5
)),
local_size
(
OperatorBase
::
GetSingleArg
<
int
>
(
"local_size"
,
5
)),
alpha
(
OperatorBase
::
GetSingleArg
<
float
>
(
"alpha"
,
float
(
0
.
0001
))),
alpha
(
OperatorBase
::
GetSingleArg
<
float
>
(
"alpha"
,
float
(
0
.
0001
))),
beta
(
OperatorBase
::
GetSingleArg
<
float
>
(
"beta"
,
float
(
0
.
75
))),
beta
(
OperatorBase
::
GetSingleArg
<
float
>
(
"beta"
,
float
(
0
.
75
))),
k
(
OperatorBase
::
GetSingleArg
<
float
>
(
"k"
,
float
(
2
.
0
)))
{}
k
(
OperatorBase
::
GetSingleArg
<
float
>
(
"k"
,
float
(
2
.
0
))),
mode
(
OperatorBase
::
GetSingleArg
<
string
>
(
"mode"
,
"ACROSS_CHANNELS"
)),
data_format
(
OperatorBase
::
GetSingleArg
<
string
>
(
"data_format"
,
"NCHW"
))
{}
void
RunOnDevice
()
override
;
void
RunOnDevice
()
override
;
template
<
typename
T
>
void
RunWithType
();
template
<
typename
T
>
void
RunWithType
();
...
@@ -34,9 +35,9 @@ class LRNOp : public Operator<Context> {
...
@@ -34,9 +35,9 @@ class LRNOp : public Operator<Context> {
template
<
typename
T
>
void
ProdRunWithType
();
template
<
typename
T
>
void
ProdRunWithType
();
protected
:
protected
:
LRNMode
mode
;
int
local_size
;
int
local_size
;
float
alpha
,
beta
,
k
;
float
alpha
,
beta
,
k
;
string
mode
,
data_format
;
unique_ptr
<
OperatorBase
>
sqr_op
,
pool_op
,
pow_op
,
prod_op
;
unique_ptr
<
OperatorBase
>
sqr_op
,
pool_op
,
pow_op
,
prod_op
;
Tensor
*
sqr_in
,
*
prod_in
,
*
sqr_out
,
*
pool_out
,
*
pow_out
;
Tensor
*
sqr_in
,
*
prod_in
,
*
sqr_out
,
*
pool_out
,
*
pow_out
;
Tensor
*
scale
;
Tensor
*
scale
;
...
@@ -47,11 +48,12 @@ class LRNGradientOp : public Operator<Context> {
...
@@ -47,11 +48,12 @@ class LRNGradientOp : public Operator<Context> {
public
:
public
:
LRNGradientOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
LRNGradientOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
Operator
<
Context
>
(
op_def
,
ws
),
:
Operator
<
Context
>
(
op_def
,
ws
),
mode
((
LRNMode
)
OperatorBase
::
GetSingleArg
<
int
>
(
"mode"
,
ACROSS_CHANNELS
)),
local_size
(
OperatorBase
::
GetSingleArg
<
int
>
(
"local_size"
,
5
)),
local_size
(
OperatorBase
::
GetSingleArg
<
int
>
(
"local_size"
,
5
)),
alpha
(
OperatorBase
::
GetSingleArg
<
float
>
(
"alpha"
,
float
(
0
.
0001
))),
alpha
(
OperatorBase
::
GetSingleArg
<
float
>
(
"alpha"
,
float
(
0
.
0001
))),
beta
(
OperatorBase
::
GetSingleArg
<
float
>
(
"beta"
,
float
(
0
.
75
))),
beta
(
OperatorBase
::
GetSingleArg
<
float
>
(
"beta"
,
float
(
0
.
75
))),
k
(
OperatorBase
::
GetSingleArg
<
float
>
(
"k"
,
float
(
2
.
0
)))
{}
k
(
OperatorBase
::
GetSingleArg
<
float
>
(
"k"
,
float
(
2
.
0
))),
mode
(
OperatorBase
::
GetSingleArg
<
string
>
(
"mode"
,
"ACROSS_CHANNELS"
)),
data_format
(
OperatorBase
::
GetSingleArg
<
string
>
(
"data_format"
,
"NCHW"
))
{}
void
RunOnDevice
()
override
;
void
RunOnDevice
()
override
;
template
<
typename
T
>
void
RunWithType
();
template
<
typename
T
>
void
RunWithType
();
...
@@ -63,9 +65,9 @@ class LRNGradientOp : public Operator<Context> {
...
@@ -63,9 +65,9 @@ class LRNGradientOp : public Operator<Context> {
template
<
typename
T
>
void
ProdRunWithType
();
template
<
typename
T
>
void
ProdRunWithType
();
protected
:
protected
:
LRNMode
mode
;
int
local_size
;
int
local_size
;
float
alpha
,
beta
,
k
;
float
alpha
,
beta
,
k
;
string
mode
,
data_format
;
unique_ptr
<
OperatorBase
>
sqr_op
,
pool_op
,
pow_op
,
prod_op
;
unique_ptr
<
OperatorBase
>
sqr_op
,
pool_op
,
pow_op
,
prod_op
;
Tensor
*
sqr_in
,
*
prod_in
,
*
sqr_out
,
*
pool_out
,
*
pow_out
;
Tensor
*
sqr_in
,
*
prod_in
,
*
sqr_out
,
*
pool_out
,
*
pow_out
;
Tensor
*
scale
;
Tensor
*
scale
;
...
...
Dragon/include/utils/cudnn_device.h
View file @
c9db9ee
...
@@ -76,6 +76,9 @@ template <typename T>
...
@@ -76,6 +76,9 @@ template <typename T>
void
cudnnSetTensor4dDesc
(
cudnnTensorDescriptor_t
*
desc
,
const
string
&
data_format
,
const
std
::
vector
<
int64_t
>&
dims
);
void
cudnnSetTensor4dDesc
(
cudnnTensorDescriptor_t
*
desc
,
const
string
&
data_format
,
const
std
::
vector
<
int64_t
>&
dims
);
template
<
typename
T
>
template
<
typename
T
>
void
cudnnSetTensor4dDescWithGroup
(
cudnnTensorDescriptor_t
*
desc
,
const
string
&
data_format
,
const
std
::
vector
<
int64_t
>&
dims
,
const
int64_t
group
);
template
<
typename
T
>
void
cudnnSetTensor5dDesc
(
cudnnTensorDescriptor_t
*
desc
,
const
string
&
data_format
,
const
std
::
vector
<
int64_t
>&
dims
);
void
cudnnSetTensor5dDesc
(
cudnnTensorDescriptor_t
*
desc
,
const
string
&
data_format
,
const
std
::
vector
<
int64_t
>&
dims
);
template
<
typename
T
>
template
<
typename
T
>
...
...
Dragon/python/dragon/core/tensor.py
View file @
c9db9ee
...
@@ -156,29 +156,39 @@ class Tensor(object):
...
@@ -156,29 +156,39 @@ class Tensor(object):
"""
"""
return
self
.
Normal
(
mu
=
mean
,
sigma
=
std
)
return
self
.
Normal
(
mu
=
mean
,
sigma
=
std
)
def
Xavier
(
self
):
def
Xavier
(
self
,
scale
=
3.0
):
"""
"""
Register as a variable with xavier initializer.
Register as a variable with xavier initializer.
"""
"""
return
self
.
_no_parameter_filler
(
'xavier'
)
filler
=
pb
.
TensorFiller
()
filler
.
tensor
=
self
.
name
filler
.
type
=
'xavier'
filler
.
scale
=
scale
ws
.
CreateFiller
(
filler
)
return
self
def
MSRA
(
self
):
def
MSRA
(
self
,
scale
=
2.0
):
"""
"""
Register as a variable with msra initializer.
Register as a variable with msra initializer.
"""
"""
return
self
.
_no_parameter_filler
(
'msra'
)
filler
=
pb
.
TensorFiller
()
filler
.
tensor
=
self
.
name
filler
.
type
=
'msra'
filler
.
scale
=
scale
ws
.
CreateFiller
(
filler
)
return
self
def
GlorotUniform
(
self
):
def
GlorotUniform
(
self
,
scale
=
3.0
):
"""
"""
Register as a variable with glorot uniform initializer.
Register as a variable with glorot uniform initializer.
"""
"""
return
self
.
Xavier
()
return
self
.
Xavier
(
scale
)
def
GlorotNormal
(
self
):
def
GlorotNormal
(
self
,
scale
=
2.0
):
"""
"""
Register as a variable with glorot normal initializer.
Register as a variable with glorot normal initializer.
"""
"""
return
self
.
MSRA
()
return
self
.
MSRA
(
scale
)
##############################################
##############################################
# #
# #
...
...
Dragon/python/dragon/docs/install.rst
View file @
c9db9ee
...
@@ -19,10 +19,18 @@ Installation - Linux (Normal, CPU)
...
@@ -19,10 +19,18 @@ Installation - Linux (Normal, CPU)
**Step 1:** Install C++ Dependencies
**Step 1:** Install C++ Dependencies
**$** Setup Python Development Environment
.. code-block:: shell
.. code-block:: shell
sudo apt-get install libpython-dev
sudo apt-get install libpython-dev
**Note:** You can also use `Anaconda`_, A powerful toolkit for Data Science.
**$** Setup C++ Development Environment
sudo apt-get install libprotobuf-dev
sudo apt-get install libprotobuf-dev
sudo apt-get install protobuf-compiler
sudo apt-get install libopenblas-dev
sudo apt-get install libopenblas-dev
**Step 2:** Install Python Requirements
**Step 2:** Install Python Requirements
...
@@ -83,10 +91,18 @@ Installation - Linux (Normal, GPU)
...
@@ -83,10 +91,18 @@ Installation - Linux (Normal, GPU)
**Step 2:** Install C++ Dependencies
**Step 2:** Install C++ Dependencies
**$** Setup Python Development Environment
.. code-block:: shell
.. code-block:: shell
sudo apt-get install libpython-dev
sudo apt-get install libpython-dev
**Note:** You can also use `Anaconda`_, A powerful toolkit for Data Science.
**$** Setup C++ Development Environment
sudo apt-get install libprotobuf-dev
sudo apt-get install libprotobuf-dev
sudo apt-get install protobuf-compiler
sudo apt-get install libopenblas-dev
sudo apt-get install libopenblas-dev
**Step 3:** Install Python Requirements
**Step 3:** Install Python Requirements
...
@@ -149,10 +165,18 @@ Installation - Linux (Distributed, CPU)
...
@@ -149,10 +165,18 @@ Installation - Linux (Distributed, CPU)
**Step 2:** Install C++ Dependencies
**Step 2:** Install C++ Dependencies
**$** Setup Python Development Environment
.. code-block:: shell
.. code-block:: shell
sudo apt-get install libpython-dev
sudo apt-get install libpython-dev
**Note:** You can also use `Anaconda`_, A powerful toolkit for Data Science.
**$** Setup C++ Development Environment
sudo apt-get install libprotobuf-dev
sudo apt-get install libprotobuf-dev
sudo apt-get install protobuf-compiler
sudo apt-get install libopenblas-dev
sudo apt-get install libopenblas-dev
**Step 3:** Install Python Requirements
**Step 3:** Install Python Requirements
...
@@ -229,10 +253,18 @@ Installation - Linux (Distributed, GPU)
...
@@ -229,10 +253,18 @@ Installation - Linux (Distributed, GPU)
**Step 3:** Install C++ Dependencies
**Step 3:** Install C++ Dependencies
**$** Setup Python Development Environment
.. code-block:: shell
.. code-block:: shell
sudo apt-get install libpython-dev
sudo apt-get install libpython-dev
**Note:** You can also use `Anaconda`_, A powerful toolkit for Data Science.
**$** Setup C++ Development Environment
sudo apt-get install libprotobuf-dev
sudo apt-get install libprotobuf-dev
sudo apt-get install protobuf-compiler
sudo apt-get install libopenblas-dev
sudo apt-get install libopenblas-dev
**Step 4:** Install Python Requirements
**Step 4:** Install Python Requirements
...
@@ -564,6 +596,7 @@ Add ``REPO_ROOT/3rdparty/bin`` to system environment variables
...
@@ -564,6 +596,7 @@ Add ``REPO_ROOT/3rdparty/bin`` to system environment variables
python setup.py install --user
python setup.py install --user
.. _Anaconda: https://www.anaconda.com/download
.. _CUDA: https://developer.nvidia.com/cuda-toolkit
.. _CUDA: https://developer.nvidia.com/cuda-toolkit
.. _CUDNN: https://developer.nvidia.com/cudnn
.. _CUDNN: https://developer.nvidia.com/cudnn
.. _NCCL: https://developer.nvidia.com/nccl
.. _NCCL: https://developer.nvidia.com/nccl
...
...
Dragon/python/dragon/operators/ndarray.py
View file @
c9db9ee
...
@@ -673,6 +673,7 @@ def Reshape(inputs, shape, **kwargs):
...
@@ -673,6 +673,7 @@ def Reshape(inputs, shape, **kwargs):
output
.
shape
=
[
1
]
*
len
(
shape
)
output
.
shape
=
[
1
]
*
len
(
shape
)
for
i
,
s
in
enumerate
(
shape
):
for
i
,
s
in
enumerate
(
shape
):
if
s
==
-
1
:
output
.
shape
[
i
]
=
1
if
s
==
-
1
:
output
.
shape
[
i
]
=
1
elif
s
==
0
:
output
.
shape
[
i
]
=
inputs
.
shape
[
i
]
else
:
output
.
shape
[
i
]
=
s
else
:
output
.
shape
[
i
]
=
s
return
output
return
output
...
...
Dragon/python/dragon/operators/norm.py
View file @
c9db9ee
...
@@ -189,7 +189,7 @@ def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs):
...
@@ -189,7 +189,7 @@ def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs):
return
output
return
output
def
L2Norm
(
inputs
,
axis
=
0
,
num_axes
=-
1
,
eps
=
1e-5
,
**
kwargs
):
def
L2Norm
(
inputs
,
axis
=
0
,
num_axes
=-
1
,
eps
=
1e-5
,
mode
=
'SUM'
,
**
kwargs
):
"""L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_.
"""L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_.
Parameters
Parameters
...
@@ -202,6 +202,8 @@ def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, **kwargs):
...
@@ -202,6 +202,8 @@ def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, **kwargs):
The number of axes of stats region. Default is ``-1`` (Till End).
The number of axes of stats region. Default is ``-1`` (Till End).
eps : float
eps : float
The eps.
The eps.
mode : str
The mode on computing normalizer. ``SUM`` or ``MEAN``.
Returns
Returns
-------
-------
...
...
Dragon/python/dragon/operators/vision.py
View file @
c9db9ee
...
@@ -61,6 +61,12 @@ def Conv2d(inputs, num_output, kernel_size,
...
@@ -61,6 +61,12 @@ def Conv2d(inputs, num_output, kernel_size,
"""
"""
CheckInputs
(
inputs
,
2
,
3
)
CheckInputs
(
inputs
,
2
,
3
)
arguments
=
ParseArguments
(
locals
())
arguments
=
ParseArguments
(
locals
())
if
padding
not
in
(
'VALID'
,
'SAME'
):
raise
ValueError
(
'Unsupported padding algorithm: {}'
.
format
(
padding
))
if
data_format
not
in
(
'NCHW'
,
'NHWC'
):
raise
ValueError
(
'Unsupported data format: {}'
.
format
(
data_format
))
if
not
isinstance
(
arguments
[
'kernel_size'
],
list
):
if
not
isinstance
(
arguments
[
'kernel_size'
],
list
):
arguments
[
'kernel_size'
]
=
[
arguments
[
'kernel_size'
]]
arguments
[
'kernel_size'
]
=
[
arguments
[
'kernel_size'
]]
if
not
isinstance
(
arguments
[
'stride'
],
list
):
if
not
isinstance
(
arguments
[
'stride'
],
list
):
...
@@ -154,6 +160,11 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
...
@@ -154,6 +160,11 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
CheckInputs
(
inputs
,
2
,
3
)
CheckInputs
(
inputs
,
2
,
3
)
arguments
=
ParseArguments
(
locals
())
arguments
=
ParseArguments
(
locals
())
if
padding
not
in
(
'VALID'
,
'SAME'
):
raise
ValueError
(
'Unsupported padding algorithm: {}'
.
format
(
padding
))
if
data_format
not
in
(
'NCHW'
,
'NHWC'
):
raise
ValueError
(
'Unsupported data format: {}'
.
format
(
data_format
))
arguments
[
'output_shape'
]
=
None
arguments
[
'output_shape'
]
=
None
if
output_shape
is
not
None
:
if
output_shape
is
not
None
:
if
not
isinstance
(
output_shape
,
list
):
if
not
isinstance
(
output_shape
,
list
):
...
@@ -170,17 +181,43 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
...
@@ -170,17 +181,43 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
if
not
isinstance
(
arguments
[
'kernel_size'
],
list
):
if
not
isinstance
(
arguments
[
'kernel_size'
],
list
):
arguments
[
'kernel_size'
]
=
[
arguments
[
'kernel_size'
]]
arguments
[
'kernel_size'
]
=
[
arguments
[
'kernel_size'
]]
if
not
isinstance
(
arguments
[
'stride'
],
list
):
if
not
isinstance
(
arguments
[
'stride'
],
list
):
arguments
[
'stride'
]
=
[
arguments
[
'stride'
]]
arguments
[
'stride'
]
=
[
arguments
[
'stride'
]]
if
not
isinstance
(
arguments
[
'pad'
],
list
):
if
not
isinstance
(
arguments
[
'pad'
],
list
):
arguments
[
'pad'
]
=
[
arguments
[
'pad'
]]
arguments
[
'pad'
]
=
[
arguments
[
'pad'
]]
if
not
isinstance
(
arguments
[
'dilation'
],
list
):
if
not
isinstance
(
arguments
[
'dilation'
],
list
):
arguments
[
'dilation'
]
=
[
arguments
[
'dilation'
]]
arguments
[
'dilation'
]
=
[
arguments
[
'dilation'
]]
return
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'Conv2dTranspose'
,
**
arguments
)
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'Conv2dTranspose'
,
**
arguments
)
if
inputs
[
0
]
.
shape
is
not
None
:
output
.
shape
=
inputs
[
0
]
.
shape
[:]
channel_axis
=
1
if
data_format
==
'NCHW'
else
-
1
spatial_axis
=
2
if
data_format
==
'NCHW'
else
1
output
.
shape
[
channel_axis
]
=
num_output
for
i
in
xrange
(
2
):
k
=
arguments
[
'kernel_size'
][
i
]
if
i
<
len
(
arguments
[
'kernel_size'
])
\
else
arguments
[
'kernel_size'
][
-
1
]
s
=
arguments
[
'stride'
][
i
]
if
i
<
len
(
arguments
[
'stride'
])
\
else
arguments
[
'stride'
][
-
1
]
p
=
arguments
[
'pad'
][
i
]
if
i
<
len
(
arguments
[
'pad'
])
\
else
arguments
[
'pad'
][
-
1
]
d
=
arguments
[
'dilation'
][
i
]
if
i
<
len
(
arguments
[
'dilation'
])
\
else
arguments
[
'dilation'
][
-
1
]
dk
=
d
*
(
k
-
1
)
+
1
dp
=
2
*
p
input_size
=
output
.
shape
[
i
+
spatial_axis
]
if
padding
!=
'SAME'
:
output
.
shape
[
i
+
spatial_axis
]
=
s
*
(
input_size
-
1
)
+
dk
-
dp
else
:
if
output_shape
is
None
:
raise
ValueError
(
'The output shape must be specified if using SAME padding algorithm.'
)
if
'dynamic_dsize'
in
arguments
:
output
.
shape
=
None
return
output
output
.
shape
[
i
+
spatial_axis
]
=
output_shape
[
i
+
spatial_axis
]
return
output
def
Pool2d
(
inputs
,
kernel_size
,
stride
,
pad
=
0
,
padding
=
'VALID'
,
def
Pool2d
(
inputs
,
kernel_size
,
stride
,
pad
=
0
,
padding
=
'VALID'
,
...
@@ -222,6 +259,14 @@ def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
...
@@ -222,6 +259,14 @@ def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
"""
"""
CheckInputs
(
inputs
,
1
)
CheckInputs
(
inputs
,
1
)
arguments
=
ParseArguments
(
locals
())
arguments
=
ParseArguments
(
locals
())
if
mode
not
in
(
'MAX'
,
'AVG'
):
raise
ValueError
(
'Unsupported lrn mode: {}'
.
format
(
mode
))
if
padding
not
in
(
'VALID'
,
'SAME'
):
raise
ValueError
(
'Unsupported padding algorithm: {}'
.
format
(
padding
))
if
data_format
not
in
(
'NCHW'
,
'NHWC'
):
raise
ValueError
(
'Unsupported data format: {}'
.
format
(
data_format
))
if
not
isinstance
(
arguments
[
'kernel_size'
],
list
):
if
not
isinstance
(
arguments
[
'kernel_size'
],
list
):
arguments
[
'kernel_size'
]
=
[
arguments
[
'kernel_size'
]]
arguments
[
'kernel_size'
]
=
[
arguments
[
'kernel_size'
]]
if
not
isinstance
(
arguments
[
'stride'
],
list
):
if
not
isinstance
(
arguments
[
'stride'
],
list
):
...
@@ -311,7 +356,8 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
...
@@ -311,7 +356,8 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
return
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'ROIAlign'
,
**
arguments
)
return
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'ROIAlign'
,
**
arguments
)
def
LRN
(
inputs
,
local_size
=
5
,
alpha
=
0.0001
,
beta
=
0.75
,
k
=
2.0
,
mode
=
'ACROSS_CHANNELS'
,
**
kwargs
):
def
LRN
(
inputs
,
local_size
=
5
,
alpha
=
0.0001
,
beta
=
0.75
,
k
=
2.0
,
mode
=
'ACROSS_CHANNELS'
,
data_format
=
'NCHW'
,
**
kwargs
):
"""Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_.
"""Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_.
Parameters
Parameters
...
@@ -328,17 +374,22 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANN
...
@@ -328,17 +374,22 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANN
The k of LRN.
The k of LRN.
mode : str
mode : str
The mode, ``ACROSS_CHANNELS`` or ``WITHIN_CHANNEL``.
The mode, ``ACROSS_CHANNELS`` or ``WITHIN_CHANNEL``.
data_format : str
The data format. ``NCHW`` or ``NHWC``.
Returns
Returns
-------
-------
Tensor
Tensor
The
normalized
tensor.
The
output
tensor.
"""
"""
CheckInputs
(
inputs
,
1
)
CheckInputs
(
inputs
,
1
)
arguments
=
ParseArguments
(
locals
())
arguments
=
ParseArguments
(
locals
())
SUPPORT_MODES
=
{
'ACROSS_CHANNELS'
:
0
,
'WITHIN_CHANNEL'
:
1
}
arguments
[
'mode'
]
=
SUPPORT_MODES
[
mode
]
if
mode
not
in
(
'ACROSS_CHANNELS'
,
'WITHIN_CHANNEL'
):
raise
ValueError
(
'Unsupported lrn mode: {}'
.
format
(
mode
))
if
data_format
not
in
(
'NCHW'
,
'NHWC'
):
raise
ValueError
(
'Unsupported data format: {}'
.
format
(
data_format
))
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'LRN'
,
**
arguments
)
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'LRN'
,
**
arguments
)
...
@@ -356,9 +407,9 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
...
@@ -356,9 +407,9 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
Parameters
Parameters
----------
----------
inputs : Tensor
inputs : Tensor
The input ten
os
r.
The input ten
so
r.
dsize : tuple, list, Tensor or None
dsize : tuple, list, Tensor or None
The output size.
The output size
, formats as (h, w)
.
fy : float
fy : float
The scale factor based on src height. Default is ``-1.0`` (Discarded).
The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float
fx : float
...
@@ -374,6 +425,10 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
...
@@ -374,6 +425,10 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
"""
"""
CheckInputs
(
inputs
,
1
)
CheckInputs
(
inputs
,
1
)
arguments
=
ParseArguments
(
locals
())
arguments
=
ParseArguments
(
locals
())
if
data_format
not
in
(
'NCHW'
,
'NHWC'
):
raise
ValueError
(
'Unsupported data format: {}'
.
format
(
data_format
))
if
arguments
[
'dsize'
]
is
not
None
:
if
arguments
[
'dsize'
]
is
not
None
:
if
isinstance
(
arguments
[
'dsize'
][
0
],
Tensor
):
if
isinstance
(
arguments
[
'dsize'
][
0
],
Tensor
):
arguments
[
'dynamic_dsize'
]
=
[
arguments
[
'dsize'
][
0
]
.
name
,
arguments
[
'dynamic_dsize'
]
=
[
arguments
[
'dsize'
][
0
]
.
name
,
...
@@ -388,6 +443,20 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
...
@@ -388,6 +443,20 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'NNResize'
,
**
arguments
)
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'NNResize'
,
**
arguments
)
if
inputs
.
shape
is
not
None
:
if
len
(
inputs
.
shape
)
!=
4
:
raise
ValueError
(
'The inputs should be a 4d Tensor.'
)
if
'dynamic_dsize'
not
in
arguments
:
output
.
shape
=
inputs
.
shape
[:]
spatial_axis
=
2
if
data_format
==
'NCHW'
else
1
for
i
in
xrange
(
2
):
output_dim
=
output
.
shape
[
spatial_axis
+
i
]
if
'static_size'
in
arguments
:
output_dim
=
dsize
[
i
]
else
:
output_dim
=
int
(
float
(
output_dim
)
*
([
fy
,
fx
])[
i
])
output
.
shape
[
spatial_axis
+
i
]
=
output_dim
return
output
return
output
...
@@ -399,9 +468,9 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
...
@@ -399,9 +468,9 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
Parameters
Parameters
----------
----------
inputs : Tensor
inputs : Tensor
The input ten
os
r.
The input ten
so
r.
dsize : tuple, list, Tensor or None
dsize : tuple, list, Tensor or None
The
dest output size
.
The
output size, formats as (h, w)
.
fy : float
fy : float
The scale factor based on src height. Default is ``-1.0`` (Discarded).
The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float
fx : float
...
@@ -417,6 +486,10 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
...
@@ -417,6 +486,10 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
"""
"""
CheckInputs
(
inputs
,
1
)
CheckInputs
(
inputs
,
1
)
arguments
=
ParseArguments
(
locals
())
arguments
=
ParseArguments
(
locals
())
if
data_format
not
in
(
'NCHW'
,
'NHWC'
):
raise
ValueError
(
'Unsupported data format: {}'
.
format
(
data_format
))
if
arguments
[
'dsize'
]
is
not
None
:
if
arguments
[
'dsize'
]
is
not
None
:
if
isinstance
(
arguments
[
'dsize'
][
0
],
Tensor
):
if
isinstance
(
arguments
[
'dsize'
][
0
],
Tensor
):
arguments
[
'dynamic_dsize'
]
=
[
arguments
[
'dsize'
][
0
]
.
name
,
arguments
[
'dynamic_dsize'
]
=
[
arguments
[
'dsize'
][
0
]
.
name
,
...
@@ -431,6 +504,20 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
...
@@ -431,6 +504,20 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'BilinearResize'
,
**
arguments
)
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'BilinearResize'
,
**
arguments
)
if
inputs
.
shape
is
not
None
:
if
len
(
inputs
.
shape
)
!=
4
:
raise
ValueError
(
'The inputs should be a 4d Tensor.'
)
if
'dynamic_dsize'
not
in
arguments
:
output
.
shape
=
inputs
.
shape
[:]
spatial_axis
=
2
if
data_format
==
'NCHW'
else
1
for
i
in
xrange
(
2
):
output_dim
=
output
.
shape
[
spatial_axis
+
i
]
if
'static_size'
in
arguments
:
output_dim
=
dsize
[
i
]
else
:
output_dim
=
int
(
float
(
output_dim
)
*
([
fy
,
fx
])[
i
])
output
.
shape
[
spatial_axis
+
i
]
=
output_dim
return
output
return
output
...
@@ -453,6 +540,9 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
...
@@ -453,6 +540,9 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
CheckInputs
(
inputs
,
2
)
CheckInputs
(
inputs
,
2
)
arguments
=
ParseArguments
(
locals
())
arguments
=
ParseArguments
(
locals
())
if
data_format
not
in
(
'NCHW'
,
'NHWC'
):
raise
ValueError
(
'Unsupported data format: {}'
.
format
(
data_format
))
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'BiasAdd'
,
**
arguments
)
output
=
Tensor
.
CreateOperator
(
nout
=
1
,
op_type
=
'BiasAdd'
,
**
arguments
)
if
inputs
[
0
]
.
shape
is
not
None
:
if
inputs
[
0
]
.
shape
is
not
None
:
...
...
Dragon/python/dragon/vm/caffe/layers/vision.py
View file @
c9db9ee
...
@@ -229,7 +229,9 @@ class LRNLayer(Layer):
...
@@ -229,7 +229,9 @@ class LRNLayer(Layer):
self
.
_param
=
{
'local_size'
:
param
.
local_size
,
self
.
_param
=
{
'local_size'
:
param
.
local_size
,
'alpha'
:
param
.
alpha
,
'alpha'
:
param
.
alpha
,
'beta'
:
param
.
beta
,
'beta'
:
param
.
beta
,
'mode'
:
{
0
:
'ACROSS_CHANNELS'
,
1
:
'WITHIN_CHANNEL'
}[
param
.
norm_region
]}
'mode'
:
{
0
:
'ACROSS_CHANNELS'
,
1
:
'WITHIN_CHANNEL'
}[
param
.
norm_region
],
'data_format'
:
'NCHW'
}
def
Setup
(
self
,
bottom
):
def
Setup
(
self
,
bottom
):
super
(
LRNLayer
,
self
)
.
Setup
(
bottom
)
super
(
LRNLayer
,
self
)
.
Setup
(
bottom
)
input
=
bottom
[
0
]
if
isinstance
(
bottom
,
list
)
else
bottom
input
=
bottom
[
0
]
if
isinstance
(
bottom
,
list
)
else
bottom
...
...
Dragon/src/core/graph.cc
View file @
c9db9ee
...
@@ -261,7 +261,7 @@ GraphDef Graph::MakeUpdate(const GraphDef& meta_graph) {
...
@@ -261,7 +261,7 @@ GraphDef Graph::MakeUpdate(const GraphDef& meta_graph) {
}
else
if
(
this
->
args_
[
"parallel_mode"
].
s
()
==
"MIXED"
)
{
}
else
if
(
this
->
args_
[
"parallel_mode"
].
s
()
==
"MIXED"
)
{
/*
/*
See: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour
See: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour
Link
s
: http://arxiv.org/abs/1706.02677
Link: http://arxiv.org/abs/1706.02677
*/
*/
NOT_IMPLEMENTED
;
NOT_IMPLEMENTED
;
}
}
...
...
Dragon/src/core/mixedmem.cc
View file @
c9db9ee
Dragon/src/core/operator.cc
View file @
c9db9ee
...
@@ -150,6 +150,8 @@ void Operator<Context>::CleanResource() {
...
@@ -150,6 +150,8 @@ void Operator<Context>::CleanResource() {
if
(
output
(
i
)
->
memory
()
!=
buffer
->
memory
())
buffer
->
Move
(
output
(
i
)
->
memory
());
if
(
output
(
i
)
->
memory
()
!=
buffer
->
memory
())
buffer
->
Move
(
output
(
i
)
->
memory
());
}
}
}
}
// post-process for sharing grads
if
(
allow_share_grads_
)
{
if
(
allow_share_grads_
)
{
// TODO(PhyscalX): we preset input(-1)->output(0) to share
// TODO(PhyscalX): we preset input(-1)->output(0) to share
Tensor
*
dY
=
&
input
(
-
1
);
Tensor
*
dY
=
&
input
(
-
1
);
...
...
Dragon/src/operators/norm/l2_norm_op.cc
View file @
c9db9ee
...
@@ -30,6 +30,7 @@ void L2NormOp<Context>::RunWithType() {
...
@@ -30,6 +30,7 @@ void L2NormOp<Context>::RunWithType() {
if
(
across_inner
)
{
if
(
across_inner
)
{
auto
*
Ndata_
=
norm
->
template
mutable_data
<
float
,
CPUContext
>
();
auto
*
Ndata_
=
norm
->
template
mutable_data
<
float
,
CPUContext
>
();
float
sum_of_sqr
=
math
::
Dot
<
T
,
Context
>
(
buffer
->
count
(),
Xdata
,
Xdata
);
float
sum_of_sqr
=
math
::
Dot
<
T
,
Context
>
(
buffer
->
count
(),
Xdata
,
Xdata
);
if
(
mode
==
"MEAN"
)
sum_of_sqr
=
sum_of_sqr
/
dim
;
Ndata_
[
n
]
=
pow
(
sum_of_sqr
+
eps
,
0.5
);
Ndata_
[
n
]
=
pow
(
sum_of_sqr
+
eps
,
0.5
);
math
::
Scale
<
T
,
Context
>
(
buffer
->
count
(),
1.0
/
Ndata_
[
n
],
Xdata
,
Ydata
);
math
::
Scale
<
T
,
Context
>
(
buffer
->
count
(),
1.0
/
Ndata_
[
n
],
Xdata
,
Ydata
);
}
else
{
}
else
{
...
@@ -37,7 +38,7 @@ void L2NormOp<Context>::RunWithType() {
...
@@ -37,7 +38,7 @@ void L2NormOp<Context>::RunWithType() {
math
::
Square
<
T
,
Context
>
(
buffer
->
count
(),
Xdata
,
Bdata
);
math
::
Square
<
T
,
Context
>
(
buffer
->
count
(),
Xdata
,
Bdata
);
// compute T1 = \sum_{i} x_{i,j}^{2}
// compute T1 = \sum_{i} x_{i,j}^{2}
math
::
Gemv
<
T
,
Context
>
(
CblasTrans
,
dim
,
inner_dim
,
math
::
Gemv
<
T
,
Context
>
(
CblasTrans
,
dim
,
inner_dim
,
1.0
,
mode
==
"MEAN"
?
1.0
/
dim
:
1.0
,
Bdata
,
DMuldata
,
Bdata
,
DMuldata
,
1.0
,
1.0
,
Ndata
);
Ndata
);
...
...
Dragon/src/operators/vision/cudnn_conv2d_op.cc
View file @
c9db9ee
This diff is collapsed.
Click to expand it.
Dragon/src/operators/vision/cudnn_conv2d_transpose_op.cc
View file @
c9db9ee
This diff is collapsed.
Click to expand it.
Dragon/src/operators/vision/cudnn_lrn_op.cc
View file @
c9db9ee
...
@@ -6,9 +6,9 @@ namespace dragon {
...
@@ -6,9 +6,9 @@ namespace dragon {
template
<
class
Context
>
template
<
typename
T
>
template
<
class
Context
>
template
<
typename
T
>
void
CuDNNLRNOp
<
Context
>::
RunWithType
()
{
void
CuDNNLRNOp
<
Context
>::
RunWithType
()
{
if
(
this
->
data_format
==
"NCHW"
)
{
cudnnSetTensorDesc
<
T
>
(
&
input_desc
,
&
input
(
0
));
cudnnSetTensorDesc
<
T
>
(
&
input_desc
,
&
input
(
0
));
cudnnSetTensorDesc
<
T
>
(
&
output_desc
,
output
(
0
));
cudnnSetTensorDesc
<
T
>
(
&
output_desc
,
output
(
0
));
auto
*
Xdata
=
input
(
0
).
template
data
<
T
,
Context
>
();
auto
*
Xdata
=
input
(
0
).
template
data
<
T
,
Context
>
();
auto
*
Ydata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
auto
*
Ydata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
CUDNN_CHECK
(
cudnnLRNCrossChannelForward
(
cudnn_handle
(),
CUDNN_CHECK
(
cudnnLRNCrossChannelForward
(
cudnn_handle
(),
...
@@ -16,20 +16,23 @@ void CuDNNLRNOp<Context>::RunWithType() {
...
@@ -16,20 +16,23 @@ void CuDNNLRNOp<Context>::RunWithType() {
CUDNN_LRN_CROSS_CHANNEL_DIM1
,
CUDNN_LRN_CROSS_CHANNEL_DIM1
,
CUDNNType
<
T
>::
one
,
input_desc
,
Xdata
,
CUDNNType
<
T
>::
one
,
input_desc
,
Xdata
,
CUDNNType
<
T
>::
zero
,
output_desc
,
Ydata
));
CUDNNType
<
T
>::
zero
,
output_desc
,
Ydata
));
}
else
LOG
(
FATAL
)
<<
"Unknown data format: "
<<
this
->
data_format
;
}
}
template
<
class
Context
>
template
<
class
Context
>
void
CuDNNLRNOp
<
Context
>::
RunOnDevice
()
{
void
CuDNNLRNOp
<
Context
>::
RunOnDevice
()
{
output
(
0
)
->
ReshapeLike
(
input
(
0
));
output
(
0
)
->
ReshapeLike
(
input
(
0
));
if
(
this
->
mode
==
ACROSS_CHANNELS
)
{
if
(
this
->
mode
==
"ACROSS_CHANNELS"
)
{
if
(
input
(
0
).
template
IsType
<
float
>
())
RunWithType
<
float
>
();
if
(
input
(
0
).
template
IsType
<
float
>
())
RunWithType
<
float
>
();
#ifdef WITH_CUDA_FP16
#ifdef WITH_CUDA_FP16
else
if
(
input
(
0
).
template
IsType
<
float16
>
())
RunWithType
<
float16
>
();
else
if
(
input
(
0
).
template
IsType
<
float16
>
())
RunWithType
<
float16
>
();
#endif
#endif
else
LOG
(
FATAL
)
<<
"Unsupported input types."
;
else
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
else
{
}
else
if
(
this
->
mode
==
"WITHIN_CHANNEL"
)
{
LRNOp
<
Context
>::
RunOnDevice
();
LRNOp
<
Context
>::
RunOnDevice
();
}
else
{
LOG
(
FATAL
)
<<
"Unsupported lrn mode: "
<<
this
->
mode
;
}
}
}
}
...
@@ -37,6 +40,7 @@ DEPLOY_CUDNN(LRN);
...
@@ -37,6 +40,7 @@ DEPLOY_CUDNN(LRN);
template
<
class
Context
>
template
<
typename
T
>
template
<
class
Context
>
template
<
typename
T
>
void
CuDNNLRNGradientOp
<
Context
>::
RunWithType
()
{
void
CuDNNLRNGradientOp
<
Context
>::
RunWithType
()
{
if
(
this
->
data_format
==
"NCHW"
)
{
cudnnSetTensorDesc
<
T
>
(
&
input_desc
,
&
input
(
-
1
));
cudnnSetTensorDesc
<
T
>
(
&
input_desc
,
&
input
(
-
1
));
cudnnSetTensorDesc
<
T
>
(
&
output_desc
,
output
(
0
));
cudnnSetTensorDesc
<
T
>
(
&
output_desc
,
output
(
0
));
...
@@ -51,20 +55,23 @@ void CuDNNLRNGradientOp<Context>::RunWithType() {
...
@@ -51,20 +55,23 @@ void CuDNNLRNGradientOp<Context>::RunWithType() {
input_desc
,
dYdata
,
input_desc
,
dYdata
,
output_desc
,
Xdata
,
output_desc
,
Xdata
,
CUDNNType
<
T
>::
zero
,
output_desc
,
dXdata
));
CUDNNType
<
T
>::
zero
,
output_desc
,
dXdata
));
}
else
LOG
(
FATAL
)
<<
"Unknown data format: "
<<
this
->
data_format
;
}
}
template
<
class
Context
>
template
<
class
Context
>
void
CuDNNLRNGradientOp
<
Context
>::
RunOnDevice
()
{
void
CuDNNLRNGradientOp
<
Context
>::
RunOnDevice
()
{
output
(
0
)
->
ReshapeLike
(
input
(
0
));
output
(
0
)
->
ReshapeLike
(
input
(
0
));
if
(
this
->
mode
==
ACROSS_CHANNELS
)
{
if
(
this
->
mode
==
"ACROSS_CHANNELS"
)
{
if
(
input
(
0
).
template
IsType
<
float
>
())
RunWithType
<
float
>
();
if
(
input
(
0
).
template
IsType
<
float
>
())
RunWithType
<
float
>
();
#ifdef WITH_CUDA_FP16
#ifdef WITH_CUDA_FP16
else
if
(
input
(
0
).
template
IsType
<
float16
>
())
RunWithType
<
float16
>
();
else
if
(
input
(
0
).
template
IsType
<
float16
>
())
RunWithType
<
float16
>
();
#endif
#endif
else
LOG
(
FATAL
)
<<
"Unsupported input types."
;
else
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
else
{
}
else
if
(
this
->
mode
==
"WITHIN_CHANNEL"
)
{
LRNGradientOp
<
Context
>::
RunOnDevice
();
LRNGradientOp
<
Context
>::
RunOnDevice
();
}
else
{
LOG
(
FATAL
)
<<
"Unsupported lrn mode: "
<<
this
->
mode
;
}
}
}
}
...
...
Dragon/src/operators/vision/lrn_op.cc
View file @
c9db9ee
...
@@ -45,15 +45,16 @@ template <class Context> template <typename T>
...
@@ -45,15 +45,16 @@ template <class Context> template <typename T>
void
LRNOp
<
Context
>::
PoolRunWithType
()
{
void
LRNOp
<
Context
>::
PoolRunWithType
()
{
pool_out
=
ws
()
->
CreateTensor
(
"/mnt/"
+
anchor
()
+
"/pool_out"
);
pool_out
=
ws
()
->
CreateTensor
(
"/mnt/"
+
anchor
()
+
"/pool_out"
);
if
(
!
pool_op
)
{
if
(
!
pool_op
)
{
Argument
ks
,
s
,
p
,
m
ode
;
Argument
ks
,
s
,
p
,
m
,
df
;
ks
.
set_name
(
"kernel_size"
);
ks
.
add_ints
(
local_size
);
ks
.
set_name
(
"kernel_size"
);
ks
.
add_ints
(
local_size
);
s
.
set_name
(
"stride"
);
s
.
add_ints
(
1
);
s
.
set_name
(
"stride"
);
s
.
add_ints
(
1
);
p
.
set_name
(
"pad"
);
p
.
add_ints
((
local_size
-
1
)
/
2
);
p
.
set_name
(
"pad"
);
p
.
add_ints
((
local_size
-
1
)
/
2
);
mode
.
set_name
(
"mode"
);
mode
.
set_s
(
"AVG"
);
m
.
set_name
(
"mode"
);
m
.
set_s
(
"AVG"
);
OperatorDef
pool_op_def
=
MakeOperatorDef
(
"Pooling"
,
""
,
df
.
set_name
(
"data_format"
);
df
.
set_s
(
data_format
);
OperatorDef
pool_op_def
=
MakeOperatorDef
(
"Pooling2d"
,
""
,
vector
<
string
>
({
sqr_out
->
name
()
}),
vector
<
string
>
({
sqr_out
->
name
()
}),
vector
<
string
>
({
pool_out
->
name
()
}),
vector
<
string
>
({
pool_out
->
name
()
}),
vector
<
Argument
>
({
ks
,
s
,
p
,
m
ode
}));
vector
<
Argument
>
({
ks
,
s
,
p
,
m
,
df
}));
if
(
this
->
op_def
().
has_device_option
())
if
(
this
->
op_def
().
has_device_option
())
pool_op_def
.
mutable_device_option
()
->
CopyFrom
(
this
->
op_def
().
device_option
());
pool_op_def
.
mutable_device_option
()
->
CopyFrom
(
this
->
op_def
().
device_option
());
pool_op
.
reset
(
CreateOperator
(
pool_op_def
,
ws
()));
pool_op
.
reset
(
CreateOperator
(
pool_op_def
,
ws
()));
...
@@ -99,12 +100,11 @@ void LRNOp<Context>::ProdRunWithType() {
...
@@ -99,12 +100,11 @@ void LRNOp<Context>::ProdRunWithType() {
template
<
class
Context
>
template
<
class
Context
>
void
LRNOp
<
Context
>::
RunOnDevice
()
{
void
LRNOp
<
Context
>::
RunOnDevice
()
{
if
(
mode
==
ACROSS_CHANNELS
)
{
if
(
mode
==
"ACROSS_CHANNELS"
)
{
if
(
input
(
0
).
template
IsType
<
float
>
())
{
if
(
input
(
0
).
template
IsType
<
float
>
())
{
AcrossRunWithType
<
float
>
();
AcrossRunWithType
<
float
>
();
}
else
{
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
}
else
{
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
}
}
else
if
(
mode
==
"WITHIN_CHANNEL"
)
{
else
{
if
(
input
(
0
).
template
IsType
<
float
>
())
{
if
(
input
(
0
).
template
IsType
<
float
>
())
{
SplitRunWithType
<
float
>
();
SplitRunWithType
<
float
>
();
SquareRunWithType
<
float
>
();
SquareRunWithType
<
float
>
();
...
@@ -112,6 +112,8 @@ void LRNOp<Context>::RunOnDevice() {
...
@@ -112,6 +112,8 @@ void LRNOp<Context>::RunOnDevice() {
PowRunWithType
<
float
>
();
PowRunWithType
<
float
>
();
ProdRunWithType
<
float
>
();
ProdRunWithType
<
float
>
();
}
else
{
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
}
else
{
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
}
else
{
LOG
(
FATAL
)
<<
"Unsupported lrn mode: "
<<
mode
;
}
}
}
}
...
@@ -173,17 +175,18 @@ template <class Context> template <typename T>
...
@@ -173,17 +175,18 @@ template <class Context> template <typename T>
void
LRNGradientOp
<
Context
>::
PoolRunWithType
()
{
void
LRNGradientOp
<
Context
>::
PoolRunWithType
()
{
sqr_out
=
ws
()
->
GetTensor
(
"/mnt/"
+
anchor
()
+
"/sqr_out"
);
sqr_out
=
ws
()
->
GetTensor
(
"/mnt/"
+
anchor
()
+
"/sqr_out"
);
if
(
!
pool_op
)
{
if
(
!
pool_op
)
{
Argument
ks
,
s
,
p
,
m
ode
;
Argument
ks
,
s
,
p
,
m
,
df
;
ks
.
set_name
(
"kernel_size"
);
ks
.
add_ints
(
local_size
);
ks
.
set_name
(
"kernel_size"
);
ks
.
add_ints
(
local_size
);
s
.
set_name
(
"stride"
);
s
.
add_ints
(
1
);
s
.
set_name
(
"stride"
);
s
.
add_ints
(
1
);
p
.
set_name
(
"pad"
);
p
.
add_ints
((
local_size
-
1
)
/
2
);
p
.
set_name
(
"pad"
);
p
.
add_ints
((
local_size
-
1
)
/
2
);
mode
.
set_name
(
"mode"
);
mode
.
set_s
(
"AVG"
);
m
.
set_name
(
"mode"
);
m
.
set_s
(
"AVG"
);
OperatorDef
pool_op_def
=
MakeOperatorDef
(
"PoolingGradient"
,
""
,
df
.
set_name
(
"data_format"
);
df
.
set_s
(
data_format
);
OperatorDef
pool_op_def
=
MakeOperatorDef
(
"Pooling2dGradient"
,
""
,
vector
<
string
>
({
sqr_out
->
name
(),
vector
<
string
>
({
sqr_out
->
name
(),
pool_out
->
name
(),
pool_out
->
name
(),
pool_out
->
name
()
+
"_grad"
}),
pool_out
->
name
()
+
"_grad"
}),
vector
<
string
>
({
sqr_out
->
name
()
+
"_grad"
}),
vector
<
string
>
({
sqr_out
->
name
()
+
"_grad"
}),
vector
<
Argument
>
({
ks
,
s
,
p
,
m
ode
}));
vector
<
Argument
>
({
ks
,
s
,
p
,
m
,
df
}));
if
(
this
->
op_def
().
has_device_option
())
if
(
this
->
op_def
().
has_device_option
())
pool_op_def
.
mutable_device_option
()
->
CopyFrom
(
this
->
op_def
().
device_option
());
pool_op_def
.
mutable_device_option
()
->
CopyFrom
(
this
->
op_def
().
device_option
());
pool_op
.
reset
(
CreateOperator
(
pool_op_def
,
ws
()));
pool_op
.
reset
(
CreateOperator
(
pool_op_def
,
ws
()));
...
@@ -224,12 +227,11 @@ void LRNGradientOp<Context>::SplitRunWithType() {
...
@@ -224,12 +227,11 @@ void LRNGradientOp<Context>::SplitRunWithType() {
template
<
class
Context
>
template
<
class
Context
>
void
LRNGradientOp
<
Context
>::
RunOnDevice
()
{
void
LRNGradientOp
<
Context
>::
RunOnDevice
()
{
if
(
mode
==
ACROSS_CHANNELS
)
{
if
(
mode
==
"ACROSS_CHANNELS"
)
{
if
(
input
(
0
).
template
IsType
<
float
>
())
{
if
(
input
(
0
).
template
IsType
<
float
>
())
{
AcrossRunWithType
<
float
>
();
AcrossRunWithType
<
float
>
();
}
else
{
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
}
else
{
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
}
}
else
if
(
mode
==
"WITHIN_CHANNEL"
)
{
else
{
if
(
input
(
0
).
template
IsType
<
float
>
())
{
if
(
input
(
0
).
template
IsType
<
float
>
())
{
ProdRunWithType
<
float
>
();
ProdRunWithType
<
float
>
();
PowRunWithType
<
float
>
();
PowRunWithType
<
float
>
();
...
@@ -237,6 +239,8 @@ void LRNGradientOp<Context>::RunOnDevice() {
...
@@ -237,6 +239,8 @@ void LRNGradientOp<Context>::RunOnDevice() {
SquareRunWithType
<
float
>
();
SquareRunWithType
<
float
>
();
SplitRunWithType
<
float
>
();
SplitRunWithType
<
float
>
();
}
else
{
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
}
else
{
LOG
(
FATAL
)
<<
"Unsupported input types."
;
}
}
else
{
LOG
(
FATAL
)
<<
"Unsupported lrn mode: "
<<
mode
;
}
}
}
}
...
...
Dragon/src/utils/cudnn_device.cc
View file @
c9db9ee
...
@@ -69,6 +69,34 @@ void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc,
...
@@ -69,6 +69,34 @@ void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc,
}
}
template
<
typename
T
>
template
<
typename
T
>
void
cudnnSetTensor4dDescWithGroup
(
cudnnTensorDescriptor_t
*
desc
,
const
string
&
data_format
,
const
vector
<
TIndex
>&
dims
,
const
TIndex
group
)
{
if
(
data_format
==
"NCHW"
)
{
CUDNN_CHECK
(
cudnnSetTensor4dDescriptorEx
(
*
desc
,
CUDNNType
<
T
>::
type
,
dims
[
0
],
dims
[
1
]
/
group
,
dims
[
2
],
dims
[
3
],
dims
[
1
]
*
dims
[
2
]
*
dims
[
3
],
dims
[
2
]
*
dims
[
3
],
dims
[
3
],
1
));
}
else
if
(
data_format
==
"NHWC"
)
{
CUDNN_CHECK
(
cudnnSetTensor4dDescriptorEx
(
*
desc
,
CUDNNType
<
T
>::
type
,
dims
[
0
],
dims
[
3
]
/
group
,
dims
[
1
],
dims
[
2
],
dims
[
1
]
*
dims
[
2
]
*
dims
[
3
],
1
,
dims
[
2
]
*
dims
[
3
],
dims
[
3
]));
}
else
LOG
(
FATAL
)
<<
"Unknown data format: "
<<
data_format
;
}
template
<
typename
T
>
void
cudnnSetTensor5dDesc
(
cudnnTensorDescriptor_t
*
desc
,
void
cudnnSetTensor5dDesc
(
cudnnTensorDescriptor_t
*
desc
,
const
string
&
data_format
,
const
string
&
data_format
,
const
vector
<
TIndex
>&
dims
)
{
const
vector
<
TIndex
>&
dims
)
{
...
@@ -169,6 +197,7 @@ template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<T
...
@@ -169,6 +197,7 @@ template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<T
template
void
cudnnSetTensor4dDesc
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor4dDesc
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor5dDesc
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor5dDesc
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor3dDesc
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor3dDesc
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor4dDescWithGroup
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
,
const
TIndex
);
template
void
cudnnSetTensorDesc
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
vector
<
TIndex
>&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensorDesc
<
float
>
(
cudnnTensorDescriptor_t
*
,
const
vector
<
TIndex
>&
,
const
vector
<
TIndex
>&
);
...
@@ -180,6 +209,7 @@ template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<
...
@@ -180,6 +209,7 @@ template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<
template
void
cudnnSetTensor4dDesc
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor4dDesc
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor5dDesc
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor5dDesc
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor3dDesc
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor3dDesc
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor4dDescWithGroup
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
,
const
TIndex
);
template
void
cudnnSetTensorDesc
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
vector
<
TIndex
>&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensorDesc
<
double
>
(
cudnnTensorDescriptor_t
*
,
const
vector
<
TIndex
>&
,
const
vector
<
TIndex
>&
);
...
@@ -192,6 +222,7 @@ template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector
...
@@ -192,6 +222,7 @@ template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector
template
void
cudnnSetTensor4dDesc
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor4dDesc
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor5dDesc
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor5dDesc
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor3dDesc
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor3dDesc
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensor4dDescWithGroup
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
string
&
,
const
vector
<
TIndex
>&
,
const
TIndex
);
template
void
cudnnSetTensorDesc
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
vector
<
TIndex
>&
,
const
vector
<
TIndex
>&
);
template
void
cudnnSetTensorDesc
<
float16
>
(
cudnnTensorDescriptor_t
*
,
const
vector
<
TIndex
>&
,
const
vector
<
TIndex
>&
);
#endif
#endif
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment