Commit 4dd97da7 by Ting PAN

Add simple serving script

1 parent e3b9b641
...@@ -9,7 +9,7 @@ The torch-style codes help us to simplify the hierarchical pipeline of modern de ...@@ -9,7 +9,7 @@ The torch-style codes help us to simplify the hierarchical pipeline of modern de
## Requirements ## Requirements
seeta-dragon >= 0.3.0.dev20201014 seeta-dragon >= 0.3.0.dev20201024
## Installation ## Installation
......
...@@ -47,7 +47,7 @@ class AnchorTarget(object): ...@@ -47,7 +47,7 @@ class AnchorTarget(object):
if self.num_strides > 1 if self.num_strides > 1
else np.array(self.scales))) else np.array(self.scales)))
# Plan the maximum shifted anchor layout # Plan the maximum shifted anchor layout
max_size = cfg.TRAIN.MAX_SIZE max_size = max(cfg.TRAIN.MAX_SIZE, max(cfg.TRAIN.SCALES))
if cfg.MODEL.COARSEST_STRIDE > 0: if cfg.MODEL.COARSEST_STRIDE > 0:
stride = float(cfg.MODEL.COARSEST_STRIDE) stride = float(cfg.MODEL.COARSEST_STRIDE)
max_size = int(math.ceil(max_size / stride) * stride) max_size = int(math.ceil(max_size / stride) * stride)
......
...@@ -45,7 +45,7 @@ class DataTransformer(multiprocessing.Process): ...@@ -45,7 +45,7 @@ class DataTransformer(multiprocessing.Process):
self.q_in = self.q_out = None self.q_in = self.q_out = None
self.daemon = True self.daemon = True
def get_boxes(self, example, im_scale, flipped): def get_boxes(self, example, im_scale, im_offset, flipped):
objects, num_objects = example.objects, 0 objects, num_objects = example.objects, 0
height, width = example.height, example.width height, width = example.height, example.width
if not self._use_diff: if not self._use_diff:
...@@ -78,6 +78,14 @@ class DataTransformer(multiprocessing.Process): ...@@ -78,6 +78,14 @@ class DataTransformer(multiprocessing.Process):
# Scale the boxes to the detecting scale. # Scale the boxes to the detecting scale.
boxes *= im_scale boxes *= im_scale
# Offset the boxes to align the cropping.
if im_offset is not None:
boxes[:, 0::2] += im_offset[1]
boxes[:, 1::2] += im_offset[0]
boxes[:, :] = np.minimum(
np.maximum(boxes[:, :], 0),
[im_offset[2][1] - 1, im_offset[2][0] - 1] * 2)
# Attach the classes. # Attach the classes.
gt_boxes = np.empty((num_objects, 5), dtype=np.float32) gt_boxes = np.empty((num_objects, 5), dtype=np.float32)
gt_boxes[:, :4], gt_boxes[:, 4] = boxes, gt_classes gt_boxes[:, :4], gt_boxes[:, 4] = boxes, gt_classes
...@@ -88,9 +96,10 @@ class DataTransformer(multiprocessing.Process): ...@@ -88,9 +96,10 @@ class DataTransformer(multiprocessing.Process):
example = Example(example) example = Example(example)
# Resize. # Resize.
target_size = npr.choice(self._scales)
img, im_scale = image_util.resize_image_with_target_size( img, im_scale = image_util.resize_image_with_target_size(
example.image, example.image,
target_size=npr.choice(self._scales), target_size=target_size,
max_size=self._max_size, max_size=self._max_size,
random_scales=self._random_scales, random_scales=self._random_scales,
) )
...@@ -101,12 +110,18 @@ class DataTransformer(multiprocessing.Process): ...@@ -101,12 +110,18 @@ class DataTransformer(multiprocessing.Process):
img = img[:, ::-1] img = img[:, ::-1]
flipped = True flipped = True
# Crop or Pad.
im_offset = None
if self._max_size == 0:
img, im_offset = image_util.get_image_with_target_size(
img, target_size)
# Distort. # Distort.
if self._use_distort: if self._use_distort:
img = image_util.distort_image(img) img = image_util.distort_image(img)
# Boxes. # Boxes.
boxes = self.get_boxes(example, im_scale, flipped) boxes = self.get_boxes(example, im_scale, im_offset, flipped)
# Standard outputs. # Standard outputs.
outputs = {'image': img, outputs = {'image': img,
......
...@@ -37,7 +37,8 @@ class Proposal(object): ...@@ -37,7 +37,8 @@ class Proposal(object):
self.defaults = collections.OrderedDict([ self.defaults = collections.OrderedDict([
('rois', np.array([[-1, 0, 0, 1, 1]], 'float32'))]) ('rois', np.array([[-1, 0, 0, 1, 1]], 'float32'))])
self.bbox_transform_clip = \ self.bbox_transform_clip = \
np.log(cfg.TRAIN.MAX_SIZE / min(self.strides)) np.log(max(cfg.TRAIN.MAX_SIZE,
max(cfg.TRAIN.SCALES)) / min(self.strides))
# Generate base anchors # Generate base anchors
self.base_anchors = [] self.base_anchors = []
for i in range(self.num_strides): for i in range(self.num_strides):
......
...@@ -52,7 +52,7 @@ def ims_detect(detector, raw_images, timer=None): ...@@ -52,7 +52,7 @@ def ims_detect(detector, raw_images, timer=None):
images, images_info = get_data(raw_images) images, images_info = get_data(raw_images)
timer.tic() if timer else timer timer.tic() if timer else timer
# Do forward # Do forward.
inputs = {'image': torch.from_numpy(images), inputs = {'image': torch.from_numpy(images),
'im_info': torch.from_numpy(images_info)} 'im_info': torch.from_numpy(images_info)}
if not hasattr(detector, 'script_forward'): if not hasattr(detector, 'script_forward'):
...@@ -65,7 +65,7 @@ def ims_detect(detector, raw_images, timer=None): ...@@ -65,7 +65,7 @@ def ims_detect(detector, raw_images, timer=None):
outputs = detector.script_forward(inputs['image'], inputs['im_info']) outputs = detector.script_forward(inputs['image'], inputs['im_info'])
outputs = dict((k, outputs[k].numpy()) for k in outputs.keys()) outputs = dict((k, outputs[k].numpy()) for k in outputs.keys())
# Decode results # Decode results.
batch_pred = box_util.bbox_transform_inv( batch_pred = box_util.bbox_transform_inv(
outputs['rois'][:, 1:5], outputs['rois'][:, 1:5],
outputs['bbox_pred'], outputs['bbox_pred'],
...@@ -79,44 +79,18 @@ def ims_detect(detector, raw_images, timer=None): ...@@ -79,44 +79,18 @@ def ims_detect(detector, raw_images, timer=None):
results[ii][0].append(outputs['cls_prob'][inds]) results[ii][0].append(outputs['cls_prob'][inds])
results[ii][1].append(boxes) results[ii][1].append(boxes)
# Merge from multiple scales # Merge from multiple scales.
ret = [(np.vstack(s), np.vstack(b)) for s, b in results] ret = [(np.vstack(s), np.vstack(b)) for s, b in results]
timer.toc() if timer else timer timer.toc() if timer else timer
return ret return ret
def test_net(weights, q_in, q_out, device, root_logger=True): def get_detections(outputs):
"""Test a network trained with Faster R-CNN algorithm.""" """Return the categorical detections from outputs."""
cfg.GPU_ID = device scores, boxes = outputs
num_classes = len(cfg.MODEL.CLASSES)
logger.set_root_logger(root_logger)
detector = new_detector(device, weights)
must_stop = False
timers = time_util.new_timers('im_detect_bbox', 'misc')
empty_detections = np.zeros((0, 5), 'float32')
while True:
if must_stop:
break
indices, raw_images = [], []
for _ in range(cfg.TEST.IMS_PER_BATCH):
i, raw_image = q_in.get()
if i < 0:
must_stop = True
break
indices.append(i)
raw_images.append(raw_image)
if len(raw_images) == 0:
continue
results = ims_detect(detector, raw_images, timers['im_detect_bbox'])
for i, (scores, boxes) in enumerate(results):
timers['misc'].tic()
boxes_this_image = [[]] boxes_this_image = [[]]
for j in range(1, num_classes): empty_detections = np.zeros((0, 5), 'float32')
for j in range(1, len(cfg.MODEL.CLASSES)):
inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0] inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0]
if len(inds) == 0: if len(inds) == 0:
boxes_this_image.append(empty_detections) boxes_this_image.append(empty_detections)
...@@ -140,8 +114,40 @@ def test_net(weights, q_in, q_out, device, root_logger=True): ...@@ -140,8 +114,40 @@ def test_net(weights, q_in, q_out, device, root_logger=True):
) )
cls_detections = cls_detections[keep, :] cls_detections = cls_detections[keep, :]
boxes_this_image.append(cls_detections) boxes_this_image.append(cls_detections)
timers['misc'].toc() return [boxes_this_image]
def test_net(weights, q_in, q_out, device, root_logger=True):
"""Test a network trained with Faster R-CNN algorithm."""
cfg.GPU_ID = device
logger.set_root_logger(root_logger)
detector = new_detector(device, weights)
timers = time_util.new_timers('im_detect_bbox', 'misc')
must_stop = False
while not must_stop:
indices, raw_images = [], []
for _ in range(cfg.TEST.IMS_PER_BATCH):
i, raw_image = q_in.get()
if i < 0:
must_stop = True
break
indices.append(i)
raw_images.append(raw_image)
if len(raw_images) == 0:
continue
# Detect on specific scales.
all_outputs = ims_detect(
detector=detector,
raw_images=raw_images,
timer=timers['im_detect_bbox'],
)
# Post-processing.
for i, outputs in enumerate(all_outputs):
with timers['misc'].tic_and_toc():
boxes_this_image, = get_detections(outputs)
q_out.put(( q_out.put((
indices[i], indices[i],
dict([('im_detect', timers['im_detect_bbox'].average_time), dict([('im_detect', timers['im_detect_bbox'].average_time),
......
...@@ -46,7 +46,7 @@ class DataTransformer(multiprocessing.Process): ...@@ -46,7 +46,7 @@ class DataTransformer(multiprocessing.Process):
self.q_in = self.q_out = None self.q_in = self.q_out = None
self.daemon = True self.daemon = True
def get_boxes_and_segms(self, example, im_scale, flipped): def get_boxes_and_segms(self, example, im_scale, im_offset, flipped):
objects, num_objects = example.objects, 0 objects, num_objects = example.objects, 0
height, width = example.height, example.width height, width = example.height, example.width
if not self._use_diff: if not self._use_diff:
...@@ -90,6 +90,11 @@ class DataTransformer(multiprocessing.Process): ...@@ -90,6 +90,11 @@ class DataTransformer(multiprocessing.Process):
# Scale the boxes to the detecting scale. # Scale the boxes to the detecting scale.
boxes *= im_scale boxes *= im_scale
# Offset the boxes to align the cropping.
if im_offset is not None:
if min(im_offset[:2]) < 0:
raise ValueError('RandomCrop with mask is not supported.')
# Attach the classes and mask flags. # Attach the classes and mask flags.
gt_boxes = np.empty((num_objects, 6), dtype=np.float32) gt_boxes = np.empty((num_objects, 6), dtype=np.float32)
gt_boxes[:, :4], gt_boxes[:, 4] = boxes, gt_classes gt_boxes[:, :4], gt_boxes[:, 4] = boxes, gt_classes
...@@ -101,9 +106,10 @@ class DataTransformer(multiprocessing.Process): ...@@ -101,9 +106,10 @@ class DataTransformer(multiprocessing.Process):
example = Example(example) example = Example(example)
# Resize. # Resize.
target_size = npr.choice(self._scales)
img, im_scale = image_util.resize_image_with_target_size( img, im_scale = image_util.resize_image_with_target_size(
example.image, example.image,
target_size=npr.choice(self._scales), target_size=target_size,
max_size=self._max_size, max_size=self._max_size,
random_scales=self._random_scales, random_scales=self._random_scales,
) )
...@@ -114,12 +120,18 @@ class DataTransformer(multiprocessing.Process): ...@@ -114,12 +120,18 @@ class DataTransformer(multiprocessing.Process):
img = img[:, ::-1] img = img[:, ::-1]
flipped = True flipped = True
# Crop or Pad.
im_offset = None
if self._max_size == 0:
img, im_offset = image_util.get_image_with_target_size(
img, target_size)
# Distort. # Distort.
if self._use_distort: if self._use_distort:
img = image_util.distort_image(img) img = image_util.distort_image(img)
# Boxes and segmentations. # Boxes and segmentations.
boxes, segms = self.get_boxes_and_segms(example, im_scale, flipped) boxes, segms = self.get_boxes_and_segms(example, im_scale, im_offset, flipped)
# Standard outputs. # Standard outputs.
outputs = {'image': img, outputs = {'image': img,
......
...@@ -118,41 +118,14 @@ def mask_detect(detector, rois): ...@@ -118,41 +118,14 @@ def mask_detect(detector, rois):
return detector.rcnn.sigmoid(mask_pred).numpy().copy() return detector.rcnn.sigmoid(mask_pred).numpy().copy()
def test_net(weights, q_in, q_out, device, root_logger=True): def get_detections(outputs):
"""Test a network trained with Mask R-CNN algorithm.""" """Return the categorical detections from outputs."""
cfg.GPU_ID = device scores, boxes, batch_inds, im_scales = outputs
num_classes = len(cfg.MODEL.CLASSES)
logger.set_root_logger(root_logger)
detector = new_detector(device, weights)
must_stop = False
timers = time_util.new_timers('im_detect_bbox', 'im_detect_mask', 'misc')
empty_detections = np.zeros((0, 5), 'float32')
empty_rois = np.zeros((0, 6), 'float32')
while True:
if must_stop:
break
indices, raw_images = [], []
for _ in range(cfg.TEST.IMS_PER_BATCH):
i, raw_image = q_in.get()
if i < 0:
must_stop = True
break
indices.append(i)
raw_images.append(raw_image)
if len(raw_images) == 0:
continue
results = ims_detect(detector, raw_images, timers['im_detect_bbox'])
for i, (scores, boxes, batch_inds, im_scales) in enumerate(results):
timers['misc'].tic()
rois_this_image = [] rois_this_image = []
boxes_this_image = [[]] boxes_this_image = [[]]
masks_this_image = [[]] empty_detections = np.zeros((0, 5), 'float32')
for j in range(1, num_classes): empty_rois = np.zeros((0, 6), 'float32')
for j in range(1, len(cfg.MODEL.CLASSES)):
inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0] inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0]
if len(inds) == 0: if len(inds) == 0:
boxes_this_image.append(empty_detections) boxes_this_image.append(empty_detections)
...@@ -179,31 +152,64 @@ def test_net(weights, q_in, q_out, device, root_logger=True): ...@@ -179,31 +152,64 @@ def test_net(weights, q_in, q_out, device, root_logger=True):
cls_detections = cls_detections[keep, :] cls_detections = cls_detections[keep, :]
cls_batch_inds = cls_batch_inds[keep] cls_batch_inds = cls_batch_inds[keep]
boxes_this_image.append(cls_detections) boxes_this_image.append(cls_detections)
rois_this_image.append( rois_this_image.append(np.hstack((
np.hstack((
cls_batch_inds, cls_batch_inds,
cls_detections[:, :4] * im_scales[cls_batch_inds], cls_detections[:, :4] * im_scales[cls_batch_inds],
np.ones((len(keep), 1)) * (j - 1), np.ones((len(keep), 1)) * (j - 1))))
))) return [boxes_this_image, rois_this_image]
mask_rois = np.concatenate(rois_this_image)
timers['misc'].toc()
def test_net(weights, q_in, q_out, device, root_logger=True):
"""Test a network trained with Mask R-CNN algorithm."""
cfg.GPU_ID = device
num_classes = len(cfg.MODEL.CLASSES)
logger.set_root_logger(root_logger)
detector = new_detector(device, weights)
timers = time_util.new_timers('im_detect_bbox', 'im_detect_mask', 'misc')
must_stop = False
while not must_stop:
# Wait inputs.
indices, raw_images = [], []
for _ in range(cfg.TEST.IMS_PER_BATCH):
i, raw_image = q_in.get()
if i < 0:
must_stop = True
break
indices.append(i)
raw_images.append(raw_image)
if len(raw_images) == 0:
continue
# Detect on specific scales.
all_outputs = ims_detect(
detector=detector,
raw_images=raw_images,
timer=timers['im_detect_bbox'],
)
# Post-processing.
for i, outputs in enumerate(all_outputs):
segms_this_image = [[]]
with timers['misc'].tic_and_toc():
boxes_this_image, rois_this_image = get_detections(outputs)
mask_rois = np.concatenate(rois_this_image)
if len(mask_rois) > 0: if len(mask_rois) > 0:
k = 0 k = 0
timers['im_detect_mask'].tic() timers['im_detect_mask'].tic()
mask_pred = mask_detect(detector, mask_rois) mask_pred = mask_detect(detector, mask_rois)
for j in range(1, num_classes): for j in range(1, num_classes):
num_pred = len(boxes_this_image[j]) num_pred = len(boxes_this_image[j])
cls_masks = mask_pred[k:k + num_pred] cls_segms = mask_pred[k:k + num_pred]
masks_this_image.append(cls_masks) segms_this_image.append(cls_segms)
k += num_pred k += num_pred
timers['im_detect_mask'].toc() timers['im_detect_mask'].toc()
q_out.put(( q_out.put((
indices[i], indices[i],
dict([('im_detect', (timers['im_detect_bbox'].average_time + dict([('im_detect', (timers['im_detect_bbox'].average_time +
timers['im_detect_mask'].average_time)), timers['im_detect_mask'].average_time)),
('misc', timers['misc'].average_time)]), ('misc', timers['misc'].average_time)]),
dict([('boxes', boxes_this_image), dict([('boxes', boxes_this_image),
('masks', masks_this_image)]), ('masks', segms_this_image)]),
)) ))
...@@ -48,9 +48,7 @@ class AnchorTarget(object): ...@@ -48,9 +48,7 @@ class AnchorTarget(object):
ratios=self.ratios, ratios=self.ratios,
sizes=sizes)) sizes=sizes))
# Plan the maximum anchor layout # Plan the maximum anchor layout
max_size = cfg.TRAIN.MAX_SIZE max_size = max(cfg.TRAIN.MAX_SIZE, max(cfg.TRAIN.SCALES))
if max_size == 0:
max_size = cfg.TRAIN.SCALES[0]
if cfg.MODEL.COARSEST_STRIDE > 0: if cfg.MODEL.COARSEST_STRIDE > 0:
stride = float(cfg.MODEL.COARSEST_STRIDE) stride = float(cfg.MODEL.COARSEST_STRIDE)
max_size = int(math.ceil(max_size / stride) * stride) max_size = int(math.ceil(max_size / stride) * stride)
......
...@@ -56,8 +56,6 @@ def ims_detect(detector, raw_images, timer=None): ...@@ -56,8 +56,6 @@ def ims_detect(detector, raw_images, timer=None):
# Do Forward # Do Forward
inputs = {'image': torch.from_numpy(images), inputs = {'image': torch.from_numpy(images),
'im_info': torch.from_numpy(images_info)} 'im_info': torch.from_numpy(images_info)}
# with torch.no_grad():
# outputs = detector.forward(inputs)
if not hasattr(detector, 'script_forward'): if not hasattr(detector, 'script_forward'):
def script_forward(self, image, im_info): def script_forward(self, image, im_info):
return self.forward({'image': image, 'im_info': im_info}) return self.forward({'image': image, 'im_info': im_info})
...@@ -65,10 +63,8 @@ def ims_detect(detector, raw_images, timer=None): ...@@ -65,10 +63,8 @@ def ims_detect(detector, raw_images, timer=None):
func=types.MethodType(script_forward, detector), func=types.MethodType(script_forward, detector),
example_inputs=[inputs['image'], inputs['im_info']], example_inputs=[inputs['image'], inputs['im_info']],
) )
outputs = detector.script_forward(inputs['image'], inputs['im_info']) outputs = detector.script_forward(inputs['image'], inputs['im_info'])
outputs = dict((k, outputs[k].numpy()) for k in outputs.keys()) outputs = dict((k, outputs[k].numpy()) for k in outputs.keys())
timer.toc() if timer else timer
# Decode results # Decode results
detections = outputs['detections'] detections = outputs['detections']
...@@ -83,48 +79,20 @@ def ims_detect(detector, raw_images, timer=None): ...@@ -83,48 +79,20 @@ def ims_detect(detector, raw_images, timer=None):
return ret return ret
def test_net(weights, q_in, q_out, device, root_logger=True): def get_detections(outputs):
"""Test a network trained with RetinaNet algorithm.""" """Return the categorical detections from outputs."""
cfg.GPU_ID = device
num_classes = len(cfg.MODEL.CLASSES) num_classes = len(cfg.MODEL.CLASSES)
logger.set_root_logger(root_logger)
detector = new_detector(device, weights)
must_stop = False
timers = time_util.new_timers('im_detect_bbox', 'misc')
empty_detections = np.zeros((0, 5), 'float32')
while True:
if must_stop:
break
indices, raw_images = [], []
for _ in range(cfg.TEST.IMS_PER_BATCH):
i, raw_image = q_in.get()
if i < 0:
must_stop = True
break
indices.append(i)
raw_images.append(raw_image)
if len(raw_images) == 0:
continue
# Run detecting on specific scales
results = ims_detect(detector, raw_images, timers['im_detect_bbox'])
# Post-processing
for i, detections in enumerate(results):
timers['misc'].tic()
boxes_this_image = [[]] boxes_this_image = [[]]
# Detection format: (x1, y1, x2, y2, score, cls) raw_detections = outputs
detections = np.array(detections) empty_detections = np.zeros((0, 5), 'float32')
for j in range(1, num_classes): for j in range(1, num_classes):
cls_indices = np.where(detections[:, 5].astype(np.int32) == j)[0] cls_indices = np.where(
raw_detections[:, 5].astype(np.int32) == j)[0]
if len(cls_indices) == 0: if len(cls_indices) == 0:
boxes_this_image.append(empty_detections) boxes_this_image.append(empty_detections)
continue continue
cls_boxes = detections[cls_indices, :4] cls_boxes = raw_detections[cls_indices, :4]
cls_scores = detections[cls_indices, 4] cls_scores = raw_detections[cls_indices, 4]
cls_detections = np.hstack(( cls_detections = np.hstack((
cls_boxes, cls_scores[:, np.newaxis])) \ cls_boxes, cls_scores[:, np.newaxis])) \
.astype(np.float32, copy=False) .astype(np.float32, copy=False)
...@@ -142,8 +110,37 @@ def test_net(weights, q_in, q_out, device, root_logger=True): ...@@ -142,8 +110,37 @@ def test_net(weights, q_in, q_out, device, root_logger=True):
) )
cls_detections = cls_detections[keep, :] cls_detections = cls_detections[keep, :]
boxes_this_image.append(cls_detections) boxes_this_image.append(cls_detections)
timers['misc'].toc() return [boxes_this_image]
def test_net(weights, q_in, q_out, device, root_logger=True):
"""Test a network trained with RetinaNet algorithm."""
cfg.GPU_ID = device
logger.set_root_logger(root_logger)
detector = new_detector(device, weights)
timers = time_util.new_timers('im_detect_bbox', 'misc')
must_stop = False
while not must_stop:
# Wait inputs.
indices, raw_images = [], []
for _ in range(cfg.TEST.IMS_PER_BATCH):
i, raw_image = q_in.get()
if i < 0:
must_stop = True
break
indices.append(i)
raw_images.append(raw_image)
if len(raw_images) == 0:
continue
# Detect on specific scales.
all_outputs = ims_detect(detector, raw_images, timers['im_detect_bbox'])
# Post-processing.
for i, outputs in enumerate(all_outputs):
with timers['misc'].tic_and_toc():
boxes_this_image, = get_detections(outputs)
q_out.put(( q_out.put((
indices[i], indices[i],
dict([('im_detect', timers['im_detect_bbox'].average_time), dict([('im_detect', timers['im_detect_bbox'].average_time),
......
...@@ -41,9 +41,10 @@ def get_data(raw_images): ...@@ -41,9 +41,10 @@ def get_data(raw_images):
return images_wide, image_scales_wide return images_wide, image_scales_wide
def ims_detect(detector, raw_images): def ims_detect(detector, raw_images, timer=None):
"""Detect images at single or multiple scales.""" """Detect images at single or multiple scales."""
images, image_scales = get_data(raw_images) images, image_scales = get_data(raw_images)
timer.tic() if timer else timer
# Do forward # Do forward
inputs = {'image': torch.from_numpy(images)} inputs = {'image': torch.from_numpy(images)}
...@@ -55,6 +56,7 @@ def ims_detect(detector, raw_images): ...@@ -55,6 +56,7 @@ def ims_detect(detector, raw_images):
example_inputs=[inputs['image']], example_inputs=[inputs['image']],
) )
outputs = detector.script_forward(inputs['image']) outputs = detector.script_forward(inputs['image'])
timer.toc() if timer else timer
# Decode results # Decode results
batch_pred = outputs['bbox_pred'].numpy() batch_pred = outputs['bbox_pred'].numpy()
...@@ -71,42 +73,16 @@ def ims_detect(detector, raw_images): ...@@ -71,42 +73,16 @@ def ims_detect(detector, raw_images):
results[i // len(cfg.TEST.SCALES)][1].append(boxes) results[i // len(cfg.TEST.SCALES)][1].append(boxes)
# Merge from multiple scales # Merge from multiple scales
return [(np.vstack(s), np.vstack(b)) for s, b in results] ret = [(np.vstack(s), np.vstack(b)) for s, b in results]
timer.toc() if timer else timer
return ret
def test_net(weights, q_in, q_out, device, root_logger=True): def get_detections(outputs):
"""Test a network trained with SSD algorithm.""" """Return the categorical detections from outputs."""
cfg.GPU_ID = device scores, boxes = outputs
num_classes = len(cfg.MODEL.CLASSES)
logger.set_root_logger(root_logger)
detector = new_detector(device, weights)
must_stop = False
_t = time_util.new_timers('im_detect', 'misc')
while True:
if must_stop:
break
indices, raw_images = [], []
for _ in range(cfg.TEST.IMS_PER_BATCH):
i, raw_image = q_in.get()
if i < 0:
must_stop = True
break
indices.append(i)
raw_images.append(raw_image)
if len(raw_images) == 0:
continue
with _t['im_detect'].tic_and_toc():
results = ims_detect(detector, raw_images)
for i, (scores, boxes) in enumerate(results):
_t['misc'].tic()
boxes_this_image = [[]] boxes_this_image = [[]]
# Detection format: (score...), (x1, y1, x2, y2) for j in range(1, len(cfg.MODEL.CLASSES)):
for j in range(1, num_classes):
inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0] inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0]
cls_scores = scores[inds, j] cls_scores = scores[inds, j]
cls_boxes = boxes[inds] cls_boxes = boxes[inds]
...@@ -130,11 +106,44 @@ def test_net(weights, q_in, q_out, device, root_logger=True): ...@@ -130,11 +106,44 @@ def test_net(weights, q_in, q_out, device, root_logger=True):
) )
cls_detections = cls_detections[keep, :] cls_detections = cls_detections[keep, :]
boxes_this_image.append(cls_detections) boxes_this_image.append(cls_detections)
_t['misc'].toc() return [boxes_this_image]
def test_net(weights, q_in, q_out, device, root_logger=True):
"""Test a network trained with SSD algorithm."""
cfg.GPU_ID = device
logger.set_root_logger(root_logger)
detector = new_detector(device, weights)
timers = time_util.new_timers('im_detect_bbox', 'misc')
must_stop = False
while not must_stop:
# Wait inputs.
indices, raw_images = [], []
for _ in range(cfg.TEST.IMS_PER_BATCH):
i, raw_image = q_in.get()
if i < 0:
must_stop = True
break
indices.append(i)
raw_images.append(raw_image)
if len(raw_images) == 0:
continue
# Detect on specific scales.
all_outputs = ims_detect(
detector=detector,
raw_images=raw_images,
timer=timers['im_detect_bbox'],
)
# Post-processing.
for i, outputs in enumerate(all_outputs):
with timers['misc'].tic_and_toc():
boxes_this_image, = get_detections(outputs)
q_out.put(( q_out.put((
indices[i], indices[i],
dict([('im_detect', _t['im_detect'].average_time), dict([('im_detect', timers['im_detect_bbox'].average_time),
('misc', _t['misc'].average_time)]), ('misc', timers['misc'].average_time)]),
dict([('boxes', boxes_this_image)]), dict([('boxes', boxes_this_image)]),
)) ))
...@@ -27,14 +27,14 @@ from seetadet.utils.pycocotools.cocoeval import COCOeval ...@@ -27,14 +27,14 @@ from seetadet.utils.pycocotools.cocoeval import COCOeval
class COCOEvaluator(object): class COCOEvaluator(object):
"""Evaluator for MS COCO dataset."""
def __init__(self, imdb, ann_file=None): def __init__(self, imdb, ann_file=None):
self.imdb = imdb self.imdb = imdb
if ann_file is not None and \ if ann_file is not None and os.path.exists(ann_file):
os.path.exists(ann_file):
self.coco = COCO(ann_file) self.coco = COCO(ann_file)
cats = self.coco.loadCats(self.coco.getCatIds()) cats = self.coco.loadCats(self.coco.getCatIds())
self.class_to_cat_id = dict( self.class_to_cat_id = dict(zip([c['name'] for c in cats],
zip([c['name'] for c in cats],
self.coco.getCatIds())) self.coco.getCatIds()))
else: else:
self.coco = None self.coco = None
...@@ -43,22 +43,22 @@ class COCOEvaluator(object): ...@@ -43,22 +43,22 @@ class COCOEvaluator(object):
def bbox_results_one_category(self, boxes, cat_id, gt_recs): def bbox_results_one_category(self, boxes, cat_id, gt_recs):
ix, results = 0, [] ix, results = 0, []
for image_name, rec in gt_recs.items(): for image_name, rec in gt_recs.items():
dets = boxes[ix] detections = boxes[ix]
ix += 1 ix += 1
if isinstance(dets, list) and len(dets) == 0: if isinstance(detections, list) and len(detections) == 0:
continue continue
dets = dets.astype('float64') detections = detections.astype('float64')
scores = dets[:, -1] scores = detections[:, -1]
xs = dets[:, 0] xs = detections[:, 0]
ys = dets[:, 1] ys = detections[:, 1]
ws = dets[:, 2] - xs + 1 ws = detections[:, 2] - xs + 1
hs = dets[:, 3] - ys + 1 hs = detections[:, 3] - ys + 1
results.extend([{ results.extend([{
'image_id': self.get_image_id(image_name), 'image_id': self.get_image_id(image_name),
'category_id': cat_id, 'category_id': cat_id,
'bbox': [xs[k], ys[k], ws[k], hs[k]], 'bbox': [xs[k], ys[k], ws[k], hs[k]],
'score': scores[k], 'score': scores[k],
} for k in range(dets.shape[0])]) } for k in range(detections.shape[0])])
return results return results
def do_bbox_eval(self, res_file): def do_bbox_eval(self, res_file):
...@@ -109,7 +109,7 @@ class COCOEvaluator(object): ...@@ -109,7 +109,7 @@ class COCOEvaluator(object):
return image_name return image_name
def get_results_file(self, results_folder, type='bbox'): def get_results_file(self, results_folder, type='bbox'):
# experiments/model_id/results/detections_taas_<comp_id>.json # experiments/model_id/results/detections.json
filename = self.get_prefix(type) + self.imdb.comp_id + '.json' filename = self.get_prefix(type) + self.imdb.comp_id + '.json'
if not os.path.exists(results_folder): if not os.path.exists(results_folder):
os.makedirs(results_folder) os.makedirs(results_folder)
......
...@@ -46,7 +46,7 @@ class Example(object): ...@@ -46,7 +46,7 @@ class Example(object):
@property @property
def image(self): def image(self):
"""Return the image data. """Return the image array.
Returns Returns
------- -------
...@@ -77,7 +77,7 @@ class Example(object): ...@@ -77,7 +77,7 @@ class Example(object):
Returns Returns
------- -------
Sequence[Dict] Sequence[dict]
The objects. The objects.
""" """
......
...@@ -21,6 +21,8 @@ from seetadet.utils.env import pickle ...@@ -21,6 +21,8 @@ from seetadet.utils.env import pickle
class VOCEvaluator(object): class VOCEvaluator(object):
"""Evaluator for PASCAL VOC dataset."""
def __init__(self, imdb): def __init__(self, imdb):
self.imdb = imdb self.imdb = imdb
......
...@@ -8,20 +8,24 @@ ...@@ -8,20 +8,24 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Modeling utilities."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
# Backbones # Backbone
import seetadet.modeling.airnet import seetadet.modeling.airnet
import seetadet.modeling.mobilenet import seetadet.modeling.mobilenet_v2
import seetadet.modeling.mobilenet_v3
import seetadet.modeling.resnet import seetadet.modeling.resnet
import seetadet.modeling.vgg import seetadet.modeling.vgg
# Custom modules # FeatureEnhancer
from seetadet.modeling.fast_rcnn import FastRCNN
from seetadet.modeling.fpn import FPN from seetadet.modeling.fpn import FPN
# RoIHead
from seetadet.modeling.fast_rcnn import FastRCNN
from seetadet.modeling.mask_rcnn import MaskRCNN from seetadet.modeling.mask_rcnn import MaskRCNN
from seetadet.modeling.retinanet import RetinaNet from seetadet.modeling.retinanet import RetinaNet
from seetadet.modeling.rpn import RPN from seetadet.modeling.rpn import RPN
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""AirNet backbone."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Generic detector."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""EfficientNet backbone."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from seetadet.core import registry
from seetadet.modeling.mobilenet_v3 import conv_triplet
from seetadet.modeling.mobilenet_v3 import conv_quintet
from seetadet.modules import init
from seetadet.modules import nn
class SqueezeExcite(nn.Module):
    """Squeeze-excite attention module.

    The input is reduced by a global average pooling, passed through a
    1x1 bottleneck (``dim_in -> dim -> dim_in``) with a swish in between,
    and the sigmoid of the result is multiplied back onto the input.

    Parameters
    ----------
    dim_in : int
        The number of channels of the input, and of the produced gate.
    dim_squeeze : int
        The dimension that ``squeeze_ratio`` is applied to.
    squeeze_ratio : float, optional, default=0.25
        The ratio used to derive the bottleneck dimension.

    """

    def __init__(self, dim_in, dim_squeeze, squeeze_ratio=0.25):
        super(SqueezeExcite, self).__init__()
        # ``int()`` truncates towards zero, so a small ``dim_squeeze``
        # would otherwise request a convolution with zero channels.
        # Keep at least one channel in the bottleneck.
        dim = max(1, int(dim_squeeze * squeeze_ratio))
        self.layers = nn.Sequential(
            nn.AvgPool2d(-1, global_pooling=True),
            nn.Conv2d(dim_in, dim, kernel_size=1),
            nn.Swish(),
            nn.Conv2d(dim, dim_in, kernel_size=1),
            # NOTE(review): the positional ``True`` is presumably the
            # ``inplace`` flag of the project's Sigmoid -- confirm.
            nn.Sigmoid(True),
        )

    def forward(self, x):
        """Return the input rescaled by the attention gate."""
        return x * self.layers(x)
class InvertedResidual(nn.Module):
    """Invert residual block.

    Layout of ``self.conv``: an optional expansion triplet (skipped when
    ``expand_ratio == 1``), the first three layers of the quintet grouped
    together, then the remaining quintet layers appended one by one.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size=3,
        expand_ratio=3,
        stride=1,
        activation=None,
        squeeze_excite=0,
    ):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        # The shortcut is only added when input and output agree.
        self.apply_residual = (stride == 1) and (dim_in == dim_out)
        expanded = int(round(dim_in * expand_ratio))
        self.dim = expanded
        self.endpoint = None  # Expansion feature

        stages = []
        if expand_ratio != 1:
            expansion = conv_triplet(dim_in, expanded, activation=activation)
            stages.append(nn.Sequential(*expansion))

        if squeeze_excite > 0:
            attention = SqueezeExcite(expanded, dim_in)
        else:
            attention = None

        quintet = conv_quintet(
            expanded,
            dim_out,
            kernel_size=kernel_size,
            stride=stride,
            activation=activation,
            expansion_transform=attention,
        )
        stages.append(nn.Sequential(*quintet[:3]))
        stages.extend(quintet[3:])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the block, caching the first stage output for stride-2 blocks."""
        out = self.conv[0](x)
        if self.stride == 2:
            self.endpoint = out
        else:
            self.endpoint = None
        for layer in self.conv[1:]:
            out = layer(out)
        if self.apply_residual:
            out += x
        return out
class NASMobileNet(nn.Module):
    """NAS variant of mobilenet class.

    Args:
        arch: Sequence holding one block index per block; its length must
            equal ``sum(repeats)`` of the preset.
        preset: Tuple ``(repeats, strides, out_channels, def_blocks)``.
            When ``def_blocks`` is ``None`` each index is decoded into the
            ``InvertedResidual`` arguments, otherwise it is used as a key
            into ``def_blocks``.

    ``forward`` returns a list with one feature per entry registered in
    ``self.feature_dims``: every stride-2 block and the final 1x1 stage.
    """

    def __init__(self, arch, preset):
        super(NASMobileNet, self).__init__()
        # Hand-craft configurations.
        repeats, strides, out_channels, def_blocks = preset
        assert sum(repeats) == len(arch), 'Bad architecture.'
        self.feature_dims = collections.OrderedDict()

        # Stem.
        features = [nn.Sequential(
            *conv_triplet(
                dim_in=3,
                dim_out=out_channels[0],
                kernel_size=3,
                stride=2,
                activation=nn.Swish(),
            ))]

        # Blocks.
        dim_in = out_channels[0]
        for repeat, dim_out, stage_stride in \
                zip(repeats, out_channels[1:], strides):
            for i in range(repeat):
                # Only the first block of a stage uses the stage stride.
                stride = stage_stride if i == 0 else 1
                # The stem occupies slot 0, hence the offset of one.
                idx = arch[len(features) - 1]
                if def_blocks is None:
                    # Decode kernel size, expand ratio and the
                    # squeeze-excite flag from the digits of the index.
                    block = functools.partial(
                        InvertedResidual,
                        kernel_size=(idx // 100) % 10,
                        expand_ratio=int(idx / 1000.) / 10,
                        squeeze_excite=idx % 10)
                else:
                    block = def_blocks[idx]
                features.append(block(
                    dim_in, dim_out,
                    stride=stride,
                    activation=nn.Swish()))
                dim_in = dim_out
                if stride == 2:
                    # Layers are keyed by identity so that ``forward``
                    # can recognize them while iterating.
                    self.feature_dims[id(features[-1])] = features[-1].dim

        # Final 1x1 stage.
        features.append(nn.Sequential(
            *conv_triplet(
                dim_in=dim_in,
                dim_out=out_channels[-1],
                kernel_size=1,
                stride=1,
                activation=nn.Swish())))
        self.feature_dims[id(features[-1])] = out_channels[-1]
        self.features = nn.Sequential(*features)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weight of every ``nn.Conv2d`` submodule."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal(m.weight, mode='fan_out')

    def forward(self, x):
        """Return the list of features collected at the registered layers."""
        outputs = []
        for layer in self.features:
            x = layer(x)
            if self.feature_dims.get(id(layer)):
                # Prefer the cached ``endpoint`` of a block if it has one,
                # otherwise fall back to the layer output.
                outputs.append(getattr(layer, 'endpoint', x))
        return outputs
class ModelSetting(object):
    """Hand-craft model setting."""

    # Default NASBlocks definition.
    # ``None`` makes ``NASMobileNet`` decode every block from its index.
    # We use the following hash method:
    #   ef * 10000 + kernel_size * 100 + se * 1
    # where ``ef`` is the expand ratio and ``se`` the squeeze-excite flag,
    # e.g., ef=4.0, ks=3, se=True, with index 40301
    DEFAULT_NAS_BLOCKS_DEF = None

    # Preset unpacked by ``NASMobileNet`` as
    # (repeats, strides, out_channels, def_blocks).
    EFFICIENT = (
        # Number of blocks in each stage.
        [1, 2, 2, 3, 3, 4, 1],
        # Stride of the first block in each stage.
        [1, 2, 2, 2, 1, 2, 1],
        # Stem width, one width per stage, then the final 1x1 stage width.
        [32, 16, 24, 40, 80, 112, 192, 320, 1280],
        DEFAULT_NAS_BLOCKS_DEF,
    )
@registry.backbone.register('efficient_b0')
def efficient_b0():
    """Build the backbone registered as ``efficient_b0``."""
    # (block index, number of consecutive blocks sharing that index).
    stages = (
        (10301, 1),
        (60301, 2),
        (60501, 2),
        (60301, 3),
        (60501, 3),
        (60501, 4),
        (60301, 1),
    )
    arch = [code for code, count in stages for _ in range(count)]
    return NASMobileNet(arch, ModelSetting.EFFICIENT)
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""FastRCNN head."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""FPN feature enhancer."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -21,7 +22,7 @@ HIGHEST_BACKBONE_LVL = 5 # E.g., "conv5"-like level ...@@ -21,7 +22,7 @@ HIGHEST_BACKBONE_LVL = 5 # E.g., "conv5"-like level
class FPN(nn.Module): class FPN(nn.Module):
"""Feature Pyramid Networks for R-CNN and RetinaNet.""" """Feature Pyramid Networks to enhance input features."""
def __init__(self, feature_dims): def __init__(self, feature_dims):
super(FPN, self).__init__() super(FPN, self).__init__()
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""MaskRCNN head."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""MobileNetV2 backbone."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -48,11 +49,14 @@ def conv_quintet(dim_in, dim_out, kernel_size, stride): ...@@ -48,11 +49,14 @@ def conv_quintet(dim_in, dim_out, kernel_size, stride):
class InvertedResidual(nn.Module): class InvertedResidual(nn.Module):
"""The invert residual block.""" """Invert residual block."""
def __init__(self, dim_in, dim_out, kernel_size=3, expand_ratio=3, stride=1): def __init__(self, dim_in, dim_out, kernel_size=3, expand_ratio=3, stride=1):
super(InvertedResidual, self).__init__() super(InvertedResidual, self).__init__()
self.stride = stride
self.apply_residual = stride == 1 and dim_in == dim_out
self.dim = dim = int(round(dim_in * expand_ratio)) self.dim = dim = int(round(dim_in * expand_ratio))
self.endpoint = None # Expansion feature
layers = [] layers = []
if expand_ratio != 1: if expand_ratio != 1:
layers.append(nn.Sequential(*conv_triplet(dim_in, dim))) layers.append(nn.Sequential(*conv_triplet(dim_in, dim)))
...@@ -60,13 +64,10 @@ class InvertedResidual(nn.Module): ...@@ -60,13 +64,10 @@ class InvertedResidual(nn.Module):
layers.append(nn.Sequential(*quintet[:3])) layers.append(nn.Sequential(*quintet[:3]))
layers.extend(quintet[3:]) layers.extend(quintet[3:])
self.conv = nn.Sequential(*layers) self.conv = nn.Sequential(*layers)
self.stride = stride
self.apply_residual = stride == 1 and dim_in == dim_out
self.feature = None
def forward(self, x): def forward(self, x):
out = self.conv[0](x) out = self.conv[0](x)
self.feature = out if self.stride == 2 else None self.endpoint = out if self.stride == 2 else None
for layer in self.conv[1:]: for layer in self.conv[1:]:
out = layer(out) out = layer(out)
if self.apply_residual: if self.apply_residual:
...@@ -75,37 +76,32 @@ class InvertedResidual(nn.Module): ...@@ -75,37 +76,32 @@ class InvertedResidual(nn.Module):
class NASMobileNet(nn.Module): class NASMobileNet(nn.Module):
"""The NAS variant of mobilenet series.""" """NAS variant of mobilenet class."""
# Pre-defined conv blocks
blocks = {
0: functools.partial(InvertedResidual, kernel_size=3, expand_ratio=3),
1: functools.partial(InvertedResidual, kernel_size=3, expand_ratio=6),
2: functools.partial(InvertedResidual, kernel_size=5, expand_ratio=3),
3: functools.partial(InvertedResidual, kernel_size=5, expand_ratio=6),
4: functools.partial(InvertedResidual, kernel_size=7, expand_ratio=3),
5: functools.partial(InvertedResidual, kernel_size=7, expand_ratio=6),
6: nn.Identity,
}
def __init__(self, arch, preset): def __init__(self, arch, preset):
super(NASMobileNet, self).__init__() super(NASMobileNet, self).__init__()
# Hand-craft configurations # Hand-craft configurations
repeats, strides, out_channels = preset repeats, strides, out_channels, def_blocks = preset
assert sum(repeats) == len(arch) assert sum(repeats) == len(arch), 'Bad architecture.'
self.feature_dims = collections.OrderedDict() self.feature_dims = collections.OrderedDict()
# Stem # Stem.
features = [nn.Sequential(*conv_triplet(3, out_channels[0], 3, stride=2)), features = [nn.Sequential(
InvertedResidual(*out_channels[:2], 3, 1)] *conv_triplet(
dim_in=3,
# Body dim_out=out_channels[0],
dim_in = out_channels[1] kernel_size=3,
stride=2,
))]
# Blocks.
dim_in, dim_out = out_channels[:2]
features.append(InvertedResidual(dim_in, dim_out, 3, 1))
for repeat, dim_out, stride in \ for repeat, dim_out, stride in \
zip(repeats, out_channels[2:], strides): zip(repeats, out_channels[2:], strides):
for i in range(repeat): for i in range(repeat):
stride = stride if i == 0 else 1 stride = stride if i == 0 else 1
block = self.blocks[arch[len(features) - 2]] block = def_blocks[arch[len(features) - 2]]
features.append(block(dim_in, dim_out, stride=stride)) features.append(block(dim_in, dim_out, stride=stride))
dim_in = dim_out dim_in = dim_out
if stride == 2: if stride == 2:
...@@ -125,32 +121,44 @@ class NASMobileNet(nn.Module): ...@@ -125,32 +121,44 @@ class NASMobileNet(nn.Module):
for layer in self.features: for layer in self.features:
x = layer(x) x = layer(x)
if self.feature_dims.get(id(layer)): if self.feature_dims.get(id(layer)):
if hasattr(layer, 'feature'): outputs.append(getattr(layer, 'endpoint', x))
outputs.append(layer.feature)
else:
outputs.append(x)
return outputs return outputs
class ModelSetting(object): class ModelSetting(object):
"""Hand-craft model setting.""" """Hand-craft model setting."""
# Default NASBlocks definition.
# See ProxyLessNAS (arxiv.1812.00332) for details.
DEFAULT_NAS_BLOCKS_DEF = {
0: functools.partial(InvertedResidual, kernel_size=3, expand_ratio=3),
1: functools.partial(InvertedResidual, kernel_size=3, expand_ratio=6),
2: functools.partial(InvertedResidual, kernel_size=5, expand_ratio=3),
3: functools.partial(InvertedResidual, kernel_size=5, expand_ratio=6),
4: functools.partial(InvertedResidual, kernel_size=7, expand_ratio=3),
5: functools.partial(InvertedResidual, kernel_size=7, expand_ratio=6),
6: nn.Identity,
}
V2 = ( V2 = (
[2, 3, 4, 3, 3, 1], [2, 3, 4, 3, 3, 1],
[2, 2, 2, 1, 2, 1], [2, 2, 2, 1, 2, 1],
[32, 16, 24, 32, 64, 96, 160, 320, 1280], [32, 16, 24, 32, 64, 96, 160, 320, 1280],
DEFAULT_NAS_BLOCKS_DEF,
) )
PROXYLESS_MOBILE = ( PROXYLESS_MOBILE = (
[4, 4, 4, 4, 4, 1], [4, 4, 4, 4, 4, 1],
[2, 2, 2, 1, 2, 1], [2, 2, 2, 1, 2, 1],
[32, 16, 32, 40, 80, 96, 192, 320, 1280], [32, 16, 32, 40, 80, 96, 192, 320, 1280],
DEFAULT_NAS_BLOCKS_DEF,
) )
PROXYLESS_GPU = ( PROXYLESS_GPU = (
[4, 4, 4, 4, 4, 1], [4, 4, 4, 4, 4, 1],
[2, 2, 2, 1, 2, 1], [2, 2, 2, 1, 2, 1],
[40, 24, 32, 56, 112, 128, 256, 432, 1280], [40, 24, 32, 56, 112, 128, 256, 432, 1280],
DEFAULT_NAS_BLOCKS_DEF,
) )
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""MobileNetV3 backbone."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from seetadet.core import registry
from seetadet.core.config import cfg
from seetadet.modules import init
from seetadet.modules import nn
def make_divisible(v, divisor=8):
    """Round ``v`` to a nearby multiple of ``divisor``.

    The result is never smaller than ``divisor``, and it is bumped by one
    extra ``divisor`` whenever rounding would fall below 90% of ``v``.
    """
    rounded = int(v + divisor / 2) // divisor * divisor
    result = max(divisor, rounded)
    if result < 0.9 * v:
        return result + divisor
    return result
def conv_triplet(dim_in, dim_out, kernel_size=1, stride=1, activation=None):
    """Return a convolution triplet.

    The triplet is the list ``[conv, norm, activation]``; ``nn.ReLU(True)``
    is used when no ``activation`` is given.
    """
    conv = nn.Conv2d(
        dim_in,
        dim_out,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        bias=False,
    )
    norm = nn.get_norm(cfg.MODEL.BACKBONE_NORM, dim_out)
    if activation is None:
        activation = nn.ReLU(True)
    return [conv, norm, activation]
def conv_quintet(
    dim_in,
    dim_out,
    kernel_size,
    stride,
    activation=None,
    expansion_transform=None,
):
    """Return a convolution quintet.

    The list holds a grouped conv (``groups=dim_in``), its norm and the
    activation, followed by a 1x1 conv to ``dim_out`` and its norm. When
    ``expansion_transform`` is given it is inserted right after the
    activation, so the list then has six entries.
    """
    grouped_conv = nn.Conv2d(
        dim_in,
        dim_in,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        groups=dim_in,
        bias=False,
    )
    grouped_norm = nn.get_norm(cfg.MODEL.BACKBONE_NORM, dim_in)
    act = nn.ReLU(True) if activation is None else activation
    layers = [grouped_conv, grouped_norm, act]
    if expansion_transform is not None:
        layers.append(expansion_transform)
    layers.append(nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=False))
    layers.append(nn.get_norm(cfg.MODEL.BACKBONE_NORM, dim_out))
    return layers
class SqueezeExcite(nn.Module):
    """Squeeze-excite attention module.

    The bottleneck width is ``make_divisible(dim_in * squeeze_ratio)``.
    """

    def __init__(self, dim_in, squeeze_ratio=0.25):
        super(SqueezeExcite, self).__init__()
        dim = make_divisible(dim_in * squeeze_ratio)
        # Global pooling -> reduce -> ReLU -> restore -> Hardsigmoid gate.
        self.layers = nn.Sequential(nn.AvgPool2d(-1, global_pooling=True),
                                    nn.Conv2d(dim_in, dim, kernel_size=1),
                                    nn.ReLU(True),
                                    nn.Conv2d(dim, dim_in, kernel_size=1),
                                    nn.Hardsigmoid(True))

    def forward(self, x):
        """Return ``x`` multiplied by the gate computed from ``x``."""
        return x * self.layers(x)
class InvertedResidual(nn.Module):
    """Invert residual block.

    Layout of ``self.conv``: an optional expansion triplet (skipped when
    ``expand_ratio == 1``), the first three layers of the quintet grouped
    together, then the remaining quintet layers appended one by one.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size=3,
        expand_ratio=3,
        stride=1,
        activation=None,
        squeeze_excite=0,
    ):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        # The shortcut is only added when input and output agree.
        self.apply_residual = (stride == 1) and (dim_in == dim_out)
        expanded = int(round(dim_in * expand_ratio))
        self.dim = expanded
        self.endpoint = None  # Expansion feature

        stages = []
        if expand_ratio != 1:
            expansion = conv_triplet(dim_in, expanded, activation=activation)
            stages.append(nn.Sequential(*expansion))

        if squeeze_excite > 0:
            attention = SqueezeExcite(expanded)
        else:
            attention = None

        quintet = conv_quintet(
            expanded,
            dim_out,
            kernel_size=kernel_size,
            stride=stride,
            activation=activation,
            expansion_transform=attention,
        )
        stages.append(nn.Sequential(*quintet[:3]))
        stages.extend(quintet[3:])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the block, caching the first stage output for stride-2 blocks."""
        out = self.conv[0](x)
        if self.stride == 2:
            self.endpoint = out
        else:
            self.endpoint = None
        for layer in self.conv[1:]:
            out = layer(out)
        if self.apply_residual:
            out += x
        return out
class NASMobileNet(nn.Module):
    """The NAS variant of mobilenet class.

    ``preset`` is the tuple ``(repeats, strides, out_channels, def_blocks)``
    and ``arch`` holds one block index per block. ``forward`` returns one
    feature per layer registered in ``self.feature_dims``.
    """

    def __init__(self, arch, preset):
        super(NASMobileNet, self).__init__()
        # Hand-craft configurations.
        repeats, strides, out_channels, def_blocks = preset
        assert sum(repeats) == len(arch), 'Bad architecture.'
        self.feature_dims = collections.OrderedDict()

        # Stem.
        stem = conv_triplet(
            dim_in=3,
            dim_out=out_channels[0],
            kernel_size=3,
            stride=2,
            activation=nn.Hardswish(),
        )
        features = [nn.Sequential(*stem)]

        # Blocks.
        dim_in = out_channels[0]
        stride_out = 2
        stages = zip(repeats, out_channels[1:], strides)
        for repeat, dim_out, stage_stride in stages:
            stride_out *= stage_stride
            for i in range(repeat):
                # Only the first block of a stage uses the stage stride.
                stride = 1 if i > 0 else stage_stride
                # The stem occupies slot 0, hence the offset of one.
                idx = arch[len(features) - 1]
                if def_blocks is not None:
                    block = def_blocks[idx]
                else:
                    # Decode the block arguments from the index digits.
                    block = functools.partial(
                        InvertedResidual,
                        kernel_size=(idx // 100) % 10,
                        expand_ratio=int(idx / 1000.) / 10,
                        squeeze_excite=idx % 10)
                # Stages with an accumulated stride above 8 use Hardswish.
                if stride_out > 8:
                    activation = nn.Hardswish()
                else:
                    activation = nn.ReLU(True)
                module = block(
                    dim_in, dim_out,
                    stride=stride,
                    activation=activation)
                features.append(module)
                dim_in = dim_out
                if stride == 2:
                    self.feature_dims[id(module)] = module.dim

        # Final 1x1 stage.
        head = conv_triplet(
            dim_in=dim_in,
            dim_out=out_channels[-1],
            kernel_size=1,
            stride=1,
            activation=nn.Hardswish(),
        )
        features.append(nn.Sequential(*head))
        self.feature_dims[id(features[-1])] = out_channels[-1]
        self.features = nn.Sequential(*features)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weight of every ``nn.Conv2d`` submodule."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init.kaiming_normal(module.weight, mode='fan_out')

    def forward(self, x):
        """Return the list of features collected at the registered layers."""
        outputs = []
        for layer in self.features:
            x = layer(x)
            if not self.feature_dims.get(id(layer)):
                continue
            # Prefer the cached ``endpoint`` of a block if it has one.
            outputs.append(getattr(layer, 'endpoint', x))
        return outputs
class ModelSetting(object):
    """Hand-craft model setting."""

    # Default NASBlocks definition.
    # ``None`` makes ``NASMobileNet`` decode every block from its index.
    # We use the following hash method:
    #   ef * 10000 + kernel_size * 100 + se * 1
    # where ``ef`` is the expand ratio and ``se`` the squeeze-excite flag,
    # e.g., ef=4.0, ks=3, se=True, with index 40301
    DEFAULT_NAS_BLOCKS_DEF = None

    # Preset unpacked by ``NASMobileNet`` as
    # (repeats, strides, out_channels, def_blocks).
    V3 = (
        # Number of blocks in each stage.
        [1, 2, 3, 4, 2, 3],
        # Stride of the first block in each stage.
        [1, 2, 2, 2, 1, 2],
        # Stem width, one width per stage, then the final 1x1 stage width.
        [16, 16, 24, 40, 80, 112, 160, 960],
        DEFAULT_NAS_BLOCKS_DEF,
    )
@registry.backbone.register('mobilenet_v3')
def mobilenet_v3():
    """Build the backbone registered as ``mobilenet_v3``."""
    # Block indices grouped per stage, matching ``ModelSetting.V3`` repeats.
    stage_codes = (
        [10300],
        [40300, 30300],
        [30501] * 3,
        [60300, 25300, 23300, 23300],
        [60301] * 2,
        [60501] * 3,
    )
    arch = [code for stage in stage_codes for code in stage]
    return NASMobileNet(arch, ModelSetting.V3)
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""ResNet backbone."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""RetinaNet head."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""RPN head."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""SSD head."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""VGGNet backbone."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Modules."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -295,6 +295,8 @@ Conv2d = nn.Conv2d ...@@ -295,6 +295,8 @@ Conv2d = nn.Conv2d
ConvTranspose2d = nn.ConvTranspose2d ConvTranspose2d = nn.ConvTranspose2d
DepthwiseConv2d = nn.DepthwiseConv2d DepthwiseConv2d = nn.DepthwiseConv2d
DropBlock2d = nn.DropBlock2d DropBlock2d = nn.DropBlock2d
Hardsigmoid = nn.Hardsigmoid
Hardswish = nn.Hardswish
Linear = nn.Linear Linear = nn.Linear
MaxPool2d = nn.MaxPool2d MaxPool2d = nn.MaxPool2d
Module = nn.Module Module = nn.Module
...@@ -302,4 +304,5 @@ ModuleList = nn.ModuleList ...@@ -302,4 +304,5 @@ ModuleList = nn.ModuleList
Sequential = nn.Sequential Sequential = nn.Sequential
Sigmoid = nn.Sigmoid Sigmoid = nn.Sigmoid
Softmax = nn.Softmax Softmax = nn.Softmax
Swish = nn.Swish
upsample = nn.functional.upsample upsample = nn.functional.upsample
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""ONNX utilities."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -82,11 +82,11 @@ def resize_image_with_target_size( ...@@ -82,11 +82,11 @@ def resize_image_with_target_size(
): ):
"""Resize an image with the target size.""" """Resize an image with the target size."""
im_shape = img.shape im_shape = img.shape
max_size = max_size if max_size > 0 else target_size
# Scale along the shortest side # Scale along the shortest side
im_size_min = np.min(im_shape[:2]) im_size_min = np.min(im_shape[:2])
im_size_max = np.max(im_shape[:2]) im_size_max = np.max(im_shape[:2])
im_scale = float(target_size) / float(im_size_min) im_scale = float(target_size) / float(im_size_min)
if max_size > 0:
# Prevent the biggest axis from being more than MAX_SIZE # Prevent the biggest axis from being more than MAX_SIZE
if np.round(im_scale * im_size_max) > max_size: if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max) im_scale = float(max_size) / float(im_size_max)
......
...@@ -7,10 +7,6 @@ ...@@ -7,10 +7,6 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/fast_rcnn/nms_wrapper.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
...@@ -27,35 +23,35 @@ except ImportError: ...@@ -27,35 +23,35 @@ except ImportError:
cpu_nms = cpu_soft_nms = print cpu_nms = cpu_soft_nms = print
def gpu_nms(dets, thresh): def gpu_nms(detections, thresh):
"""Filter out the detections using GPU-NMS.""" """Filter out the detections using GPU-NMS."""
if dets.shape[0] == 0: if detections.shape[0] == 0:
return [] return []
scores = dets[:, 4] scores = detections[:, 4]
order = scores.argsort()[::-1] order = scores.argsort()[::-1]
sorted_dets = env.new_tensor(dets[order, :]) sorted_detections = env.new_tensor(detections[order, :])
keep = det.nms(sorted_dets, iou_threshold=thresh).numpy() keep = det.nms(sorted_detections, iou_threshold=thresh).numpy()
return order[keep] return order[keep]
def nms(dets, thresh): def nms(detections, thresh):
"""Filter out the detections using NMS.""" """Filter out the detections using NMS."""
if dets.shape[0] == 0: if detections.shape[0] == 0:
return [] return []
if cpu_nms is print: if cpu_nms is print:
raise ImportError('Failed to load <cython_nms> library.') raise ImportError('Failed to load <cython_nms> library.')
return cpu_nms(dets, thresh) return cpu_nms(detections, thresh)
def soft_nms( def soft_nms(
dets, detections,
thresh, thresh,
method='linear', method='linear',
sigma=0.5, sigma=0.5,
score_thresh=0.001, score_thresh=0.001,
): ):
"""Filter out the detections using Soft-NMS.""" """Filter out the detections using Soft-NMS."""
if dets.shape[0] == 0: if detections.shape[0] == 0:
return [] return []
if cpu_soft_nms is print: if cpu_soft_nms is print:
raise ImportError('Failed to load <cython_nms> library.') raise ImportError('Failed to load <cython_nms> library.')
...@@ -63,7 +59,7 @@ def soft_nms( ...@@ -63,7 +59,7 @@ def soft_nms(
if method not in methods: if method not in methods:
raise ValueError('Unknown soft nms method:', method) raise ValueError('Unknown soft nms method:', method)
return cpu_soft_nms( return cpu_soft_nms(
dets, detections,
thresh, thresh,
methods[method], methods[method],
sigma, sigma,
......
...@@ -57,7 +57,7 @@ def benchmark_flops(module, normalizer=1e6): ...@@ -57,7 +57,7 @@ def benchmark_flops(module, normalizer=1e6):
if original_training: if original_training:
module.eval() module.eval()
with torch.no_grad(): with torch.no_grad():
module() module.benchmark()
if original_training: if original_training:
module.train() module.train()
return collect_flops(module, normalizer) return collect_flops(module, normalizer)
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Export a detection network into the onnx model."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -27,7 +28,7 @@ from seetadet.utils import logger ...@@ -27,7 +28,7 @@ from seetadet.utils import logger
def parse_args(): def parse_args():
"""Parse arguments""" """Parse arguments."""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Export a detection network into the onnx model') description='Export a detection network into the onnx model')
parser.add_argument( parser.add_argument(
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Train a detection network with mpi utilities."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Deploy a detection network for serving."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import importlib
import os
import threading
import argparse
import cv2
import dragon
import flask
import kpl_helper
import numpy as np
import pprint
from seetadet.core.config import cfg
from seetadet.core.coordinator import Coordinator
from seetadet.modeling.detector import new_detector
from seetadet.utils import logger
def parse_args():
    """Parse arguments."""
    parser = argparse.ArgumentParser(
        description='Deploy a detection network for serving')
    # (flag, keyword arguments) in the order they are registered.
    options = (
        ('--cfg', dict(dest='cfg_file', default=None, help='config file')),
        ('--exp_dir', dict(default='', help='experiment dir')),
        ('--model_dir', dict(default='', help='final model dir')),
        ('--iter', dict(type=int, default=None,
                        help='test checkpoint of given step')),
        ('--port', dict(type=int, default=5050, help='listening port')),
    )
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
def get_image(base64_str):
    """Decode a base64 string into an image via ``cv2.imdecode``.

    Returns the decoded result, or ``None`` if any decoding step raises;
    the failure is logged instead of being propagated.
    """
    try:
        raw = np.frombuffer(base64.b64decode(base64_str), np.uint8)
        decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    except Exception as e:
        logger.info('Decode base64 image failed. detail: ' + str(e))
        return None
    return decoded
def get_objects(boxes_this_image):
    """Convert per-class detections into a flat list of object dicts.

    ``boxes_this_image`` is indexed by class position in
    ``cfg.MODEL.CLASSES``; rows whose fifth column exceeds ``cfg.VIS_TH``
    are kept. The first four columns are emitted as integer
    xmin/ymin/xmax/ymax and the fifth as the float score.
    """
    objects = []
    for cls_idx, name in enumerate(cfg.MODEL.CLASSES):
        if name == '__background__':
            continue
        detections = boxes_this_image[cls_idx]
        keep = np.where(detections[:, 4] > cfg.VIS_TH)[0]
        objects.extend(
            {
                'score': float(row[4]),
                'name': name,
                'xmin': int(row[0]),
                'ymin': int(row[1]),
                'xmax': int(row[2]),
                'ymax': int(row[3]),
            }
            for row in detections[keep]
        )
    logger.info('Detect objects: ' + str(objects))
    return objects
class Wrapper(object):
    """Inference wrapper.

    Holds the detector, the algorithm specific test module and a lock
    that serializes the detector calls.
    """

    def __init__(self, args):
        """Resolve the checkpoint from ``args`` and create the detector.

        ``args.model_dir`` takes precedence and points at
        ``model_final.pkl``; otherwise the checkpoint of ``args.iter`` is
        looked up in ``args.exp_dir`` through the coordinator.
        """
        if args.model_dir:
            # Built for its side effects only; presumably this is what
            # loads ``args.cfg_file`` into ``cfg`` -- TODO confirm.
            Coordinator(args.cfg_file, exp_dir=args.model_dir)
            checkpoint = os.path.join(args.model_dir, 'model_final.pkl')
        else:
            coordinator = Coordinator(args.cfg_file, exp_dir=args.exp_dir)
            checkpoint, _ = coordinator.checkpoint(args.iter, wait=False)
        logger.info('Load model from: ' + checkpoint)
        self.test_module = importlib.import_module(
            'seetadet.algo.%s.test' % cfg.MODEL.TYPE)
        self.detector = new_detector(cfg.GPU_ID, checkpoint)
        self.lock = threading.RLock()

    def do_inference(self, img):
        """Run the detector on a single image and return its detections.

        Only the detector call is guarded by the lock; post-processing
        runs outside of it.
        """
        compute_fn = getattr(self.test_module, 'ims_detect')
        process_fn = getattr(self.test_module, 'get_detections')
        # The context manager releases the lock only if it was acquired,
        # unlike calling ``acquire()`` inside a ``try``/``finally``.
        with self.lock:
            outputs = compute_fn(self.detector, [img])[0]
        outputs = process_fn(outputs)
        return outputs[0]
# Script entry: load the detector once, then serve it over HTTP with Flask.
if __name__ == '__main__':
    os.environ['FLASK_ENV'] = 'production'
    args = parse_args()
    logger.info('Called with args:\n' + str(args))
    # NOTE(review): ``cfg`` is logged before ``Wrapper`` builds the
    # Coordinator with ``args.cfg_file``; if the Coordinator is what loads
    # the config file, this prints the default values -- confirm.
    logger.info('Using config:\n' + pprint.pformat(cfg))
    app = flask.Flask('SeetaDet')
    # The detector is created under a dedicated workspace, and the same
    # workspace is re-entered for every inference below.
    workspace = dragon.Workspace()
    with workspace.as_default():
        wrapper = Wrapper(args)

    @app.route("/", methods=['POST'])
    def infer():
        # Expects a JSON body carrying the image under 'base64_image'.
        try:
            req = flask.request.get_json(force=True)
            base64_str = req['base64_image']
        except KeyError:
            print('Not found base64 image.')
            return flask.abort(400)
        response = kpl_helper.deploy.RectangleBoxObjectDetectionResponse(0, 0, 0)
        # Keep only the text after the last comma; presumably this drops
        # a 'data:...;base64,' style prefix -- verify against the clients.
        base64_str = base64_str.split(",")[-1]
        img = get_image(base64_str)
        # Undecodable input yields the initial response instead of an error.
        if not isinstance(img, np.ndarray):
            return flask.jsonify(response.dumps())
        response.height, response.width, response.depth = img.shape
        with workspace.as_default():
            detections = wrapper.do_inference(img)
        objects = get_objects(detections)
        for obj in objects:
            response.add_object(obj['name'],
                                obj['xmin'],
                                obj['ymin'],
                                obj['xmax'],
                                obj['ymax'],
                                obj['score'])
        return flask.jsonify(response.dumps())

    app.run(host="0.0.0.0", port=args.port)
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Test a detection network."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -28,9 +29,9 @@ from seetadet.utils import logger ...@@ -28,9 +29,9 @@ from seetadet.utils import logger
def parse_args(): def parse_args():
"""Parse arguments""" """Parse arguments."""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Test a detection network with a specified checkpoint') description='Test a detection network')
parser.add_argument( parser.add_argument(
'--cfg', '--cfg',
dest='cfg_file', dest='cfg_file',
...@@ -81,8 +82,7 @@ def parse_args(): ...@@ -81,8 +82,7 @@ def parse_args():
if len(sys.argv) == 1: if len(sys.argv) == 1:
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
args = parser.parse_args() return parser.parse_args()
return args
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Train a detection network."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!