Commit 406662ad by Ting PAN

Comment I/O prefetch detailedly

1 parent d5f7d2d9
...@@ -121,7 +121,7 @@ class DataBatch(mp.Process): ...@@ -121,7 +121,7 @@ class DataBatch(mp.Process):
self._transformers = [] self._transformers = []
for i in range(self._num_transformers): for i in range(self._num_transformers):
transformer = DataTransformer(**kwargs) transformer = DataTransformer(**kwargs)
transformer._rng_seed += (i + rank * self._num_transformers) transformer._seed += (i + rank * self._num_transformers)
transformer.q_in = self.Q1 transformer.q_in = self.Q1
transformer.q1_out, transformer.q2_out = self.Q21, self.Q22 transformer.q1_out, transformer.q2_out = self.Q21, self.Q22
transformer.start() transformer.start()
...@@ -175,10 +175,15 @@ class DataBatch(mp.Process): ...@@ -175,10 +175,15 @@ class DataBatch(mp.Process):
'gt_boxes': np.concatenate(all_boxes, axis=0), 'gt_boxes': np.concatenate(all_boxes, axis=0),
} }
# Two queues to implement aspect-grouping
# This is necessary to reduce the gpu memory
# from fetching a huge square batch blob
q1, q2 = self.Q21, self.Q22 q1, q2 = self.Q21, self.Q22
# Main prefetch loop
while True: while True:
if q1.qsize() >= cfg.TRAIN.IMS_PER_BATCH: if q1.qsize() >= cfg.TRAIN.IMS_PER_BATCH:
self.Q3.put(produce(q1)) self.Q3.put(produce(q1))
elif q2.qsize() >= cfg.TRAIN.IMS_PER_BATCH: elif q2.qsize() >= cfg.TRAIN.IMS_PER_BATCH:
self.Q3.put(produce(q2)) self.Q3.put(produce(q2))
q1, q2 = q2, q1 # Sample two queues uniformly q1, q2 = q2, q1 # Uniform sampling trick
...@@ -26,13 +26,12 @@ from lib.utils.boxes import flip_boxes ...@@ -26,13 +26,12 @@ from lib.utils.boxes import flip_boxes
class DataTransformer(multiprocessing.Process): class DataTransformer(multiprocessing.Process):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DataTransformer, self).__init__() super(DataTransformer, self).__init__()
self._rng_seed = cfg.RNG_SEED self._seed = cfg.RNG_SEED
self._use_flipped = cfg.TRAIN.USE_FLIPPED self._use_flipped = cfg.TRAIN.USE_FLIPPED
self._use_diff = cfg.TRAIN.USE_DIFF self._use_diff = cfg.TRAIN.USE_DIFF
self._classes = kwargs.get('classes', ('__background__',)) self._classes = kwargs.get('classes', ('__background__',))
self._num_classes = len(self._classes) self._num_classes = len(self._classes)
self._class_to_ind = dict(zip(self._classes, range(self._num_classes))) self._class_to_ind = dict(zip(self._classes, range(self._num_classes)))
self._queues = []
self.q_in = self.q1_out = self.q2_out = None self.q_in = self.q1_out = self.q2_out = None
self.daemon = True self.daemon = True
...@@ -147,7 +146,10 @@ class DataTransformer(multiprocessing.Process): ...@@ -147,7 +146,10 @@ class DataTransformer(multiprocessing.Process):
return im, im_scale, gt_boxes return im, im_scale, gt_boxes
def run(self): def run(self):
np.random.seed(self._rng_seed) # Fix the process-local random seed
np.random.seed(self._seed)
# Main prefetch loop
while True: while True:
outputs = self.get(self.q_in.get()) outputs = self.get(self.q_in.get())
if len(outputs[2]) < 1: if len(outputs[2]) < 1:
......
...@@ -66,19 +66,6 @@ class RetinaNetDecoder(torch.nn.Module): ...@@ -66,19 +66,6 @@ class RetinaNetDecoder(torch.nn.Module):
(2 ** (octave / float(scales_per_octave))) (2 ** (octave / float(scales_per_octave)))
for octave in range(scales_per_octave)] for octave in range(scales_per_octave)]
def register_operator(self):
return {
'op_type': 'Proposal',
'arguments': {
'det_type': 'RETINANET',
'strides': self.strides,
'scales': self.scales,
'ratios': [float(e) for e in cfg.RETINANET.ASPECT_RATIOS],
'pre_nms_top_n': cfg.RETINANET.PRE_NMS_TOP_N,
'score_thresh': cfg.TEST.SCORE_THRESH,
}
}
def forward(self, features, cls_prob, bbox_pred, ims_info): def forward(self, features, cls_prob, bbox_pred, ims_info):
return F.decode_retinanet( return F.decode_retinanet(
features=features, features=features,
......
...@@ -115,7 +115,7 @@ class DataBatch(mp.Process): ...@@ -115,7 +115,7 @@ class DataBatch(mp.Process):
self._transformers = [] self._transformers = []
for i in range(self._num_transformers): for i in range(self._num_transformers):
transformer = DataTransformer(**kwargs) transformer = DataTransformer(**kwargs)
transformer._rng_seed += (i + rank * self._num_transformers) transformer._seed += (i + rank * self._num_transformers)
transformer.q_in, transformer.q_out = self.Q1, self.Q2 transformer.q_in, transformer.q_out = self.Q1, self.Q2
transformer.start() transformer.start()
self._transformers.append(transformer) self._transformers.append(transformer)
...@@ -159,6 +159,7 @@ class DataBatch(mp.Process): ...@@ -159,6 +159,7 @@ class DataBatch(mp.Process):
cfg.SSD.RESIZE.WIDTH, 3, cfg.SSD.RESIZE.WIDTH, 3,
) )
# Main prefetch loop
while True: while True:
boxes_to_pack = [] boxes_to_pack = []
image_batch = np.zeros(image_batch_shape, 'uint8') image_batch = np.zeros(image_batch_shape, 'uint8')
......
...@@ -26,7 +26,7 @@ from lib.utils.boxes import flip_boxes ...@@ -26,7 +26,7 @@ from lib.utils.boxes import flip_boxes
class DataTransformer(multiprocessing.Process): class DataTransformer(multiprocessing.Process):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DataTransformer, self).__init__() super(DataTransformer, self).__init__()
self._rng_seed = cfg.RNG_SEED self._seed = cfg.RNG_SEED
self._mirror = cfg.TRAIN.USE_FLIPPED self._mirror = cfg.TRAIN.USE_FLIPPED
self._use_diff = cfg.TRAIN.USE_DIFF self._use_diff = cfg.TRAIN.USE_DIFF
self._classes = kwargs.get('classes', ('__background__',)) self._classes = kwargs.get('classes', ('__background__',))
...@@ -114,7 +114,10 @@ class DataTransformer(multiprocessing.Process): ...@@ -114,7 +114,10 @@ class DataTransformer(multiprocessing.Process):
return img, gt_boxes return img, gt_boxes
def run(self): def run(self):
np.random.seed(self._rng_seed) # Fix the process-local random seed
np.random.seed(self._seed)
# Main prefetch loop
while True: while True:
outputs = self.get(self.q_in.get()) outputs = self.get(self.q_in.get())
if len(outputs[1]) < 1: if len(outputs[1]) < 1:
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!