Commit 80267d8f by Ting PAN

Use sequential sampling as the default shuffle policy

Summary:
This commit reimplements the default shuffle policy of the data reader with
sequential sampling (consistent with DALI) instead of chunk permutation (the MXNet solution).
Sequential sampling is tuned by the ``initial_fill`` argument only, and works well on both HDD and SSD.
1 parent cca00c0d
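For readers unfamiliar with the scheme: sequential sampling keeps a buffer of ``initial_fill`` examples read in order, and for each output emits a uniformly random element of that buffer, refilling the slot with the next sequential read. A minimal sketch of the idea (illustrative names only, not the committed implementation):

```python
import random

def sequential_sampling(dataset, initial_fill=1024):
    """Yield examples in a weakly shuffled order (sketch)."""
    stream = iter(dataset)
    buffer = []
    # Fill the buffer with the first ``initial_fill`` sequential reads.
    for example in stream:
        buffer.append(example)
        if len(buffer) == initial_fill:
            break
    # Emit a random buffered example, then refill its slot sequentially.
    for example in stream:
        i = random.randrange(len(buffer))
        yield buffer[i]
        buffer[i] = example
    # Drain the remaining buffer at the end of the stream.
    random.shuffle(buffer)
    for example in buffer:
        yield example
```

With ``initial_fill=1`` this degenerates to plain sequential reading; larger buffers approach a global shuffle while keeping disk reads sequential, which is why a single knob covers both HDD and SSD.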
@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
+"""BBox ops."""
from __future__ import absolute_import
from __future__ import division
...
@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
+"""Builtin ops."""
from __future__ import absolute_import
from __future__ import division
...
@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
+"""Decoder ops."""
from __future__ import absolute_import
from __future__ import division
...
@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
+"""Generic ops."""
from __future__ import absolute_import
from __future__ import division
...
@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
+"""Image ops."""
from __future__ import absolute_import
from __future__ import division
@@ -300,8 +301,8 @@ class RandomBBoxCrop(object):
aspect_ratio=(0.5, 2.0),
thresholds=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9),
allow_no_crop=True,
-ltrb=True,
num_attempts=10,
+bbox_layout=None,
**kwargs
):
"""Create a ``RandomBBoxCrop`` operator.
@@ -316,10 +317,10 @@ class RandomBBoxCrop(object):
The minimum IoU(s) to satisfy.
allow_no_crop : bool, optional, default=True
**True** to include no-cropping as an option.
-ltrb : bool, optional, default=True
-Indicate whether the bbox is in ``ltrb`` or ``xywh`` format.
num_attempts : int, optional, default=10
The max number of sampling trials.
+bbox_layout : str, optional
+The optional bbox layout.
Returns
-------
@@ -332,8 +333,8 @@ class RandomBBoxCrop(object):
aspect_ratio=aspect_ratio,
thresholds=thresholds,
allow_no_crop=allow_no_crop,
-ltrb=ltrb,
num_attempts=num_attempts,
+bbox_layout=bbox_layout,
device='cpu',
**kwargs
)
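As a quick orientation for the API change above, a hypothetical call after this commit passes a DALI-style layout string instead of the removed ``ltrb`` flag (the ``'xyXY'`` value shown is an assumption based on DALI's bbox layout convention):

```python
# Hypothetical usage; only arguments shown in this diff are used.
crop = RandomBBoxCrop(
    aspect_ratio=(0.5, 2.0),
    thresholds=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9),
    allow_no_crop=True,
    num_attempts=10,
    bbox_layout='xyXY',  # assumed DALI layout string for ltrb boxes
)
```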
@@ -359,7 +360,9 @@ class RandomResizedCrop(object):
def __new__(
cls,
size,
-interp_type='LINEAR',
+interp_type=None,
+mag_filter=None,
+min_filter=None,
random_area=(0.08, 1.),
random_aspect_ratio=(0.75, 1.33),
num_attempts=10,
@@ -371,8 +374,12 @@ class RandomResizedCrop(object):
----------
size : Union[int, Sequence[int]]
The output image size.
-interp_type : {'NN', 'LINEAR', 'TRIANGULAR', 'CUBIC', 'GAUSSIAN', 'LANCZOS3'}, optional
-The interpolation method.
+interp_type : str, optional
+The interpolation for both upsampling and downsampling.
+mag_filter : str, optional, default='LINEAR'
+The interpolation for upsampling.
+min_filter : str, optional, default='TRIANGULAR'
+The interpolation for downsampling.
random_area : Sequence[float], optional, default=(0.08, 1.)
The range of scale for sampling.
random_aspect_ratio : Sequence[float], optional, default=(0.75, 1.33)
@@ -388,9 +395,15 @@ class RandomResizedCrop(object):
"""
if isinstance(interp_type, six.string_types):
interp_type = getattr(types, 'INTERP_' + interp_type.upper())
+if isinstance(mag_filter, six.string_types):
+mag_filter = getattr(types, 'INTERP_' + mag_filter.upper())
+if isinstance(min_filter, six.string_types):
+min_filter = getattr(types, 'INTERP_' + min_filter.upper())
return ops.RandomResizedCrop(
size=size,
interp_type=interp_type,
+mag_filter=mag_filter,
+min_filter=min_filter,
random_area=random_area,
random_aspect_ratio=random_aspect_ratio,
num_attempts=num_attempts,
@@ -425,6 +438,8 @@ class Resize(object):
resize_longer=None,
max_size=None,
interp_type='LINEAR',
+mag_filter=None,
+min_filter=None,
**kwargs
):
"""Create a ``Resize`` operator.
@@ -441,12 +456,20 @@ class Resize(object):
Resize along the longer side.
max_size : int, optional, default=0
The limited size for ``resize_shorter``.
-interp_type : {'NN', 'LINEAR', 'TRIANGULAR', 'CUBIC', 'GAUSSIAN', 'LANCZOS3'}, optional
-The interpolation method.
+interp_type : str, optional
+The interpolation for both upsampling and downsampling.
+mag_filter : str, optional, default='LINEAR'
+The interpolation for upsampling.
+min_filter : str, optional, default='TRIANGULAR'
+The interpolation for downsampling.
"""
if isinstance(interp_type, six.string_types):
interp_type = getattr(types, 'INTERP_' + interp_type.upper())
+if isinstance(mag_filter, six.string_types):
+mag_filter = getattr(types, 'INTERP_' + mag_filter.upper())
+if isinstance(min_filter, six.string_types):
+min_filter = getattr(types, 'INTERP_' + min_filter.upper())
return ops.Resize(
resize_x=resize_x,
resize_y=resize_y,
@@ -454,6 +477,8 @@ class Resize(object):
resize_longer=resize_longer,
max_size=max_size,
interp_type=interp_type,
+mag_filter=mag_filter,
+min_filter=min_filter,
device=context.get_device_type(),
**kwargs
)
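A hedged usage sketch of the split filters introduced above: ``interp_type`` still sets both directions at once, while ``mag_filter``/``min_filter`` override upsampling and downsampling separately (argument names and defaults follow the docstring in this diff):

```python
# Hypothetical call; resize_shorter/max_size are documented above.
resize = Resize(
    resize_shorter=256,
    max_size=512,
    mag_filter='LINEAR',      # filter when enlarging
    min_filter='TRIANGULAR',  # filter when shrinking
)
```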
@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
+"""Random ops."""
from __future__ import absolute_import
from __future__ import division
...
@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
+"""Reader ops."""
from __future__ import absolute_import
from __future__ import division
@@ -38,6 +39,7 @@ class KPLRecordReader(object):
```python
class MyPipeline(dali.Pipeline):
def __init__():
+super(MyPipeline, self).__init__()
# Assume that we have the following data:
@@ -48,11 +50,11 @@ class KPLRecordReader(object):
path='/data'
features=('image', 'label'),
pipeline=self,
-# Shuffle globally within a specified number of chunks
-# once an epoch is finished.
-shuffle_after_epoch=True,
-# Set **0** to split each example as a chunk.
-shuffle_chunks=0,
+# Shuffle locally in the next ``initial_fill`` examples.
+# Shuffling weakens as ``initial_fill`` decreases,
+# and is disabled if ``initial_fill`` is set to **1**.
+random_shuffle=True,
+initial_fill=1024,
)
def iter_step(self):
@@ -71,8 +73,8 @@ class KPLRecordReader(object):
pipeline,
shard_id=0,
num_shards=1,
-shuffle_after_epoch=False,
-shuffle_chunks=0,
+random_shuffle=False,
+initial_fill=1024,
**kwargs
):
"""Create a ``KPLRecordReader``.
@@ -81,14 +83,18 @@ class KPLRecordReader(object):
----------
path : str
The folder of record files.
+features : Sequence[str], required
+The names of the features to extract.
+pipeline : nvidia.dali.Pipeline, required
+The pipeline to connect to.
shard_id : int, optional, default=0
-The index of specific shard.
+The index of the partition to read.
num_shards : int, optional, default=1
-The total number of shards.
+The total number of partitions over the dataset.
-shuffle_after_epoch : bool, optional, default=False
-**True** to shuffle examples once an epoch is finished.
+random_shuffle : bool, optional, default=False
+Whether to shuffle the data.
-shuffle_chunks : int, optional, default=0
-The number of chunks to shuffle.
+initial_fill : int, optional, default=1024
+The length of the sampling sequence for shuffle.
"""
self._pipe = pipeline
@@ -99,8 +105,8 @@ class KPLRecordReader(object):
source=path,
part_idx=shard_id,
num_parts=num_shards,
-shuffle=shuffle_after_epoch,
-num_chunks=shuffle_chunks,
+shuffle=random_shuffle,
+initial_fill=initial_fill,
**kwargs
)
self._buffer = self._reader.q_out = mp.Queue(
@@ -197,13 +203,13 @@ class TFRecordReader(object):
path : str
The folder of record files.
shard_id : int, optional, default=0
-The index of specific shard.
+The index of the partition to read.
num_shards : int, optional, default=1
-The total number of shards.
+The total number of partitions over the dataset.
random_shuffle : bool, optional, default=False
-**True** to shuffle examples in a sequence.
+Whether to shuffle the data.
initial_fill : int, optional, default=1024
-The length of sequence for shuffle.
+The length of the sampling sequence for shuffle.
Returns
-------
...
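Putting the renamed reader arguments together, a hypothetical setup after this commit (the path and shard values are placeholders):

```python
# Shuffling strength is now governed by ``initial_fill`` alone:
# 1 disables shuffling; larger values approach a global shuffle
# at the cost of a proportionally larger in-memory buffer.
reader = TFRecordReader(
    path='/data/records',  # placeholder path
    shard_id=0,
    num_shards=1,
    random_shuffle=True,
    initial_fill=1024,
)
```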
@@ -14,10 +14,6 @@ before_first
############
.. automethod:: dragon.io.DataReader.before_first
-next_chunk
-##########
-.. automethod:: dragon.io.DataReader.next_chunk
next_example
############
.. automethod:: dragon.io.DataReader.next_example
...
@@ -18,6 +18,7 @@
#include <array>
#include <climits>
#include <cmath>
+#include <cstring>
#include <ctime>
#include <functional>
#include <map>
...
@@ -30,39 +30,39 @@ class DataReader(multiprocessing.Process):
simple_reader = DataReader(dataset=dataset, source=path)
```
-Partitions are available over distributed nodes:
+Shuffle is supported by randomly sampling from a sequence buffer:
```python
-distributed_reader = DataReader(
+shuffle_reader = DataReader(
dataset=dataset,
source=path,
-part_idx=rank,
-num_parts=num_ranks,
+shuffle=True,
+# It is recommended to set a buffer size larger than
+# the batch size to make batches of a single node more diverse.
+# The default value 1024 is sufficient for most cases.
+initial_fill=1024,
)
```
-There are two shuffle schemes:
+Partitions are available over distributed nodes:
```python
-# Recommendation: SSD or dataset is tiny
-example_wise_shuffle_reader = DataReader(
-dataset=dataset,
-source=path,
-shuffle=True,
-num_chunks=0,  # Set to the number of examples
-)
-# Recommendation: HDD or dataset is huge
-chunk_wise_shuffle_reader = DataReader(
+distributed_reader = DataReader(
dataset=dataset,
source=path,
-shuffle=True,
-num_chunks=2048,
+part_idx=rank,
+num_parts=world_size,
)
```
"""
+class PartBoundaries(object):
+"""Record the boundaries of the current part."""
+def __init__(self, start, end):
+self.start, self.end = start, end
def __init__(self, **kwargs):
"""Create a ``DataReader``.
@@ -72,14 +72,14 @@ class DataReader(multiprocessing.Process):
The dataset class to load examples.
source : str
The path of data source.
-shuffle : bool, optional, default=False
-Whether to shuffle the data.
-num_chunks : int, optional, default=0
-The number of chunks to split.
-num_parts : int, optional, default=1
-The number of partitions over dataset.
part_idx : int, optional, default=0
-The index of current partition.
+The index of the partition to read.
+num_parts : int, optional, default=1
+The total number of partitions over the dataset.
+shuffle : bool, optional, default=False
+Whether to shuffle the data.
+initial_fill : int, optional, default=1024
+The length of the sampling sequence for shuffle.
seed : int, optional
The random seed to use instead.
@@ -87,64 +87,68 @@ class DataReader(multiprocessing.Process):
super(DataReader, self).__init__()
self._dataset = kwargs.get('dataset', None)
self._source = kwargs.get('source', '')
-self._shuffle = kwargs.get('shuffle', False)
-self._num_chunks = kwargs.get('num_chunks', 0)
-self._num_parts = kwargs.get('num_parts', 1)
self._part_idx = kwargs.get('part_idx', 0)
+self._num_parts = kwargs.get('num_parts', 1)
+self._shuffle = kwargs.get('shuffle', False)
+self._initial_fill = kwargs.get('initial_fill', 1024) if self._shuffle else 1
self._seed = kwargs.get('seed', config.config().random_seed)
-self._begin, self._end = 0, 0
-self._perm_size, self._perm = 1, None
-self._chunk_size, self._num_examples = 1, 1
-self._example_cursor, self._chunk_cursor = 0, 0
+self._first, self._cursor, self._last = 0, 0, 0
+self._part_size = 0
+self._num_examples = 0
+self._example_buffer = []
+self._parts = []
self.q_out = None
self.daemon = True
def before_first(self):
"""Move the cursor before begin."""
-self._example_cursor = self._begin
-self._dataset.redirect(self._begin)
+self._cursor = self._first
+self._dataset.redirect(self._first)
def next_example(self):
"""Return the next example."""
-self._example_cursor += 1
+self._cursor += 1
return self._dataset.get()
-def next_chunk(self):
-"""Select the next chunk."""
-self._chunk_cursor += 1
-if self._chunk_cursor >= self._perm_size:
-self.reset()
-else:
-chunk_idx = self._part_idx * self._perm_size + int(self._perm[self._chunk_cursor])
-self._begin = chunk_idx * self._chunk_size
-if self._begin >= self._num_examples:
-self.next_chunk()
-else:
-self._end = min(self._begin + self._chunk_size, self._num_examples)
-self.before_first()
-def reset(self):
+def reset(self, stick_to_part=False):
"""Reset the environment of dataset."""
-if self._num_parts > 1 or self._shuffle:
-self._chunk_cursor = -1
+# Redirect to the adjacent part if available.
+if not stick_to_part:
self._part_idx = (self._part_idx + 1) % self._num_parts
-if self._shuffle:
-self._perm = numpy.random.permutation(self._perm_size)
-self.next_chunk()
-else:
-self._begin, self._end = 0, self._num_examples
+self._first = self._part_idx * self._part_size
+self._last = min(self._first + self._part_size, self._num_examples)
self.before_first()
+# Use the new boundaries to avoid sampling duplicates
+# when buffer size is greater than dataset size.
+counter = self._parts[-1].end
+self._parts.append(DataReader.PartBoundaries(counter, counter))
def run(self):
"""Start the process."""
self._init_dataset()
# Persist a loop to read examples.
while True:
-self.q_out.put(self.next_example())
-if self._example_cursor >= self._end:
-if self._num_parts > 1 or self._shuffle:
-self.next_chunk()
-else:
+# Pop the depleted part if necessary.
+if self._parts[0].start == self._parts[0].end:
+self._parts.pop(0)
+offset = 0
+if self._shuffle:
+# Sample a random offset if shuffle is required.
+offset = self._parts[0].end - self._parts[0].start
+offset = int(numpy.random.uniform(high=offset))
+# Choose a loaded example from the buffer.
+i = self._parts[0].start % len(self._example_buffer)
+j = (self._parts[0].start + offset) % len(self._example_buffer)
+self.q_out.put(self._example_buffer[j])
+self._example_buffer[j] = self._example_buffer[i]
+# Load and push back a new example into the buffer.
+k = self._parts[-1].end % len(self._example_buffer)
+self._example_buffer[k] = self.next_example()
+# Increase the part boundaries.
+self._parts[-1].end += 1
+self._parts[0].start += 1
+# Reset the cursor if necessary.
+if self._cursor >= self._last:
self.reset()
def _init_dataset(self):
@@ -154,24 +158,16 @@ class DataReader(multiprocessing.Process):
# Instantiate the dataset here to avoid a fork of process.
# Fork will somehow fail if dataset is implemented in C/C++.
self._dataset = self._dataset(self._source)
-self._num_examples = self._dataset.size
-# Determine the chunk scheme on different settings.
-def div_up(a, b):
-return (a + b - 1) // b
-if self._shuffle:
-if self._num_chunks <= 0:
-# Each chunk has at most 1 example (ExampleWise).
-self._perm_size = div_up(self._num_examples, self._num_parts)
-else:
-# Each chunk has several examples (ChunkWise).
-self._perm_size = div_up(self._num_chunks, self._num_parts)
-self._chunk_size = div_up(self._num_examples, self._num_chunks)
-else:
-# Each chunk has the examples of whole shard (ShardWise).
-self._chunk_size = div_up(self._num_examples, self._num_parts)
-# Reset the layout of permutation.
-self._perm = numpy.arange(self._perm_size)
+# Determine the part specification.
+self._num_examples = self._dataset.size
+self._part_size = (self._num_examples + self._num_parts - 1) // self._num_parts
+self._parts.append(DataReader.PartBoundaries(0, 0))
+# Fill the initial buffer to support random sampling.
+self.reset(stick_to_part=True)
+for i in range(self._initial_fill):
+self._example_buffer.append(self.next_example())
+self._parts[-1].end += 1
+if self._cursor >= self._last:
self.reset()
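To see why the new ``run()`` loop produces the documented behavior, its buffer logic can be condensed into a standalone step (a sketch with hypothetical helper names, mirroring the committed arithmetic):

```python
import numpy

def sample_step(buffer, parts, shuffle, next_example, q_out):
    """Emit one example from a sequential-sampling buffer (sketch)."""
    if parts[0].start == parts[0].end:
        parts.pop(0)  # the oldest pass over a part is depleted
    offset = 0
    if shuffle:
        # Uniform position inside the oldest part's live window.
        offset = int(numpy.random.uniform(high=parts[0].end - parts[0].start))
    i = parts[0].start % len(buffer)
    j = (parts[0].start + offset) % len(buffer)
    q_out.put(buffer[j])        # emit a random live example
    buffer[j] = buffer[i]       # move the window's head into the freed slot
    k = parts[-1].end % len(buffer)
    buffer[k] = next_example()  # a sequential read refills the buffer
    parts[-1].end += 1
    parts[0].start += 1
```

The swap into ``buffer[j]`` keeps the live window contiguous modulo the buffer length, so every example is emitted exactly once per pass even when shuffling is enabled.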
@@ -35,8 +35,7 @@ class Registry(object):
if not self.has(name):
raise KeyError(
"`%s` is not registered in <%s>."
-% (name, self._name)
-)
+% (name, self._name))
return self._registry[name]
def has(self, name):
...
@@ -74,8 +74,8 @@ class DataIterator(object):
The path of data source.
shuffle : bool, optional, default=False
Whether to shuffle the data.
-num_chunks : int, optional, default=0
-The number of chunks to split.
+initial_fill : int, optional, default=1024
+The length of the sampling sequence for shuffle.
resize : int, optional, default=0
The size for the shortest edge.
padding : int, optional, default=0
@@ -94,7 +94,7 @@ class DataIterator(object):
The range of scales to sample a crop randomly.
random_aspect_ratios : Sequence[float], optional, default=(0.75, 1.33)
The range of aspect ratios to sample a crop randomly.
-augment_color : bool, optional, default=False
+distort_color : bool, optional, default=False
Whether to apply color distortion.
inverse_color : bool, optional, default=False
Whether to invert the channels of color images.
...
@@ -56,7 +56,7 @@ class DataTransformer(multiprocessing.Process):
The range of scales to sample a crop randomly.
random_aspect_ratios : Sequence[float], optional, default=(0.75, 1.33)
The range of aspect ratios to sample a crop randomly.
-augment_color : bool, optional, default=False
+distort_color : bool, optional, default=False
Whether to apply color distortion.
inverse_color : bool, optional, default=False
Whether to invert the channels of color images.
@@ -76,7 +76,7 @@ class DataTransformer(multiprocessing.Process):
self._mirror = kwargs.get('mirror', False)
self._random_scales = kwargs.get('random_scales', (0.08, 1.))
self._random_ratios = kwargs.get('random_aspect_ratios', (3. / 4., 4. / 3.))
-self._augment_color = kwargs.get('augment_color', False)
+self._distort_color = kwargs.get('distort_color', False)
self._inverse_color = kwargs.get('inverse_color', False)
self._phase = kwargs.get('phase', 'TRAIN')
self._seed = kwargs.get('seed', config.config().random_seed)
@@ -84,7 +84,7 @@ class DataTransformer(multiprocessing.Process):
self.daemon = True
if cv2 is None:
raise ImportError('Failed to import package <cv2>.')
-if self._augment_color and PIL is None:
+if self._distort_color and PIL is None:
raise ImportError('Failed to import package <PIL>.')
def get(self, example):
@@ -105,7 +105,7 @@ class DataTransformer(multiprocessing.Process):
"""
# Decode.
img = numpy.frombuffer(example['data'], numpy.uint8)
-if example['encoded'] > 0:
+if example.get('encoded', 0) > 0:
img = cv2.imdecode(img, 1)
else:
img = img.reshape(example['shape'])
@@ -117,13 +117,11 @@ class DataTransformer(multiprocessing.Process):
pass
else:
if w < h:
-ow, oh = size, size * h // w
+ow, oh, im_scale = size, size * h // w, float(size) / w
else:
-oh, ow = size, size * w // h
-img = cv2.resize(
-img, (ow, oh),
-interpolation=cv2.INTER_LINEAR,
-)
+oh, ow, im_scale = size, size * w // h, float(size) / h
+interp = cv2.INTER_AREA if im_scale < 1 else cv2.INTER_LINEAR
+img = cv2.resize(img, (ow, oh), interpolation=interp)
# Padding.
if self._padding > 0:
@@ -181,7 +179,9 @@ class DataTransformer(multiprocessing.Process):
j = (width - w) // 2
img = img[i:i + h, j:j + w, :]
new_size = (self._random_crop_size, self._random_crop_size)
-img = cv2.resize(img, new_size, interpolation=cv2.INTER_LINEAR)
+min_scale = self._random_crop_size / max(img.shape[:2])
+interp = cv2.INTER_AREA if min_scale < 1 else cv2.INTER_LINEAR
+img = cv2.resize(img, new_size, interpolation=interp)
# CutOut.
if self._cutout_size > 0:
@@ -199,8 +199,8 @@ class DataTransformer(multiprocessing.Process):
if numpy.random.randint(0, 2) > 0:
img = img[:, ::-1, :]
-# Color augmentation.
-if self._augment_color:
+# Color distortion.
+if self._distort_color:
img = PIL.Image.fromarray(img)
transforms = [
PIL.ImageEnhance.Brightness,
...
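The interpolation rule adopted in the two resize paths above generalizes to a small helper; a sketch of the rationale (``scale_aware_resize`` is an illustrative name, and ``cv2``/``numpy`` are the dependencies the transformer already uses):

```python
import cv2
import numpy

def scale_aware_resize(img, ow, oh):
    """Resize with INTER_AREA when shrinking, INTER_LINEAR when enlarging."""
    scale = min(float(ow) / img.shape[1], float(oh) / img.shape[0])
    # Area averaging suppresses aliasing on downscale; linear is
    # cheaper and sufficient on upscale.
    interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
    return cv2.resize(img, (ow, oh), interpolation=interp)

img = numpy.zeros((480, 640, 3), numpy.uint8)
out = scale_aware_resize(img, 224, 224)  # downscale -> INTER_AREA
```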
@@ -13,8 +13,6 @@
#ifndef DRAGON_UTILS_CAST_H_
#define DRAGON_UTILS_CAST_H_
-#include <cstring>
#include "dragon/core/types.h"
#include "dragon/utils/device/common_cuda.h"
...
@@ -118,7 +118,7 @@ def load_and_assign_pkl_dict(name, module, skip=False):
value_dict = six.moves.pickle.load(f)
except UnicodeDecodeError:
with open(name, 'rb') as f:
-value_dict = six.moves.pickle.load(f, encoding='iso-8859-1')
+value_dict = six.moves.pickle.load(f, encoding='bytes')
weight_dict = {w.name: w for w in module.weights}
return _assign_weights_from_dict(weight_dict, value_dict, skip=skip)
...
@@ -66,7 +66,7 @@ def load(f, pickle_module=PICKLE_MODULE):
f, 'rb', lambda f: pickle_module.load(f))
except UnicodeDecodeError:
return _with_file_like(
-f, 'rb', lambda f: pickle_module.load(f, encoding='iso-8859-1'))
+f, 'rb', lambda f: pickle_module.load(f, encoding='bytes'))
def _save_dict(obj):
...
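For context on the ``encoding`` switch in both loaders: either value lets Python 3 read Python 2 pickles, but ``'bytes'`` keeps binary payloads as ``bytes`` objects instead of decoding them through latin-1 text as ``'iso-8859-1'`` did. A self-contained sketch of the fallback pattern (``load_py2_pickle`` is a hypothetical helper, not part of this commit):

```python
import pickle

def load_py2_pickle(path):
    """Load a pickle written by Python 2 under Python 3 (sketch)."""
    with open(path, 'rb') as f:
        try:
            return pickle.load(f)
        except UnicodeDecodeError:
            f.seek(0)  # rewind after the failed attempt
            return pickle.load(f, encoding='bytes')
```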