Commit b1c4e901 by Ting PAN

Simplify I/O settings

1 parent bf361560
@@ -40,25 +40,19 @@ class DataBatch(object):
         ----------
         source : str
             The path of the database.
-        multiple_nodes : boolean
-            Whether to split data for multiple parallel nodes. Default is ``False``.
-        shuffle : boolean
-            Whether to shuffle the data. Default is ``False``.
-        num_chunks : int
-            The number of chunks to split. Default is ``2048``.
-        chunk_size : int
-            The size (MB) of each chunk. Default is -1 (refer to ``num_chunks``).
-        batch_size : int
-            The size of a training batch.
-        partition : boolean
-            Whether to partition the batch. Default is ``False``.
-        prefetch : int
-            The prefetch count. Default is ``5``.
+        shuffle : bool, optional, default=False
+            Whether to shuffle the data.
+        num_chunks : int, optional, default=2048
+            The number of chunks to split.
+        batch_size : int, optional, default=128
+            The size of a mini-batch.
+        prefetch : int, optional, default=5
+            The prefetch count.

         """
         super(DataBatch, self).__init__()
         # Init mpi
-        global_rank = 0; local_rank = 0; group_size = 1
+        global_rank, local_rank, group_size = 0, 0, 1
         if mpi.Is_Init():
             idx, group = mpi.AllowParallel()
             if idx != -1:  # DataParallel
@@ -70,6 +64,7 @@ class DataBatch(object):
         # Configuration
         self._prefetch = kwargs.get('prefetch', 5)
+        self._batch_size = kwargs.get('batch_size', 2)
         self._num_readers = kwargs.get('num_readers', 1)
         self._num_transformers = kwargs.get('num_transformers', -1)
         self._max_transformers = kwargs.get('max_transformers', 3)
@@ -81,37 +76,28 @@ class DataBatch(object):
             # Add 1 transformer for color augmentation
             if cfg.TRAIN.COLOR_JITTERING:
                 self._num_transformers += 1
-        self._num_transformers = min(self._num_transformers, self._max_transformers)
-
-        self._batch_size = kwargs.get('batch_size', 100)
-        self._partition = kwargs.get('partition', False)
-        if self._partition:
-            self._batch_size = int(self._batch_size / kwargs['group_size'])
+        self._num_transformers = min(
+            self._num_transformers, self._max_transformers)

         # Init queues
-        self.Q_level_1 = Queue(self._prefetch * self._num_readers * self._batch_size)
-        self.Q1_level_2 = Queue(self._prefetch * self._num_readers * self._batch_size)
-        self.Q2_level_2 = Queue(self._prefetch * self._num_readers * self._batch_size)
-        self.Q_level_3 = Queue(self._prefetch * self._num_readers)
+        self.Q1 = Queue(self._prefetch * self._num_readers * self._batch_size)
+        self.Q21 = Queue(self._prefetch * self._num_readers * self._batch_size)
+        self.Q22 = Queue(self._prefetch * self._num_readers * self._batch_size)
+        self.Q3 = Queue(self._prefetch * self._num_readers)

         # Init readers
         self._readers = []
         for i in range(self._num_readers):
             self._readers.append(DataReader(**kwargs))
-            self._readers[-1].Q_out = self.Q_level_1
+            self._readers[-1].Q_out = self.Q1
         for i in range(self._num_readers):
-            num_parts = self._num_readers
-            part_idx = i
-            if self._readers[i]._multiple_nodes or \
-                    self._readers[i]._use_shuffle:
-                num_parts *= group_size
-                part_idx += local_rank * self._num_readers
+            part_idx, num_parts = i, self._num_readers
+            num_parts *= group_size
+            part_idx += local_rank * self._num_readers
             self._readers[i]._num_parts = num_parts
             self._readers[i]._part_idx = part_idx
-            self._readers[i]._random_seed += part_idx
+            self._readers[i]._rng_seed += part_idx
             self._readers[i].start()
             time.sleep(0.1)
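The simplified loop above always scales the reader partition by the MPI group instead of gating it on ``multiple_nodes``; a standalone sketch of the resulting layout, assuming 2 readers and a group of 2 nodes (the sizes are illustrative only, not from the commit):

    # Partition layout produced by the reader loop above (illustration).
    num_readers, group_size = 2, 2
    for local_rank in range(group_size):
        for i in range(num_readers):
            num_parts = num_readers * group_size     # 4 parts in total
            part_idx = i + local_rank * num_readers  # this reader's part
            print('node %d, reader %d -> part %d of %d'
                  % (local_rank, i, part_idx, num_parts))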
@@ -119,10 +105,10 @@ class DataBatch(object):
         self._transformers = []
         for i in range(self._num_transformers):
             transformer = DataTransformer(**kwargs)
-            transformer._random_seed += (i + local_rank * self._num_transformers)
-            transformer.Q_in = self.Q_level_1
-            transformer.Q1_out = self.Q1_level_2
-            transformer.Q2_out = self.Q2_level_2
+            transformer._rng_seed += (i + local_rank * self._num_transformers)
+            transformer.Q_in = self.Q1
+            transformer.Q1_out = self.Q21
+            transformer.Q2_out = self.Q22
             transformer.start()
             self._transformers.append(transformer)
             time.sleep(0.1)
@@ -131,9 +117,9 @@ class DataBatch(object):
         self._fetchers = []
         for i in range(self._num_fetchers):
             fetcher = BlobFetcher(**kwargs)
-            fetcher.Q1_in = self.Q1_level_2
-            fetcher.Q2_in = self.Q2_level_2
-            fetcher.Q_out = self.Q_level_3
+            fetcher.Q1_in = self.Q21
+            fetcher.Q2_in = self.Q22
+            fetcher.Q_out = self.Q3
             fetcher.start()
             self._fetchers.append(fetcher)
             time.sleep(0.1)
@@ -163,7 +149,7 @@ class DataBatch(object):
             The batch dict.

         """
-        return self.Q_level_3.get()
+        return self.Q3.get()

     def echo(self):
         """Print I/O Information.
...
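With the simplified settings, wiring up the pipeline above reduces to a handful of keys; a minimal usage sketch (the database path is hypothetical and the call site of ``DataBatch`` is assumed):

    # Hedged usage sketch of the simplified I/O settings.
    batch = DataBatch(
        source='/data/train_lmdb',  # hypothetical LMDB path
        shuffle=True,               # shuffle the records
        num_chunks=2048,            # chunk-wise shuffle granularity
        batch_size=128,             # mini-batch size
        prefetch=5,                 # prefetch count
    )
    blobs = batch.get()             # blocks until a batch is ready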
@@ -14,15 +14,14 @@ from __future__ import division
 from __future__ import print_function

 import math
-import numpy as np
-import numpy.random as npr
-from multiprocessing import Process
+import numpy
+import multiprocessing

-import dragon.config as config
-from dragon.tools.db import LMDB
+from dragon import config as _cfg
+from dragon.tools import db as _db


-class DataReader(Process):
+class DataReader(multiprocessing.Process):
     """DataReader is deployed to queue encoded strings from `LMDB`_.

     It adaptively partitions and shuffles records over all distributed nodes.
@@ -35,28 +34,19 @@ class DataReader(Process):
         ----------
         source : str
             The path of the database.
-        multiple_nodes : boolean
-            Whether to split data for multiple parallel nodes. Default is ``False``.
-        shuffle : boolean
-            Whether to shuffle the data. Default is ``False``.
-        num_chunks : int
-            The number of chunks to split. Default is ``2048``.
-        chunk_size : int
-            The size (MB) of each chunk. Default is -1 (refer to ``num_chunks``).
+        shuffle : bool, optional, default=False
+            Whether to shuffle the data.
+        num_chunks : int, optional, default=2048
+            The number of chunks to split.

         """
         super(DataReader, self).__init__()
         self._source = kwargs.get('source', '')
-        self._multiple_nodes = kwargs.get('multiple_nodes', False)
         self._use_shuffle = kwargs.get('shuffle', False)
-        self._use_instance_chunk = kwargs.get('instance_chunk', False)
         self._num_chunks = kwargs.get('num_chunks', 2048)
-        self._chunk_size = kwargs.get('chunk_size', -1)
         self._part_idx, self._num_parts = 0, 1
-        self._cur_idx, self._cur_chunk_idx = 0, 0
-        self._random_seed = config.GetRandomSeed()
+        self._cursor, self._chunk_cursor = 0, 0
+        self._rng_seed = _cfg.GetRandomSeed()
         self.Q_out = None
         self.daemon = True
@@ -71,13 +61,13 @@ class DataReader(Process):
         """
         return self._db.value()

-    def redirect(self, target_idx):
+    def redirect(self, target):
         """Redirect to the target position.

         Parameters
         ----------
-        target_idx : int
-            The key of instance in ``LMDB``.
+        target : int
+            The key of the record.

         Returns
         -------
@@ -85,17 +75,17 @@ class DataReader(Process):
         Notes
         -----
-        The redirection reopens the ``LMDB``.
+        The redirection reopens the database.

         You can drop caches by ``echo 3 > /proc/sys/vm/drop_caches``.
-        This will disturb getting stuck when ``Database Size`` >> ``RAM Size``.
+        This avoids getting stuck when *Database Size* >> *RAM Size*.

         """
         self._db.close()
         self._db.open(self._source)
-        self._cur_idx = target_idx
-        self._db.set(str(self._cur_idx).zfill(self._zfill))
+        self._cursor = target
+        self._db.set(str(target).zfill(self._zfill))
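Record keys are zero-padded decimal strings, so redirection only needs the integer cursor; for instance, assuming a key width (``zfill``) of 8:

    str(2048).zfill(8)  # -> '00002048', the database key of record 2048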
     def reset(self):
         """Reset the cursor and environment.
@@ -105,19 +95,18 @@ class DataReader(Process):
         None

         """
-        if self._multiple_nodes or self._use_shuffle:
-            if self._use_shuffle: self._perm = npr.permutation(self._num_shuffle_parts)
-            self._cur_chunk_idx = 0
-            self._start_idx = int(self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx])
-            self._start_idx = int(self._start_idx * self._chunk_size)
-            if self._start_idx >= self._num_entries: self.next_chunk()
-            self._end_idx = self._start_idx + self._chunk_size
-            self._end_idx = min(self._num_entries, self._end_idx)
+        if self._num_parts > 1 or self._use_shuffle:
+            self._chunk_cursor = 0
+            self._part_idx = (self._part_idx + 1) % self._num_parts
+            if self._use_shuffle: self._perm = numpy.random.permutation(self._perm_size)
+            self._head = self._part_idx * self._perm_size + self._perm[self._chunk_cursor]
+            self._head = self._head * self._chunk_size
+            if self._head >= self._num_entries: self.next_chunk()
+            self._tail = self._head + self._chunk_size
+            self._tail = min(self._num_entries, self._tail)
         else:
-            self._start_idx = 0
-            self._end_idx = self._num_entries
-        self.redirect(self._start_idx)
+            self._head, self._tail = 0, self._num_entries
+        self.redirect(self._head)

     def next_record(self):
         """Step the cursor of records.
@@ -127,8 +116,8 @@ class DataReader(Process):
         None

         """
-        self._cur_idx += 1
         self._db.next()
+        self._cursor += 1

     def next_chunk(self):
         """Step the cursor of shuffling chunks.
@@ -138,16 +127,17 @@ class DataReader(Process):
         None

         """
-        self._cur_chunk_idx += 1
-        if self._cur_chunk_idx >= self._num_shuffle_parts: self.reset()
+        self._chunk_cursor += 1
+        if self._chunk_cursor >= self._perm_size: self.reset()
         else:
-            self._start_idx = self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx]
-            self._start_idx = self._start_idx * self._chunk_size
-            if self._start_idx >= self._num_entries: self.next_chunk()
+            self._head = self._part_idx * self._perm_size + self._perm[self._chunk_cursor]
+            self._head = self._head * self._chunk_size
+            if self._head >= self._num_entries:
+                self.next_chunk()
             else:
-                self._end_idx = self._start_idx + self._chunk_size
-                self._end_idx = min(self._num_entries, self._end_idx)
-            self.redirect(self._start_idx)
+                self._tail = self._head + self._chunk_size
+                self._tail = min(self._num_entries, self._tail)
+            self.redirect(self._head)

     def run(self):
         """Start the process.
@@ -157,44 +147,42 @@ class DataReader(Process):
         None

         """
-        # fix seed
-        npr.seed(self._random_seed)
-        # init db
-        self._db = LMDB()
+        # Fix seed
+        numpy.random.seed(self._rng_seed)
+        # Init db
+        self._db = _db.LMDB()
         self._db.open(self._source)
         self._zfill = self._db.zfill()
         self._num_entries = self._db.num_entries()
-        self._epoch_size = int(self._num_entries / self._num_parts + 1)
+        epoch_size = self._num_entries // self._num_parts + 1

         if self._use_shuffle:
-            if self._chunk_size == 1:
-                # Each chunk has at most 1 record [For Fully Shuffle]
-                self._chunk_size, self._num_shuffle_parts = \
-                    1, int(self._num_entries / self._num_parts) + 1
+            if self._num_chunks <= 0:
+                # Each chunk has at most 1 record (Record-Wise)
+                self._chunk_size, self._perm_size = 1, epoch_size
             else:
-                if self._use_shuffle and self._chunk_size == -1:
-                    # Search an optimal chunk size by chunks [For Chunk Shuffle]
-                    max_chunk_size = self._db._total_size / (self._num_chunks * (1 << 20))
-                    min_chunk_size = 1
-                    while min_chunk_size * 2 < max_chunk_size: min_chunk_size *= 2
-                    self._chunk_size = min_chunk_size
-                self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
-                    (self._num_parts * self._chunk_size << 20)))
-                self._chunk_size = int(self._num_entries / self._num_shuffle_parts / self._num_parts + 1)
-                limit = (self._num_parts - 0.5) * self._num_shuffle_parts * self._chunk_size
+                # Search an optimal chunk size (Chunk-Wise)
+                min_size, max_size = \
+                    1, self._db._total_size * 1.0 \
+                        / (self._num_chunks * (1 << 20))
+                while min_size * 2 < max_size: min_size *= 2
+                self._perm_size = int(math.ceil(
+                    self._db._total_size * 1.1 /
+                    (self._num_parts * min_size << 20)))
+                self._chunk_size = int(
+                    self._num_entries * 1.0 /
+                    (self._perm_size * self._num_parts) + 1)
+                limit = (self._num_parts - 0.5) * self._perm_size * self._chunk_size
                 if self._num_entries <= limit:
-                    # Roll back to fully shuffle
-                    self._chunk_size, self._num_shuffle_parts = \
-                        1, int(self._num_entries / self._num_parts) + 1
+                    # Roll back to Record-Wise shuffle
+                    self._chunk_size, self._perm_size = 1, epoch_size
         else:
-            # Each chunk has at most K records [For Multiple Nodes]
-            # Note that if ``shuffle`` and ``multiple_nodes`` are all ``False``,
-            # ``chunk_size`` and ``num_shuffle_parts`` are meaningless
-            self._chunk_size = int(self._num_entries / self._num_parts) + 1
-            self._num_shuffle_parts = 1
+            # One chunk has at most K records
+            self._chunk_size, self._perm_size = epoch_size, 1

-        self._perm = np.arange(self._num_shuffle_parts)
+        self._perm = numpy.arange(self._perm_size)

         # Init env
         self.reset()
@@ -203,7 +191,7 @@ class DataReader(Process):
         while True:
             self.Q_out.put(self.element())
             self.next_record()
-            if self._cur_idx >= self._end_idx:
-                if self._multiple_nodes or \
-                    self._use_shuffle: self.next_chunk()
+            if self._cursor >= self._tail:
+                if self._num_parts > 1 or self._use_shuffle:
+                    self.next_chunk()
                 else: self.reset()
\ No newline at end of file
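The chunk search in ``run`` can be followed numerically; a worked sketch under assumed sizes (a 1 GB database with 1,000,000 entries, 2 parts, and the default 2048 chunks; none of these numbers come from the commit):

    import math

    total_size = 1 << 30  # database size in bytes (assumed)
    num_entries, num_parts, num_chunks = 1000000, 2, 2048
    # Largest power-of-two chunk size (MB) below the requested average.
    min_size, max_size = 1, total_size * 1.0 / (num_chunks * (1 << 20))
    while min_size * 2 < max_size:
        min_size *= 2
    perm_size = int(math.ceil(
        total_size * 1.1 / ((num_parts * min_size) << 20)))
    chunk_size = int(num_entries * 1.0 / (perm_size * num_parts) + 1)
    print(min_size, perm_size, chunk_size)  # -> 1 564 887

Here the limit check ``(2 - 0.5) * 564 * 887`` stays below 1,000,000 entries, so no rollback to record-wise shuffle occurs.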
@@ -35,7 +35,7 @@ import lib.utils.logger as logger
 class DataTransformer(Process):
     def __init__(self, **kwargs):
         super(DataTransformer, self).__init__()
-        self._random_seed = cfg.RNG_SEED
+        self._rng_seed = cfg.RNG_SEED
         self._use_flipped = cfg.TRAIN.USE_FLIPPED
         self._use_diff = cfg.TRAIN.USE_DIFF
         self._classes = kwargs.get('classes', ('__background__',))
@@ -164,7 +164,7 @@ class DataTransformer(Process):
         return im, im_scale, gt_boxes

     def run(self):
-        npr.seed(self._random_seed)
+        npr.seed(self._rng_seed)
         while True:
             serialized = self.Q_in.get()
             data = self.get(serialized)
...
@@ -28,8 +28,7 @@ class DataLayer(torch.nn.Module):
             'source': database.source,
             'classes': database.classes,
             'shuffle': cfg.TRAIN.USE_SHUFFLE,
-            'multiple_nodes': True,
-            'chunk_size': 1,  # Valid if using shuffle
+            'num_chunks': 0,  # Record-Wise Shuffle
             'batch_size': cfg.TRAIN.IMS_PER_BATCH * 2,
         })
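Here ``num_chunks: 0`` drives ``DataReader.run`` into its record-wise branch, where every chunk holds exactly one record; spelled out under an assumed entry count:

    # What 'num_chunks': 0 resolves to inside DataReader.run (illustration).
    num_entries, num_parts = 120000, 1
    epoch_size = num_entries // num_parts + 1
    chunk_size, perm_size = 1, epoch_size  # permutation over single records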
...
@@ -39,25 +39,19 @@ class DataBatch(object):
         ----------
         source : str
             The path of the database.
-        multiple_nodes : boolean
-            Whether to split data for multiple parallel nodes. Default is ``False``.
-        shuffle : boolean
-            Whether to shuffle the data. Default is ``False``.
-        num_chunks : int
-            The number of chunks to split. Default is ``2048``.
-        chunk_size : int
-            The size (MB) of each chunk. Default is -1 (refer to ``num_chunks``).
-        batch_size : int
-            The size of a training batch.
-        partition : boolean
-            Whether to partition the batch. Default is ``False``.
-        prefetch : int
-            The prefetch count. Default is ``5``.
+        shuffle : bool, optional, default=False
+            Whether to shuffle the data.
+        num_chunks : int, optional, default=2048
+            The number of chunks to split.
+        batch_size : int, optional, default=128
+            The size of a mini-batch.
+        prefetch : int, optional, default=5
+            The prefetch count.

         """
         super(DataBatch, self).__init__()
         # Init mpi
-        global_rank = 0; local_rank = 0; group_size = 1
+        global_rank, local_rank, group_size = 0, 0, 1
         if mpi.Is_Init():
             idx, group = mpi.AllowParallel()
             if idx != -1:  # DataParallel
@@ -69,6 +63,7 @@ class DataBatch(object):
         # Configuration
         self._prefetch = kwargs.get('prefetch', 5)
+        self._batch_size = kwargs.get('batch_size', 32)
         self._num_readers = kwargs.get('num_readers', 1)
         self._num_transformers = kwargs.get('num_transformers', -1)
         self._max_transformers = kwargs.get('max_transformers', 3)
@@ -77,43 +72,27 @@ class DataBatch(object):
         # Io-Aware Policy
         if self._num_transformers == -1:
             self._num_transformers = 3
-            # Add 1 transformer for color augmentation
-            if kwargs.get('color_augmentation', False):
-                self._num_transformers += 1
-            # Add 1 transformer for random scale
-            if kwargs.get('max_random_scale', 1.0) - \
-                    kwargs.get('min_random_scale', 1.0) != 0:
-                self._num_transformers += 1
-        self._num_transformers = min(self._num_transformers, self._max_transformers)
-
-        self._batch_size = kwargs.get('batch_size', 100)
-        self._partition = kwargs.get('partition', False)
-        if self._partition:
-            self._batch_size = int(self._batch_size / kwargs['group_size'])
+        self._num_transformers = min(
+            self._num_transformers, self._max_transformers)

         # Init queues
-        self.Q_level_1 = Queue(self._prefetch * self._num_readers * self._batch_size)
-        self.Q_level_2 = Queue(self._prefetch * self._num_readers * self._batch_size)
-        self.Q_level_3 = Queue(self._prefetch * self._num_readers)
+        self.Q1 = Queue(self._prefetch * self._num_readers * self._batch_size)
+        self.Q2 = Queue(self._prefetch * self._num_readers * self._batch_size)
+        self.Q3 = Queue(self._prefetch * self._num_readers)

         # Init readers
         self._readers = []
         for i in range(self._num_readers):
             self._readers.append(DataReader(**kwargs))
-            self._readers[-1].Q_out = self.Q_level_1
+            self._readers[-1].Q_out = self.Q1
         for i in range(self._num_readers):
-            num_parts = self._num_readers
-            part_idx = i
-            if self._readers[i]._multiple_nodes or \
-                    self._readers[i]._use_shuffle:
-                num_parts *= group_size
-                part_idx += local_rank * self._num_readers
+            part_idx, num_parts = i, self._num_readers
+            num_parts *= group_size
+            part_idx += local_rank * self._num_readers
             self._readers[i]._num_parts = num_parts
             self._readers[i]._part_idx = part_idx
-            self._readers[i]._random_seed += part_idx
+            self._readers[i]._rng_seed += part_idx
             self._readers[i].start()
             time.sleep(0.1)
@@ -121,9 +100,9 @@ class DataBatch(object):
         self._transformers = []
         for i in range(self._num_transformers):
             transformer = DataTransformer(**kwargs)
-            transformer._random_seed += (i + local_rank * self._num_transformers)
-            transformer.Q_in = self.Q_level_1
-            transformer.Q_out = self.Q_level_2
+            transformer._rng_seed += (i + local_rank * self._num_transformers)
+            transformer.Q_in = self.Q1
+            transformer.Q_out = self.Q2
             transformer.start()
             self._transformers.append(transformer)
             time.sleep(0.1)
@@ -132,8 +111,8 @@ class DataBatch(object):
         self._fetchers = []
         for i in range(self._num_fetchers):
             fetcher = BlobFetcher(**kwargs)
-            fetcher.Q_in = self.Q_level_2
-            fetcher.Q_out = self.Q_level_3
+            fetcher.Q_in = self.Q2
+            fetcher.Q_out = self.Q3
             fetcher.start()
             self._fetchers.append(fetcher)
             time.sleep(0.1)
@@ -163,7 +142,7 @@ class DataBatch(object):
             The batch dict.

         """
-        return self.Q_level_3.get()
+        return self.Q3.get()

     def echo(self):
         """Print I/O Information.
...
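The queue capacities above follow directly from the configuration; checking the defaults of this file (illustrative arithmetic only):

    prefetch, num_readers, batch_size = 5, 1, 32
    print(prefetch * num_readers * batch_size)  # Q1/Q2 capacity: 160 records
    print(prefetch * num_readers)               # Q3 capacity: 5 batches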
@@ -31,7 +31,7 @@ class DataTransformer(Process):
         self._expander = Expander()
         self._sampler = Sampler(cfg.SSD.SAMPLERS)
         self._resizer = Resizer()
-        self._random_seed = cfg.RNG_SEED
+        self._rng_seed = cfg.RNG_SEED
         self._mirror = cfg.TRAIN.USE_FLIPPED
         self._use_diff = cfg.TRAIN.USE_DIFF
         self._classes = kwargs.get('classes', ('__background__',))
@@ -112,7 +112,7 @@ class DataTransformer(Process):
         return im, gt_boxes

     def run(self):
-        npr.seed(self._random_seed)
+        npr.seed(self._rng_seed)
         while True:
             serialized = self.Q_in.get()
             im, gt_boxes = self.get(serialized)
...
@@ -28,7 +28,7 @@ class DataLayer(torch.nn.Module):
             'source': database.source,
             'classes': database.classes,
             'shuffle': cfg.TRAIN.USE_SHUFFLE,
-            'multiple_nodes': True,
+            'num_chunks': 2048,  # Chunk-Wise Shuffle
             'batch_size': cfg.TRAIN.IMS_PER_BATCH * 2,
         })
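Taken together, the two layers updated in this commit select both shuffle modes through the single ``num_chunks`` key (values copied from the hunks above):

    rcnn_io = {'shuffle': True, 'num_chunks': 0}    # record-wise shuffle
    ssd_io = {'shuffle': True, 'num_chunks': 2048}  # chunk-wise shuffle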
...