Commit 8651e1b5 by Ting PAN

speedup io

1 parent ee893b1b
3rdparty/bin
This directory holds (*after you download them*):
- msmpi.dll / mpiexec.exe / smpd.exe (For ``mpi``, Windows Only)
- cudnn64_*.dll (For ``cudnn``, Windows Only)
- libopenblas.dll / libquadmath-0.dll / libgfortran-3.dll / libgcc_s_seh-1.dll (For ``cblas``, Windows Only)
3rdparty/include
This directory holds (*after you download them*):
- mpi/*.h (For ``mpi``, Windows/Linux)
- google/protobuf/*.h (For ``google protobuf``, Windows Only)
- cudnn.h (For ``cudnn``, Windows Only)
- cblas.h and relevant header files (For ``cblas``, Windows/Linux)
- getopt.h and unistd.h (For ``platform-specific`` headers, Windows Only)
3rdparty/lib
This directory holds (*after you download them*):
- msmpi.lib / libmpi.so (For ``mpi``, Windows/Linux)
- libprotobuf.lib (For ``google protobuf``, Windows Only)
- cudnn.lib (For ``cudnn``, Windows Only)
- libopenblas.lib (For ``cblas``, Windows Only)
- python27.lib (For ``python27``, Windows Only)
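The three listings above spell out what must land in 3rdparty before building. A quick pre-build sanity check like the sketch below can save a failed configure; the file names come from the listings, but the bin/include/lib layout is an assumption, not part of the commit:

```python
import os

# Hypothetical pre-build check for a Windows setup; directory layout
# (bin/include/lib) is assumed from the listings above.
REQUIRED = [
    'bin/msmpi.dll', 'bin/mpiexec.exe', 'bin/smpd.exe',
    'include/cudnn.h', 'include/cblas.h',
    'lib/msmpi.lib', 'lib/cudnn.lib', 'lib/libprotobuf.lib',
    'lib/libopenblas.lib', 'lib/python27.lib',
]

missing = [f for f in REQUIRED if not os.path.exists(os.path.join('3rdparty', f))]
if missing:
    print('Download these into 3rdparty first: %s' % ', '.join(missing))
```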
@@ -17,7 +17,7 @@ option(WITH_MPI_CUDA "Set to ON to use MPI_CUDA_AWARE" OFF)
 option(WITH_CUDA_FP16 "Set to ON to use FP16" ON)
 # set your 3rdparty
-set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/3rdparty)
+set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
 # set your py27
 set(PYTHON_DIR /usr/include/python2.7)  # prefer
......
@@ -19,7 +19,8 @@ class DataReader(Process):
         self._source = GetProperty(kwargs, 'source', '')
         self._use_shuffle = GetProperty(kwargs, 'shuffle', False)
         self._use_step = GetProperty(kwargs, 'node_step', False)
-        self._chunk_size = GetProperty(kwargs, 'chunk_size', 4)  # >= 4MB
+        self._num_chunks = GetProperty(kwargs, 'num_chunks', 2048)
+        self._chunk_size = GetProperty(kwargs, 'chunk_size', -1)
         self._num_parts = 1
         self._part_idx = 0
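For reference, a usage sketch of the two new keys; the call mirrors the GetProperty defaults above, and the source path is hypothetical:

```python
# chunk_size=-1 (the new default) triggers the auto-search in the next hunk;
# pass an explicit size in MB to keep the old fixed-size behavior.
reader = DataReader(source='/data/train_lmdb',   # hypothetical path
                    shuffle=True,
                    num_chunks=2048,             # target number of chunks
                    chunk_size=-1)               # -1 = search automatically
```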
@@ -91,6 +92,12 @@ class DataReader(Process):
         self._db_size = int(self._db.get('size'))
         self._db_zfill = int(self._db.get('zfill'))
         self._epoch_size = self._db_size / self._num_parts + 1
+        # search an optimal chunk size by the number of chunks
+        if self._chunk_size == -1:
+            max_chunk_size = self._db._total_size / ((self._num_chunks * (1 << 20)))
+            min_chunk_size = 1
+            while min_chunk_size * 2 < max_chunk_size: min_chunk_size *= 2
+            self._chunk_size = min_chunk_size
         self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
                                      (self._num_parts * self._chunk_size << 20)))
         self._chunk_size = self._db_size / self._num_shuffle_parts / self._num_parts + 1
......
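The auto-search above effectively picks the largest power of two (in MB) strictly below total_size / num_chunks, so the database splits into roughly num_chunks sequential chunks and random seeks are amortized over multi-megabyte reads. A standalone sketch, with hypothetical names:

```python
def search_chunk_size(total_size, num_chunks=2048):
    """Mirror of the auto-search above: the largest power of two (in MB)
    strictly below total_size / num_chunks."""
    max_chunk_size = total_size // (num_chunks * (1 << 20))
    chunk_size = 1
    while chunk_size * 2 < max_chunk_size:
        chunk_size *= 2
    return chunk_size

# a 20 GB database with the default 2048 chunks -> 8 MB chunks
assert search_chunk_size(20 << 30) == 8
```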
@@ -20,17 +20,26 @@ class DataBatch(object):
     """DataBatch use Triple-Buffering to speed up"""
+        # init mpi
+        global_rank = 0; local_rank = 0; group_size = 1
+        if mpi.is_init():
+            idx, group = mpi.allow_parallel()
+            if idx != -1:  # data parallel
+                global_rank = mpi.rank()
+                group_size = len(group)
+                for i, node in enumerate(group):
+                    if global_rank == node: local_rank = i
+        kwargs['group_size'] = group_size
         # configuration
-        self._prefetch = GetProperty(kwargs, 'prefetch', 10)
+        self._prefetch = GetProperty(kwargs, 'prefetch', 40)
         self._num_readers = GetProperty(kwargs, 'num_readers', 1)
         self._num_transformers = GetProperty(kwargs, 'num_transformers', -1)
+        self._num_fetchers = GetProperty(kwargs, 'num_fetchers', 3)
         # default policy
         if self._num_transformers == -1:
             self._num_transformers = 1
+            # add 1 transformer for random crop
+            if GetProperty(kwargs, 'crop_size', 0) > 0:
+                self._num_transformers += 1
             # add 1 transformer for color augmentation
             if GetProperty(kwargs, 'color_augmentation', False):
                 self._num_transformers += 1
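The relocated MPI block computes this node's position inside its data-parallel group before any queue or batch size is derived. The rank lookup in the loop above reduces to an index search; a standalone restatement with hypothetical rank values:

```python
# stand-ins for mpi.rank() and the group returned by mpi.allow_parallel()
global_rank, group = 5, [4, 5, 6, 7]
group_size = len(group)                # 4 nodes share the data
local_rank = group.index(global_rank)  # this node is slot 1 of its group
assert (group_size, local_rank) == (4, 1)
```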
@@ -40,23 +49,15 @@ class DataBatch(object):
                 self._num_transformers += 1
         self._batch_size = GetProperty(kwargs, 'batch_size', 100)
+        self._partition = GetProperty(kwargs, 'partition', False)
+        if self._partition:
+            self._batch_size = int(self._batch_size / kwargs['group_size'])
         # init queues
         self.Q_level_1 = Queue(self._prefetch * self._num_readers * self._batch_size)
         self.Q_level_2 = Queue(self._prefetch * self._num_readers * self._batch_size)
         self.Q_level_3 = Queue(self._prefetch * self._num_readers)
-        # init mpi
-        global_rank = 0; local_rank = 0; group_size = 1
-        if mpi.is_init():
-            idx, group = mpi.allow_parallel()
-            if idx != -1:  # data parallel
-                global_rank = mpi.rank()
-                group_size = len(group)
-                for i, node in enumerate(group):
-                    if global_rank == node: local_rank = i
-        kwargs['group_size'] = group_size
         # init readers
         self._readers = []
         for i in xrange(self._num_readers):
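Moving the MPI init ahead of the configuration is what makes the new partition option work: group_size is already known when the batch size is set, so each node takes an equal share of a global batch. In numbers (values hypothetical):

```python
batch_size, group_size = 256, 4        # hypothetical global batch, 4 nodes
partition = True
if partition:
    batch_size = int(batch_size / group_size)
assert batch_size == 64                # 4 nodes x 64 = the global 256
```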
@@ -88,11 +89,15 @@ class DataBatch(object):
             self._transformers.append(transformer)
             time.sleep(0.1)
-        # init blob fetcher
-        self._fetcher = BlobFetcher(**kwargs)
-        self._fetcher.Q_in = self.Q_level_2
-        self._fetcher.Q_out = self.Q_level_3
-        self._fetcher.start()
+        # init blob fetchers
+        self._fetchers = []
+        for i in xrange(self._num_fetchers):
+            fetcher = BlobFetcher(**kwargs)
+            fetcher.Q_in = self.Q_level_2
+            fetcher.Q_out = self.Q_level_3
+            fetcher.start()
+            self._fetchers.append(fetcher)
+            time.sleep(0.1)
         #self.echo()
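A single BlobFetcher serializes the batch-packing stage; running several lets them drain Q_level_2 in parallel. The toy below is not the project's BlobFetcher, just a minimal sketch of the same fan-out pattern on multiprocessing queues:

```python
import multiprocessing as mp

def fetch(q_in, q_out, batch_size):
    # simplified stand-in for BlobFetcher: pack items into batches
    while True:
        q_out.put([q_in.get() for _ in range(batch_size)])

if __name__ == '__main__':
    q2, q3 = mp.Queue(1024), mp.Queue(16)
    fetchers = [mp.Process(target=fetch, args=(q2, q3, 8), daemon=True)
                for _ in range(3)]     # three fetchers, as in the new default
    for f in fetchers:
        f.start()
    for i in range(64):
        q2.put(i)
    print(len(q3.get()))               # -> 8
```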
@@ -103,8 +108,9 @@ class DataBatch(object):
     def echo(self):
         print '---------------------------------------------------------'
         print 'BatchReader, Using config:'
-        params = {'num_readers': self._num_readers,
+        params = {'prefetching': self._prefetch,
+                  'num_readers': self._num_readers,
                   'num_transformers': self._num_transformers,
-                  'num_prefetching': self._prefetch}
+                  'num_fetchers': self._num_fetchers}
         pprint.pprint(params)
         print '---------------------------------------------------------'
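With the defaults above, the revised echo() now reports all four pipeline knobs. A quick illustration of the output (transformer count hypothetical):

```python
import pprint
# pprint sorts the keys; values here are the new defaults
pprint.pprint({'prefetching': 40, 'num_readers': 1,
               'num_transformers': 1, 'num_fetchers': 3})
# {'num_fetchers': 3, 'num_readers': 1, 'num_transformers': 1, 'prefetching': 40}
```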
@@ -676,7 +676,7 @@ message DataParameter {
   optional bool force_encoded_color = 9 [default = false];
   // Prefetch queue (Number of batches to prefetch to host memory, increase if
   // data access bandwidth varies).
-  optional uint32 prefetch = 10 [default = 10];
+  optional uint32 prefetch = 10 [default = 40];
 }
 message DropoutParameter {
......
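Quadrupling the default prefetch (10 to 40 batches) smooths bursty disk reads at the cost of host memory, which the protobuf comment already hints at. A rough cost estimate, with hypothetical image shapes:

```python
# host memory held by the prefetch queue, float32 images assumed
batch_size, channels, height, width = 64, 3, 224, 224
bytes_per_batch = batch_size * channels * height * width * 4
for prefetch in (10, 40):
    print('%d batches ~ %.0f MB' % (prefetch, prefetch * bytes_per_batch / 2.0 ** 20))
# 10 batches ~ 368 MB
# 40 batches ~ 1470 MB
```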
@@ -65,9 +65,13 @@
 - Run 3rdparty/setup_mpi.sh
   ```Shell
-  sudo ./setup_mpi.sh
+  ./setup_mpi.sh
   ```
+- Install
+  ```Shell
+  sudo cp openmpi/install/bin/mpirun /usr/bin
+  ```
 #### Windows:
 - We use Microsoft MPI, which runs perfectly on the latest Windows 10
 - Microsoft MPI is integrated into 3rdparty, so no extra setup is needed
......