Commit 89f1ee28 by Ting PAN

Update README.md

1 parent 19c489b6
...@@ -3,9 +3,7 @@ The list of most significant changes made over time in SeetaDet. ...@@ -3,9 +3,7 @@ The list of most significant changes made over time in SeetaDet.
SeetaDet 0.1.0 (20190311) SeetaDet 0.1.0 (20190311)
Recommended docker for Dragon: Dragon Minimum Required (Version 0.3.0.0)
seetaresearch/dragon:0.3.0.0-rc4-cuda9.1-ubuntu16.04
Changes: Changes:
......
...@@ -3,12 +3,14 @@ ...@@ -3,12 +3,14 @@
## WHAT's SeetaDet? ## WHAT's SeetaDet?
SeetaDet contains many useful object detectors, including R-CNN series, SSD, SeetaDet contains many useful object detectors, including R-CNN series, SSD,
and the recent RetinaNet. We have achieved the same or higher performance than and the recent RetinaNet.
the baseline reported by the original paper.
We have achieved the same or higher performance than the baseline reported by the original paper.
This repository is based on our [Dragon](https://github.com/seetaresearch/Dragon), This repository is based on our [Dragon](https://github.com/seetaresearch/Dragon),
while the style of codes is PyTorch. The torch-style codes help us to simplify the while the style of codes is PyTorch.
hierarchical pipeline of modern detection.
The torch-style codes help us to simplify the hierarchical pipeline of modern detection.
## Installation ## Installation
...@@ -22,10 +24,41 @@ pip install opencv-python Pillow ...@@ -22,10 +24,41 @@ pip install opencv-python Pillow
#### 2. Compile the C Extensions #### 2. Compile the C Extensions
```bash ```bash
cd SeeTADet/compile cd SeetaDet/compile
bash ./make.sh bash ./make.sh
``` ```
## Quick Start
#### Train a detection model
```bash
cd SeetaDet/tools
python train.py --cfg <MODEL_YAML>
```
We have provided the default YAML examples into ``SeetaDet/configs``.
#### Test a detection model
```bash
cd SeetaDet/tools
python test.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR> --iter <ITERATION>
```
Or
```bash
cd SeetaDet/tools
python test_all.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR>
```
#### Export a detection model to ONNX
```bash
cd SeetaDet/tools
python export.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR> --iter <ITERATION>
```
## Resources ## Resources
#### Pre-trained ImageNet models #### Pre-trained ImageNet models
......
...@@ -32,7 +32,7 @@ FRCNN: ...@@ -32,7 +32,7 @@ FRCNN:
ROI_XFORM_METHOD: RoIAlign ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7 ROI_XFORM_RESOLUTION: 7
TRAIN: TRAIN:
WEIGHTS: '../data/imagenet_models/R-101.Affine.pth' WEIGHTS: '/data/models/imagenet/R-101.Affine.pth'
DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb' DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb'
IMS_PER_BATCH: 2 IMS_PER_BATCH: 2
USE_DIFF: False # Do not use crowd objects USE_DIFF: False # Do not use crowd objects
......
...@@ -32,7 +32,7 @@ FRCNN: ...@@ -32,7 +32,7 @@ FRCNN:
ROI_XFORM_METHOD: RoIAlign ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7 ROI_XFORM_RESOLUTION: 7
TRAIN: TRAIN:
WEIGHTS: '../data/imagenet_models/R-101.Affine.pth' WEIGHTS: '/data/models/imagenet/R-101.Affine.pth'
DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb' DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb'
IMS_PER_BATCH: 2 IMS_PER_BATCH: 2
USE_DIFF: False # Do not use crowd objects USE_DIFF: False # Do not use crowd objects
......
...@@ -23,7 +23,7 @@ FRCNN: ...@@ -23,7 +23,7 @@ FRCNN:
ROI_XFORM_METHOD: RoIAlign ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7 ROI_XFORM_RESOLUTION: 7
TRAIN: TRAIN:
WEIGHTS: '../data/imagenet_models/R-50.Affine.pth' WEIGHTS: '/data/models/imagenet/R-50.Affine.pth'
DATABASE: 'taas:/data/voc_0712_trainval_lmdb' DATABASE: 'taas:/data/voc_0712_trainval_lmdb'
IMS_PER_BATCH: 2 IMS_PER_BATCH: 2
BATCH_SIZE: 128 BATCH_SIZE: 128
......
...@@ -28,7 +28,7 @@ FRCNN: ...@@ -28,7 +28,7 @@ FRCNN:
ROI_XFORM_RESOLUTION: 7 ROI_XFORM_RESOLUTION: 7
MLP_HEAD_DIM: 4096 MLP_HEAD_DIM: 4096
TRAIN: TRAIN:
WEIGHTS: '../data/imagenet_models/VGG16.RCNN.pth' WEIGHTS: '/data/models/imagenet/VGG16.RCNN.pth'
DATABASE: 'taas:/data/voc_0712_trainval_lmdb' DATABASE: 'taas:/data/voc_0712_trainval_lmdb'
RPN_MIN_SIZE: 16 RPN_MIN_SIZE: 16
IMS_PER_BATCH: 2 IMS_PER_BATCH: 2
......
...@@ -32,7 +32,7 @@ FPN: ...@@ -32,7 +32,7 @@ FPN:
RPN_MIN_LEVEL: 3 RPN_MIN_LEVEL: 3
RPN_MAX_LEVEL: 7 RPN_MAX_LEVEL: 7
TRAIN: TRAIN:
WEIGHTS: '../data/imagenet_models/R-50.Affine.pth' WEIGHTS: '/data/models/imagenet/R-50.Affine.pth'
DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb' DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb'
IMS_PER_BATCH: 8 IMS_PER_BATCH: 8
SCALES: [400] SCALES: [400]
......
...@@ -35,7 +35,7 @@ DROPBLOCK: ...@@ -35,7 +35,7 @@ DROPBLOCK:
DROP_ON: True DROP_ON: True
DECREMENT: 0.000005 # * 20000 = 0.1 DECREMENT: 0.000005 # * 20000 = 0.1
TRAIN: TRAIN:
WEIGHTS: '../data/imagenet_models/R-50.Affine.pth' WEIGHTS: '/data/models/imagenet/R-50.Affine.pth'
DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb' DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb'
IMS_PER_BATCH: 8 IMS_PER_BATCH: 8
SCALES: [400] SCALES: [400]
......
...@@ -29,7 +29,7 @@ SSD: ...@@ -29,7 +29,7 @@ SSD:
STRIDES: [8, 16, 32] STRIDES: [8, 16, 32]
ASPECT_RATIOS: [[1, 2, 0.5], [1, 2, 0.5], [1, 2, 0.5]] ASPECT_RATIOS: [[1, 2, 0.5], [1, 2, 0.5], [1, 2, 0.5]]
TRAIN: TRAIN:
WEIGHTS: '../data/imagenet_models/AirNet.SSD.pth' WEIGHTS: '/data/models/imagenet/AirNet.SSD.pth'
DATABASE: 'taas:/data/voc_0712_trainval_lmdb' DATABASE: 'taas:/data/voc_0712_trainval_lmdb'
IMS_PER_BATCH: 32 IMS_PER_BATCH: 32
TEST: TEST:
......
...@@ -32,7 +32,7 @@ SSD: ...@@ -32,7 +32,7 @@ SSD:
ASPECT_RATIOS: [[1, 2, 0.5], [1, 2, 0.5, 3, 0.33], [1, 2, 0.5, 3, 0.33], ASPECT_RATIOS: [[1, 2, 0.5], [1, 2, 0.5, 3, 0.33], [1, 2, 0.5, 3, 0.33],
[1, 2, 0.5, 3, 0.33], [1, 2, 0.5], [1, 2, 0.5]] [1, 2, 0.5, 3, 0.33], [1, 2, 0.5], [1, 2, 0.5]]
TRAIN: TRAIN:
WEIGHTS: '../data/imagenet_models/VGG16.SSD.pth' WEIGHTS: '/data/models/imagenet/VGG16.SSD.pth'
DATABASE: 'taas:/data/voc_0712_trainval_lmdb' DATABASE: 'taas:/data/voc_0712_trainval_lmdb'
IMS_PER_BATCH: 32 IMS_PER_BATCH: 32
TEST: TEST:
......
...@@ -13,48 +13,27 @@ ...@@ -13,48 +13,27 @@
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
from lib.datasets.pascal_voc import pascal_voc from lib.datasets.taas import TaaS
from lib.datasets.coco import coco
from lib.datasets.taas import taas
__sets = {}
# pascal voc # TaaS DataSet
for year in ['2007', '2012', '0712']: _GLOBAL_DATA_SETS = {'taas': lambda source: TaaS(source)}
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
# coco 2014
for year in ['2014']:
for split in ['train', 'val', 'trainval35k', 'minival', 'valminusminival']:
name = 'coco_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: coco(split, year))
# coco 2015 & 2017
for year in ['2015', '2017']:
for split in ['test', 'test-dev']:
name = 'coco_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: coco(split, year))
# taas
__sets['taas'] = (lambda source: taas(source))
def get_imdb(name): def get_imdb(name):
"""Get an imdb (image database) by name.""" """Get an imdb (image database) by name."""
keys = name.split(':') keys = name.split(':')
if len(keys) == 2: if len(keys) >= 2:
cls, source = keys cls, source = keys[0], ':'.join(keys[1:])
if cls not in __sets: if cls not in _GLOBAL_DATA_SETS:
raise KeyError('Unknown dataset: {}'.format(cls)) raise KeyError('Unknown dataset: {}'.format(cls))
return __sets[cls](source) return _GLOBAL_DATA_SETS[cls](source)
elif len(keys) == 1: elif len(keys) == 1:
return __sets[name]() return _GLOBAL_DATA_SETS[name]()
else: else:
raise ValueError('Illegal format of image database: {}'.format(name)) raise ValueError('Illegal format of image database: {}'.format(name))
def list_imdbs(): def list_imdbs():
"""List all registered imdbs.""" """List all registered imdbs."""
return __sets.keys() return _GLOBAL_DATA_SETS.keys()
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import numpy as np
import uuid
try:
import cPickle
except:
import pickle as cPickle
from .imdb import imdb
from .voc_eval import voc_bbox_eval, voc_segm_eval
class pascal_voc(imdb):
def __init__(self, image_set, year, name='voc'):
imdb.__init__(self, name + '_' + year + '_' + image_set)
self._year = year
self._image_set = image_set
self._classes = ('__background__', # always index 0
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
self._salt = str(uuid.uuid4())
self.config = {'cleanup': True, 'use_salt': True}
def _get_comp_id(self):
return '_' + self._salt if self.config['use_salt'] else ''
def _get_prefix(self, type='bbox'):
if type == 'bbox': return 'detections_'
elif type == 'segm': return 'segmentations_'
elif type == 'kpt': return 'keypoints_'
return ''
def _get_voc_results_T(self, results_folder, type='bbox'):
# experiments/model_id/results/detections_voc_2007_test_<comp_id>_aeroplane.txt
filename = self._get_prefix(type) + self._name + self._get_comp_id() + '_{:s}.txt'
if not os.path.exists(results_folder): os.makedirs(results_folder)
return os.path.join(results_folder, filename)
def _write_voc_bbox_results(self, all_boxes, gt_recs, output_dir):
for cls_ind, cls in enumerate(self.classes):
if cls == '__background__': continue
print('Writing {} VOC format bbox results'.format(cls))
filename = self._get_voc_results_T(output_dir).format(cls)
with open(filename, 'wt') as f:
ix = 0
for image_id, rec in gt_recs.items():
dets = all_boxes[cls_ind][ix]; ix += 1
if dets == []: continue
for k in range(dets.shape[0]):
f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
format(image_id, dets[k, -1],
dets[k, 0] + 1, dets[k, 1] + 1,
dets[k, 2] + 1, dets[k, 3] + 1))
def _write_seg_results_file(self, all_boxes, all_masks):
for cls_inds, cls in enumerate(self.classes):
if cls == '__background__': continue
print('Writing {} VOC results file'.format(cls))
results_folder = os.path.join(self._devkit_path, 'results', 'seg')
if not os.path.exists(results_folder): os.makedirs(results_folder)
det_filename = os.path.join(results_folder, cls + '_det.pkl')
seg_filename = os.path.join(results_folder, cls + '_seg.pkl')
with open(det_filename, 'wb') as f:
cPickle.dump(all_boxes[cls_inds], f, cPickle.HIGHEST_PROTOCOL)
with open(seg_filename, 'wb') as f:
cPickle.dump(all_masks[cls_inds], f, cPickle.HIGHEST_PROTOCOL)
def _do_voc_bbox_eval(self, gt_recs, output_dir):
aps = []
# The PASCAL VOC metric changed in 2010
use_07_metric = True if int(self._year) < 2010 else False
print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No') + '\n')
for i, cls in enumerate(self._classes):
if cls == '__background__':
continue
det_file = self._get_voc_results_T(output_dir).format(cls)
rec, prec, ap = voc_bbox_eval(det_file, gt_recs, cls,
IoU=0.5, use_07_metric=use_07_metric)
aps += [ap]
print('AP for {} = {:.4f}'.format(cls, ap))
print('Mean AP = {:.4f}\n'.format(np.mean(aps)))
def _do_voc_segm_eval(self, imagenames, output_dir):
aps = []
# define this as true according to SDS's evaluation protocol
use_07_metric = True
print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
print('~~~~~~ Evaluation use min overlap = 0.5 ~~~~~~')
for i, cls in enumerate(self.classes):
if cls == '__background__':
continue
det_file = os.path.join(output_dir, 'bbox_' + cls + '.pkl')
seg_file = os.path.join(output_dir, 'segm_' + cls + '.pkl')
mask_file = os.path.join(self.cache_path, self.name + '.pkl')
ap = seg_eval_v2(det_file, seg_file, mask_file, imagenames, cls,
ovthresh=0.5, use_07_metric=use_07_metric)
aps += [ap]
print('AP for {} = {:.2f}'.format(cls, ap))
print('Mean AP@0.5 = {:.2f}'.format(np.mean(aps)))
print('~~~~~~ Evaluation use min overlap = 0.7 ~~~~~~')
aps = []
for i, cls in enumerate(self.classes):
if cls == '__background__':
continue
det_file = os.path.join(output_dir, 'bbox_' + cls + '.pkl')
seg_file = os.path.join(output_dir, 'segm_' + cls + '.pkl')
mask_file = os.path.join(self.cache_path, self.name + '.pkl')
ap = seg_eval_v2(det_file, seg_file, mask_file, imagenames, cls,
ovthresh=0.7, use_07_metric=use_07_metric)
aps += [ap]
print('AP for {} = {:.2f}'.format(cls, ap))
print('Mean AP@0.7 = {:.2f}'.format(np.mean(aps)))
def evaluate_detections(self, all_boxes, gt_recs, output_dir):
self._write_voc_bbox_results(all_boxes, gt_recs, output_dir)
self._do_voc_bbox_eval(gt_recs, output_dir)
if self.config['cleanup']:
for cls in self._classes:
if cls == '__background__': continue
filename = self._get_voc_results_T(output_dir).format(cls)
os.remove(filename)
def competition_mode(self, on):
if on:
self.config['use_salt'] = False
self.config['cleanup'] = False
else:
self.config['use_salt'] = True
self.config['cleanup'] = True
\ No newline at end of file
...@@ -35,7 +35,7 @@ from lib.utils import boxes as box_utils ...@@ -35,7 +35,7 @@ from lib.utils import boxes as box_utils
from lib.pycocotools.mask import encode as encode_masks from lib.pycocotools.mask import encode as encode_masks
class taas(imdb): class TaaS(imdb):
def __init__(self, source): def __init__(self, source):
imdb.__init__(self, 'taas') imdb.__init__(self, 'taas')
self._classes = cfg.MODEL.CLASSES self._classes = cfg.MODEL.CLASSES
...@@ -151,6 +151,40 @@ class taas(imdb): ...@@ -151,6 +151,40 @@ class taas(imdb):
############################################## ##############################################
# # # #
# ROT #
# #
##############################################
def _write_voc_rbox_results(self, all_boxes, gt_recs, output_dir):
for cls_ind, cls in enumerate(self.classes):
if cls == '__background__': continue
print('Writing {} VOC format rbox results'.format(cls))
filename = self._get_voc_results_T(output_dir).format(cls)
with open(filename, 'wt') as f:
ix = 0
for image_id, rec in gt_recs.items():
dets = all_boxes[cls_ind][ix]; ix += 1
if dets == []: continue
for k in range(dets.shape[0]):
f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f} {:.2f}\n'.
format(image_id, dets[k, -1],
dets[k, 0] + 1, dets[k, 1] + 1,
dets[k, 2] + 1, dets[k, 3] + 1, dets[k, 4]))
def _do_voc_rbox_eval(self, gt_recs, output_dir, IoU=0.5, use_07_metric=True):
aps = []
print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
for i, cls in enumerate(self._classes):
if cls == '__background__': continue
det_file = self._get_voc_results_T(output_dir).format(cls)
rec, prec, ap = voc_rbox_eval(det_file, gt_recs, cls,
IoU=IoU, use_07_metric=use_07_metric)
if ap > 0: aps += [ap]
print('AP for {} = {:.4f}'.format(cls, ap))
print('Mean AP = {:.4f}\n'.format(np.mean(aps)))
##############################################
# #
# COCO # # COCO #
# # # #
############################################## ##############################################
...@@ -398,6 +432,15 @@ class taas(imdb): ...@@ -398,6 +432,15 @@ class taas(imdb):
print('~~~~~~ Evaluation IoU@0.7 ~~~~~~') print('~~~~~~ Evaluation IoU@0.7 ~~~~~~')
self._do_voc_bbox_eval(gt_recs, output_dir, self._do_voc_bbox_eval(gt_recs, output_dir,
IoU=0.7, use_07_metric='2007' in protocol) IoU=0.7, use_07_metric='2007' in protocol)
elif 'rot' in protocol:
self._write_voc_rbox_results(all_boxes, gt_recs, output_dir)
if not 'wo' in protocol:
print('\n~~~~~~ Evaluation IoU@0.5 ~~~~~~')
self._do_voc_rbox_eval(gt_recs, output_dir,
IoU=0.5, use_07_metric='2007' in protocol)
print('~~~~~~ Evaluation IoU@0.7 ~~~~~~')
self._do_voc_rbox_eval(gt_recs, output_dir,
IoU=0.7, use_07_metric='2007' in protocol)
elif 'coco' in protocol: elif 'coco' in protocol:
from lib.pycocotools.coco import COCO from lib.pycocotools.coco import COCO
if os.path.exists(cfg.TEST.JSON_FILE): if os.path.exists(cfg.TEST.JSON_FILE):
......
...@@ -30,11 +30,6 @@ try: ...@@ -30,11 +30,6 @@ try:
except ImportError as e: except ImportError as e:
print('Failed to import gpu nms. Error: {0}'.format(str(e))) print('Failed to import gpu nms. Error: {0}'.format(str(e)))
try:
from lib.utils.rboxes import RNMSWrapper
except ImportError as e:
print('Failed to import rnms. Error: {0}'.format(str(e)))
def nms(detections, thresh, force_cpu=False): def nms(detections, thresh, force_cpu=False):
"""Perform either CPU or GPU Hard-NMS.""" """Perform either CPU or GPU Hard-NMS."""
...@@ -62,20 +57,4 @@ def soft_nms( ...@@ -62,20 +57,4 @@ def soft_nms(
methods[method], methods[method],
sigma, sigma,
score_thresh, score_thresh,
) )
\ No newline at end of file
def rnms(detections, thresh):
"""Perform CPU Hard-NMS on rotated boxes.
Parameters
----------
detections : numpy.ndarray
(N, 6) of double [cx, cy, w, h, a, scores]
thresh : float
The nms thresh.
"""
if detections.shape[0] == 0: return []
wrapper = RNMSWrapper()
return wrapper.nms(detections, thresh)
\ No newline at end of file
...@@ -13,8 +13,7 @@ from __future__ import absolute_import ...@@ -13,8 +13,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import PIL.Image import cv2
import numpy as np
import numpy.random as npr import numpy.random as npr
from lib.core.config import cfg from lib.core.config import cfg
...@@ -25,17 +24,17 @@ class Resizer(object): ...@@ -25,17 +24,17 @@ class Resizer(object):
self._re_height = cfg.SSD.RESIZE.HEIGHT self._re_height = cfg.SSD.RESIZE.HEIGHT
self._re_width = cfg.SSD.RESIZE.WIDTH self._re_width = cfg.SSD.RESIZE.WIDTH
interp_list = { interp_list = {
'LINEAR': PIL.Image.BILINEAR, 'LINEAR': cv2.INTER_LINEAR,
'AREA': PIL.Image.BILINEAR, 'AREA': cv2.INTER_AREA,
'NEAREST': PIL.Image.NEAREST, 'NEAREST': cv2.INTER_NEAREST,
'CUBIC': PIL.Image.CUBIC, 'CUBIC': cv2.INTER_CUBIC,
'LANCZOS4': PIL.Image.LANCZOS, 'LANCZOS4': cv2.INTER_LANCZOS4,
} }
interp_mode = cfg.SSD.RESIZE.INTERP_MODE interp_mode = cfg.SSD.RESIZE.INTERP_MODE
self._interp_mode = [interp_list[key] for key in interp_mode] self._interp_mode = [interp_list[key] for key in interp_mode]
def resize_image(self, im): def resize_image(self, im):
rand = npr.randint(0, len(self._interp_mode)) rand = npr.randint(0, len(self._interp_mode))
im = PIL.Image.fromarray(im) return cv2.resize(
im = im.resize((self._re_width, self._re_height), self._interp_mode[rand]) im, (self._re_width, self._re_height),
return np.array(im) interpolation=self._interp_mode[rand])
\ No newline at end of file \ No newline at end of file
...@@ -13,7 +13,6 @@ from __future__ import absolute_import ...@@ -13,7 +13,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys
import cv2 import cv2
import numpy as np import numpy as np
...@@ -24,11 +23,9 @@ from lib.core.config import cfg ...@@ -24,11 +23,9 @@ from lib.core.config import cfg
def resize_image(im, fx, fy): def resize_image(im, fx, fy):
im_shape = im.shape return cv2.resize(
im = PIL.Image.fromarray(im) im, None, fx=fx, fy=fy,
size = (int(np.ceil(im_shape[1] * fx)), int(np.ceil(im_shape[0] * fy))) interpolation=cv2.INTER_LINEAR)
im = im.resize(size, PIL.Image.BILINEAR)
return np.array(im)
# Faster and robust resizing than OpenCV methods # Faster and robust resizing than OpenCV methods
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!