Commit c020594c by Ting PAN

Support rotated boxes for SSD

1 parent 406662ad
------------------------------------------------------------------------
The list of most significant changes made over time in SeetaDet.
SeetaDet 0.2.1 (20191017)
Dragon Minimum Required (Version 0.3.0.dev20191017)
Changes:
Preview Features:
- Rotated boxes and FPN support for SSD.
- Graph freezing to speed up inference.
Bugs fixed:
- None
------------------------------------------------------------------------
SeetaDet 0.2.0 (20190929)
Dragon Minimum Required (Version 0.3.0.dev20190929)
......
......@@ -67,10 +67,10 @@ python export.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR> --iter <ITERATION>
| :------: | :------: |
| [VGG16.SSD](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/VGG16.SSD.pth)| SSD |
| [VGG16.RCNN](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/VGG16.RCNN.pth)| R-CNN |
| [R-18.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-18.Affine.pth)| R-CNN, RetinaNet |
| [R-34.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-34.Affine.pth)| R-CNN, RetinaNet |
| [R-50.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-50.Affine.pth)| R-CNN, RetinaNet |
| [R-101.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-101.Affine.pth)| R-CNN, RetinaNet |
| [R-18.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-18.Affine.pth)| R-CNN, RetinaNet, SSD |
| [R-34.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-34.Affine.pth)| R-CNN, RetinaNet, SSD |
| [R-50.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-50.Affine.pth)| R-CNN, RetinaNet, SSD |
| [R-101.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-101.Affine.pth)| R-CNN, RetinaNet, SSD |
| [AirNet.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/AirNet.Affine.pth)| R-CNN, RetinaNet, SSD |
## References
......
......@@ -5,6 +5,7 @@ rm -r build install *.c *.cpp
# compile cython modules
python setup.py build_ext --inplace
g++ -o ../lib/utils/ctypes_rbox.so -shared -fPIC -O2 rbox.cc -fopenmp
# compile cuda modules
cd build && cmake .. && make install && cd ..
......
......@@ -224,7 +224,7 @@ __C.MODEL.FOCAL_LOSS_GAMMA = 2.0
# Stride of the coarsest feature level
# This is needed so the input can be padded properly
__C.MODEL.COARSEST_STRIDE = -1
__C.MODEL.COARSEST_STRIDE = 32
###########################################
......@@ -269,6 +269,9 @@ __C.RETINANET.ANCHOR_SCALE = 4
# NOTE: this doesn't include the last conv for logits
__C.RETINANET.NUM_CONVS = 4
# Weight for bbox regression loss
__C.RETINANET.BBOX_REG_WEIGHT = 1.
# During inference, #locs to select based on cls score before NMS is performed
__C.RETINANET.PRE_NMS_TOP_N = 5000
......@@ -359,6 +362,13 @@ __C.SSD = edict()
# Whether to enable FPN enhancement?
__C.SSD.FPN_ON = False
# Convolutions to use in the cls and bbox tower
# NOTE: this doesn't include the last conv for logits
__C.SSD.NUM_CONVS = 0
# Weight for bbox regression loss
__C.SSD.BBOX_REG_WEIGHT = 1.
__C.SSD.MULTIBOX = edict()
# MultiBox configs
__C.SSD.MULTIBOX.STRIDES = []
......@@ -523,7 +533,7 @@ __C.PIXEL_MEANS = [102., 115., 122.]
__C.BBOX_REG_WEIGHTS = (10., 10., 5., 5.)
# Default weights on (dx, dy, dw, dh, da) for normalizing rbox regression targets
__C.RBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0, 10.0)
__C.RBOX_REG_WEIGHTS = (10.0, 10.0, 5., 5., 10.)
# Prior prob for the positives at the beginning of training.
# This is used to set the bias init for the logits layer
......
......@@ -153,11 +153,13 @@ class TaaS(imdb):
if len(dets) == 0:
continue
for k in range(dets.shape[0]):
f.write(
'{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'
content = '{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}' \
.format(image_id, dets[k, -1],
dets[k, 0] + 1, dets[k, 1] + 1,
dets[k, 2] + 1, dets[k, 3] + 1))
dets[k, 2] + 1, dets[k, 3] + 1)
if dets.shape[1] == 6:
content += ' {:.2f}'.format(dets[k, 4])
f.write(content + '\n')
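For rotated detections each row is laid out as [cx, cy, w, h, angle, score], so the writer above appends the angle as a sixth field. Two made-up lines, purely to illustrate the format (image id and values are invented):

000012 0.953 49.0 241.0 196.0 372.0        (axis-aligned: id, score, x1, y1, x2, y2)
000012 0.953 121.5 201.0 61.0 31.0 45.00   (rotated: angle appended with two decimals)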
def _write_voc_segm_results(self, all_boxes, all_masks, output_dir):
for cls_inds, cls in enumerate(self.classes):
......
......@@ -27,16 +27,13 @@ except:
from lib.core.config import cfg
from lib.pycocotools.mask_utils import mask_rle2im
from lib.utils import rotated_boxes
from lib.utils.boxes import expand_boxes
from lib.utils.mask import mask_overlap
def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
"""
"""Compute VOC AP given precision and recall."""
if use_07_metric:
# 11 point metric
ap = 0.
......@@ -47,20 +44,20 @@ def voc_ap(rec, prec, use_07_metric=False):
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
# Correct AP calculation
# First append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
# compute the precision envelope
# Compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
# to calculate area under PR curve, look for points
# To calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
# And sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
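As a sanity check, the interpolated AP above can be exercised with toy arrays (numbers invented; a sketch, not repo code):

import numpy as np
rec = np.array([0.5, 0.5, 1.0])
prec = np.array([1.0, 0.5, 2.0 / 3.0])
# The precision envelope becomes [1.0, 1.0, 2/3, 2/3, 0.0] and recall
# steps at 0.5 and 1.0, so AP = 0.5 * 1.0 + 0.5 * (2/3) ~= 0.833.
print(voc_ap(rec, prec, use_07_metric=False))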
......@@ -72,19 +69,14 @@ def voc_bbox_eval(
IoU=0.5,
use_07_metric=False,
):
class_recs = {}
n_pos = 0
class_recs, n_pos = {}, 0
for image_name, rec in gt_recs.items():
R = [obj for obj in rec['objects'] if obj['name'] == cls_name]
bbox = np.array([x['bbox'] for x in R])
difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
diff = np.array([x['difficult'] for x in R]).astype(np.bool)
det = [False] * len(R)
n_pos = n_pos + sum(~difficult)
class_recs[image_name] = {
'bbox': bbox,
'difficult': difficult,
'det': det
}
n_pos = n_pos + sum(~diff)
class_recs[image_name] = {'bbox': bbox, 'difficult': diff, 'det': det}
# Read detections
with open(det_file, 'r') as f:
......@@ -107,50 +99,53 @@ def voc_bbox_eval(
# Go down detections and mark TPs and FPs
nd = len(image_ids)
tp, fp = np.zeros(nd), np.zeros(nd)
def overlaps4(bb, BBGT):
ixmin = np.maximum(BBGT[:, 0], bb[0])
iymin = np.maximum(BBGT[:, 1], bb[1])
ixmax = np.minimum(BBGT[:, 2], bb[2])
iymax = np.minimum(BBGT[:, 3], bb[3])
iw = np.maximum(ixmax - ixmin + 1., 0.)
ih = np.maximum(iymax - iymin + 1., 0.)
inters = iw * ih
uni = ((bb[2] - bb[0] + 1.) *
(bb[3] - bb[1] + 1.) +
(BBGT[:, 2] - BBGT[:, 0] + 1.) *
(BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
return inters / uni
def overlaps5(bb, BBGT):
return rotated_boxes.bbox_overlaps(bb.reshape((1, 5)), BBGT)[0]
for d in range(nd):
R = class_recs[image_ids[d]]
bb = BB[d, :].astype(float)
ovmax, jmax = -np.inf, 0
ov_max, j_max = -np.inf, 0
BBGT = R['bbox'].astype(float)
if BBGT.size > 0:
# Compute overlaps intersection
ixmin = np.maximum(BBGT[:, 0], bb[0])
iymin = np.maximum(BBGT[:, 1], bb[1])
ixmax = np.minimum(BBGT[:, 2], bb[2])
iymax = np.minimum(BBGT[:, 3], bb[3])
iw = np.maximum(ixmax - ixmin + 1., 0.)
ih = np.maximum(iymax - iymin + 1., 0.)
inters = iw * ih
# Union
uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
(BBGT[:, 2] - BBGT[:, 0] + 1.) *
(BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
overlaps = inters / uni
ovmax = np.max(overlaps)
jmax = np.argmax(overlaps)
if ovmax > IoU:
if not R['difficult'][jmax]:
if not R['det'][jmax]:
overlaps = overlaps4(bb, BBGT) \
if len(bb) == 4 else overlaps5(bb, BBGT)
ov_max = np.max(overlaps)
j_max = np.argmax(overlaps)
if ov_max > IoU:
if not R['difficult'][j_max]:
if not R['det'][j_max]:
tp[d] = 1.
R['det'][jmax] = 1
R['det'][j_max] = 1
else:
fp[d] = 1.
else:
fp[d] = 1.
# compute precision recall
# Compute precision recall
fp = np.cumsum(fp)
tp = np.cumsum(tp)
rec = tp / float(n_pos)
# avoid divide by zero in case the first detection matches a difficult
# ground truth
# Avoid divide by zero in case the first detection matches a difficult ground truth
prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
ap = voc_ap(rec, prec, use_07_metric)
return rec, prec, ap
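The cumulative sums above turn per-detection TP/FP flags into running recall and precision, which is exactly what voc_ap consumes; a toy trace (flags invented):

import numpy as np
tp, fp, n_pos = np.array([1., 0., 1.]), np.array([0., 1., 0.]), 2
rec = np.cumsum(tp) / float(n_pos)                  # [0.5, 0.5, 1.0]
prec = np.cumsum(tp) / np.maximum(
    np.cumsum(tp) + np.cumsum(fp),
    np.finfo(np.float64).eps)                       # [1.0, 0.5, 0.667]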
......
......@@ -18,12 +18,12 @@ import numpy.random as npr
import dragon.vm.torch as torch
from lib.core.config import cfg
from lib.faster_rcnn.generate_anchors import generate_anchors
from lib.utils import logger
from lib.utils.blob import blob_to_tensor
from lib.utils.boxes import bbox_overlaps
from lib.utils.boxes import bbox_transform
from lib.utils.boxes import dismantle_gt_boxes
from lib.utils.cython_bbox import bbox_overlaps
from lib.faster_rcnn.generate_anchors import generate_anchors
class AnchorTargetLayer(torch.nn.Module):
......@@ -116,10 +116,7 @@ class AnchorTargetLayer(torch.nn.Module):
labels.fill(-1)
# Overlaps between the anchors and the gt boxes
overlaps = bbox_overlaps(
np.ascontiguousarray(anchors, dtype=np.float),
np.ascontiguousarray(gt_boxes, dtype=np.float),
)
overlaps = bbox_overlaps(anchors, gt_boxes)
argmax_overlaps = overlaps.argmax(axis=1)
max_overlaps = overlaps[np.arange(num_inside), argmax_overlaps]
gt_argmax_overlaps = overlaps.argmax(axis=0)
......
......@@ -72,7 +72,7 @@ class DataBatch(mp.Process):
super(DataBatch, self).__init__()
# Distributed settings
rank, group_size = 0, 1
process_group = dragon.distributed.get_default_process_group()
process_group = dragon.distributed.get_group()
if process_group is not None and kwargs.get(
'phase', 'TRAIN') == 'TRAIN':
group_size = process_group.size
......
......@@ -101,11 +101,18 @@ class DataTransformer(multiprocessing.Process):
def get_annotations(cls, example):
objects = []
for ix, obj in enumerate(example['object']):
objects.append({
'name': obj['name'],
'difficult': obj.get('difficult', 0),
'bbox': [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax']],
})
if 'xmin' in obj:
objects.append({
'name': obj['name'],
'difficult': obj.get('difficult', 0),
'bbox': [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax']],
})
else:
objects.append({
'name': obj['name'],
'difficult': obj.get('difficult', 0),
'bbox': obj['bbox'],
})
return example['id'], objects
def get(self, example):
......
......@@ -19,9 +19,9 @@ import numpy.random as npr
from lib.core.config import cfg
from lib.utils.blob import blob_to_tensor
from lib.utils.boxes import bbox_overlaps
from lib.utils.boxes import bbox_transform
from lib.utils.boxes import dismantle_gt_boxes
from lib.utils.cython_bbox import bbox_overlaps
class ProposalTargetLayer(torch.nn.Module):
......@@ -124,10 +124,7 @@ def _sample_rois(
num_classes,
):
"""Generate a random sample of RoIs."""
overlaps = bbox_overlaps(
np.ascontiguousarray(all_rois[:, 1:5], dtype=np.float),
np.ascontiguousarray(gt_boxes[:, :4], dtype=np.float),
)
overlaps = bbox_overlaps(all_rois[:, 1:5], gt_boxes[:, :4])
gt_assignment = overlaps.argmax(axis=1)
max_overlaps = overlaps.max(axis=1)
labels = gt_boxes[gt_assignment, 4]
......
......@@ -20,40 +20,51 @@ from lib.core.config import cfg
from lib.nms.nms_wrapper import nms
from lib.nms.nms_wrapper import soft_nms
from lib.utils.blob import im_list_to_blob
from lib.utils.blob import tensor_to_blob
from lib.utils.boxes import bbox_transform_inv
from lib.utils.boxes import clip_tiled_boxes
from lib.utils.image import scale_image
from lib.utils.timer import Timer
from lib.utils.graph import FrozenGraph
from lib.utils.vis import vis_one_image
def im_detect(detector, raw_image):
"""Detect a image, with single or multiple scales."""
# Prepare images
ims, ims_scale = scale_image(raw_image)
# Prepare blobs
blobs = {'data': im_list_to_blob(ims)}
blobs['ims_info'] = np.array([
list(blobs['data'].shape[1:3]) + [im_scale]
for im_scale in ims_scale], dtype=np.float32)
blobs['data'] = torch.from_numpy(blobs['data'])
for im_scale in ims_scale
], dtype=np.float32)
# Do Forward
with torch.no_grad():
outputs = detector.forward(inputs=blobs)
if not hasattr(detector, 'frozen_graph'):
inputs = {
'data': torch.from_numpy(blobs['data']),
'ims_info': torch.from_numpy(blobs['ims_info']),
}
with torch.no_grad():
with torch.jit.Recorder(retain_ops=True):
outputs = detector.forward(inputs)
detector.frozen_graph = FrozenGraph(
{'data': inputs['data'],
'ims_info': inputs['ims_info']},
{'rois': outputs['rois'],
'cls_prob': outputs['cls_prob'],
'bbox_pred': outputs['bbox_pred']},
)
outputs = detector.frozen_graph(**blobs)
# Decode results
batch_rois = tensor_to_blob(outputs['rois'])
batch_scores = tensor_to_blob(outputs['cls_prob'])
batch_deltas = tensor_to_blob(outputs['bbox_pred'])
batch_rois = outputs['rois']
batch_scores = outputs['cls_prob']
batch_deltas = outputs['bbox_pred']
batch_boxes = bbox_transform_inv(
boxes=batch_rois[:, 1:5],
deltas=batch_deltas,
weights=cfg.BBOX_REG_WEIGHTS,
batch_rois[:, 1:5],
batch_deltas,
cfg.BBOX_REG_WEIGHTS,
)
scores_wide, boxes_wide = [], []
......
......@@ -22,9 +22,9 @@ from lib.core.config import cfg
from lib.faster_rcnn.generate_anchors import generate_anchors
from lib.utils import logger
from lib.utils.blob import blob_to_tensor
from lib.utils.boxes import bbox_overlaps
from lib.utils.boxes import bbox_transform
from lib.utils.boxes import dismantle_gt_boxes
from lib.utils.cython_bbox import bbox_overlaps
class AnchorTargetLayer(torch.nn.Module):
......@@ -123,10 +123,7 @@ class AnchorTargetLayer(torch.nn.Module):
labels.fill(-1)
# Overlaps between the anchors and the gt boxes
overlaps = bbox_overlaps(
np.ascontiguousarray(anchors, dtype=np.float),
np.ascontiguousarray(gt_boxes, dtype=np.float),
)
overlaps = bbox_overlaps(anchors, gt_boxes)
argmax_overlaps = overlaps.argmax(axis=1)
max_overlaps = overlaps[np.arange(num_inside), argmax_overlaps]
......@@ -164,10 +161,10 @@ class AnchorTargetLayer(torch.nn.Module):
bbox_targets = np.zeros((num_inside, 4), dtype=np.float32)
bbox_targets[fg_inds, :] = bbox_transform(
anchors[fg_inds, :],
gt_boxes[argmax_overlaps[fg_inds], 0:4],
gt_boxes[argmax_overlaps[fg_inds], :4],
)
bbox_inside_weights = np.zeros((num_inside, 4), dtype=np.float32)
bbox_inside_weights[labels == 1, :] = np.array((1.0, 1.0, 1.0, 1.0))
bbox_inside_weights[labels == 1, :] = np.array((1., 1., 1., 1.))
bbox_outside_weights = np.zeros((num_inside, 4), dtype=np.float32)
bbox_outside_weights[labels == 1, :] = np.ones((1, 4)) / cfg.TRAIN.RPN_BATCHSIZE
bbox_outside_weights[labels == 0, :] = np.ones((1, 4)) / cfg.TRAIN.RPN_BATCHSIZE
......
......@@ -19,9 +19,9 @@ import dragon.vm.torch as torch
from lib.core.config import cfg
from lib.utils.blob import blob_to_tensor
from lib.utils.boxes import bbox_overlaps
from lib.utils.boxes import bbox_transform
from lib.utils.boxes import dismantle_gt_boxes
from lib.utils.cython_bbox import bbox_overlaps
class ProposalTargetLayer(torch.nn.Module):
......@@ -160,9 +160,7 @@ def _map_rois_to_fpn_levels(rois, k_min, k_max):
def _sample_rois(all_rois, gt_boxes, fg_rois_per_image, rois_per_image, num_classes):
"""Sample a batch of RoIs comprising foreground and background examples."""
# overlaps: (rois x gt_boxes)
overlaps = bbox_overlaps(
np.ascontiguousarray(all_rois[:, 1:5], dtype=np.float),
np.ascontiguousarray(gt_boxes[:, :4], dtype=np.float))
overlaps = bbox_overlaps(all_rois[:, 1:5], gt_boxes[:, :4])
gt_assignment = overlaps.argmax(axis=1)
max_overlaps = overlaps.max(axis=1)
labels = gt_boxes[gt_assignment, 4]
......
......@@ -41,7 +41,6 @@ class Detector(torch.nn.Module):
model = cfg.MODEL.TYPE
backbone = cfg.MODEL.BACKBONE.lower().split('.')
body, modules = backbone[0], backbone[1:]
self.recorder = None
# + Data Loader
self.data_layer = importlib.import_module(
......
......@@ -34,7 +34,7 @@ class FPN(torch.nn.Module):
for lvl in range(cfg.FPN.RPN_MIN_LEVEL, HIGHEST_BACKBONE_LVL + 1):
self.C.append(conv1x1(feature_dims[lvl - 1], cfg.FPN.DIM, bias=True))
self.P.append(conv3x3(cfg.FPN.DIM, cfg.FPN.DIM, bias=True))
if 'retinanet' in cfg.MODEL.TYPE:
if 'retinanet' in cfg.MODEL.TYPE or 'ssd' in cfg.MODEL.TYPE:
for lvl in range(HIGHEST_BACKBONE_LVL + 1, cfg.FPN.RPN_MAX_LEVEL + 1):
dim_in = feature_dims[-1] if lvl == HIGHEST_BACKBONE_LVL + 1 else cfg.FPN.DIM
self.P.append(conv3x3(dim_in, cfg.FPN.DIM, stride=2, bias=True))
......@@ -64,7 +64,7 @@ class FPN(torch.nn.Module):
for i in range(HIGHEST_BACKBONE_LVL - 1, min_lvl - 1, -1):
lateral_output = self.C[i - min_lvl](features[i - 1])
upscale_output = torch.vision.ops.nn_resize(
fpn_input, dsize=lateral_output.shape[-2:])
fpn_input, dsize=None, fx=2., fy=2.)
fpn_input = lateral_output.__iadd__(upscale_output)
outputs.insert(0, self.P[i - min_lvl](fpn_input))
return outputs
......@@ -83,7 +83,7 @@ class FPN(torch.nn.Module):
for i in range(HIGHEST_BACKBONE_LVL - 1, min_lvl - 1, -1):
lateral_output = self.C[i - min_lvl](features[i - 1])
upscale_output = torch.vision.ops.nn_resize(
fpn_input, dsize=lateral_output.shape[-2:])
fpn_input, dsize=None, fx=2., fy=2.)
fpn_input = lateral_output.__iadd__(upscale_output)
outputs.insert(0, self.P[i - min_lvl](fpn_input))
return outputs
......
......@@ -32,17 +32,34 @@ class SSD(torch.nn.Module):
# SSD outputs #
########################################
self.cls_conv = torch.nn.ModuleList(
conv3x3(feature_dims[0], feature_dims[0], bias=True)
for _ in range(cfg.SSD.NUM_CONVS)
)
self.bbox_conv = torch.nn.ModuleList(
conv3x3(feature_dims[0], feature_dims[0], bias=True)
for _ in range(cfg.SSD.NUM_CONVS)
)
self.cls_score = torch.nn.ModuleList()
self.bbox_pred = torch.nn.ModuleList()
self.softmax = torch.nn.Softmax(dim=2)
self.relu = torch.nn.ReLU(inplace=True)
C = cfg.MODEL.NUM_CLASSES
self.box_dim = len(cfg.BBOX_REG_WEIGHTS)
if len(feature_dims) == 1 and \
len(feature_dims) != len(cfg.SSD.MULTIBOX.STRIDES):
feature_dims = feature_dims * len(cfg.SSD.MULTIBOX.STRIDES)
feature_dims = list(filter(None, feature_dims))
for i, dim_in in enumerate(feature_dims):
A = len(cfg.SSD.MULTIBOX.ASPECT_RATIOS[i]) + 1
if self.box_dim == 5 and \
len(cfg.SSD.MULTIBOX.ASPECT_ANGLES) > 0:
A *= len(cfg.SSD.MULTIBOX.ASPECT_ANGLES)
self.cls_score.append(conv3x3(dim_in, A * C, bias=True))
self.bbox_pred.append(conv3x3(dim_in, A * 4, bias=True))
self.bbox_pred.append(conv3x3(dim_in, A * self.box_dim, bias=True))
self.prior_box_layer = PriorBoxLayer()
......@@ -58,12 +75,20 @@ class SSD(torch.nn.Module):
self.reset_parameters()
def reset_parameters(self):
# Careful Initialization
# Weight ~ Normal(0, 0.001)
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.normal_(m.weight, std=0.001)
torch.nn.init.constant_(m.bias, 0)
if cfg.SSD.NUM_CONVS > 0:
# Initialization following the RPN
# Weight ~ Normal(0, 0.01)
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.normal_(m.weight, std=0.01)
torch.nn.init.constant_(m.bias, 0)
else:
# Careful Initialization
# Weight ~ Normal(0, 0.001)
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.normal_(m.weight, std=0.001)
torch.nn.init.constant_(m.bias, 0)
def compute_outputs(self, features):
"""Compute the SSD logits.
......@@ -77,18 +102,22 @@ class SSD(torch.nn.Module):
# Compute logits
cls_score_wide, bbox_pred_wide = [], []
for i, feature in enumerate(features):
cls_x, bbox_x = feature, feature
for j in range(cfg.SSD.NUM_CONVS):
cls_x = self.relu(self.cls_conv[j](cls_x))
bbox_x = self.relu(self.bbox_conv[j](bbox_x))
cls_score_wide.append(
self.cls_score[i](feature)
self.cls_score[i](cls_x)
.permute(0, 2, 3, 1).view(0, -1))
bbox_pred_wide.append(
self.bbox_pred[i](feature)
self.bbox_pred[i](bbox_x)
.permute(0, 2, 3, 1).view(0, -1))
# Concat them if necessary
return \
torch.cat(cls_score_wide, dim=1) \
.view(0, -1, cfg.MODEL.NUM_CLASSES), \
torch.cat(bbox_pred_wide, dim=1).view(0, -1, 4)
torch.cat(bbox_pred_wide, dim=1).view(0, -1, self.box_dim)
def compute_losses(
self,
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
from lib.core.config import cfg
from lib.utils import logger
from lib.utils import rotated_boxes
try:
from lib.nms.cpu_nms import cpu_nms, cpu_soft_nms
......@@ -35,6 +36,8 @@ def nms(detections, thresh, force_cpu=False):
"""Perform either CPU or GPU Hard-NMS."""
if detections.shape[0] == 0:
return []
if detections.shape[1] == 6:
return rotated_boxes.nms(detections, thresh)
if cfg.USE_GPU_NMS and not force_cpu:
return gpu_nms(detections, thresh, device_id=cfg.GPU_ID)
else:
......
......@@ -71,7 +71,7 @@ class RetinaNetDecoder(torch.nn.Module):
features=features,
cls_prob=cls_prob,
bbox_pred=bbox_pred,
ims_info=blob_to_tensor(ims_info, enforce_cpu=True),
ims_info=ims_info,
strides=self.strides,
ratios=[float(e) for e in cfg.RETINANET.ASPECT_RATIOS],
scales=self.scales,
......@@ -94,7 +94,7 @@ class RPNDecoder(torch.nn.Module):
features=features,
cls_prob=cls_prob,
bbox_pred=bbox_pred,
ims_info=blob_to_tensor(ims_info, enforce_cpu=True),
ims_info=ims_info,
num_outputs=self.K,
strides=cfg.RPN.STRIDES,
ratios=[float(e) for e in cfg.RPN.ASPECT_RATIOS],
......
......@@ -20,9 +20,9 @@ from lib.core.config import cfg
from lib.faster_rcnn.generate_anchors import generate_anchors_v2
from lib.utils import logger
from lib.utils.blob import blob_to_tensor
from lib.utils.boxes import bbox_overlaps
from lib.utils.boxes import bbox_transform
from lib.utils.boxes import dismantle_gt_boxes
from lib.utils.cython_bbox import bbox_overlaps
class AnchorTargetLayer(torch.nn.Module):
......@@ -104,10 +104,7 @@ class AnchorTargetLayer(torch.nn.Module):
labels.fill(-1)
# Overlaps between the anchors and the gt boxes
overlaps = bbox_overlaps(
np.ascontiguousarray(anchors, dtype=np.float),
np.ascontiguousarray(gt_boxes, dtype=np.float),
)
overlaps = bbox_overlaps(anchors, gt_boxes)
argmax_overlaps = overlaps.argmax(axis=1)
max_overlaps = overlaps[np.arange(num_inside), argmax_overlaps]
......@@ -133,8 +130,9 @@ class AnchorTargetLayer(torch.nn.Module):
bbox_inside_weights = np.zeros((num_inside, 4), dtype=np.float32)
bbox_inside_weights[fg_inds, :] = np.array((1., 1., 1., 1.))
bbox_reg_weight = float(cfg.RETINANET.BBOX_REG_WEIGHT)
bbox_outside_weights = np.zeros((num_inside, 4), dtype=np.float32)
bbox_outside_weights[fg_inds, :] = np.ones((1, 4)) / max(len(fg_inds), 1)
bbox_outside_weights[fg_inds, :] = bbox_reg_weight / max(len(fg_inds), 1)
labels_wide[ix, inds_inside] = labels
bbox_targets_wide[ix, inds_inside] = bbox_targets
......
......@@ -20,7 +20,7 @@ from lib.core.config import cfg
from lib.nms.nms_wrapper import nms
from lib.nms.nms_wrapper import soft_nms
from lib.utils.blob import im_list_to_blob
from lib.utils.blob import tensor_to_blob
from lib.utils.graph import FrozenGraph
from lib.utils.image import scale_image
from lib.utils.timer import Timer
from lib.utils.vis import vis_one_image
......@@ -28,28 +28,35 @@ from lib.utils.vis import vis_one_image
def im_detect(detector, raw_image):
"""Detect a image, with single or multiple scales."""
# Prepare images
ims, ims_scale = scale_image(raw_image)
# Prepare blobs
blobs = {'data': im_list_to_blob(ims)}
blobs['ims_info'] = np.array([
list(blobs['data'].shape[1:3]) + [im_scale]
for im_scale in ims_scale], dtype=np.float32,
)
blobs['data'] = torch.from_numpy(blobs['data'])
for im_scale in ims_scale
], dtype=np.float32)
# Do Forward
with torch.no_grad():
outputs = detector.forward(inputs=blobs)
# Unpack results
return tensor_to_blob(outputs['detections'])[:, 1:]
if not hasattr(detector, 'frozen_graph'):
inputs = {
'data': torch.from_numpy(blobs['data']),
'ims_info': torch.from_numpy(blobs['ims_info']),
}
with torch.no_grad():
with torch.jit.Recorder(retain_ops=True):
outputs = detector.forward(inputs)
detector.frozen_graph = FrozenGraph(
{'data': inputs['data'],
'ims_info': inputs['ims_info']},
{'detections': outputs['detections']},
)
outputs = detector.frozen_graph(**blobs)
return outputs['detections'][:, 1:]
def ims_detect(detector, raw_images):
"""Detect images, with single or multiple scales."""
# Prepare images
ims, ims_scale = scale_image(raw_images[0])
num_scales = len(ims_scale)
ims_shape = [im.shape for im in raw_images]
......@@ -62,16 +69,27 @@ def ims_detect(detector, raw_images):
blobs = {'data': im_list_to_blob(ims)}
blobs['ims_info'] = np.array([
list(blobs['data'].shape[1:3]) + [im_scale]
for im_scale in ims_scale], dtype=np.float32,
)
blobs['data'] = torch.from_numpy(blobs['data'])
for im_scale in ims_scale
], dtype=np.float32)
# Do Forward
with torch.no_grad():
outputs = detector.forward(inputs=blobs)
if not hasattr(detector, 'frozen_graph'):
inputs = {
'data': torch.from_numpy(blobs['data']),
'ims_info': torch.from_numpy(blobs['ims_info']),
}
with torch.no_grad():
with torch.jit.Recorder(retain_ops=True):
outputs = detector.forward(inputs)
detector.frozen_graph = FrozenGraph(
{'data': inputs['data'],
'ims_info': inputs['ims_info']},
{'detections': outputs['detections']},
)
outputs = detector.frozen_graph(**blobs)
# Unpack results
results = tensor_to_blob(outputs['detections'])
results = outputs['detections']
detections_wide = [[] for _ in range(len(ims_shape))]
for i in range(len(ims)):
......@@ -86,7 +104,7 @@ def ims_detect(detector, raw_images):
return detections_wide
def test_net(net, server):
def test_net(detector, server):
# Load settings
classes = server.classes
num_images = server.num_images
......@@ -107,9 +125,9 @@ def test_net(net, server):
# Run detecting on specific scales
_t['im_detect'].tic()
if cfg.TEST.IMS_PER_BATCH > 1:
results = ims_detect(net, raw_images)
results = ims_detect(detector, raw_images)
else:
results = [im_detect(net, raw_images[0])]
results = [im_detect(detector, raw_images[0])]
_t['im_detect'].toc()
# Post-Processing
......
......@@ -71,7 +71,7 @@ class DataBatch(mp.Process):
super(DataBatch, self).__init__()
# Distributed settings
rank, group_size = 0, 1
process_group = dragon.distributed.get_default_process_group()
process_group = dragon.distributed.get_group()
if process_group is not None and kwargs.get(
'phase', 'TRAIN') == 'TRAIN':
group_size = process_group.size
......
......@@ -20,6 +20,7 @@ import numpy as np
from lib.core.config import cfg
from lib.ssd import transforms
from lib.utils import rotated_boxes
from lib.utils.boxes import flip_boxes
......@@ -42,7 +43,7 @@ class DataTransformer(multiprocessing.Process):
self.daemon = True
def make_roi_dict(self, example, flip=False):
n_objects = 0
n_objects, box_dim = 0, len(cfg.BBOX_REG_WEIGHTS)
if not self._use_diff:
for obj in example['object']:
if obj.get('difficult', 0) == 0:
......@@ -54,8 +55,8 @@ class DataTransformer(multiprocessing.Process):
'width': example['width'],
'height': example['height'],
'gt_classes': np.zeros((n_objects,), 'int32'),
'boxes': np.zeros((n_objects, 4), 'float32'),
'normalized_boxes': np.zeros((n_objects, 4), 'float32'),
'boxes': np.zeros((n_objects, box_dim), 'float32'),
'normalized_boxes': np.zeros((n_objects, box_dim), 'float32'),
}
# Filter the difficult instances
......@@ -64,12 +65,32 @@ class DataTransformer(multiprocessing.Process):
if not self._use_diff and \
obj.get('difficult', 0) > 0:
continue
roi_dict['boxes'][object_idx, :] = [
max(0, obj['xmin']),
max(0, obj['ymin']),
min(obj['xmax'], example['width'] - 1),
min(obj['ymax'], example['height'] - 1),
]
if box_dim == 4:
roi_dict['boxes'][object_idx, :] = [
max(0, obj['xmin']),
max(0, obj['ymin']),
min(obj['xmax'], example['width'] - 1),
min(obj['ymax'], example['height'] - 1),
]
elif box_dim == 5:
if 'bbox' in obj:
roi_dict['boxes'][object_idx, :] = [
max(0, obj['bbox'][0]),
max(0, obj['bbox'][1]),
min(obj['bbox'][2], example['width'] - 1),
min(obj['bbox'][3], example['height'] - 1),
rotated_boxes.clip_angle(obj['bbox'][4]),
]
else:
roi_dict['boxes'][object_idx, :] = \
rotated_boxes.canonicalize(
[obj['x1'], obj['y1'],
obj['x2'], obj['y2'],
obj['x3'], obj['y3'],
obj['x4'], obj['y4']]
)
else:
raise ValueError('Expected box4d or box5d.')
roi_dict['gt_classes'][object_idx] = \
self._class_to_ind[obj['name']]
object_idx += 1
......@@ -78,8 +99,10 @@ class DataTransformer(multiprocessing.Process):
roi_dict['boxes'] = flip_boxes(
roi_dict['boxes'], roi_dict['width'])
roi_dict['boxes'][:, 0::2] /= roi_dict['width']
roi_dict['boxes'][:, 1::2] /= roi_dict['height']
roi_dict['boxes'][:, 0] /= roi_dict['width']
roi_dict['boxes'][:, 1] /= roi_dict['height']
roi_dict['boxes'][:, 2] /= roi_dict['width']
roi_dict['boxes'][:, 3] /= roi_dict['height']
return roi_dict
......@@ -99,8 +122,10 @@ class DataTransformer(multiprocessing.Process):
# Post-Process for gt boxes
# Shape like: [num_objects, {x1, y1, x2, y2, cls}]
gt_boxes = np.empty((len(roi_dict['gt_classes']), 5), 'float32')
gt_boxes[:, :4], gt_boxes[:, 4] = roi_dict['boxes'], roi_dict['gt_classes']
box_dim = roi_dict['boxes'].shape[1]
gt_boxes = np.empty((roi_dict['gt_classes'].size, box_dim + 1), 'float32')
gt_boxes[:, :box_dim], gt_boxes[:, box_dim] = \
roi_dict['boxes'], roi_dict['gt_classes']
# Distort => Expand => Sample => Resize
img, gt_boxes = self._image_aug(img, gt_boxes)
......
......@@ -16,28 +16,68 @@ from __future__ import print_function
import numpy as np
def generate_anchors(min_sizes, max_sizes, ratios):
def generate_anchors(min_sizes, max_sizes, ratios, angles=()):
"""
Generate anchor (reference) windows by enumerating
aspect ratios, min_sizes, max_sizes wrt a reference ctr (x, y, w, h).
"""
if len(angles) > 0:
return generate_rotated_anchors(
min_sizes, max_sizes, ratios, angles)
total_anchors = []
for idx, min_size in enumerate(min_sizes):
# Note that SSD assumes it is a ctr-anchor
base_anchor = np.array([0, 0, min_size, min_size])
anchors = _ratio_enum(base_anchor, ratios)
anchors = _ratio_enum(base_anchor, ratios, _mkanchors)
if len(max_sizes) > 0:
max_size = max_sizes[idx]
_anchors = anchors[0].reshape((1, 4))
_anchors = np.vstack([_anchors, _max_size_enum(
base_anchor, min_size, max_size)])
_anchors = np.vstack([
_anchors,
_max_size_enum(
base_anchor,
min_size,
max_size,
_mkanchors,
)])
anchors = np.vstack([_anchors, anchors[1:]])
total_anchors.append(anchors)
return np.vstack(total_anchors)
def generate_rotated_anchors(min_sizes, max_sizes, ratios, angles):
"""
Generate anchor (reference) windows by enumerating
aspect ratios, min_sizes, max_sizes wrt a reference ctr (x, y, w, h).
"""
total_anchors = []
for angle in angles:
for idx, min_size in enumerate(min_sizes):
angle_array = np.ones((len(ratios), 1)) * angle
# Note that SSD assumes it is a ctr-anchor
base_anchor = np.array([0, 0, min_size, min_size])
anchors = _ratio_enum(base_anchor, ratios, _mkanchors_v2)
if len(max_sizes) > 0:
max_size = max_sizes[idx]
_anchors = anchors[0].reshape((1, 4))
_anchors = np.vstack([
_anchors,
_max_size_enum(
base_anchor,
min_size,
max_size,
_mkanchors_v2,
)])
anchors = np.vstack([_anchors, anchors[1:]])
angle_array = np.vstack((angle_array, angle))
anchors = np.hstack((anchors, angle_array))
total_anchors.append(anchors)
return np.vstack(total_anchors)
def _whctrs(anchor):
"""Return width, height, x center, and y center for an anchor (window)."""
w, h = anchor[2], anchor[3]
......@@ -46,37 +86,43 @@ def _whctrs(anchor):
def _mkanchors(ws, hs, x_ctr, y_ctr):
"""Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
"""
Given a vector of widths (ws) and heights (hs) around a center
ws, hs = ws[:, np.newaxis], hs[:, np.newaxis]
return np.hstack((
x_ctr - 0.5 * ws,
y_ctr - 0.5 * hs,
x_ctr + 0.5 * ws,
y_ctr + 0.5 * hs,
))
def _mkanchors_v2(ws, hs, x_ctr, y_ctr):
"""Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
"""
ws = ws[:, np.newaxis]
hs = hs[:, np.newaxis]
anchors = np.hstack((x_ctr - 0.5 * ws,
y_ctr - 0.5 * hs,
x_ctr + 0.5 * ws,
y_ctr + 0.5 * hs))
return anchors
ws, hs = ws[:, np.newaxis], hs[:, np.newaxis]
return np.hstack((0 * (ws) + x_ctr, 0 * (hs) + y_ctr, ws, hs))
def _ratio_enum(anchor, ratios):
def _ratio_enum(anchor, ratios, make_fn):
"""Enumerate a set of anchors for each aspect ratio wrt an anchor."""
w, h, x_ctr, y_ctr = _whctrs(anchor)
size = w * h
size_ratios = size / ratios
hs = np.round(np.sqrt(size_ratios))
ws = np.round(hs * ratios)
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
return make_fn(ws, hs, x_ctr, y_ctr)
def _max_size_enum(base_anchor, min_size, max_size):
def _max_size_enum(base_anchor, min_size, max_size, make_fn):
"""Enumerate a anchor for max_size wrt base_anchor."""
w, h, x_ctr, y_ctr = _whctrs(base_anchor)
ws = hs = np.sqrt([min_size * max_size])
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
return make_fn(ws, hs, x_ctr, y_ctr)
if __name__ == '__main__':
print(generate_anchors(min_sizes=[30], max_sizes=[60], ratios=[1, 0.5, 2]))
print(generate_anchors(min_sizes=[30], max_sizes=[60], ratios=[1]))
print(generate_rotated_anchors(min_sizes=[30], max_sizes=[60], ratios=[1], angles=[1]))
......@@ -18,9 +18,9 @@ import dragon.vm.torch as torch
from lib.core.config import cfg
from lib.utils.blob import blob_to_tensor
from lib.utils.boxes import bbox_overlaps
from lib.utils.boxes import bbox_transform
from lib.utils.boxes import dismantle_gt_boxes
from lib.utils.cython_bbox import bbox_overlaps
class MultiBoxMatchLayer(torch.nn.Module):
......@@ -30,12 +30,12 @@ class MultiBoxMatchLayer(torch.nn.Module):
def forward(self, prior_boxes, gt_boxes):
num_images = cfg.TRAIN.IMS_PER_BATCH
gt_boxes_wide = dismantle_gt_boxes(gt_boxes, num_images)
num_priors = len(prior_boxes)
num_priors, box_dim = prior_boxes.shape[:]
# Do matching between prior boxes and gt boxes
match_inds_wide = -np.ones((num_images, num_priors), dtype=np.int32)
match_labels_wide = np.zeros(match_inds_wide.shape, dtype=np.int64)
max_overlaps_wide = np.zeros(match_inds_wide.shape, dtype=np.float32)
match_inds_wide = -np.ones((num_images, num_priors), 'int32')
match_labels_wide = np.zeros(match_inds_wide.shape, 'int64')
max_overlaps_wide = np.zeros(match_inds_wide.shape, 'float32')
for ix in range(num_images):
# GT boxes (x1, y1, x2, y2, label)
......@@ -44,26 +44,24 @@ class MultiBoxMatchLayer(torch.nn.Module):
continue
# Compute the overlaps between prior boxes and gt boxes
overlaps = bbox_overlaps(
np.ascontiguousarray(prior_boxes, dtype=np.float),
np.ascontiguousarray(gt_boxes, dtype=np.float))
overlaps = bbox_overlaps(prior_boxes, gt_boxes)
argmax_overlaps = overlaps.argmax(axis=1)
max_overlaps = overlaps[np.arange(num_priors), argmax_overlaps]
max_overlaps_wide[ix] = max_overlaps
# Bipartite matching & assignments
# Bipartite matching and assignments
bipartite_inds = overlaps.argmax(axis=0)
class_assignment = gt_boxes[:, 4]
class_assignment = gt_boxes[:, -1]
match_inds_wide[ix][bipartite_inds] = np.arange(
gt_boxes.shape[0], dtype=np.int32)
match_labels_wide[ix][bipartite_inds] = class_assignment
# Per prediction matching & assignments
# Per prediction matching and assignments
# Note that SSD matches each prior box only once
# We simply implement it by clobbering the assignments matched in the bipartite step
per_inds = np.where(max_overlaps >= cfg.TRAIN.FG_THRESH)[0]
gt_assignment = argmax_overlaps[per_inds]
class_assignment = gt_boxes[gt_assignment, 4]
class_assignment = gt_boxes[gt_assignment, -1]
match_inds_wide[ix][per_inds] = gt_assignment
match_labels_wide[ix][per_inds] = class_assignment
......@@ -78,24 +76,30 @@ class MultiBoxTargetLayer(torch.nn.Module):
def __init__(self):
super(MultiBoxTargetLayer, self).__init__()
def forward(self, match_inds, match_labels, prior_boxes, gt_boxes):
def forward(
self,
match_inds,
match_labels,
prior_boxes,
gt_boxes,
):
num_images = cfg.TRAIN.IMS_PER_BATCH
# GT assignments between default boxes and gt boxes
match_inds_wide = match_inds
# Matched labels (possibly after hard mining)
match_labels_wide = match_labels
num_priors = len(prior_boxes)
num_priors, box_dim = prior_boxes.shape[:]
gt_boxes_wide = dismantle_gt_boxes(gt_boxes, num_images)
bbox_targets_wide = np.zeros((num_images, num_priors, 4), dtype=np.float32)
bbox_inside_weights_wide = np.zeros(bbox_targets_wide.shape, dtype=np.float32)
bbox_outside_weights_wide = np.zeros(bbox_targets_wide.shape, dtype=np.float32)
bbox_targets_wide = np.zeros((num_images, num_priors, box_dim), 'float32')
bbox_inside_weights_wide = np.zeros(bbox_targets_wide.shape, 'float32')
bbox_outside_weights_wide = np.zeros(bbox_targets_wide.shape, 'float32')
# Number of matched boxes (#positives)
# We divide it by the number of images, as SmoothL1Loss will also divide by it
n_pos = max(len(np.where(match_labels_wide > 0)[0]), 1)
bbox_normalization = n_pos / num_images
n_pos = float(max(len(np.where(match_labels_wide > 0)[0]), 1))
bbox_reg_weight = cfg.SSD.BBOX_REG_WEIGHT * num_images / n_pos
for ix in range(num_images):
gt_boxes = gt_boxes_wide[ix]
......@@ -113,8 +117,8 @@ class MultiBoxTargetLayer(torch.nn.Module):
# Assign targets & inside weights & outside weights
bbox_targets_wide[ix][ex_inds] = bbox_transform(
ex_rois, gt_rois, cfg.BBOX_REG_WEIGHTS)
bbox_inside_weights_wide[ix, :] = (1.0, 1.0, 1.0, 1.0)
bbox_outside_weights_wide[ix][ex_inds] = 1.0 / bbox_normalization
bbox_inside_weights_wide[ix, :] = 1.
bbox_outside_weights_wide[ix][ex_inds] = bbox_reg_weight
return {
'bbox_targets': blob_to_tensor(bbox_targets_wide),
......
......@@ -34,7 +34,7 @@ class PriorBoxLayer(torch.nn.Module):
len(min_sizes), len(max_sizes)))
self.strides = cfg.SSD.MULTIBOX.STRIDES
aspect_ratios = cfg.SSD.MULTIBOX.ASPECT_RATIOS
self.num_anchors = len(min_sizes) * len(aspect_ratios) + len(max_sizes)
aspect_angles = cfg.SSD.MULTIBOX.ASPECT_ANGLES
self.base_anchors = []
for i in range(len(min_sizes)):
self.base_anchors.append(
......@@ -44,6 +44,7 @@ class PriorBoxLayer(torch.nn.Module):
max_sizes[i] if isinstance(
max_sizes[i], (list, tuple)) else [max_sizes[i]],
aspect_ratios[i],
aspect_angles,
)
)
......@@ -55,17 +56,34 @@ class PriorBoxLayer(torch.nn.Module):
shift_x = (np.arange(0, width) + 0.5) * self.strides[i]
shift_y = (np.arange(0, height) + 0.5) * self.strides[i]
shift_x, shift_y = np.meshgrid(shift_x, shift_y)
shifts = np.vstack((shift_x.ravel(), shift_y.ravel(),
shift_x.ravel(), shift_y.ravel())).transpose()
# 2. Apply anchors on base grids
# Add A anchors (1, A, 4) to
# cell K shifts (K, 1, 4) to get
# shift anchors (K, A, 4)
# Reshape to (K * A, 4) shifted anchors
A = self.base_anchors[i].shape[0]
D = self.base_anchors[i].shape[1]
if D == 4:
shifts = np.vstack((
shift_x.ravel(),
shift_y.ravel(),
shift_x.ravel(),
shift_y.ravel())
).transpose()
elif D == 5:
shifts = np.vstack((
shift_x.ravel(),
shift_y.ravel(),
shift_x.ravel() * 0,
shift_y.ravel() * 0,
shift_y.ravel() * 0)
).transpose()
else:
raise ValueError('Expected anchor4d or anchor5d.')
K = shifts.shape[0] # K = map_h * map_w
anchors = (self.base_anchors[i].reshape((1, A, 4)) +
shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
anchors = anchors.reshape((K * A, 4)).astype(np.float32)
anchors = (self.base_anchors[i].reshape((1, A, D)) +
shifts.reshape((1, K, D)).transpose((1, 0, 2)))
anchors = anchors.reshape((K * A, D)).astype(np.float32)
all_anchors.append(anchors)
return np.concatenate(all_anchors, axis=0)
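The grid step above is the usual broadcast trick: A base anchors (1, A, D) plus K shifts (K, 1, D) yield (K, A, D) anchors. A standalone toy sketch with invented sizes, assuming 4-d anchors:

import numpy as np
base_anchors = np.array([[-8., -8., 8., 8.],
                         [-16., -16., 16., 16.]])    # (A, 4), A = 2
shift_x, shift_y = np.meshgrid((np.arange(2) + 0.5) * 16,
                               (np.arange(2) + 0.5) * 16)
shifts = np.vstack((shift_x.ravel(), shift_y.ravel(),
                    shift_x.ravel(), shift_y.ravel())).transpose()  # (K, 4), K = 4
anchors = (base_anchors.reshape((1, 2, 4)) +
           shifts.reshape((1, 4, 4)).transpose((1, 0, 2)))
print(anchors.reshape((-1, 4)).shape)                # (8, 4) == (K * A, 4)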
......@@ -20,10 +20,10 @@ import numpy as np
from lib.core.config import cfg
from lib.nms.nms_wrapper import nms
from lib.nms.nms_wrapper import soft_nms
from lib.utils.blob import tensor_to_blob
from lib.utils.boxes import bbox_transform_inv
from lib.utils.boxes import clip_tiled_boxes
from lib.utils.boxes import clip_boxes
from lib.utils.timer import Timer
from lib.utils.graph import FrozenGraph
from lib.utils.vis import vis_one_image
......@@ -41,33 +41,40 @@ def get_images(ims):
def ims_detect(detector, ims):
"""Detect images, with the single scale."""
# Prepare blobs
data, im_scales = get_images(ims)
data = torch.from_numpy(data).cuda(cfg.GPU_ID)
# Do Forward
with torch.no_grad():
outputs = detector.forward(inputs={'data': data})
if not hasattr(detector, 'frozen_graph'):
image = torch.from_numpy(data)
with torch.no_grad():
with torch.jit.Recorder(retain_ops=True):
outputs = detector.forward(inputs={'data': image})
detector.frozen_graph = FrozenGraph(
{'data': image},
{'cls_prob': outputs['cls_prob'],
'bbox_pred': outputs['bbox_pred']},
{'prior_boxes': outputs['prior_boxes']},
)
outputs = detector.frozen_graph(data=data)
# Decode results
batch_boxes = []
scores = tensor_to_blob(outputs['cls_prob'])
prior_boxes = tensor_to_blob(outputs['prior_boxes'])
box_deltas = tensor_to_blob(outputs['bbox_pred'])
for i in range(box_deltas.shape[0]):
for i in range(len(im_scales)):
boxes = bbox_transform_inv(
boxes=prior_boxes,
deltas=box_deltas[i],
weights=cfg.BBOX_REG_WEIGHTS,
outputs['prior_boxes'],
outputs['bbox_pred'][i],
cfg.BBOX_REG_WEIGHTS,
)
boxes[:, 0::2] /= im_scales[i][1]
boxes[:, 1::2] /= im_scales[i][0]
batch_boxes.append(clip_tiled_boxes(boxes, ims[i].shape))
boxes[:, 0] /= im_scales[i][1]
boxes[:, 1] /= im_scales[i][0]
boxes[:, 2] /= im_scales[i][1]
boxes[:, 3] /= im_scales[i][0]
batch_boxes.append(clip_boxes(boxes, ims[i].shape))
return scores, batch_boxes
return outputs['cls_prob'], batch_boxes
def test_net(net, server):
def test_net(detector, server):
# Load settings
classes = server.classes
num_images = server.num_images
......@@ -87,7 +94,7 @@ def test_net(net, server):
raw_images.append(raw_image)
_t['im_detect'].tic()
batch_scores, batch_boxes = ims_detect(net, raw_images)
batch_scores, batch_boxes = ims_detect(detector, raw_images)
_t['im_detect'].toc()
_t['misc'].tic()
......@@ -129,7 +136,7 @@ def test_net(net, server):
classes,
boxes_this_image,
thresh=cfg.VIS_TH,
box_alpha=1.0,
box_alpha=1.,
show_class=True,
filename=server.get_save_filename(image_ids[item_idx]),
)
......
......@@ -19,6 +19,9 @@ from __future__ import print_function
import numpy as np
from lib.utils import cython_bbox
from lib.utils import rotated_boxes
def intersection(boxes1, boxes2):
"""Compute pairwise intersection areas between boxes.
......@@ -104,15 +107,29 @@ def ioa2(boxes1, boxes2):
return intersect / areas
def bbox_overlaps(boxes1, boxes2):
"""Compute the overlaps between two group of boxes."""
if boxes1.shape[1] == 5:
return rotated_boxes.bbox_overlaps(boxes1, boxes2)
return cython_bbox.bbox_overlaps(
np.ascontiguousarray(boxes1, dtype=np.float),
np.ascontiguousarray(boxes2, dtype=np.float),
)
def bbox_transform(ex_rois, gt_rois, weights=(1., 1., 1., 1.)):
"""Transform the boxes to the regression targets."""
ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0
ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0
if len(weights) == 5:
# Transform the rotated boxes
return rotated_boxes.bbox_transform(ex_rois, gt_rois, weights)
ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.
ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.
ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths
ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights
gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.0
gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.0
gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.
gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.
gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths
gt_ctr_y = gt_rois[:, 1] + 0.5 * gt_heights
......@@ -122,22 +139,22 @@ def bbox_transform(ex_rois, gt_rois, weights=(1., 1., 1., 1.)):
targets_dw = ww * np.log(gt_widths / ex_widths)
targets_dh = wh * np.log(gt_heights / ex_heights)
targets = np.vstack(
(targets_dx, targets_dy,
targets_dw, targets_dh)).transpose()
return targets
return np.vstack((targets_dx, targets_dy, targets_dw, targets_dh)).transpose()
def bbox_transform_inv(boxes, deltas, weights=(1., 1., 1., 1.)):
"""Decode the final boxes according to the deltas."""
if len(weights) == 5:
# Decode the rotated boxes
return rotated_boxes.bbox_transform_inv(boxes, deltas, weights)
if boxes.shape[0] == 0:
return np.zeros((0, deltas.shape[1]), dtype=deltas.dtype)
boxes = boxes.astype(deltas.dtype, copy=False)
widths = boxes[:, 2] - boxes[:, 0] + 1.0
heights = boxes[:, 3] - boxes[:, 1] + 1.0
widths = boxes[:, 2] - boxes[:, 0] + 1.
heights = boxes[:, 3] - boxes[:, 1] + 1.
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
......@@ -170,6 +187,20 @@ def boxes_area(boxes):
return areas
def clip_boxes(boxes, im_shape):
if boxes.shape[1] == 5:
return rotated_boxes.clip_boxes(boxes, im_shape)
# x1 >= 0
boxes[:, 0] = np.maximum(np.minimum(boxes[:, 0], im_shape[1] - 1), 0)
# y1 >= 0
boxes[:, 1] = np.maximum(np.minimum(boxes[:, 1], im_shape[0] - 1), 0)
# x2 < im_shape[1]
boxes[:, 2] = np.maximum(np.minimum(boxes[:, 2], im_shape[1] - 1), 0)
# y2 < im_shape[0]
boxes[:, 3] = np.maximum(np.minimum(boxes[:, 3], im_shape[0] - 1), 0)
return boxes
def clip_tiled_boxes(boxes, im_shape):
# x1 >= 0
boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
......@@ -203,6 +234,8 @@ def expand_boxes(boxes, scale):
def flip_boxes(boxes, width):
"""Flip the boxes horizontally."""
if boxes.shape[1] == 5:
return rotated_boxes.flip_boxes(boxes, width)
flip_boxes = boxes.copy()
old_x1 = boxes[:, 0].copy()
old_x2 = boxes[:, 2].copy()
......@@ -224,5 +257,5 @@ def dismantle_gt_boxes(gt_boxes, num_images):
return [
gt_boxes[
np.where(gt_boxes[:, -1].astype(np.int32) == ix)[0]
] for ix in range(num_images)
][:, :-1] for ix in range(num_images)
]
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import dragon
from dragon.core.framework import tensor_util
from dragon.vm.torch.jit.recorder import get_default_recorder
class FrozenGraph(object):
"""Simple sequential graph to accelerate inference.
The frozen graph reduces the overhead of Python functions
under eager execution. This overhead is at least 15ms for
common backbones, which caps inference at about 60 FPS.
For more details, see the eager mechanism of Dragon.
"""
def __init__(self, inputs, outputs, constants=None):
def canonicalize(input_dict):
if input_dict is None:
return {}
for k, v in input_dict.items():
input_dict[k] = v.name if hasattr(v, 'name') else v
return input_dict
self._inputs = canonicalize(inputs)
self._outputs = canonicalize(outputs)
self._constants = canonicalize(constants)
self._graph = dragon.Workspace() \
.merge_from(dragon.workspace.get_default())
self._tape = get_default_recorder()
def forward(self, **kwargs):
# Assign inputs
for name, tensor in self._inputs.items():
value = kwargs.get(name, None)
tensor_util.set_array(tensor, value)
# Replay the tape
self._tape.replay()
# Collect outputs
# 1) Target results
# 2) Constant values
outputs = collections.OrderedDict()
for name, tensor in self._outputs.items():
outputs[name] = tensor_util.to_array(tensor, True)
for name, value in self._constants.items():
outputs[name] = value
return outputs
def __call__(self, **kwargs):
with self._graph.as_default():
return self.forward(**kwargs)
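A minimal sketch of the capture-then-replay pattern the im_detect functions above build on (detector, blob and output names are placeholders, not a fixed API):

if not hasattr(detector, 'frozen_graph'):
    inputs = {'data': torch.from_numpy(blobs['data'])}
    with torch.no_grad():
        with torch.jit.Recorder(retain_ops=True):  # record executed ops once
            outputs = detector.forward(inputs)
    detector.frozen_graph = FrozenGraph(
        {'data': inputs['data']},           # feeds: assigned on every call
        {'cls_prob': outputs['cls_prob']},  # fetches: collected on every call
    )
# Later calls skip eager Python dispatch and just replay the recorded tape
outputs = detector.frozen_graph(data=blobs['data'])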
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ctypes import *
import os.path as osp
import numpy as np
class LibRotatedBoxes(object):
def __init__(self):
self._nms = cdll.LoadLibrary(
osp.join(osp.split(
osp.abspath(__file__))[0],
"ctypes_rbox.so")
).NMS
self._overlaps = cdll.LoadLibrary(
osp.join(osp.split(
osp.abspath(__file__))[0],
"ctypes_rbox.so")
).Overlaps
self._nms.argtypes = (
POINTER(c_double),
POINTER(c_int),
POINTER(c_double),
POINTER(c_int),
c_double,
)
self._overlaps.argtypes = \
(POINTER(c_double),
POINTER(c_double),
POINTER(c_int),
POINTER(c_double)
)
self._nms.restype = None
self._overlaps.restype = None
def nms(self, dets, thresh):
"""CPU Hard-NMS.
Parameters
----------
dets: (N, 6) ndarray of double [cx, cy, w, h, a, scores]
thresh : float
"""
assert dets.shape[1] == 6
order = dets[:, 5].argsort()[::-1]
sorted_dets = dets[order, :]
N = sorted_dets.shape[0]
num_ctypes = c_int(N)
thresh = c_double(thresh)
pred_boxes = sorted_dets[:, 0:-1].flatten()
pred_scores = sorted_dets[:, -1:].flatten()
indices = np.zeros(N)
_boxes = np.ascontiguousarray(pred_boxes, dtype=np.double)
_scores = np.ascontiguousarray(pred_scores, dtype=np.double)
_inds = np.ascontiguousarray(indices, dtype=np.int32)
boxes_ctypes_ptr = _boxes.ctypes.data_as(POINTER(c_double))
scores_ctypes_ptr = _scores.ctypes.data_as(POINTER(c_double))
inds_ctypes_ptr = _inds.ctypes.data_as(POINTER(c_int32))
self._nms(
boxes_ctypes_ptr,
inds_ctypes_ptr,
scores_ctypes_ptr,
byref(num_ctypes),
thresh,
)
keep_indices = np.ctypeslib.as_array(
(c_int32 * num_ctypes.value).from_address(
addressof(inds_ctypes_ptr.contents)))
return list(order[keep_indices.astype(np.int32)])
def overlaps(self, boxes, query_boxes):
"""Computer overlaps between boxes and query boxes.
Parameters
----------
boxes: (N, 5) ndarray of double [cx, cy, w, h, a]
query_boxes: (K, 6) ndarray of double [cx, cy, w, h, a, cls]
Returns
-------
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
"""
assert boxes.shape[1] == 5
if query_boxes.shape[1] == 6:
query_boxes = query_boxes[:, :-1]
N = boxes.shape[0]
K = query_boxes.shape[0]
num_ctypes = (c_int * 2)()
num_ctypes[0] = N
num_ctypes[1] = K
num_ctypes_ptr = cast(num_ctypes, POINTER(c_int))
_boxes = boxes.flatten()
_query_boxes = query_boxes.flatten()
_areas = np.zeros((N, K), dtype=np.double).flatten()
_boxes = np.ascontiguousarray(_boxes, dtype=np.double)
_query_boxes = np.ascontiguousarray(_query_boxes, dtype=np.double)
_areas = np.ascontiguousarray(_areas, dtype=np.double)
boxes_ctypes_ptr = _boxes.ctypes.data_as(POINTER(c_double))
query_boxes_ctypes_ptr = _query_boxes.ctypes.data_as(POINTER(c_double))
areas_ctypes_ptr = _areas.ctypes.data_as(POINTER(c_double))
self._overlaps(
boxes_ctypes_ptr,
query_boxes_ctypes_ptr,
num_ctypes_ptr,
areas_ctypes_ptr,
)
area = np.ctypeslib.as_array(
(c_double * K * N).from_address(
addressof(areas_ctypes_ptr.contents)
)
)
rarea = np.nan_to_num(area.astype(np.float32))
return rarea
libc = LibRotatedBoxes()
def bbox_overlaps(boxes1, boxes2):
"""Compute the overlaps between two group of boxes."""
return libc.overlaps(boxes1, boxes2)
def bbox_transform(ex_rois, gt_rois, weights=(1., 1., 1., 1., 1.)):
"""Transform the boxes to the regression targets."""
ex_ctr_x = ex_rois[:, 0]
ex_ctr_y = ex_rois[:, 1]
ex_widths = ex_rois[:, 2]
ex_heights = ex_rois[:, 3]
ex_angles = ex_rois[:, 4]
gt_ctr_x = gt_rois[:, 0]
gt_ctr_y = gt_rois[:, 1]
gt_widths = gt_rois[:, 2]
gt_heights = gt_rois[:, 3]
gt_angles = gt_rois[:, 4]
wx, wy, ww, wh, wa = weights
targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
targets_dw = ww * np.log(gt_widths / ex_widths)
targets_dh = wh * np.log(gt_heights / ex_heights)
targets_da = wa * np.sin(np.radians(gt_angles - ex_angles))
return np.vstack((
targets_dx,
targets_dy,
targets_dw,
targets_dh,
targets_da,
)).transpose()
def bbox_transform_inv(boxes, deltas, weights=(1., 1., 1., 1., 1.)):
"""Decode the final boxes according to the deltas."""
if boxes.shape[0] == 0:
return np.zeros((0, deltas.shape[1]), dtype=deltas.dtype)
boxes = boxes.astype(deltas.dtype, copy=False)
ctr_x = boxes[:, 0]
ctr_y = boxes[:, 1]
widths = boxes[:, 2]
heights = boxes[:, 3]
angles = boxes[:, 4:5]
wx, wy, ww, wh, wa = weights
dx = deltas[:, 0::5] / wx
dy = deltas[:, 1::5] / wy
dw = deltas[:, 2::5] / ww
dh = deltas[:, 3::5] / wh
da = deltas[:, 4::5] / wa
pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
pred_w = np.exp(dw) * widths[:, np.newaxis]
pred_h = np.exp(dh) * heights[:, np.newaxis]
da = np.minimum(np.maximum(da, -1), 1)
pred_a = np.rad2deg(np.arcsin(da)) + angles
pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype)
pred_boxes[:, 0::5] = pred_ctr_x # x_ctr
pred_boxes[:, 1::5] = pred_ctr_y # y_ctr
pred_boxes[:, 2::5] = pred_w # w
pred_boxes[:, 3::5] = pred_h # h
pred_boxes[:, 4::5] = pred_a # angle
return pred_boxes
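A quick round-trip check of the two rotated transforms (toy boxes, invented numbers). Since the decode uses arcsin, the round trip only holds while the angle gap stays within +/-90 degrees:

import numpy as np
ex = np.array([[50., 50., 20., 10., 10.]])  # [cx, cy, w, h, a]
gt = np.array([[55., 52., 24., 12., 40.]])
deltas = bbox_transform(ex, gt)             # da = sin(radians(30)) = 0.5
print(bbox_transform_inv(ex, deltas))       # ~[[55., 52., 24., 12., 40.]]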
def canonicalize(values):
def poly8_to_poly5(values):
pt1, pt2 = values[0:2], values[2:4]
pt3, pt4 = values[4:6], values[6:8]
edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (pt1[1] - pt2[1]) * (pt1[1] - pt2[1]))
edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) + (pt2[1] - pt3[1]) * (pt2[1] - pt3[1]))
angle, width, height = 0, 0, 0
if edge1 > edge2:
width = edge1
height = edge2
if pt1[0] - pt2[0] != 0:
angle = -np.arctan(float(pt1[1] - pt2[1]) / float(pt1[0] - pt2[0])) / 3.1415926 * 180
else:
angle = 90.
elif edge2 >= edge1:
width = edge2
height = edge1
if pt2[0] - pt3[0] != 0:
angle = -np.arctan(float(pt2[1] - pt3[1]) / float(pt2[0] - pt3[0])) / 3.1415926 * 180
else:
angle = 90.
if angle < -45.:
angle = angle + 180.
x_ctr = (pt1[0] + pt3[0]) / 2.
y_ctr = (pt1[1] + pt3[1]) / 2.
return x_ctr, y_ctr, width, height, angle
if len(values) == 8:
return poly8_to_poly5(values)
return values
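A worked example with invented corners: an axis-aligned 4 x 2 rectangle reduces to the 5-parameter [x_ctr, y_ctr, w, h, angle] form:

print(canonicalize([0, 0, 4, 0, 4, 2, 0, 2]))  # -> center (2, 1), size 4 x 2, angle 0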
def clip_angle(d):
while d < 0:
d += 360
while d >= 360:
d -= 360
return d
def clip_boxes(boxes, im_shape):
# ctr_x >= 0
boxes[:, 0] = np.maximum(np.minimum(boxes[:, 0], im_shape[1] - 1), 0)
# ctr_y >= 0
boxes[:, 1] = np.maximum(np.minimum(boxes[:, 1], im_shape[0] - 1), 0)
# w < im_shape[1]
boxes[:, 2] = np.maximum(np.minimum(boxes[:, 2], im_shape[1] - 1), 0)
# h < im_shape[0]
boxes[:, 3] = np.maximum(np.minimum(boxes[:, 3], im_shape[0] - 1), 0)
# 0 < a < 360
boxes[:, 4] = np.maximum(np.minimum(boxes[:, 4], 359), 0)
return boxes
def clip_tiled_boxes(boxes, im_shape):
# ctr_x >= 0
boxes[:, 0::5] = np.maximum(np.minimum(boxes[:, 0::5], im_shape[1] - 1), 0)
# ctr_y >= 0
boxes[:, 1::5] = np.maximum(np.minimum(boxes[:, 1::5], im_shape[0] - 1), 0)
# w < im_shape[1]
boxes[:, 2::5] = np.maximum(np.minimum(boxes[:, 2::5], im_shape[1] - 1), 0)
# h < im_shape[0]
boxes[:, 3::5] = np.maximum(np.minimum(boxes[:, 3::5], im_shape[0] - 1), 0)
# 0 < a < 360
boxes[:, 4::5] = np.maximum(np.minimum(boxes[:, 4::5], 359), 0)
return boxes
def flip_boxes(boxes, width):
ca = np.vectorize(clip_angle)
flip_boxes = boxes.copy()
old_cx = boxes[:, 0].copy()
old_a = boxes[:, 4].copy()
flip_boxes[:, 0] = width - old_cx - 1
flip_boxes[:, 4] = ca(180 - old_a)
return flip_boxes
def nms(dets, thresh):
return libc.nms(dets, thresh)
if __name__ == "__main__":
prior_boxes = np.array([[4, 4, 5, 5, 90], [4, 4, 15, 15, 90]], dtype=np.double)
gt_boxes = np.array([[4, 4, 15, 15, 90, 1]], dtype=np.double)
ov = bbox_overlaps(prior_boxes, gt_boxes)
print(ov)
\ No newline at end of file
......@@ -63,28 +63,7 @@ def kp_connections(keypoints):
return kp_lines
def convert_from_cls_format(cls_boxes, cls_segms, cls_keyps):
"""Convert from the class boxes/segms/keyps format generated by the testing code."""
box_list = [b for b in cls_boxes if len(b) > 0]
if len(box_list) > 0:
boxes = np.concatenate(box_list)
else:
boxes = None
if cls_segms is not None:
segms = [s for slist in cls_segms for s in slist]
else:
segms = None
if cls_keyps is not None:
keyps = [k for klist in cls_keyps for k in klist]
else:
keyps = None
classes = []
for j in range(len(cls_boxes)):
classes += [j] * len(cls_boxes[j])
return boxes, segms, keyps, classes
def convert_from_cls_format_v2(cls_boxes, cls_segms, cls_keyps, class_names):
def convert_from_cls_format(cls_boxes, cls_segms, cls_keyps, class_names):
"""Convert from the class boxes/segms/keyps format generated by the testing code."""
box_list, segm_list = [], []
for j, name in enumerate(class_names):
......@@ -118,6 +97,29 @@ def get_class_string(class_name, score):
return class_name + ' {:0.2f}'.format(score).lstrip('0')
def get_bbox_contours(rotated_box):
def point_rotate(p, c, radian):
x = (p[0] - c[0]) * np.cos(radian) - (p[1] - c[1]) * np.sin(radian) + c[0]
y = (p[0] - c[0]) * np.sin(radian) + (p[1] - c[1]) * np.cos(radian) + c[1]
return x, y
cx, cy, w, h, angle = rotated_box
angle = -angle
x1, y1 = cx - (w / 2), cy - (h / 2)
x2, y2 = x1 + w, y1 + h
radian = np.radians(angle)
r11 = point_rotate((x1, y1), (cx, cy), radian)
r12 = point_rotate((x1, y2), (cx, cy), radian)
r21 = point_rotate((x2, y2), (cx, cy), radian)
r22 = point_rotate((x2, y1), (cx, cy), radian)
quad = np.array([r11, r12, r21, r22, r11])
# Main direction
mside = max(w, h) / 2
x_end = mside * np.cos(radian)
y_end = mside * np.sin(radian)
main_direction = np.array([[cx, cy], [cx + x_end, cy + y_end]])
return quad, main_direction
def get_mask(boxes, segms, im_shape, mask_thresh=0.4):
i, masks = 0, np.zeros(list(im_shape) + [len(boxes)], dtype=np.uint8)
for det, msk in zip(boxes, segms):
......@@ -144,107 +146,6 @@ def get_mask(boxes, segms, im_shape, mask_thresh=0.4):
return masks
def vis_mask(img, mask, col, alpha=0.4, show_border=True, border_thick=1):
"""Visualizes a single binary mask."""
img = img.astype(np.float32)
idx = np.nonzero(mask)
img[idx[0], idx[1], :] *= 1.0 - alpha
img[idx[0], idx[1], :] += alpha * col
if show_border:
_, contours, _ = cv2.findContours(
mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
cv2.drawContours(img, contours, -1, _WHITE, border_thick, cv2.LINE_AA)
return img.astype(np.uint8)
def vis_class(img, pos, class_str, font_scale=0.35):
"""Visualizes the class."""
x0, y0 = int(pos[0]), int(pos[1])
# Compute text size.
txt = class_str
font = cv2.FONT_HERSHEY_SIMPLEX
((txt_w, txt_h), _) = cv2.getTextSize(txt, font, font_scale, 1)
# Place text background.
back_tl = x0, y0 - int(1.3 * txt_h)
back_br = x0 + txt_w, y0
cv2.rectangle(img, back_tl, back_br, _GREEN, -1)
# Show text.
txt_tl = x0, y0 - int(0.3 * txt_h)
cv2.putText(img, txt, txt_tl, font, font_scale, _GRAY, lineType=cv2.LINE_AA)
return img
def vis_bbox(img, bbox, thick=1):
"""Visualizes a bounding box."""
(x0, y0, w, h) = bbox
x1, y1 = int(x0 + w), int(y0 + h)
x0, y0 = int(x0), int(y0)
cv2.rectangle(img, (x0, y0), (x1, y1), _GREEN, thickness=thick)
return img
def vis_one_image_opencv(
im,
class_names,
boxes,
segms=None,
keypoints=None,
thresh=0.9,
kp_thresh=2,
show_box=False,
show_class=False,
):
"""Constructs a numpy array with the detections visualized."""
boxes, segms, keypoints, classes = \
convert_from_cls_format_v2(boxes, segms, keypoints, class_names)
if boxes is None \
or boxes.shape[0] == 0 or \
max(boxes[:, 4]) < thresh:
return im
mask_color_id, masks, color_list = 0, None, colormap()
if segms is not None and len(segms) > 0:
masks = get_mask(boxes, segms, im.shape[0:2])
# Display in largest to smallest order to reduce occlusion
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
sorted_inds = np.argsort(-areas)
for i in sorted_inds:
bbox = boxes[i, :4]
score = boxes[i, -1]
if score < thresh:
continue
# show box (off by default)
if show_box:
im = vis_bbox(
im, (bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]))
# show class (off by default)
if show_class:
class_str = get_class_string(class_names[classes[i]], score)
im = vis_class(im, (bbox[0], bbox[1] - 2), class_str)
# show mask
if segms is not None and len(segms) > i:
color_mask = color_list[mask_color_id % len(color_list), 0:3]
mask_color_id += 1
im = vis_mask(im, masks[..., i], color_mask)
# # show keypoints
# if keypoints is not None and len(keypoints) > i:
# im = vis_keypoints(im, keypoints[i], kp_thresh)
cv2.imshow('Detectron', im)
cv2.waitKey(0)
def vis_one_image(
im,
class_names,
......@@ -260,11 +161,11 @@ def vis_one_image(
):
"""Visual debugging of detections."""
boxes, segms, keypoints, classes = \
convert_from_cls_format_v2(boxes, segms, keypoints, class_names)
convert_from_cls_format(boxes, segms, keypoints, class_names)
if boxes is None \
or boxes.shape[0] == 0 or \
max(boxes[:, 4]) < thresh:
max(boxes[:, -1]) < thresh:
return
im, mask, masks = im[:, :, ::-1], None, None
......@@ -282,35 +183,69 @@ def vis_one_image(
ax.imshow(im)
# Display in largest to smallest order to reduce occlusion
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
if boxes.shape[1] == 5:
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
elif boxes.shape[1] == 6:
areas = boxes[:, 2] * boxes[:, 3]
else:
raise ValueError('Expected box4d or box5d.')
sorted_inds = np.argsort(-areas)
mask_color_id = 0
for i in sorted_inds:
bbox = boxes[i, :4]
bbox = boxes[i, :-1]
score = boxes[i, -1]
if score < thresh:
continue
# show box (off by default)
ax.add_patch(
plt.Rectangle((bbox[0], bbox[1]),
bbox[2] - bbox[0],
bbox[3] - bbox[1],
fill=False, edgecolor='g',
linewidth=1.0, alpha=box_alpha))
# Show box
if bbox.size == 4:
ax.add_patch(
plt.Rectangle(
(bbox[0], bbox[1]),
bbox[2] - bbox[0],
bbox[3] - bbox[1],
fill=False,
edgecolor='g',
linewidth=1.,
alpha=box_alpha,
)
)
elif bbox.size == 5:
quad, md = get_bbox_contours(bbox)
ax.add_patch(
Polygon(
quad,
fill=False,
edgecolor='g',
linewidth=1.,
alpha=box_alpha,
)
)
ax.add_patch(
plt.arrow(
md[0, 0],
md[0, 1],
md[1, 0] - md[0, 0],
md[1, 1] - md[0, 1],
width=2,
color='g',
alpha=box_alpha,
)
)
# Show class
if show_class:
ax.text(
bbox[0], bbox[1] - 2,
get_class_string(class_names[classes[i]], score),
fontsize=11,
family='serif',
bbox=dict(
facecolor='g', alpha=0.4, pad=0, edgecolor='none'),
color='white')
bbox=dict(facecolor='g', alpha=0.4, pad=0, edgecolor='none'),
color='white',
)
# show mask
# Show mask
if segms is not None and len(segms) > i:
img = np.ones(im.shape)
color_mask = color_list[mask_color_id % len(color_list), 0:3]
......@@ -327,12 +262,14 @@ def vis_one_image(
e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
for c in contour:
polygon = Polygon(
ax.add_patch(Polygon(
c.reshape((-1, 2)),
fill=True, facecolor=color_mask,
edgecolor='w', linewidth=1.2,
alpha=0.5)
ax.add_patch(polygon)
fill=True,
facecolor=color_mask,
edgecolor='w',
linewidth=1.2,
alpha=0.5,
))
if filename is not None:
fig.savefig(filename, dpi=dpi)
......