Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
SeetaResearch
/
Dragon
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit 007d9c21
authored
Aug 21, 2017
by
Ting PAN
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add NCCL support for synchronous distributed training
1 parent
2f685b88
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
87 additions
and
36 deletions
Dragon/CMakeLists.txt
Dragon/include/operators/update/async_update_op.h
Dragon/include/operators/update/update_op_base.h
Dragon/include/utils/cuda_device.h
Dragon/src/operators/mpi/mpi_broadcast_op.cc
Dragon/src/operators/mpi/mpi_gather_op.cc
Dragon/src/operators/update/async_update_op.cc
Dragon/src/operators/update/update_op_base.cc
README.md
Dragon/CMakeLists.txt
View file @
007d9c2
...
@@ -14,7 +14,8 @@ option(WITH_CUDNN "Set ON to use CUDNN" OFF)
...
@@ -14,7 +14,8 @@ option(WITH_CUDNN "Set ON to use CUDNN" OFF)
option
(
WITH_BLAS
"Set ON to use BLAS"
OFF
)
option
(
WITH_BLAS
"Set ON to use BLAS"
OFF
)
option
(
WITH_SSE
"Set ON to use SSE 4.1"
ON
)
option
(
WITH_SSE
"Set ON to use SSE 4.1"
ON
)
option
(
WITH_MPI
"Set ON to use MPI"
OFF
)
option
(
WITH_MPI
"Set ON to use MPI"
OFF
)
option
(
WITH_MPI_CUDA
"Set ON to use MPI_CUDA_AWARE"
OFF
)
option
(
WITH_MPI_CUDA
"Set ON to use MPI-CUDA"
OFF
)
option
(
WITH_MPI_NCCL
"Set ON to use MPI-NCCL"
OFF
)
option
(
WITH_CUDA_FP16
"Set ON to use FP16"
ON
)
option
(
WITH_CUDA_FP16
"Set ON to use FP16"
ON
)
# set your 3rdparty
# set your 3rdparty
...
@@ -24,7 +25,7 @@ set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
...
@@ -24,7 +25,7 @@ set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
set
(
PYTHON_DIR /usr/include/python2.7
)
# prefer
set
(
PYTHON_DIR /usr/include/python2.7
)
# prefer
#set(PYTHON_DIR /usr/include/python3.x) # optional, set specific version
#set(PYTHON_DIR /usr/include/python3.x) # optional, set specific version
#set(ANACONDA_DIR /xxx/anaconda) # optional, root folder of anaconda, preset for 2.7, 3.5, and 3.6
#set(ANACONDA_DIR /xxx/anaconda) # optional, root folder of anaconda, preset for 2.7, 3.5, and 3.6
set
(
NUMPY_DIR /xxx/numpy
)
# require, root folder of numpy package
set
(
NUMPY_DIR /xxx/numpy
)
# require
d
, root folder of numpy package
# set CUDA compiling architecture
# set CUDA compiling architecture
set
(
CUDA_ARCH -gencode arch=compute_20,code=sm_20
set
(
CUDA_ARCH -gencode arch=compute_20,code=sm_20
...
@@ -61,7 +62,7 @@ set(CUDA_ARCH -gencode arch=compute_20,code=sm_20
...
@@ -61,7 +62,7 @@ set(CUDA_ARCH -gencode arch=compute_20,code=sm_20
# ---[ Dependencies
# ---[ Dependencies
if
(
WITH_CUDA
)
if
(
WITH_CUDA
)
FIND_PACKAGE
(
CUDA REQUIRED
)
FIND_PACKAGE
(
CUDA REQUIRED
)
endif
()
endif
()
set
(
CMAKE_CXX_STANDARD 11
)
set
(
CMAKE_CXX_STANDARD 11
)
...
@@ -129,9 +130,13 @@ if (WITH_MPI)
...
@@ -129,9 +130,13 @@ if (WITH_MPI)
message
(
STATUS
"Use MPI [Optional]"
)
message
(
STATUS
"Use MPI [Optional]"
)
endif
()
endif
()
if
(
WITH_MPI_CUDA
)
if
(
WITH_MPI_CUDA
)
ADD_DEFINITIONS
(
-DWITH_
CUDA_AWARE
)
ADD_DEFINITIONS
(
-DWITH_
MPI_CUDA
)
message
(
STATUS
"Use MPI-CUDA [Optional]"
)
message
(
STATUS
"Use MPI-CUDA [Optional]"
)
endif
()
endif
()
if
(
WITH_MPI_NCCL
)
ADD_DEFINITIONS
(
-DWITH_MPI_NCCL
)
message
(
STATUS
"Use MPI-NCCL [Optional]"
)
endif
()
if
(
WITH_CUDA_FP16
)
if
(
WITH_CUDA_FP16
)
ADD_DEFINITIONS
(
-DWITH_CUDA_FP16
)
ADD_DEFINITIONS
(
-DWITH_CUDA_FP16
)
message
(
STATUS
"Use CUDA FP16 [Optional]"
)
message
(
STATUS
"Use CUDA FP16 [Optional]"
)
...
...
Dragon/include/operators/update/async_update_op.h
View file @
007d9c2
...
@@ -37,7 +37,7 @@ class AsyncUpdateOp final: public UpdateOpBase<Context> {
...
@@ -37,7 +37,7 @@ class AsyncUpdateOp final: public UpdateOpBase<Context> {
Map
<
int
,
int
>
local_timestamp
;
Map
<
int
,
int
>
local_timestamp
;
std
::
unique_ptr
<
std
::
thread
>
thread
;
std
::
unique_ptr
<
std
::
thread
>
thread
;
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
cudaStream_t
stream
;
cudaStream_t
stream
;
cublasHandle_t
handle
;
cublasHandle_t
handle
;
#endif
#endif
...
...
Dragon/include/operators/update/update_op_base.h
View file @
007d9c2
...
@@ -46,7 +46,12 @@ class UpdateOpBase : public Operator<Context> {
...
@@ -46,7 +46,12 @@ class UpdateOpBase : public Operator<Context> {
#ifdef WITH_MPI
#ifdef WITH_MPI
MPI_Comm
comm
;
MPI_Comm
comm
;
MPI_Group
group
;
MPI_Group
group
;
#endif // WITH_MPI
#endif // WITH_MPI
#ifdef WITH_MPI_NCCL
ncclComm_t
nccl_comm
;
cudaStream_t
stream
;
#endif // WITH_MPI_NCCL
};
};
...
...
Dragon/include/utils/cuda_device.h
View file @
007d9c2
...
@@ -14,6 +14,10 @@
...
@@ -14,6 +14,10 @@
#include <curand.h>
#include <curand.h>
#include <cuda.h>
#include <cuda.h>
#ifdef WITH_MPI_NCCL
#include <nccl/nccl.h>
#endif // WITH_MPI_NCCL
#include "core/common.h"
#include "core/common.h"
namespace
dragon
{
namespace
dragon
{
...
@@ -25,19 +29,27 @@ static const int CUDA_NUM_THREADS = 1024;
...
@@ -25,19 +29,27 @@ static const int CUDA_NUM_THREADS = 1024;
do { \
do { \
cudaError_t error = condition; \
cudaError_t error = condition; \
CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
} while (0)
} while (0)
#define CUBLAS_CHECK(condition) \
#define CUBLAS_CHECK(condition) \
do { \
do { \
cublasStatus_t status = condition; \
cublasStatus_t status = condition; \
CHECK_EQ(status, CUBLAS_STATUS_SUCCESS); \
CHECK_EQ(status, CUBLAS_STATUS_SUCCESS); \
} while (0)
} while (0)
#define CURAND_CHECK(condition) \
#define CURAND_CHECK(condition) \
do { \
do { \
curandStatus_t status = condition; \
curandStatus_t status = condition; \
CHECK_EQ(status, CURAND_STATUS_SUCCESS); \
CHECK_EQ(status, CURAND_STATUS_SUCCESS); \
} while (0)
} while (0)
#ifdef WITH_MPI_NCCL
// Evaluates an NCCL call exactly once and hands its result to CHECK_EQ
// against ncclSuccess, streaming ncclGetErrorString() into the message.
// Wrapped in do/while(0) so the macro behaves as a single statement.
#define NCCL_CHECK(condition) \
  do { \
    ncclResult_t nccl_result = (condition); \
    CHECK_EQ(nccl_result, ncclSuccess) \
        << " " << ncclGetErrorString(nccl_result); \
  } while (0)
#endif  // WITH_MPI_NCCL
#define CUDA_KERNEL_LOOP(i, n) \
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
...
...
Dragon/src/operators/mpi/mpi_broadcast_op.cc
View file @
007d9c2
...
@@ -8,7 +8,7 @@ namespace dragon {
...
@@ -8,7 +8,7 @@ namespace dragon {
template
<
class
Context
>
template
<
typename
T
>
template
<
class
Context
>
template
<
typename
T
>
void
MPIBroadcastOp
<
Context
>::
RunWithType
()
{
void
MPIBroadcastOp
<
Context
>::
RunWithType
()
{
if
(
this
->
comm_rank
==
this
->
comm_root
)
{
if
(
this
->
comm_rank
==
this
->
comm_root
)
{
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
Xdata
=
input
(
0
).
template
mutable_data
<
T
,
Context
>
();
auto
*
Xdata
=
input
(
0
).
template
mutable_data
<
T
,
Context
>
();
#else
#else
auto
*
Xdata
=
input
(
0
).
template
mutable_data
<
T
,
CPUContext
>
();
auto
*
Xdata
=
input
(
0
).
template
mutable_data
<
T
,
CPUContext
>
();
...
@@ -16,7 +16,7 @@ void MPIBroadcastOp<Context>::RunWithType() {
...
@@ -16,7 +16,7 @@ void MPIBroadcastOp<Context>::RunWithType() {
MPI_Bcast
(
Xdata
,
input
(
0
).
count
(),
MPI_FLOAT
,
this
->
comm_root
,
this
->
comm
);
MPI_Bcast
(
Xdata
,
input
(
0
).
count
(),
MPI_FLOAT
,
this
->
comm_root
,
this
->
comm
);
output
(
0
)
->
Share
(
input
(
0
));
output
(
0
)
->
Share
(
input
(
0
));
}
else
{
}
else
{
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
Ydata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
auto
*
Ydata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
#else
#else
auto
*
Ydata
=
output
(
0
)
->
template
mutable_data
<
T
,
CPUContext
>
();
auto
*
Ydata
=
output
(
0
)
->
template
mutable_data
<
T
,
CPUContext
>
();
...
@@ -59,7 +59,7 @@ OPERATOR_SCHEMA(MPIBroadcast).NumInputs(1).NumOutputs(1);
...
@@ -59,7 +59,7 @@ OPERATOR_SCHEMA(MPIBroadcast).NumInputs(1).NumOutputs(1);
template
<
class
Context
>
template
<
typename
T
>
template
<
class
Context
>
template
<
typename
T
>
void
MPIBroadcastGradientOp
<
Context
>::
RunWithType
()
{
void
MPIBroadcastGradientOp
<
Context
>::
RunWithType
()
{
if
(
this
->
comm_rank
==
this
->
comm_root
)
{
if
(
this
->
comm_rank
==
this
->
comm_root
)
{
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
dYdata
=
input
(
-
1
).
template
mutable_data
<
T
,
Context
>
();
auto
*
dYdata
=
input
(
-
1
).
template
mutable_data
<
T
,
Context
>
();
auto
*
dXdata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
auto
*
dXdata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
ctx
().
template
Copy
<
T
,
Context
,
Context
>
(
output
(
0
)
->
count
(),
dXdata
,
dYdata
);
ctx
().
template
Copy
<
T
,
Context
,
Context
>
(
output
(
0
)
->
count
(),
dXdata
,
dYdata
);
...
@@ -72,7 +72,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
...
@@ -72,7 +72,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
for
(
int
i
=
0
;
i
<
this
->
comm_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
this
->
comm_size
;
i
++
)
{
if
(
i
==
this
->
comm_root
)
continue
;
if
(
i
==
this
->
comm_root
)
continue
;
MPI_Recv
(
dYdata
,
output
(
0
)
->
count
(),
MPI_FLOAT
,
i
,
0
,
this
->
comm
,
MPI_STATUS_IGNORE
);
MPI_Recv
(
dYdata
,
output
(
0
)
->
count
(),
MPI_FLOAT
,
i
,
0
,
this
->
comm
,
MPI_STATUS_IGNORE
);
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
math
::
Add
<
T
,
Context
>
(
output
(
0
)
->
count
(),
dYdata
,
dXdata
,
dXdata
);
math
::
Add
<
T
,
Context
>
(
output
(
0
)
->
count
(),
dYdata
,
dXdata
,
dXdata
);
#else
#else
math
::
Add
<
T
,
CPUContext
>
(
output
(
0
)
->
count
(),
dYdata
,
dXdata
,
dXdata
);
math
::
Add
<
T
,
CPUContext
>
(
output
(
0
)
->
count
(),
dYdata
,
dXdata
,
dXdata
);
...
@@ -80,7 +80,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
...
@@ -80,7 +80,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
}
}
}
}
else
{
else
{
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
dYdata
=
input
(
-
1
).
template
data
<
T
,
Context
>
();
auto
*
dYdata
=
input
(
-
1
).
template
data
<
T
,
Context
>
();
#else
#else
auto
*
dYdata
=
input
(
-
1
).
template
data
<
T
,
CPUContext
>
();
auto
*
dYdata
=
input
(
-
1
).
template
data
<
T
,
CPUContext
>
();
...
...
Dragon/src/operators/mpi/mpi_gather_op.cc
View file @
007d9c2
...
@@ -11,7 +11,7 @@ void MPIGatherOp<Context>::RunWithType() {
...
@@ -11,7 +11,7 @@ void MPIGatherOp<Context>::RunWithType() {
output
(
this
->
comm_rank
)
->
Share
(
input
(
0
));
output
(
this
->
comm_rank
)
->
Share
(
input
(
0
));
for
(
int
i
=
0
;
i
<
this
->
comm_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
this
->
comm_size
;
i
++
)
{
if
(
i
==
this
->
comm_root
)
continue
;
if
(
i
==
this
->
comm_root
)
continue
;
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
Ydata
=
output
(
i
)
->
template
mutable_data
<
T
,
Context
>
();
auto
*
Ydata
=
output
(
i
)
->
template
mutable_data
<
T
,
Context
>
();
#else
#else
auto
*
Ydata
=
output
(
i
)
->
template
mutable_data
<
T
,
CPUContext
>
();
auto
*
Ydata
=
output
(
i
)
->
template
mutable_data
<
T
,
CPUContext
>
();
...
@@ -20,7 +20,7 @@ void MPIGatherOp<Context>::RunWithType() {
...
@@ -20,7 +20,7 @@ void MPIGatherOp<Context>::RunWithType() {
}
}
}
}
else
{
else
{
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
Xdata
=
input
(
0
).
template
data
<
T
,
Context
>
();
auto
*
Xdata
=
input
(
0
).
template
data
<
T
,
Context
>
();
#else
#else
auto
*
Xdata
=
input
(
0
).
template
data
<
T
,
CPUContext
>
();
auto
*
Xdata
=
input
(
0
).
template
data
<
T
,
CPUContext
>
();
...
@@ -53,7 +53,7 @@ void MPIGatherGradientOp<Context>::RunWithType() {
...
@@ -53,7 +53,7 @@ void MPIGatherGradientOp<Context>::RunWithType() {
output
(
0
)
->
Share
(
input
(
this
->
comm_rank
+
1
));
output
(
0
)
->
Share
(
input
(
this
->
comm_rank
+
1
));
for
(
int
i
=
0
;
i
<
this
->
comm_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
this
->
comm_size
;
i
++
)
{
if
(
i
==
this
->
comm_root
)
continue
;
if
(
i
==
this
->
comm_root
)
continue
;
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
dYdata
=
input
(
this
->
comm_rank
+
1
).
template
data
<
T
,
Context
>
();
auto
*
dYdata
=
input
(
this
->
comm_rank
+
1
).
template
data
<
T
,
Context
>
();
#else
#else
auto
*
dYdata
=
input
(
this
->
comm_rank
+
1
).
template
data
<
T
,
CPUContext
>
();
auto
*
dYdata
=
input
(
this
->
comm_rank
+
1
).
template
data
<
T
,
CPUContext
>
();
...
@@ -62,7 +62,7 @@ void MPIGatherGradientOp<Context>::RunWithType() {
...
@@ -62,7 +62,7 @@ void MPIGatherGradientOp<Context>::RunWithType() {
}
}
}
}
else
{
else
{
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
dXdata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
auto
*
dXdata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
#else
#else
auto
*
dXdata
=
output
(
0
)
->
template
mutable_data
<
T
,
CPUContext
>
();
auto
*
dXdata
=
output
(
0
)
->
template
mutable_data
<
T
,
CPUContext
>
();
...
...
Dragon/src/operators/update/async_update_op.cc
View file @
007d9c2
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
#ifdef WITH_MPI
#ifdef WITH_MPI
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
#include <cublas_v2.h>
#include <cublas_v2.h>
#endif // WITH_
CUDA_AWARE
#endif // WITH_
MPI_CUDA
namespace
dragon
{
namespace
dragon
{
...
@@ -62,7 +62,7 @@ AsyncUpdateOp<Context>::AsyncUpdateOp(const OperatorDef& op_def, Workspace* ws)
...
@@ -62,7 +62,7 @@ AsyncUpdateOp<Context>::AsyncUpdateOp(const OperatorDef& op_def, Workspace* ws)
}
}
// create independent stream for thread if using cuda-aware
// create independent stream for thread if using cuda-aware
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
cudaStreamCreate
(
&
stream
);
cudaStreamCreate
(
&
stream
);
cublasCreate_v2
(
&
handle
);
cublasCreate_v2
(
&
handle
);
cublasSetStream
(
handle
,
stream
);
cublasSetStream
(
handle
,
stream
);
...
@@ -78,7 +78,7 @@ void AsyncUpdateOp<Context>::RootRunWithType() {
...
@@ -78,7 +78,7 @@ void AsyncUpdateOp<Context>::RootRunWithType() {
if
(
mode
!=
"Async_No_Lock"
)
ws
()
->
LockTensor
(
output
(
i
)
->
name
());
if
(
mode
!=
"Async_No_Lock"
)
ws
()
->
LockTensor
(
output
(
i
)
->
name
());
int
delay
=
GetDelay
(
i
);
UpdateTimestamp
(
i
);
int
delay
=
GetDelay
(
i
);
UpdateTimestamp
(
i
);
math
::
Axpy
<
T
,
Context
>
(
input
(
i
).
count
(),
-
1.0
/
delay
,
dXdata
,
Xdata
);
math
::
Axpy
<
T
,
Context
>
(
input
(
i
).
count
(),
-
1.0
/
delay
,
dXdata
,
Xdata
);
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
cudaStreamSynchronize
(
cudaStreamDefault
);
cudaStreamSynchronize
(
cudaStreamDefault
);
#endif
#endif
if
(
mode
!=
"Async_No_Lock"
)
ws
()
->
UnlockTensor
(
output
(
i
)
->
name
());
if
(
mode
!=
"Async_No_Lock"
)
ws
()
->
UnlockTensor
(
output
(
i
)
->
name
());
...
@@ -108,14 +108,14 @@ void AsyncUpdateOp<Context>::ThreadRunWithType() {
...
@@ -108,14 +108,14 @@ void AsyncUpdateOp<Context>::ThreadRunWithType() {
Tensor
*
X
=
ws
()
->
GetTensor
(
tags
[
status
.
MPI_TAG
]);
Tensor
*
X
=
ws
()
->
GetTensor
(
tags
[
status
.
MPI_TAG
]);
if
(
X
->
count
()
==
0
)
continue
;
// wait for server
if
(
X
->
count
()
==
0
)
continue
;
// wait for server
recv_buffer
->
ReshapeLike
(
*
X
);
recv_buffer
->
ReshapeLike
(
*
X
);
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
Bdata
=
recv_buffer
->
template
mutable_data
<
T
,
Context
>
();
auto
*
Bdata
=
recv_buffer
->
template
mutable_data
<
T
,
Context
>
();
#else
#else
auto
*
Bdata
=
recv_buffer
->
template
mutable_data
<
T
,
CPUContext
>
();
auto
*
Bdata
=
recv_buffer
->
template
mutable_data
<
T
,
CPUContext
>
();
#endif
#endif
MPI_Recv
(
Bdata
,
X
->
count
(),
MPI_FLOAT
,
status
.
MPI_SOURCE
,
status
.
MPI_TAG
,
this
->
comm
,
MPI_STATUS_IGNORE
);
MPI_Recv
(
Bdata
,
X
->
count
(),
MPI_FLOAT
,
status
.
MPI_SOURCE
,
status
.
MPI_TAG
,
this
->
comm
,
MPI_STATUS_IGNORE
);
// update
// update
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
Xdata
=
X
->
template
mutable_data
<
T
,
Context
>
();
auto
*
Xdata
=
X
->
template
mutable_data
<
T
,
Context
>
();
if
(
mode
!=
"Async_No_Lock"
)
ws
()
->
LockTensor
(
output
(
status
.
MPI_TAG
)
->
name
());
if
(
mode
!=
"Async_No_Lock"
)
ws
()
->
LockTensor
(
output
(
status
.
MPI_TAG
)
->
name
());
int
delay
=
GetDelay
(
status
.
MPI_TAG
);
int
delay
=
GetDelay
(
status
.
MPI_TAG
);
...
...
Dragon/src/operators/update/update_op_base.cc
View file @
007d9c2
...
@@ -29,13 +29,38 @@ void UpdateOpBase<Context>::InitMPI() {
...
@@ -29,13 +29,38 @@ void UpdateOpBase<Context>::InitMPI() {
MPI_Comm_group
(
MPI_COMM_WORLD
,
&
world_group
);
MPI_Comm_group
(
MPI_COMM_WORLD
,
&
world_group
);
MPI_Group_translate_ranks
(
world_group
,
1
,
&
world_root
,
group
,
&
comm_root
);
MPI_Group_translate_ranks
(
world_group
,
1
,
&
world_root
,
group
,
&
comm_root
);
CHECK
(
comm_root
!=
MPI_UNDEFINED
)
<<
"MPI root is not included in layer group."
;
CHECK
(
comm_root
!=
MPI_UNDEFINED
)
<<
"MPI root is not included in layer group."
;
#endif
#endif // WITH_MPI
#ifdef WITH_MPI_NCCL
    // Build the NCCL communicator over the same ranks as the MPI
    // communicator `comm`.
    //
    // The unique id is generated only on the rank equal to comm_root, so the
    // broadcast has to be rooted at comm_root as well. Rooting it at a
    // hard-coded rank 0 distributes an uninitialized id whenever
    // comm_root != 0.
    ncclUniqueId id;
    if (comm_rank == comm_root) NCCL_CHECK(ncclGetUniqueId(&id));
    MPI_Bcast(&id, sizeof(id), MPI_BYTE, comm_root, comm);
    // NOTE(review): presumably makes this context's GPU the current device
    // before the device-bound NCCL/CUDA objects are created -- confirm.
    ctx().SwitchToDevice();
    NCCL_CHECK(ncclCommInitRank(&nccl_comm, comm_size, id, comm_rank));
    // Create the CUDA stream kept in the `stream` member.
    CUDA_CHECK(cudaStreamCreate(&stream));
#endif  // WITH_MPI_NCCL
}
}
}
}
template
<
class
Context
>
template
<
typename
T
>
template
<
class
Context
>
template
<
typename
T
>
void
UpdateOpBase
<
Context
>::
ReduceRunWithType
()
{
void
UpdateOpBase
<
Context
>::
ReduceRunWithType
()
{
#ifdef WITH_MPI
if
(
TypeMeta
::
Id
<
Context
>
()
==
TypeMeta
::
Id
<
CUDAContext
>
())
{
#ifdef WITH_MPI_NCCL
TIndex
count
=
input
(
0
).
count
();
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
Context
>
();
NCCL_CHECK
(
ncclAllReduce
((
const
void
*
)
dXdata
,
(
void
*
)
dXdata
,
count
,
ncclFloat
,
ncclSum
,
nccl_comm
,
stream
));
CUDA_CHECK
(
cudaStreamSynchronize
(
stream
));
math
::
Scal
<
T
,
Context
>
(
count
,
T
(
1.0
/
comm_size
),
dXdata
);
return
;
#endif
}
#ifdef WITH_MPI // WITH_MPI
MPI_Request
recv_req
;
MPI_Request
recv_req
;
TIndex
count
=
input
(
0
).
count
();
TIndex
count
=
input
(
0
).
count
();
TIndex
segment_size
=
count
/
comm_size
;
TIndex
segment_size
=
count
/
comm_size
;
...
@@ -48,13 +73,13 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
...
@@ -48,13 +73,13 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
segment_ends
[
i
]
=
segment_sizes
[
i
]
+
segment_ends
[
i
-
1
];
segment_ends
[
i
]
=
segment_sizes
[
i
]
+
segment_ends
[
i
-
1
];
buffer
=
ws
()
->
GetBuffer
();
buffer
=
ws
()
->
GetBuffer
();
buffer
->
Reshape
(
vector
<
TIndex
>
(
1
,
segment_sizes
[
0
]));
buffer
->
Reshape
(
vector
<
TIndex
>
(
1
,
segment_sizes
[
0
]));
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
Bdata
=
buffer
->
mutable_data
<
T
,
Context
>
();
auto
*
Bdata
=
buffer
->
mutable_data
<
T
,
Context
>
();
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
Context
>
();
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
Context
>
();
#else
#else
auto
*
Bdata
=
buffer
->
mutable_data
<
T
,
CPUContext
>
();
auto
*
Bdata
=
buffer
->
mutable_data
<
T
,
CPUContext
>
();
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
CPUContext
>
();
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
CPUContext
>
();
#endif // WITH_
CUDA_AWARE
#endif // WITH_
MPI_CUDA
int
recv_from
=
(
comm_rank
-
1
+
comm_size
)
%
comm_size
;
int
recv_from
=
(
comm_rank
-
1
+
comm_size
)
%
comm_size
;
int
send_to
=
(
comm_rank
+
1
)
%
comm_size
;
int
send_to
=
(
comm_rank
+
1
)
%
comm_size
;
...
@@ -72,14 +97,14 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
...
@@ -72,14 +97,14 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
auto
*
segment_update
=
&
(
dXdata
[
segment_ends
[
recv_chunk
]
-
auto
*
segment_update
=
&
(
dXdata
[
segment_ends
[
recv_chunk
]
-
segment_sizes
[
recv_chunk
]]);
segment_sizes
[
recv_chunk
]]);
MPI_Wait
(
&
recv_req
,
MPI_STATUS_IGNORE
);
MPI_Wait
(
&
recv_req
,
MPI_STATUS_IGNORE
);
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
math
::
Axpy
<
T
,
Context
>
(
segment_sizes
[
recv_chunk
],
math
::
Axpy
<
T
,
Context
>
(
segment_sizes
[
recv_chunk
],
1.0
,
Bdata
,
segment_update
);
1.0
,
Bdata
,
segment_update
);
cudaStreamSynchronize
(
cudaStreamDefault
);
cudaStreamSynchronize
(
cudaStreamDefault
);
#else
#else
math
::
Axpy
<
T
,
CPUContext
>
(
segment_sizes
[
recv_chunk
],
math
::
Axpy
<
T
,
CPUContext
>
(
segment_sizes
[
recv_chunk
],
1.0
,
Bdata
,
segment_update
);
1.0
,
Bdata
,
segment_update
);
#endif // WITH_
CUDA_AWARE
#endif // WITH_
MPI_CUDA
}
}
ws
()
->
ReleaseBuffer
(
buffer
);
ws
()
->
ReleaseBuffer
(
buffer
);
...
@@ -99,11 +124,11 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
...
@@ -99,11 +124,11 @@ void UpdateOpBase<Context>::ReduceRunWithType() {
// ave-normalize
// ave-normalize
if
(
comm_size
>
1
)
{
if
(
comm_size
>
1
)
{
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
math
::
Scal
<
T
,
Context
>
(
count
,
T
(
1.0
/
comm_size
),
dXdata
);
math
::
Scal
<
T
,
Context
>
(
count
,
T
(
1.0
/
comm_size
),
dXdata
);
#else
#else
math
::
Scal
<
T
,
CPUContext
>
(
count
,
T
(
1.0
/
comm_size
),
dXdata
);
math
::
Scal
<
T
,
CPUContext
>
(
count
,
T
(
1.0
/
comm_size
),
dXdata
);
#endif // WITH_
CUDA_AWARE
#endif // WITH_
MPI_CUDA
}
}
#endif // WITH_MPI
#endif // WITH_MPI
}
}
...
@@ -159,11 +184,11 @@ void UpdateOpBase<Context>::UpdateRunWithType() {
...
@@ -159,11 +184,11 @@ void UpdateOpBase<Context>::UpdateRunWithType() {
}
}
CHECK
(
async_tag
!=
-
1
);
CHECK
(
async_tag
!=
-
1
);
}
}
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
Context
>
();
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
Context
>
();
#else
#else
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
CPUContext
>
();
auto
*
dXdata
=
input
(
0
).
template
mutable_data
<
T
,
CPUContext
>
();
#endif // WITH_
CUDA_AWARE
#endif // WITH_
MPI_CUDA
MPI_Send
(
dXdata
,
input
(
0
).
count
(),
MPI_FLOAT
,
this
->
comm_root
,
async_tag
,
this
->
comm
);
MPI_Send
(
dXdata
,
input
(
0
).
count
(),
MPI_FLOAT
,
this
->
comm_root
,
async_tag
,
this
->
comm
);
#endif // WITH_MPI
#endif // WITH_MPI
}
}
...
@@ -173,12 +198,12 @@ template <class Context> template <typename T>
...
@@ -173,12 +198,12 @@ template <class Context> template <typename T>
void
UpdateOpBase
<
Context
>::
RecvRunWithType
()
{
void
UpdateOpBase
<
Context
>::
RecvRunWithType
()
{
#ifdef WITH_MPI
#ifdef WITH_MPI
if
(
comm_rank
!=
comm_root
)
{
if
(
comm_rank
!=
comm_root
)
{
#ifdef WITH_
CUDA_AWARE
#ifdef WITH_
MPI_CUDA
auto
*
Xdata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
auto
*
Xdata
=
output
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
#else
#else
auto
*
Xdata
=
output
(
0
)
->
template
mutable_data
<
T
,
CPUContext
>
();
auto
*
Xdata
=
output
(
0
)
->
template
mutable_data
<
T
,
CPUContext
>
();
#endif // WITH_
CUDA_AWARE
#endif // WITH_
MPI_CUDA
MPI_Recv
(
Xdata
,
output
(
0
)
->
count
(),
MPI_FLOAT
,
MPI_Recv
(
Xdata
,
output
(
0
)
->
count
(),
MPI_FLOAT
,
this
->
comm_root
,
async_tag
,
this
->
comm
,
MPI_STATUS_IGNORE
);
this
->
comm_root
,
async_tag
,
this
->
comm
,
MPI_STATUS_IGNORE
);
}
}
...
...
README.md
View file @
007d9c2
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
2.
CUDA
[
Optional
]
2.
CUDA
[
Optional
]
3.
CUDNN
[
Optional
]
3.
CUDNN
[
Optional
]
4.
OpenMPI
[
Optional
]
4.
OpenMPI
[
Optional
]
5.
NCCL
[
Optional
]
-----
-----
### Installation
### Installation
...
@@ -16,6 +17,8 @@
...
@@ -16,6 +17,8 @@
(Optional) Download and install [CUDNN](https://developer.nvidia.com/cudnn)
(Optional) Download and install [CUDNN](https://developer.nvidia.com/cudnn)
(Optional, Linux Only) Download and install [NCCL](https://developer.nvidia.com/nccl)
3.
(Optional) Download 3rdparty.zip and unzip to Dragon/3rdparty (Out of source code dir)
3.
(Optional) Download 3rdparty.zip and unzip to Dragon/3rdparty (Out of source code dir)
[*Win64-VS2013*](https://pan.baidu.com/s/1miGAZl2) (OpenBLAS / Protobuf2.6 for VS2013 / CUDNN v7 / Microsoft MPI)
[*Win64-VS2013*](https://pan.baidu.com/s/1miGAZl2) (OpenBLAS / Protobuf2.6 for VS2013 / CUDNN v7 / Microsoft MPI)
...
@@ -42,6 +45,7 @@
...
@@ -42,6 +45,7 @@
-
Set CUDA compiling architectures if necessary
-
Set CUDA compiling architectures if necessary
-
GCC version(4.8+, 5.0-) should add
``-std=c++11``
to
``CUDA_NVCC_FLAGS``
, if
``nullptr``
is not found
-
GCC version(4.8+, 5.0-) should add
``-std=c++11``
to
``CUDA_NVCC_FLAGS``
, if
``nullptr``
is not found
-
We pre-generated files under
``Dragon/src/protos``
with protobuf-2.6, run
``protoc``
by yourself if higher are required
-
We pre-generated files under
``Dragon/src/protos``
with protobuf-2.6, run
``protoc``
by yourself if higher are required
-
A build with OpenMPI can enable NCCL and our own CUDA implementation at the same time; we recommend not enabling NCCL (*memory inefficient*)
6.
Environment Variables
6.
Environment Variables
### Linux(Only for OpenMPI):
### Linux(Only for OpenMPI):
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment