Commit f8359d17 by Ting PAN

Adapt to SeetaRecord

1 parent ca255ea0
Showing with 1441 additions and 2218 deletions
@@ -47,4 +47,4 @@ __pycache__
.idea
# OSX dir files
.DS_Store
\ No newline at end of file
------------------------------------------------------------------------
The list of most significant changes made over time in SeetaDet.
SeetaDet 0.2.0 (20190929)
Dragon Minimum Required (Version 0.3.0.dev20190929)
Changes:
Preview Features:
- Use SeetaRecord instead of LMDB.
- Flatten the implementation of layers.
Bugs fixed:
- None
------------------------------------------------------------------------
SeetaDet 0.1.2 (20190723)
Dragon Minimum Required (Version 0.3.0.0)
......
#!/bin/sh
# delete cache
rm -r build install *.c *.cpp
-# compile proto files
-protoc -I ../lib/proto --python_out=../lib/proto ../lib/proto/anno.proto
# compile cython modules
python setup.py build_ext --inplace
# compile cuda modules
-cd build
-cmake .. && make install && cd ..
+cd build && cmake .. && make install && cd ..
# setup
cp -r install/lib ../
@@ -32,15 +32,15 @@ FRCNN:
  ROI_XFORM_METHOD: RoIAlign
  ROI_XFORM_RESOLUTION: 7
TRAIN:
-  WEIGHTS: '/data/models/imagenet/R-101.Affine.pth'
-  DATABASE: '/data/coco_2014_trainval35k_lmdb'
+  WEIGHTS: '/model/R-101.Affine.pth'
+  DATABASE: '/data/coco_2014_trainval35k'
  IMS_PER_BATCH: 2
  USE_DIFF: False  # Do not use crowd objects
  BATCH_SIZE: 512
  SCALES: [800]
  MAX_SIZE: 1333
TEST:
-  DATABASE: '/data/coco_2014_minival_lmdb'
+  DATABASE: '/data/coco_2014_minival'
  JSON_FILE: '/data/instances_minival2014.json'
  PROTOCOL: 'coco'
  RPN_POST_NMS_TOP_N: 1000
......
@@ -32,15 +32,15 @@ FRCNN:
  ROI_XFORM_METHOD: RoIAlign
  ROI_XFORM_RESOLUTION: 7
TRAIN:
-  WEIGHTS: '/data/models/imagenet/R-101.Affine.pth'
-  DATABASE: '/data/coco_2014_trainval35k_lmdb'
+  WEIGHTS: '/model/R-101.Affine.pth'
+  DATABASE: '/data/coco_2014_trainval35k'
  IMS_PER_BATCH: 2
  USE_DIFF: False  # Do not use crowd objects
  BATCH_SIZE: 512
  SCALES: [800]
  MAX_SIZE: 1333
TEST:
-  DATABASE: '/data/coco_2014_minival_lmdb'
+  DATABASE: '/data/coco_2014_minival'
  JSON_FILE: '/data/instances_minival2014.json'
  PROTOCOL: 'coco'
  RPN_POST_NMS_TOP_N: 1000
......
@@ -23,14 +23,14 @@ FRCNN:
  ROI_XFORM_METHOD: RoIAlign
  ROI_XFORM_RESOLUTION: 7
TRAIN:
-  WEIGHTS: '/data/models/imagenet/R-50.Affine.pth'
-  DATABASE: '/data/voc_0712_trainval_lmdb'
+  WEIGHTS: '/model/R-50.Affine.pth'
+  DATABASE: '/data/voc_0712_trainval'
  IMS_PER_BATCH: 2
  BATCH_SIZE: 128
  SCALES: [600]
  MAX_SIZE: 1000
TEST:
-  DATABASE: '/data/voc_2007_test_lmdb'
+  DATABASE: '/data/voc_2007_test'
  PROTOCOL: 'voc2007'  # 'voc2007', 'voc2010', 'coco'
  RPN_POST_NMS_TOP_N: 1000
  SCALES: [600]
......
@@ -28,15 +28,15 @@ FRCNN:
  ROI_XFORM_RESOLUTION: 7
  MLP_HEAD_DIM: 4096
TRAIN:
-  WEIGHTS: '/data/models/imagenet/VGG16.RCNN.pth'
-  DATABASE: '/data/voc_0712_trainval_lmdb'
+  WEIGHTS: '/model/VGG16.RCNN.pth'
+  DATABASE: '/data/voc_0712_trainval'
  RPN_MIN_SIZE: 16
  IMS_PER_BATCH: 2
  BATCH_SIZE: 128
  SCALES: [600]
  MAX_SIZE: 1000
TEST:
-  DATABASE: '/data/voc_2007_test_lmdb'
+  DATABASE: '/data/voc_2007_test'
  PROTOCOL: 'voc2007'  # 'voc2007', 'voc2010', 'coco'
  RPN_MIN_SIZE: 16
  RPN_POST_NMS_TOP_N: 300
......
@@ -32,13 +32,13 @@ FPN:
  RPN_MIN_LEVEL: 3
  RPN_MAX_LEVEL: 7
TRAIN:
-  WEIGHTS: '/data/models/imagenet/R-50.Affine.pth'
-  DATABASE: '/data/coco_2014_trainval35k_lmdb'
+  WEIGHTS: '/model/R-50.Affine.pth'
+  DATABASE: '/data/coco_2014_trainval35k'
  IMS_PER_BATCH: 8
  SCALES: [400]
  MAX_SIZE: 666
TEST:
-  DATABASE: '/data/coco_2014_minival_lmdb'
+  DATABASE: '/data/coco_2014_minival'
  JSON_FILE: '/data/instances_minival2014.json'
  PROTOCOL: 'coco'
  IMS_PER_BATCH: 1
......
@@ -36,8 +36,8 @@ DROPBLOCK:
  DROP_ON: True
  DECREMENT: 0.000005  # * 20000 = 0.1
TRAIN:
-  WEIGHTS: '/data/models/imagenet/R-50.Affine.pth'
-  DATABASE: '/data/coco_2014_trainval35k_lmdb'
+  WEIGHTS: '/model/R-50.Affine.pth'
+  DATABASE: '/data/coco_2014_trainval35k'
  IMS_PER_BATCH: 8
  SCALES: [400]
  MAX_SIZE: 666
@@ -45,7 +45,7 @@ TRAIN:
  COLOR_JITTERING: True
  SCALE_RANGE: [0.75, 1.33]
TEST:
-  DATABASE: '/data/coco_2014_minival_lmdb'
+  DATABASE: '/data/coco_2014_minival'
  JSON_FILE: '/data/instances_minival2014.json'
  PROTOCOL: 'coco'
  IMS_PER_BATCH: 1
......
@@ -23,8 +23,8 @@ FPN:
  RPN_MIN_LEVEL: 3
  RPN_MAX_LEVEL: 7
TRAIN:
-  WEIGHTS: '/data/models/imagenet/AirNet.Affine.pth'
-  DATABASE: '/data/voc_0712_trainval_lmdb'
+  WEIGHTS: '/model/AirNet.Affine.pth'
+  DATABASE: '/data/voc_0712_trainval'
  IMS_PER_BATCH: 32
  SCALES: [300]
  MAX_SIZE: 500
@@ -32,7 +32,7 @@ TRAIN:
  SCALE_JITTERING: True
  COLOR_JITTERING: True
TEST:
-  DATABASE: '/data/voc_2007_test_lmdb'
+  DATABASE: '/data/voc_2007_test'
  PROTOCOL: 'voc2007'  # 'voc2007', 'voc2010', 'coco'
  IMS_PER_BATCH: 1
  SCALES: [300]
......
@@ -24,8 +24,8 @@ FPN:
  RPN_MIN_LEVEL: 3
  RPN_MAX_LEVEL: 7
TRAIN:
-  WEIGHTS: '/data/models/imagenet/R-18.Affine.pth'
-  DATABASE: '/data/voc_0712_trainval_lmdb'
+  WEIGHTS: '/model/R-18.Affine.pth'
+  DATABASE: '/data/voc_0712_trainval'
  IMS_PER_BATCH: 32
  SCALES: [300]
  MAX_SIZE: 500
@@ -33,7 +33,7 @@ TRAIN:
  SCALE_JITTERING: True
  COLOR_JITTERING: True
TEST:
-  DATABASE: '/data/voc_2007_test_lmdb'
+  DATABASE: '/data/voc_2007_test'
  PROTOCOL: 'voc2007'  # 'voc2007', 'voc2010', 'coco'
  IMS_PER_BATCH: 1
  SCALES: [300]
......
@@ -24,8 +24,8 @@ FPN:
  RPN_MIN_LEVEL: 3
  RPN_MAX_LEVEL: 7
TRAIN:
-  WEIGHTS: '/data/models/imagenet/R-34.Affine.pth'
-  DATABASE: '/data/voc_0712_trainval_lmdb'
+  WEIGHTS: '/model/R-34.Affine.pth'
+  DATABASE: '/data/voc_0712_trainval'
  IMS_PER_BATCH: 32
  SCALES: [300]
  MAX_SIZE: 500
@@ -33,7 +33,7 @@ TRAIN:
  SCALE_JITTERING: True
  COLOR_JITTERING: True
TEST:
-  DATABASE: '/data/voc_2007_test_lmdb'
+  DATABASE: '/data/voc_2007_test'
  PROTOCOL: 'voc2007'  # 'voc2007', 'voc2010', 'coco'
  IMS_PER_BATCH: 1
  SCALES: [300]
......
@@ -29,11 +29,11 @@ SSD:
  STRIDES: [8, 16, 32]
  ASPECT_RATIOS: [[1, 2, 0.5], [1, 2, 0.5], [1, 2, 0.5]]
TRAIN:
-  WEIGHTS: '/data/models/imagenet/AirNet.Affine.pth'
-  DATABASE: '/data/voc_0712_trainval_lmdb'
+  WEIGHTS: '/model/AirNet.Affine.pth'
+  DATABASE: '/data/voc_0712_trainval'
  IMS_PER_BATCH: 32
TEST:
-  DATABASE: '/data/voc_2007_test_lmdb'
+  DATABASE: '/data/voc_2007_test'
  PROTOCOL: 'voc2007'  # 'voc2007', 'voc2010', 'coco'
  IMS_PER_BATCH: 8
  NMS_TOP_K: 400
......
@@ -32,11 +32,11 @@ SSD:
  ASPECT_RATIOS: [[1, 2, 0.5], [1, 2, 0.5, 3, 0.33], [1, 2, 0.5, 3, 0.33],
                  [1, 2, 0.5, 3, 0.33], [1, 2, 0.5], [1, 2, 0.5]]
TRAIN:
-  WEIGHTS: '/data/models/imagenet/VGG16.SSD.pth'
-  DATABASE: '/data/voc_0712_trainval_lmdb'
+  WEIGHTS: '/model/VGG16.SSD.pth'
+  DATABASE: '/data/voc_0712_trainval'
  IMS_PER_BATCH: 32
TEST:
-  DATABASE: '/data/voc_2007_test_lmdb'
+  DATABASE: '/data/voc_2007_test'
  PROTOCOL: 'voc2007'  # 'voc2007', 'voc2010', 'coco'
  IMS_PER_BATCH: 8
  NMS_TOP_K: 400
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os.path as osp

sys.path.insert(0, '../../../')
from database.frcnn.utils.make_from_xml import make_db

if __name__ == '__main__':
    VOC_ROOT_DIR = '/home/workspace/datasets/VOC'

    # train database: voc_2007_trainval + voc_2012_trainval
    make_db(database_file=osp.join(VOC_ROOT_DIR, 'cache/voc_0712_trainval_lmdb'),
            images_path=[osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/JPEGImages'),
                         osp.join(VOC_ROOT_DIR, 'VOCdevkit2012/VOC2012/JPEGImages')],
            annotations_path=[osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/Annotations'),
                              osp.join(VOC_ROOT_DIR, 'VOCdevkit2012/VOC2012/Annotations')],
            imagesets_path=[osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/ImageSets/Main'),
                            osp.join(VOC_ROOT_DIR, 'VOCdevkit2012/VOC2012/ImageSets/Main')],
            splits=['trainval', 'trainval'])

    # test database: voc_2007_test
    make_db(database_file=osp.join(VOC_ROOT_DIR, 'cache/voc_2007_test_lmdb'),
            images_path=osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/JPEGImages'),
            annotations_path=osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/Annotations'),
            imagesets_path=osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/ImageSets/Main'),
            splits=['test'])
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time

import cv2

from dragon.tools.db import LMDB

sys.path.insert(0, '../../..')
from lib.proto import anno_pb2 as pb

ZFILL = 8
ENCODE_QUALITY = 95


def set_zfill(value):
    global ZFILL
    ZFILL = value


def set_quality(value):
    global ENCODE_QUALITY
    ENCODE_QUALITY = value


def make_datum(image_id, image_file, objects):
    anno_datum = pb.AnnotatedDatum()
    datum = pb.Datum()
    im = cv2.imread(image_file)
    datum.height, datum.width, datum.channels = im.shape
    datum.encoded = ENCODE_QUALITY != 100
    if datum.encoded:
        result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), ENCODE_QUALITY])
    datum.data = im.tostring()
    anno_datum.datum.CopyFrom(datum)
    anno_datum.filename = image_id
    for ix, obj in enumerate(objects):
        anno = pb.Annotation()
        anno.x1, anno.y1, anno.x2, anno.y2 = obj['bbox']
        anno.name = obj['name']
        anno.difficult = obj['difficult']
        anno_datum.annotation.add().CopyFrom(anno)
    return anno_datum


def make_db(database_file, images_path, gt_recs, ext='.png'):
    if os.path.isdir(database_file) is True:
        raise ValueError('The database path already exists.')
    root_dir = database_file[:database_file.rfind('/')]
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)
    print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
    db = LMDB(max_commit=10000)
    db.open(database_file, mode='w')
    count = 0
    total_line = len(gt_recs)
    start_time = time.time()
    zfill_flag = '{0:0%d}' % ZFILL
    for image_id, objects in gt_recs.items():
        count += 1
        if count % 10000 == 0:
            now_time = time.time()
            print('{0} / {1} in {2:.2f} sec'.format(
                count, total_line, now_time - start_time))
            db.commit()
        image_file = os.path.join(images_path, image_id + ext)
        datum = make_datum(image_id, image_file, objects)
        db.put(zfill_flag.format(count - 1), datum.SerializeToString())
    now_time = time.time()
    print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
    db.commit()
    db.close()
    end_time = time.time()
    print('{0} images have been stored in the database.'.format(total_line))
    print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
    print('The size of database is {0} MB.'.format(
        float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
\ No newline at end of file
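For reference, a hedged sketch of the gt_recs mapping that the deleted make_db above consumed; the field names come from make_datum, while the concrete values are hypothetical:

# Hypothetical gt_recs: image_id -> list of objects, as read by make_datum.
gt_recs = {
    '000001': [{
        'name': 'dog',
        'bbox': [48.0, 240.0, 195.0, 371.0],  # x1, y1, x2, y2
        'difficult': 0,
    }],
}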
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import xml.etree.ElementTree as ET

import cv2

from dragon.tools.db import LMDB

sys.path.insert(0, '../../..')
from lib.proto import anno_pb2 as pb

ZFILL = 8
ENCODE_QUALITY = 95
class_name_set = set()  # Collect the class names seen in the annotations


def set_zfill(value):
    global ZFILL
    ZFILL = value


def set_quality(value):
    global ENCODE_QUALITY
    ENCODE_QUALITY = value


def make_datum(image_file, xml_file):
    tree = ET.parse(xml_file)
    filename = os.path.split(xml_file)[-1]
    objs = tree.findall('object')
    anno_datum = pb.AnnotatedDatum()
    datum = pb.Datum()
    im = cv2.imread(image_file)
    if im is None or im.shape[0] == 0 or im.shape[1] == 0:
        print('Invalid image, ignored:', xml_file)
        return None
    datum.height, datum.width, datum.channels = im.shape
    datum.encoded = ENCODE_QUALITY != 100
    if datum.encoded:
        result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), ENCODE_QUALITY])
        if im is None or im.shape[0] == 0 or im.shape[1] == 0:
            print('Invalid image, ignored:', xml_file)
            return None
    datum.data = im.tostring()
    anno_datum.datum.CopyFrom(datum)
    anno_datum.filename = filename.split('.')[0]
    if len(objs) == 0:
        return None
    for ix, obj in enumerate(objs):
        anno = pb.Annotation()
        bbox = obj.find('bndbox')
        x1 = float(bbox.find('xmin').text)
        y1 = float(bbox.find('ymin').text)
        x2 = float(bbox.find('xmax').text)
        y2 = float(bbox.find('ymax').text)
        cls = obj.find('name').text.strip()
        anno.x1, anno.y1, anno.x2, anno.y2 = (x1, y1, x2, y2)
        anno.name = cls
        class_name_set.add(cls)
        anno.difficult = False
        if obj.find('difficult') is not None:
            anno.difficult = int(obj.find('difficult').text) == 1
        anno_datum.annotation.add().CopyFrom(anno)
    return anno_datum


def make_db(
    database_file,
    images_path,
    annotations_path,
    imagesets_path,
    splits,
):
    if os.path.isdir(database_file) is True:
        print('Warning: The database path already exists.')
    else:
        root_dir = database_file[:database_file.rfind('/')]
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)
    if not isinstance(images_path, list):
        images_path = [images_path]
    if not isinstance(annotations_path, list):
        annotations_path = [annotations_path]
    if not isinstance(imagesets_path, list):
        imagesets_path = [imagesets_path]
    assert len(splits) == len(imagesets_path)
    assert len(splits) == len(images_path)
    assert len(splits) == len(annotations_path)
    print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
    db = LMDB(max_commit=1000)
    db.open(database_file, mode='w')
    count = 0
    total_line = 0
    start_time = time.time()
    zfill_flag = '{0:0%d}' % ZFILL
    for db_idx, split in enumerate(splits):
        split_file = os.path.join(imagesets_path[db_idx], split + '.txt')
        assert os.path.exists(split_file)
        with open(split_file, 'r') as f:
            lines = f.readlines()
        total_line += len(lines)
        for line in lines:
            filename = line.strip()
            image_file = os.path.join(images_path[db_idx], filename + '.jpg')
            xml_file = os.path.join(annotations_path[db_idx], filename + '.xml')
            datum = make_datum(image_file, xml_file)
            if datum is not None:
                count += 1
                db.put(zfill_flag.format(count - 1), datum.SerializeToString())
                if count % 1000 == 0:
                    now_time = time.time()
                    print('{0} / {1} in {2:.2f} sec'.format(
                        count, total_line, now_time - start_time))
                    db.commit()
    now_time = time.time()
    print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
    db.commit()
    db.close()
    end_time = time.time()
    print('{0} images have been stored in the database.'.format(count))
    print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
    print('The size of database is {0} MB.'.format(
        float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
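A hedged round-trip sketch for a database written by the script above, reading the first record back with the same LMDB and protobuf helpers it uses; the key layout follows zfill_flag (keys start at '0'.zfill(8)), and opening without a mode argument is assumed to default to reading:

import cv2
import numpy as np
from dragon.tools.db import LMDB
from lib.proto import anno_pb2 as pb

db = LMDB()
db.open('/data/voc_0712_trainval_lmdb')  # read mode assumed as the default
db.set('0'.zfill(8))                     # seek to the first zero-padded key
datum = pb.AnnotatedDatum()
datum.ParseFromString(db.value())
im = cv2.imdecode(np.frombuffer(datum.datum.data, np.uint8), -1)
print(datum.filename, im.shape, len(datum.annotation))
db.close()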
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import time

import numpy as np

from lib.core.config import cfg
from lib.core.config import cfg_from_file


class Coordinator(object):
    """Coordinator is a simple tool to manage the
    unique experiments from the YAML configurations.
    """

    def __init__(self, cfg_file, exp_dir=None):
        # Override the default configs
        cfg_from_file(cfg_file)
        if cfg.EXP_DIR != '':
            exp_dir = cfg.EXP_DIR
        if exp_dir is None:
            model_id = time.strftime(
                '%Y%m%d_%H%M%S', time.localtime(time.time()))
            self.experiment_dir = '../experiments/{}'.format(model_id)
            if not os.path.exists(self.experiment_dir):
                os.makedirs(self.experiment_dir)
        else:
            if not os.path.exists(exp_dir):
                raise ValueError('ExperimentDir({}) does not exist.'.format(exp_dir))
            self.experiment_dir = exp_dir

    def _path_at(self, file, auto_create=True):
        path = os.path.abspath(os.path.join(self.experiment_dir, file))
        if auto_create and not os.path.exists(path):
            os.makedirs(path)
        return path

    def checkpoints_dir(self):
        return self._path_at('checkpoints')

    def exports_dir(self):
        return self._path_at('exports')

    def results_dir(self, checkpoint=None):
        sub_dir = os.path.splitext(os.path.basename(checkpoint))[0] if checkpoint else ''
        return self._path_at(os.path.join('results', sub_dir))

    def checkpoint(self, global_step=None, wait=True):
        def locate():
            files = os.listdir(self.checkpoints_dir())
            steps = []
            for ix, file in enumerate(files):
                step = int(file.split('_iter_')[-1].split('.')[0])
                if global_step == step:
                    return os.path.join(self.checkpoints_dir(), files[ix]), step
                steps.append(step)
            if global_step is None:
                if len(files) == 0:
                    return None, 0
                last_idx = int(np.argmax(steps))
                last_step = steps[last_idx]
                return os.path.join(self.checkpoints_dir(), files[last_idx]), last_step
            return None, 0
        result = locate()
        while result[0] is None and wait:
            print('\rWaiting for step_{}.checkpoint to exist...'.format(global_step), end='')
            time.sleep(10)
            result = locate()
        return result

    def delete_experiment(self):
        if os.path.exists(self.experiment_dir):
            shutil.rmtree(self.experiment_dir)
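A minimal usage sketch of Coordinator; the config path below is illustrative:

coordinator = Coordinator('../configs/faster_rcnn_voc.yml')
ckpt_file, last_step = coordinator.checkpoint(wait=False)  # latest checkpoint, if any
print('Checkpoints dir:', coordinator.checkpoints_dir())
print('Resume from:', ckpt_file, 'at step', last_step)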
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-import os
-import cv2
-from multiprocessing import Queue
-from collections import OrderedDict
+import collections
+import multiprocessing as mp
+import os
+
+import cv2
+import dragon

from lib.core.config import cfg
from lib.datasets.factory import get_imdb
-# All detectors share the same reader/transformer during testing
-from lib.faster_rcnn.data.data_reader import DataReader
-from lib.faster_rcnn.data.data_transformer import DataTransformer
+from lib.faster_rcnn.data_transformer import DataTransformer


class TestServer(object):
    def __init__(self, output_dir):
        self.imdb = get_imdb(cfg.TEST.DATABASE)
        self.imdb.competition_mode(cfg.TEST.COMPETITION_MODE)
        self.num_images, self.num_classes, self.classes = \
            self.imdb.num_images, self.imdb.num_classes, self.imdb.classes
-        self.data_reader = DataReader(**{'source': self.imdb.source})
+        self.data_reader = dragon.io.DataReader(
+            dataset=lambda: dragon.io.SeetaRecordDataset(self.imdb.source))
        self.data_transformer = DataTransformer()
-        self.data_reader.q_out = Queue(cfg.TEST.IMS_PER_BATCH)
+        self.data_reader.q_out = mp.Queue(cfg.TEST.IMS_PER_BATCH)
        self.data_reader.start()
-        self.gt_recs = OrderedDict()
+        self.gt_recs = collections.OrderedDict()
        self.output_dir = output_dir
        if cfg.VIS_ON_FILE:
            self.vis_dir = os.path.join(self.output_dir, 'vis')
            if not os.path.exists(self.vis_dir):
                os.makedirs(self.vis_dir)

    def set_transformer(self, transformer_cls):
        self.data_transformer = transformer_cls()

    def get_image(self):
-        serialized = self.data_reader.q_out.get()
-        image = self.data_transformer.get_image(serialized)
-        image_id, objects = self.data_transformer.get_annotations(serialized)
+        example = self.data_reader.q_out.get()
+        image = self.data_transformer.get_image(example)
+        image_id, objects = self.data_transformer.get_annotations(example)
        self.gt_recs[image_id] = {
            'objects': objects,
            'width': image.shape[1],
            'height': image.shape[0],
        }
        return image_id, image

    def get_save_filename(self, image_id, ext='.jpg'):
        return os.path.join(self.vis_dir, image_id + ext) \
            if cfg.VIS_ON_FILE else None

    def get_records(self):
        if len(self.gt_recs) != self.num_images:
            raise RuntimeError(
                'Loading {} records, while {} required.'
                .format(len(self.gt_recs), self.num_images),
            )
        return self.gt_recs

    def evaluate_detections(self, all_boxes):
        self.imdb.evaluate_detections(
            all_boxes,
            self.get_records(),
            self.output_dir,
        )

    def evaluate_segmentations(self, all_boxes, all_masks):
        self.imdb.evaluate_segmentations(
            all_boxes,
            all_masks,
            self.get_records(),
            self.output_dir,
        )


class InferServer(object):
    def __init__(self, output_dir):
        self.images_dir = cfg.TEST.DATABASE
        self.imdb = get_imdb('taas:/empty')
        self.images = os.listdir(self.images_dir)
        self.num_images, self.num_classes, self.classes = \
            len(self.images), cfg.MODEL.NUM_CLASSES, cfg.MODEL.CLASSES
        self.data_transformer = DataTransformer()
-        self.gt_recs = OrderedDict()
+        self.gt_recs = collections.OrderedDict()
        self.output_dir = output_dir
        self.image_idx = 0
        if cfg.VIS_ON_FILE:
            self.vis_dir = os.path.join(self.output_dir, 'vis')
            if not os.path.exists(self.vis_dir):
                os.makedirs(self.vis_dir)

    def set_transformer(self, transformer_cls):
        self.data_transformer = transformer_cls()

    def get_image(self):
        image_name = self.images[self.image_idx]
        image_id = image_name.split('.')[0]
        image = cv2.imread(os.path.join(self.images_dir, image_name))
        self.image_idx = (self.image_idx + 1) % self.num_images
        self.gt_recs[image_id] = {'width': image.shape[1], 'height': image.shape[0]}
        return image_id, image

    def get_save_filename(self, image_id, ext='.jpg'):
        return os.path.join(self.vis_dir, image_id + ext) \
            if cfg.VIS_ON_FILE else None

    def get_records(self):
        if len(self.gt_recs) != self.num_images:
            raise RuntimeError(
                'Loading {} records, while {} required.'
                .format(len(self.gt_recs), self.num_images),
            )
        return self.gt_recs

    def evaluate_detections(self, all_boxes):
        self.imdb.evaluate_detections(
            all_boxes,
            self.get_records(),
            self.output_dir,
        )

    def evaluate_segmentations(self, all_boxes, all_masks):
        self.imdb.evaluate_segmentations(
            all_boxes,
            all_masks,
            self.get_records(),
            self.output_dir,
        )
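A hedged sketch of how a test loop consumes TestServer; the output directory is illustrative and the detection step itself is elided:

server = TestServer(output_dir='/tmp/results')
for _ in range(server.num_images):
    image_id, image = server.get_image()
    # ... run the detector on `image` and collect per-class boxes here ...
print(len(server.get_records()), 'ground-truth records collected')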
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
#    <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/factory.py>
#
# ------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from lib.datasets.taas import TaaS

# TaaS DataSet
_GLOBAL_DATA_SETS = {'taas': lambda source: TaaS(source)}


def get_imdb(name):
    """Get an imdb (image database) by name."""
    keys = name.split(':')
    if len(keys) >= 2:
        cls, source = keys[0], ':'.join(keys[1:])
        if cls not in _GLOBAL_DATA_SETS:
            raise KeyError('Unknown DataSet: {}'.format(cls))
        return _GLOBAL_DATA_SETS[cls](source)
    elif os.path.exists(name):
        return _GLOBAL_DATA_SETS['taas'](name)
    else:
        raise ValueError('Illegal Database: {}'.format(name))


def list_imdbs():
    """List all registered imdbs."""
    return _GLOBAL_DATA_SETS.keys()
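Usage follows the scheme:source convention parsed above; a short sketch with illustrative paths:

imdb = get_imdb('taas:/data/voc_0712_trainval')  # explicit 'taas' scheme
imdb = get_imdb('/data/voc_0712_trainval')       # bare existing path falls back to 'taas'
print(imdb.num_classes, list_imdbs())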
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
#    <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/imdb.py>
#
# ------------------------------------------------------------

import os

-from dragon.tools.db import LMDB
+import dragon

from lib.core.config import cfg


class imdb(object):
    def __init__(self, name):
        self._name = name
        self._num_classes = 0
        self._classes = []

    @property
    def name(self):
        return self._name

    @property
    def num_classes(self):
        return len(self._classes)

    @property
    def classes(self):
        return self._classes

    @property
    def cache_path(self):
        cache_path = os.path.abspath(os.path.join(cfg.DATA_DIR, 'cache'))
        if not os.path.exists(cache_path):
            os.makedirs(cache_path)
        return cache_path

    @property
    def source(self):
-        excepted_source = os.path.join(self.cache_path, self.name + '_lmdb')
+        excepted_source = os.path.join(self.cache_path, self.name)
        if not os.path.exists(excepted_source):
-            raise RuntimeError('Expected LMDB source at: {}, '
-                               'but it does not exist.'.format(excepted_source))
+            raise RuntimeError(
+                'Expected source at: {}, '
+                'but it does not exist.'
+                .format(excepted_source)
+            )
        return excepted_source

    @property
    def num_images(self):
-        self._db = LMDB()
-        self._db.open(self.source)
-        num_entries = self._db.num_entries()
-        self._db.close()
-        return num_entries
+        return dragon.io.SeetaRecordDataset(self.source).size

    def evaluate_detections(self, all_boxes, gt_recs, output_dir):
        pass

    def evaluate_masks(self, all_boxes, all_masks, output_dir):
        pass
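For reference, a hedged sketch of one SeetaRecord example as the new code consumes it; the field names are inferred from the DataTransformer changes later in this diff, and the concrete values are hypothetical:

example = {
    'id': '000001',                 # image identifier
    'content': b'<encoded bytes>',  # JPEG/PNG buffer fed to cv2.imdecode
    'height': 375,
    'width': 500,
    'object': [{
        'name': 'dog',
        'xmin': 48.0, 'ymin': 240.0, 'xmax': 195.0, 'ymax': 371.0,
        'difficult': 0,
    }],
}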
@@ -13,7 +13,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-from lib.faster_rcnn.layers.anchor_target_layer import AnchorTargetLayer
-from lib.faster_rcnn.layers.data_layer import DataLayer
-from lib.faster_rcnn.layers.proposal_layer import ProposalLayer
-from lib.faster_rcnn.layers.proposal_target_layer import ProposalTargetLayer
+from lib.faster_rcnn.anchor_target_layer import AnchorTargetLayer
+from lib.faster_rcnn.data_layer import DataLayer
+from lib.faster_rcnn.proposal_layer import ProposalLayer
+from lib.faster_rcnn.proposal_target_layer import ProposalTargetLayer
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

import numpy as np

from lib.core.config import cfg
from lib.utils.blob import im_list_to_blob


class BlobFetcher(multiprocessing.Process):
    def __init__(self, **kwargs):
        super(BlobFetcher, self).__init__()
        self.q1_in = self.q2_in = self.q_out = None
        self.daemon = True

    def get(self, Q_in):
        processed_ims, ims_info, all_boxes = [], [], []
        for ix in range(cfg.TRAIN.IMS_PER_BATCH):
            im, im_scale, gt_boxes = Q_in.get()
            processed_ims.append(im)
            ims_info.append(list(im.shape[0:2]) + [im_scale])
            # Encode boxes by adding the idx of images
            im_boxes = np.zeros((gt_boxes.shape[0], gt_boxes.shape[1] + 1), dtype=np.float32)
            im_boxes[:, 0:gt_boxes.shape[1]] = gt_boxes
            im_boxes[:, -1] = ix
            all_boxes.append(im_boxes)
        return {
            'data': im_list_to_blob(processed_ims),
            'ims_info': np.array(ims_info, dtype=np.float32),
            'gt_boxes': np.concatenate(all_boxes, axis=0),
        }

    def run(self):
        while True:
            if self.q1_in.qsize() >= cfg.TRAIN.IMS_PER_BATCH:
                self.q_out.put(self.get(self.q1_in))
            elif self.q2_in.qsize() >= cfg.TRAIN.IMS_PER_BATCH:
                self.q_out.put(self.get(self.q2_in))
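A toy illustration of the box packing in get() above: each ground-truth box gains a trailing column holding the index of its image within the mini-batch, so boxes from all images can be concatenated into one array:

import numpy as np

gt_boxes = np.array([[10, 20, 80, 90, 1]], dtype=np.float32)  # x1, y1, x2, y2, cls
im_boxes = np.zeros((gt_boxes.shape[0], gt_boxes.shape[1] + 1), dtype=np.float32)
im_boxes[:, 0:gt_boxes.shape[1]] = gt_boxes
im_boxes[:, -1] = 0  # this box belongs to image 0 of the batch
print(im_boxes)      # [[10. 20. 80. 90.  1.  0.]]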
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import multiprocessing

import numpy

from dragon.tools import db

from lib.core.config import cfg


class DataReader(multiprocessing.Process):
    """Collect encoded str from `LMDB`_.

    Partition and shuffle records over distributed nodes.

    Parameters
    ----------
    source : str
        The path of database.
    shuffle : bool, optional, default=False
        Whether to shuffle the data.
    num_chunks : int, optional, default=2048
        The number of chunks to split.

    """

    def __init__(self, **kwargs):
        """Create a DataReader."""
        super(DataReader, self).__init__()
        self._source = kwargs.get('source', '')
        self._use_shuffle = kwargs.get('shuffle', False)
        self._num_chunks = kwargs.get('num_chunks', 2048)
        self._part_idx, self._num_parts = 0, 1
        self._cursor, self._chunk_cursor = 0, 0
        self._chunk_size, self._perm_size = 0, 0
        self._head, self._tail, self._num_entries = 0, 0, 0
        self._db, self._zfill, self._perm = None, None, None
        self._rng_seed = cfg.RNG_SEED
        self.q_out = None
        self.daemon = True

    def element(self):
        """Get the value of current record.

        Returns
        -------
        str
            The encoded str.

        """
        return self._db.value()

    def redirect(self, target):
        """Redirect to the target position.

        Parameters
        ----------
        target : int
            The key of the record.

        Notes
        -----
        The redirection reopens the database.
        You can drop caches by ``echo 3 > /proc/sys/vm/drop_caches``.
        This prevents getting stuck when *Database Size* >> *RAM Size*.

        """
        self._db.close()
        self._db.open(self._source)
        self._cursor = target
        self._db.set(str(target).zfill(self._zfill))

    def reset(self):
        """Reset the cursor and environment."""
        if self._num_parts > 1 or self._use_shuffle:
            self._chunk_cursor = 0
            self._part_idx = (self._part_idx + 1) % self._num_parts
            if self._use_shuffle:
                self._perm = numpy.random.permutation(self._perm_size)
            self._head = self._part_idx * self._perm_size + self._perm[self._chunk_cursor]
            self._head = self._head * self._chunk_size
            if self._head >= self._num_entries:
                self.next_chunk()
            self._tail = self._head + self._chunk_size
            self._tail = min(self._num_entries, self._tail)
        else:
            self._head, self._tail = 0, self._num_entries
        self.redirect(self._head)

    def next_record(self):
        """Step the cursor of records."""
        self._db.next()
        self._cursor += 1

    def next_chunk(self):
        """Step the cursor of chunks."""
        self._chunk_cursor += 1
        if self._chunk_cursor >= self._perm_size:
            self.reset()
        else:
            self._head = self._part_idx * self._perm_size + self._perm[self._chunk_cursor]
            self._head = self._head * self._chunk_size
            if self._head >= self._num_entries:
                self.next_chunk()
            else:
                self._tail = self._head + self._chunk_size
                self._tail = min(self._num_entries, self._tail)
                self.redirect(self._head)

    def run(self):
        """Start the process."""
        # Fix seed
        numpy.random.seed(self._rng_seed)

        # Init db
        self._db = db.LMDB()
        self._db.open(self._source)
        self._zfill = self._db.zfill()
        self._num_entries = self._db.num_entries()
        epoch_size = self._num_entries // self._num_parts + 1

        if self._use_shuffle:
            if self._num_chunks <= 0:
                # Each chunk has at most 1 record (Record-Wise)
                self._chunk_size, self._perm_size = 1, epoch_size
            else:
                # Search an optimal chunk size (Chunk-Wise)
                min_size, max_size = \
                    1, self._db._total_size * 1.0 \
                        / (self._num_chunks * (1 << 20))
                while min_size * 2 < max_size:
                    min_size *= 2
                self._perm_size = int(math.ceil(
                    self._db._total_size * 1.1 /
                    (self._num_parts * min_size << 20)))
                self._chunk_size = int(
                    self._num_entries * 1.0 /
                    (self._perm_size * self._num_parts) + 1)
                limit = (self._num_parts - 0.5) * self._perm_size * self._chunk_size
                if self._num_entries <= limit:
                    # Roll back to Record-Wise shuffle
                    self._chunk_size, self._perm_size = 1, epoch_size
        else:
            # One chunk has at most K records
            self._chunk_size, self._perm_size = epoch_size, 1

        self._perm = numpy.arange(self._perm_size)

        # Init env
        self.reset()

        # Run!
        while True:
            self.q_out.put(self.element())
            self.next_record()
            if self._cursor >= self._tail:
                if self._num_parts > 1 or self._use_shuffle:
                    self.next_chunk()
                else:
                    self.reset()
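A toy walk-through of the chunk arithmetic above, with all numbers hypothetical: with 10000 records split over 2 distributed parts into 100-record chunks, each part permutes 50 chunks and reads one contiguous [head, tail) range at a time:

num_entries, num_parts, chunk_size = 10000, 2, 100
perm_size = num_entries // (num_parts * chunk_size)  # 50 chunks per part
part_idx, chunk_cursor = 1, 0
perm = list(range(perm_size))  # shuffled in the real reader

head = (part_idx * perm_size + perm[chunk_cursor]) * chunk_size
tail = min(num_entries, head + chunk_size)
print(head, tail)  # 5000 5100: the next record range for part 1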
@@ -13,55 +13,70 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-from multiprocessing import Queue
+import multiprocessing as mp
import time

import dragon
-import pprint
+import dragon.vm.torch as torch
+import numpy as np

from lib.core.config import cfg
-from lib.faster_rcnn.data.data_reader import DataReader
-from lib.faster_rcnn.data.data_transformer import DataTransformer
-from lib.faster_rcnn.data.blob_fetcher import BlobFetcher
+from lib.faster_rcnn.data_transformer import DataTransformer
+from lib.datasets.factory import get_imdb
from lib.utils import logger
+from lib.utils.blob import im_list_to_blob


-class DataBatch(object):
-    """DataBatch aims to prefetch data by ``Triple-Buffering``.
-
-    It takes full advantages of the Process/Thread of Python,
-    which provides remarkable I/O speed up for scalable distributed training.
-
-    """
+class DataLayer(torch.nn.Module):
+    """Generate a mini-batch of data."""
+
+    def __init__(self):
+        super(DataLayer, self).__init__()
+        database = get_imdb(cfg.TRAIN.DATABASE)
+        self.data_batch = DataBatch(**{
+            'dataset': lambda: dragon.io.SeetaRecordDataset(database.source),
+            'classes': database.classes,
+            'shuffle': cfg.TRAIN.USE_SHUFFLE,
+            'num_chunks': cfg.TRAIN.NUM_SHUFFLE_CHUNKS,
+            'batch_size': cfg.TRAIN.IMS_PER_BATCH * 2,
+        })
+
+    def forward(self):
+        # Get an array blob from the Queue
+        outputs = self.data_batch.get()
+        # Zero-Copy the array to tensor
+        outputs['data'] = torch.from_numpy(outputs['data'])
+        return outputs
+
+
+class DataBatch(mp.Process):
+    """Prefetch the batch of data."""

    def __init__(self, **kwargs):
        """Construct a ``DataBatch``.

        Parameters
        ----------
-        source : str
-            The path of database.
+        dataset : lambda
+            The creator of a dataset.
        shuffle : bool, optional, default=False
            Whether to shuffle the data.
-        num_chunks : int, optional, default=2048
+        num_chunks : int, optional, default=0
            The number of chunks to split.
-        batch_size : int, optional, default=128
+        batch_size : int, optional, default=2
            The size of a mini-batch.
        prefetch : int, optional, default=5
            The prefetch count.

        """
        super(DataBatch, self).__init__()
-        # Init mpi
-        global_rank, local_rank, group_size = 0, 0, 1
-        if dragon.mpi.is_init():
-            group = dragon.mpi.is_parallel()
-            if group is not None:  # DataParallel
-                global_rank = dragon.mpi.rank()
-                group_size = len(group)
-                for i, node in enumerate(group):
-                    if global_rank == node:
-                        local_rank = i
+        # Distributed settings
+        rank, group_size = 0, 1
+        process_group = dragon.distributed.get_default_process_group()
+        if process_group is not None and kwargs.get(
+                'phase', 'TRAIN') == 'TRAIN':
+            group_size = process_group.size
+            rank = dragon.distributed.get_rank(process_group)
        kwargs['group_size'] = group_size

        # Configuration
@@ -71,6 +86,7 @@ class DataBatch(object):
        self._num_transformers = kwargs.get('num_transformers', -1)
        self._max_transformers = kwargs.get('max_transformers', 3)
        self._num_fetchers = kwargs.get('num_fetchers', 1)
+        self.daemon = True

        # Io-Aware Policy
        if self._num_transformers == -1:
@@ -81,66 +97,52 @@ class DataBatch(object):
        self._num_transformers = min(
            self._num_transformers, self._max_transformers)

-        # Init queues
-        self.Q1 = Queue(self._prefetch * self._num_readers * self._batch_size)
-        self.Q21 = Queue(self._prefetch * self._num_readers * self._batch_size)
-        self.Q22 = Queue(self._prefetch * self._num_readers * self._batch_size)
-        self.Q3 = Queue(self._prefetch * self._num_readers)
+        # Initialize queues
+        num_batches = self._prefetch * self._num_readers
+        self.Q1 = mp.Queue(num_batches * self._batch_size)
+        self.Q21 = mp.Queue(num_batches * self._batch_size)
+        self.Q22 = mp.Queue(num_batches * self._batch_size)
+        self.Q3 = mp.Queue(num_batches)

-        # Init readers
+        # Initialize readers
        self._readers = []
        for i in range(self._num_readers):
-            self._readers.append(DataReader(**kwargs))
-            self._readers[-1].q_out = self.Q1
-        for i in range(self._num_readers):
            part_idx, num_parts = i, self._num_readers
            num_parts *= group_size
-            part_idx += local_rank * self._num_readers
-            self._readers[i]._num_parts = num_parts
-            self._readers[i]._part_idx = part_idx
-            self._readers[i]._rng_seed += part_idx
+            part_idx += rank * self._num_readers
+            self._readers.append(dragon.io.DataReader(
+                num_parts=num_parts, part_idx=part_idx, **kwargs))
+            self._readers[i]._seed += part_idx
+            self._readers[i].q_out = self.Q1
            self._readers[i].start()
            time.sleep(0.1)

-        # Init transformers
+        # Initialize transformers
        self._transformers = []
        for i in range(self._num_transformers):
            transformer = DataTransformer(**kwargs)
-            transformer._rng_seed += (i + local_rank * self._num_transformers)
+            transformer._rng_seed += (i + rank * self._num_transformers)
            transformer.q_in = self.Q1
-            transformer.q1_out = self.Q21
-            transformer.q2_out = self.Q22
+            transformer.q1_out, transformer.q2_out = self.Q21, self.Q22
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

-        # Init blob fetchers
-        self._fetchers = []
-        for i in range(self._num_fetchers):
-            fetcher = BlobFetcher(**kwargs)
-            fetcher.q1_in = self.Q21
-            fetcher.q2_in = self.Q22
-            fetcher.q_out = self.Q3
-            fetcher.start()
-            self._fetchers.append(fetcher)
-            time.sleep(0.1)
-
-        # Prevent to echo multiple nodes
-        if local_rank == 0:
-            self.echo()
+        # Initialize batch-producer
+        self.start()

+        # Register cleanup callbacks
        def cleanup():
            def terminate(processes):
                for process in processes:
                    process.terminate()
                    process.join()
-            terminate(self._fetchers)
-            logger.info('Terminating BlobFetcher ......')
+            terminate([self])
+            logger.info('Terminate DataBatch.')
            terminate(self._transformers)
-            logger.info('Terminating DataTransformer ......')
+            logger.info('Terminate DataTransformer.')
            terminate(self._readers)
-            logger.info('Terminating DataReader ......')
+            logger.info('Terminate DataReader.')

        import atexit
        atexit.register(cleanup)
@@ -156,20 +158,27 @@ class DataBatch(object):
        """
        return self.Q3.get()

-    def echo(self):
-        """Print I/O Information.
-
-        Returns
-        -------
-        None
-
-        """
-        print('---------------------------------------------------------')
-        print('BatchFetcher({} Threads), Using config:'.format(
-            self._num_readers + self._num_transformers + self._num_fetchers))
-        params = {'queue_size': self._prefetch,
-                  'n_readers': self._num_readers,
-                  'n_transformers': self._num_transformers,
-                  'n_fetchers': self._num_fetchers}
-        pprint.pprint(params)
-        print('---------------------------------------------------------')
+    def run(self):
+        """Start the process to produce batches."""
+        def produce(q_in):
+            processed_ims, ims_info, all_boxes = [], [], []
+            for image_index in range(cfg.TRAIN.IMS_PER_BATCH):
+                im, im_scale, gt_boxes = q_in.get()
+                processed_ims.append(im)
+                ims_info.append(list(im.shape[:2]) + [im_scale])
+                im_boxes = np.zeros((gt_boxes.shape[0], gt_boxes.shape[1] + 1), 'float32')
+                im_boxes[:, :gt_boxes.shape[1]], im_boxes[:, -1] = gt_boxes, image_index
+                all_boxes.append(im_boxes)
+            return {
+                'data': im_list_to_blob(processed_ims),
+                'ims_info': np.array(ims_info, dtype=np.float32),
+                'gt_boxes': np.concatenate(all_boxes, axis=0),
+            }

+        q1, q2 = self.Q21, self.Q22
+        while True:
+            if q1.qsize() >= cfg.TRAIN.IMS_PER_BATCH:
+                self.Q3.put(produce(q1))
+            elif q2.qsize() >= cfg.TRAIN.IMS_PER_BATCH:
+                self.Q3.put(produce(q2))
+            q1, q2 = q2, q1  # Sample two queues uniformly
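A hedged sketch of the training-side call path for the new DataLayer; it assumes cfg.TRAIN.DATABASE points at an existing SeetaRecord dataset:

data_layer = DataLayer()
outputs = data_layer.forward()
print(outputs['data'].shape)      # torch tensor of the stacked images
print(outputs['ims_info'])        # per-image (height, width, scale), numpy
print(outputs['gt_boxes'].shape)  # (num_objects, 6): x1, y1, x2, y2, cls, image_index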
...@@ -14,22 +14,13 @@ from __future__ import division ...@@ -14,22 +14,13 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import multiprocessing import multiprocessing
import numpy as np
import numpy.random as npr
try: import cv2
import cv2 import numpy as np
except ImportError as e:
print('Failed to import cv2. Error: {0}'.format(str(e)))
try:
import PIL.Image
except ImportError as e:
print('Failed to import PIL. Error: {0}'.format(str(e)))
from lib.core.config import cfg from lib.core.config import cfg
from lib.proto import anno_pb2 as pb
from lib.utils import logger
from lib.utils.blob import prep_im_for_blob from lib.utils.blob import prep_im_for_blob
from lib.utils.boxes import flip_boxes
class DataTransformer(multiprocessing.Process): class DataTransformer(multiprocessing.Process):
...@@ -47,44 +38,45 @@ class DataTransformer(multiprocessing.Process): ...@@ -47,44 +38,45 @@ class DataTransformer(multiprocessing.Process):
def make_roi_dict( def make_roi_dict(
self, self,
ann_datum, example,
im_scale, im_scale,
apply_flip=False, apply_flip=False,
offsets=None, offsets=None,
): ):
annotations = ann_datum.annotation
n_objects = 0 n_objects = 0
if not self._use_diff: if not self._use_diff:
for ann in annotations: for obj in example['object']:
if not ann.difficult: if obj.get('difficult', 0) == 0:
n_objects += 1 n_objects += 1
else: else:
n_objects = len(annotations) n_objects = len(example['object'])
roi_dict = { roi_dict = {
'width': ann_datum.datum.width, 'width': example['width'],
'height': ann_datum.datum.height, 'height': example['height'],
'gt_classes': np.zeros((n_objects,), 'int32'), 'gt_classes': np.zeros((n_objects,), 'int32'),
'boxes': np.zeros((n_objects, 4), 'float32'), 'boxes': np.zeros((n_objects, 4), 'float32'),
} }
# Filter the difficult instances # Filter the difficult instances
rec_idx = 0 object_idx = 0
for ann in annotations: for obj in example['object']:
if not self._use_diff and ann.difficult: if not self._use_diff and \
obj.get('difficult', 0) > 0:
continue continue
roi_dict['boxes'][rec_idx, :] = [ roi_dict['boxes'][object_idx, :] = [
max(0, ann.x1), max(0, obj['xmin']),
max(0, ann.y1), max(0, obj['ymin']),
min(ann.x2, ann_datum.datum.width - 1), min(obj['xmax'], example['width'] - 1),
min(ann.y2, ann_datum.datum.height - 1), min(obj['ymax'], example['height'] - 1),
] ]
roi_dict['gt_classes'][rec_idx] = self._class_to_ind[ann.name] roi_dict['gt_classes'][object_idx] = \
rec_idx += 1 self._class_to_ind[obj['name']]
object_idx += 1
# Flip the boxes if necessary # Flip the boxes if necessary
if apply_flip: if apply_flip:
roi_dict['boxes'] = _flip_boxes( roi_dict['boxes'] = flip_boxes(
roi_dict['boxes'], roi_dict['width']) roi_dict['boxes'], roi_dict['width'])
# Scale the boxes to the detecting scale # Scale the boxes to the detecting scale
...@@ -102,50 +94,34 @@ class DataTransformer(multiprocessing.Process): ...@@ -102,50 +94,34 @@ class DataTransformer(multiprocessing.Process):
return roi_dict return roi_dict
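For concreteness, here is a minimal sketch of the record-to-RoI conversion above, with invented values; the example dict mirrors the fields the new code reads, and class_to_ind is a hypothetical name-to-index mapping:

# A hypothetical SeetaRecord example, shaped like the ones consumed above.
example = {
    'width': 500, 'height': 375,
    'object': [{'name': 'dog', 'difficult': 0,
                'xmin': 48., 'ymin': 240., 'xmax': 195., 'ymax': 371.}],
}
class_to_ind = {'dog': 12}  # hypothetical mapping

obj = example['object'][0]
# Clip the box to the image bounds, exactly as make_roi_dict does.
box = [max(0, obj['xmin']), max(0, obj['ymin']),
       min(obj['xmax'], example['width'] - 1),
       min(obj['ymax'], example['height'] - 1)]
print(box, class_to_ind[obj['name']])  # [48.0, 240.0, 195.0, 371.0] 12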
@classmethod @classmethod
def get_image(cls, serialized): def get_image(cls, example):
datum = pb.AnnotatedDatum() img = np.frombuffer(example['content'], np.uint8)
datum.ParseFromString(serialized) return cv2.imdecode(img, -1)
datum = datum.datum
im = np.fromstring(datum.data, np.uint8)
return cv2.imdecode(im, -1) if datum.encoded is True else \
im.reshape((datum.height, datum.width, datum.channels))
@classmethod @classmethod
def get_annotations(cls, serialized): def get_annotations(cls, example):
datum = pb.AnnotatedDatum()
datum.ParseFromString(serialized)
filename = datum.filename
annotations = datum.annotation
objects = [] objects = []
for ix, ann in enumerate(annotations): for ix, obj in enumerate(example['object']):
objects.append({ objects.append({
'name': ann.name, 'name': obj['name'],
'difficult': int(ann.difficult), 'difficult': obj.get('difficult', 0),
'bbox': [ann.x1, ann.y1, ann.x2, ann.y2], 'bbox': [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax']],
'mask': ann.mask,
}) })
return filename, objects return example['id'], objects
def get(self, serialized): def get(self, example):
datum = pb.AnnotatedDatum() img = np.frombuffer(example['content'], np.uint8)
datum.ParseFromString(serialized) img = cv2.imdecode(img, -1)
im_datum = datum.datum
im = np.fromstring(im_datum.data, np.uint8)
if im_datum.encoded is True:
im = cv2.imdecode(im, -1)
else:
h, w = im_datum.height, im_datum.width
im = im.reshape((h, w, im_datum.channels))
# Scale # Scale
scale_indices = npr.randint(len(cfg.TRAIN.SCALES)) scale_indices = np.random.randint(len(cfg.TRAIN.SCALES))
target_size = cfg.TRAIN.SCALES[scale_indices] target_size = cfg.TRAIN.SCALES[scale_indices]
im, im_scale, jitter = prep_im_for_blob(im, target_size, cfg.TRAIN.MAX_SIZE) im, im_scale, jitter = prep_im_for_blob(img, target_size, cfg.TRAIN.MAX_SIZE)
# Flip # Flip
apply_flip = False apply_flip = False
if self._use_flipped: if self._use_flipped:
if npr.randint(0, 2) > 0: if np.random.randint(2) > 0:
im = im[:, ::-1, :] im = im[:, ::-1, :]
apply_flip = True apply_flip = True
...@@ -160,8 +136,8 @@ class DataTransformer(multiprocessing.Process): ...@@ -160,8 +136,8 @@ class DataTransformer(multiprocessing.Process):
# To a square (target_size, target_size) # To a square (target_size, target_size)
im, offsets = _get_image_with_target_size([target_size] * 2, im) im, offsets = _get_image_with_target_size([target_size] * 2, im)
# Datum -> RoIDict # Example -> RoIDict
roi_dict = self.make_roi_dict(datum, im_scale, apply_flip, offsets) roi_dict = self.make_roi_dict(example, im_scale, apply_flip, offsets)
# Post-Process for gt boxes # Post-Process for gt boxes
# Shape like: [num_objects, {x1, y1, x2, y2, cls}] # Shape like: [num_objects, {x1, y1, x2, y2, cls}]
...@@ -171,29 +147,16 @@ class DataTransformer(multiprocessing.Process): ...@@ -171,29 +147,16 @@ class DataTransformer(multiprocessing.Process):
return im, im_scale, gt_boxes return im, im_scale, gt_boxes
def run(self): def run(self):
npr.seed(self._rng_seed) np.random.seed(self._rng_seed)
while True: while True:
serialized = self.q_in.get() outputs = self.get(self.q_in.get())
data = self.get(serialized) if len(outputs[2]) < 1:
# Ensure there is at least 1 ground-truth continue # Skip images without any objects
if len(data[2]) < 1: aspect_ratio = float(outputs[0].shape[0]) / outputs[0].shape[1]
continue if aspect_ratio > 1.:
aspect_ratio = float(data[0].shape[0]) / data[0].shape[1] self.q1_out.put(outputs)
if aspect_ratio > 1.0:
self.q1_out.put(data)
else: else:
self.q2_out.put(data) self.q2_out.put(outputs)
def _flip_boxes(boxes, width):
flip_boxes = boxes.copy()
old_x1 = boxes[:, 0].copy()
old_x2 = boxes[:, 2].copy()
flip_boxes[:, 0] = width - old_x2 - 1
flip_boxes[:, 2] = width - old_x1 - 1
if not (flip_boxes[:, 2] >= flip_boxes[:, 0]).all():
logger.fatal('Encounter invalid coordinates after flipping boxes.')
return flip_boxes
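The helper above moves to lib.utils.boxes.flip_boxes, but the mapping itself is unchanged: a box spanning (x1, x2) in an image of the given width flips to (width - 1 - x2, width - 1 - x1), with y-coordinates untouched. A quick worked check of that arithmetic:

import numpy as np

def flip_boxes_sketch(boxes, width):
    # Mirror the x-coordinates; y-coordinates are untouched.
    flipped = boxes.copy()
    flipped[:, 0] = width - boxes[:, 2] - 1
    flipped[:, 2] = width - boxes[:, 0] - 1
    return flipped

# In a 100px-wide image, the box [10, 0, 29, 49] flips to [70, 0, 89, 49].
print(flip_boxes_sketch(np.array([[10., 0., 29., 49.]]), 100))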
def _get_image_with_target_size(target_size, img): def _get_image_with_target_size(target_size, img):
......
# ------------------------------------------------------------ # ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd. # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# #
# Licensed under the BSD 2-Clause License. # Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License # You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See, # along with the software. If not, See,
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on: # Codes are based on:
# #
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py> # <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
import numpy as np from __future__ import absolute_import
from __future__ import division
# Verify that we compute the same anchors as Shaoqing's matlab implementation: from __future__ import print_function
#
# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat import numpy as np
# >> anchors
# # Verify that we compute the same anchors as Shaoqing's matlab implementation:
# anchors = #
# # >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
# -83 -39 100 56 # >> anchors
# -175 -87 192 104 #
# -359 -183 376 200 # anchors =
# -55 -55 72 72 #
# -119 -119 136 136 # -83 -39 100 56
# -247 -247 264 264 # -175 -87 192 104
# -35 -79 52 96 # -359 -183 376 200
# -79 -167 96 184 # -55 -55 72 72
# -167 -343 184 360 # -119 -119 136 136
# -247 -247 264 264
# array([[ -83., -39., 100., 56.], # -35 -79 52 96
# [-175., -87., 192., 104.], # -79 -167 96 184
# [-359., -183., 376., 200.], # -167 -343 184 360
# [ -55., -55., 72., 72.],
# [-119., -119., 136., 136.], # array([[ -83., -39., 100., 56.],
# [-247., -247., 264., 264.], # [-175., -87., 192., 104.],
# [ -35., -79., 52., 96.], # [-359., -183., 376., 200.],
# [ -79., -167., 96., 184.], # [ -55., -55., 72., 72.],
# [-167., -343., 184., 360.]]) # [-119., -119., 136., 136.],
# [-247., -247., 264., 264.],
# [ -35., -79., 52., 96.],
def generate_anchors( # [ -79., -167., 96., 184.],
base_size=16, # [-167., -343., 184., 360.]])
ratios=(0.5, 1, 2),
scales=2**np.arange(3, 6),
): def generate_anchors(
""" base_size=16,
Generate anchor (reference) windows by enumerating aspect ratios X ratios=(0.5, 1, 2),
scales wrt a reference (0, 0, 15, 15) window. scales=2**np.arange(3, 6),
""" ):
base_anchor = np.array([1, 1, base_size, base_size]) - 1 """
ratio_anchors = _ratio_enum(base_anchor, ratios) Generate anchor (reference) windows by enumerating aspect ratios X
anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales) scales wrt a reference (0, 0, 15, 15) window.
for i in range(ratio_anchors.shape[0])]) """
return anchors base_anchor = np.array([1, 1, base_size, base_size]) - 1
ratio_anchors = _ratio_enum(base_anchor, ratios)
anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
def generate_anchors_v2( for i in range(ratio_anchors.shape[0])])
stride=16, return anchors
ratios=(0.5, 1, 2),
sizes=(32, 64, 128, 256, 512),
): def generate_anchors_v2(
""" stride=16,
Generates a matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors ratios=(0.5, 1, 2),
are centered on stride / 2, have (approximate) sqrt areas of the specified sizes=(32, 64, 128, 256, 512),
sizes, and aspect ratios as given. ):
""" """
return generate_anchors( Generates a matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
base_size=stride, are centered on stride / 2, have (approximate) sqrt areas of the specified
ratios=ratios, sizes, and aspect ratios as given.
scales=np.array(sizes, dtype=np.float) / stride, """
) return generate_anchors(
base_size=stride,
ratios=ratios,
def _whctrs(anchor): scales=np.array(sizes, dtype='float64') / stride,
"""Return width, height, x center, and y center for an anchor (window).""" )
w = anchor[2] - anchor[0] + 1
h = anchor[3] - anchor[1] + 1
x_ctr = anchor[0] + 0.5 * (w - 1) def _whctrs(anchor):
y_ctr = anchor[1] + 0.5 * (h - 1) """Return width, height, x center, and y center for an anchor (window)."""
return w, h, x_ctr, y_ctr w = anchor[2] - anchor[0] + 1
h = anchor[3] - anchor[1] + 1
x_ctr = anchor[0] + 0.5 * (w - 1)
def _mkanchors(ws, hs, x_ctr, y_ctr): y_ctr = anchor[1] + 0.5 * (h - 1)
""" return w, h, x_ctr, y_ctr
Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
""" def _mkanchors(ws, hs, x_ctr, y_ctr):
ws = ws[:, np.newaxis] """
hs = hs[:, np.newaxis] Given a vector of widths (ws) and heights (hs) around a center
anchors = np.hstack((x_ctr - 0.5 * (ws - 1), (x_ctr, y_ctr), output a set of anchors (windows).
y_ctr - 0.5 * (hs - 1), """
x_ctr + 0.5 * (ws - 1), ws = ws[:, np.newaxis]
y_ctr + 0.5 * (hs - 1))) hs = hs[:, np.newaxis]
return anchors anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1),
def _ratio_enum(anchor, ratios): y_ctr + 0.5 * (hs - 1)))
"""Enumerate a set of anchors for each aspect ratio wrt an anchor.""" return anchors
w, h, x_ctr, y_ctr = _whctrs(anchor)
size = w * h
size_ratios = size / ratios def _ratio_enum(anchor, ratios):
ws = np.round(np.sqrt(size_ratios)) """Enumerate a set of anchors for each aspect ratio wrt an anchor."""
hs = np.round(ws * ratios) w, h, x_ctr, y_ctr = _whctrs(anchor)
anchors = _mkanchors(ws, hs, x_ctr, y_ctr) size = w * h
return anchors size_ratios = size / ratios
ws = np.round(np.sqrt(size_ratios))
hs = np.round(ws * ratios)
def _scale_enum(anchor, scales): anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
"""Enumerate a set of anchors for each scale wrt an anchor.""" return anchors
w, h, x_ctr, y_ctr = _whctrs(anchor)
ws = w * scales
hs = h * scales def _scale_enum(anchor, scales):
anchors = _mkanchors(ws, hs, x_ctr, y_ctr) """Enumerate a set of anchors for each scale wrt an anchor."""
return anchors w, h, x_ctr, y_ctr = _whctrs(anchor)
ws = w * scales
hs = h * scales
if __name__ == '__main__': anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
print(generate_anchors()) return anchors
if __name__ == '__main__':
print(generate_anchors())
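For reference, generate_anchors_v2 simply reduces FPN-style (stride, sizes) arguments to the original (base_size, scales) form; for example, stride 16 with size 32 is just scale 2. A usage sketch, assuming the module above is imported:

# Anchors for a single FPN level: stride 16, area ~32x32, three ratios.
anchors = generate_anchors_v2(stride=16, ratios=(0.5, 1, 2), sizes=(32,))
print(anchors.shape)  # (3, 4): one (x1, y1, x2, y2) row per aspect ratio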
# --------------------------------------------------------
# Mask R-CNN @ Detectron
# Copyright (c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dragon.vm.torch as torch
from lib.core.config import cfg
from lib.datasets.factory import get_imdb
from lib.faster_rcnn.data.data_batch import DataBatch
class DataLayer(torch.nn.Module):
def __init__(self):
super(DataLayer, self).__init__()
database = get_imdb(cfg.TRAIN.DATABASE)
self.data_batch = DataBatch(**{
'source': database.source,
'classes': database.classes,
'shuffle': cfg.TRAIN.USE_SHUFFLE,
'num_chunks': 0, # Record-Wise Shuffle
'batch_size': cfg.TRAIN.IMS_PER_BATCH * 2,
})
def forward(self):
# Get an array blob from the Queue
outputs = self.data_batch.get()
# Zero-Copy the array to tensor
outputs['data'] = torch.from_numpy(outputs['data'])
return outputs
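Used from a training loop, the layer is called like any other module; a hedged sketch, assuming a configured cfg.TRAIN.DATABASE and that the batch dict carries the 'data', 'ims_info' and 'gt_boxes' keys produced above:

# Hypothetical usage inside a training step.
data_layer = DataLayer()
blobs = data_layer()        # invokes forward(), pulling one prefetched batch
images = blobs['data']      # tensor zero-copied from the prefetched numpy array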
# ------------------------------------------------------------ # ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd. # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# #
# Licensed under the BSD 2-Clause License. # Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License # You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See, # along with the software. If not, See,
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import dragon.vm.torch as torch import dragon.vm.torch as torch
import numpy as np import numpy as np
from lib.core.config import cfg from lib.core.config import cfg
from lib.nms.nms_wrapper import nms from lib.nms.nms_wrapper import nms
from lib.nms.nms_wrapper import soft_nms from lib.nms.nms_wrapper import soft_nms
from lib.utils.blob import im_list_to_blob from lib.utils.blob import im_list_to_blob
from lib.utils.blob import tensor_to_blob from lib.utils.blob import tensor_to_blob
from lib.utils.boxes import bbox_transform_inv from lib.utils.boxes import bbox_transform_inv
from lib.utils.boxes import clip_tiled_boxes from lib.utils.boxes import clip_tiled_boxes
from lib.utils.image import scale_image from lib.utils.image import scale_image
from lib.utils.timer import Timer from lib.utils.timer import Timer
from lib.utils.vis import vis_one_image from lib.utils.vis import vis_one_image
def im_detect(detector, raw_image): def im_detect(detector, raw_image):
"""Detect a image, with single or multiple scales.""" """Detect a image, with single or multiple scales."""
# Prepare images # Prepare images
ims, ims_scale = scale_image(raw_image) ims, ims_scale = scale_image(raw_image)
# Prepare blobs # Prepare blobs
blobs = {'data': im_list_to_blob(ims)} blobs = {'data': im_list_to_blob(ims)}
blobs['ims_info'] = np.array([ blobs['ims_info'] = np.array([
list(blobs['data'].shape[1:3]) + [im_scale] list(blobs['data'].shape[1:3]) + [im_scale]
for im_scale in ims_scale], dtype=np.float32) for im_scale in ims_scale], dtype=np.float32)
blobs['data'] = torch.from_numpy(blobs['data'])
blobs['data'] = torch.from_numpy(blobs['data'])
# Do Forward
with torch.no_grad(): # Do Forward
outputs = detector.forward(inputs=blobs) with torch.no_grad():
outputs = detector.forward(inputs=blobs)
# Decode results
batch_rois = tensor_to_blob(outputs['rois']) # Decode results
batch_scores = tensor_to_blob(outputs['cls_prob']) batch_rois = tensor_to_blob(outputs['rois'])
batch_deltas = tensor_to_blob(outputs['bbox_pred']) batch_scores = tensor_to_blob(outputs['cls_prob'])
batch_deltas = tensor_to_blob(outputs['bbox_pred'])
batch_boxes = bbox_transform_inv(
boxes=batch_rois[:, 1:5], batch_boxes = bbox_transform_inv(
deltas=batch_deltas, boxes=batch_rois[:, 1:5],
weights=cfg.BBOX_REG_WEIGHTS, deltas=batch_deltas,
) weights=cfg.BBOX_REG_WEIGHTS,
)
scores_wide, boxes_wide = [], []
scores_wide, boxes_wide = [], []
for im_idx in range(len(ims)):
indices = np.where(batch_rois[:, 0].astype(np.int32) == im_idx)[0] for im_idx in range(len(ims)):
boxes = batch_boxes[indices] indices = np.where(batch_rois[:, 0].astype(np.int32) == im_idx)[0]
boxes /= ims_scale[im_idx] boxes = batch_boxes[indices]
clip_tiled_boxes(boxes, raw_image.shape) boxes /= ims_scale[im_idx]
scores_wide.append(batch_scores[indices]) clip_tiled_boxes(boxes, raw_image.shape)
boxes_wide.append(boxes) scores_wide.append(batch_scores[indices])
boxes_wide.append(boxes)
return (np.vstack(scores_wide), np.vstack(boxes_wide)) \
if len(scores_wide) > 1 else (scores_wide[0], boxes_wide[0]) return (np.vstack(scores_wide), np.vstack(boxes_wide)) \
if len(scores_wide) > 1 else (scores_wide[0], boxes_wide[0])
def test_net(detector, server):
# Load settings def test_net(detector, server):
classes = server.classes # Load settings
num_images = server.num_images classes = server.classes
num_classes = server.num_classes num_images = server.num_images
all_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)] num_classes = server.num_classes
all_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]
_t = {'im_detect': Timer(), 'misc': Timer()}
_t = {'im_detect': Timer(), 'misc': Timer()}
for i in range(num_images):
image_id, raw_image = server.get_image() for i in range(num_images):
image_id, raw_image = server.get_image()
_t['im_detect'].tic()
scores, boxes = im_detect(detector, raw_image) _t['im_detect'].tic()
_t['im_detect'].toc() scores, boxes = im_detect(detector, raw_image)
_t['im_detect'].toc()
_t['misc'].tic()
boxes_this_image = [[]] _t['misc'].tic()
for j in range(1, num_classes): boxes_this_image = [[]]
inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0] for j in range(1, num_classes):
cls_scores = scores[inds, j] inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0]
cls_boxes = boxes[inds, j*4:(j+1)*4] cls_scores = scores[inds, j]
cls_detections = np.hstack( cls_boxes = boxes[inds, j*4:(j+1)*4]
(cls_boxes, cls_scores[:, np.newaxis]) cls_detections = np.hstack(
).astype(np.float32, copy=False) (cls_boxes, cls_scores[:, np.newaxis])
if cfg.TEST.USE_SOFT_NMS: ).astype(np.float32, copy=False)
keep = soft_nms( if cfg.TEST.USE_SOFT_NMS:
cls_detections, cfg.TEST.NMS, keep = soft_nms(
method=cfg.TEST.SOFT_NMS_METHOD, cls_detections, cfg.TEST.NMS,
sigma=cfg.TEST.SOFT_NMS_SIGMA, method=cfg.TEST.SOFT_NMS_METHOD,
) sigma=cfg.TEST.SOFT_NMS_SIGMA,
else: )
keep = nms(cls_detections, cfg.TEST.NMS, force_cpu=True) else:
cls_detections = cls_detections[keep, :] keep = nms(cls_detections, cfg.TEST.NMS, force_cpu=True)
all_boxes[j][i] = cls_detections cls_detections = cls_detections[keep, :]
boxes_this_image.append(cls_detections) all_boxes[j][i] = cls_detections
boxes_this_image.append(cls_detections)
if cfg.VIS or cfg.VIS_ON_FILE:
vis_one_image( if cfg.VIS or cfg.VIS_ON_FILE:
raw_image, classes, boxes_this_image, vis_one_image(
thresh=cfg.VIS_TH, box_alpha=1.0, show_class=True, raw_image, classes, boxes_this_image,
filename=server.get_save_filename(image_id), thresh=cfg.VIS_TH, box_alpha=1.0, show_class=True,
) filename=server.get_save_filename(image_id),
)
# Limit to max_per_image detections *over all classes*
if cfg.TEST.DETECTIONS_PER_IM > 0: # Limit to max_per_image detections *over all classes*
image_scores = [] if cfg.TEST.DETECTIONS_PER_IM > 0:
for j in range(1, num_classes): image_scores = []
if len(all_boxes[j][i]) < 1: continue for j in range(1, num_classes):
image_scores.append(all_boxes[j][i][:, -1]) if len(all_boxes[j][i]) < 1: continue
if len(image_scores) > 0: image_scores.append(all_boxes[j][i][:, -1])
image_scores = np.hstack(image_scores) if len(image_scores) > 0:
if len(image_scores) > cfg.TEST.DETECTIONS_PER_IM: image_scores = np.hstack(image_scores)
image_thresh = np.sort(image_scores)[-cfg.TEST.DETECTIONS_PER_IM] if len(image_scores) > cfg.TEST.DETECTIONS_PER_IM:
for j in range(1, num_classes): image_thresh = np.sort(image_scores)[-cfg.TEST.DETECTIONS_PER_IM]
keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0] for j in range(1, num_classes):
all_boxes[j][i] = all_boxes[j][i][keep, :] keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
_t['misc'].toc() all_boxes[j][i] = all_boxes[j][i][keep, :]
_t['misc'].toc()
print('\rim_detect: {:d}/{:d} {:.3f}s {:.3f}s'
.format(i + 1, num_images, _t['im_detect'].average_time, print('\rim_detect: {:d}/{:d} {:.3f}s {:.3f}s'
_t['misc'].average_time), end='') .format(i + 1, num_images,
_t['im_detect'].average_time,
print('\n>>>>>>>>>>>>>>>>>>> Evaluating <<<<<<<<<<<<<<<<<<<<') _t['misc'].average_time),
end='')
print('Evaluating detections')
server.evaluate_detections(all_boxes) print('\n>>>>>>>>>>>>>>>>>>> Evaluating <<<<<<<<<<<<<<<<<<<<')
print('Evaluating detections')
server.evaluate_detections(all_boxes)
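A note on the DETECTIONS_PER_IM cap above: it keeps only the top-K scores over all classes by taking the K-th largest score as the threshold. A standalone illustration of that step:

import numpy as np

scores = np.array([.9, .8, .75, .6, .3])
max_per_image = 3
thresh = np.sort(scores)[-max_per_image]  # 0.75, the 3rd largest score
print(scores[scores >= thresh])           # [0.9, 0.8, 0.75]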
...@@ -13,6 +13,6 @@ from __future__ import absolute_import ...@@ -13,6 +13,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from lib.fpn.layers.anchor_target_layer import AnchorTargetLayer from lib.fpn.anchor_target_layer import AnchorTargetLayer
from lib.fpn.layers.proposal_layer import ProposalLayer from lib.fpn.proposal_layer import ProposalLayer
from lib.fpn.layers.proposal_target_layer import ProposalTargetLayer from lib.fpn.proposal_target_layer import ProposalTargetLayer
# --------------------------------------------------------
# Mask R-CNN @ Detectron
# Copyright (c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------ # ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd. # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# #
# Licensed under the BSD 2-Clause License. # Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License # You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See, # along with the software. If not, See,
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
# Import custom modules # Import custom modules
from lib.modeling.base import affine from lib.modeling.base import affine
from lib.modeling.base import bn from lib.modeling.base import bn
from lib.modeling.base import conv1x1 from lib.modeling.base import conv1x1
from lib.modeling.base import conv3x3 from lib.modeling.base import conv3x3
from lib.modeling.fast_rcnn import FastRCNN from lib.modeling.fast_rcnn import FastRCNN
from lib.modeling.fpn import FPN from lib.modeling.fpn import FPN
from lib.modeling.retinanet import RetinaNet from lib.modeling.retinanet import RetinaNet
from lib.modeling.rpn import RPN from lib.modeling.rpn import RPN
from lib.modeling.ssd import SSD from lib.modeling.ssd import SSD
# ------------------------------------------------------------ # ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd. # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# #
# Licensed under the BSD 2-Clause License. # Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License # You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See, # along with the software. If not, See,
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Define some basic structures.""" """Define some basic structures."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import dragon.vm.torch as torch import dragon.vm.torch as torch
def affine(dim_in, inplace=True): def affine(dim_in, inplace=True):
"""AffineBN, weight and bias are fixed.""" """AffineBN, weight and bias are fixed."""
return torch.nn.Affine( return torch.nn.Affine(
dim_in, dim_in,
fix_weight=True, fix_weight=True,
fix_bias=True, fix_bias=True,
inplace=inplace, inplace=inplace,
) )
def bn(dim_in, eps=1e-5): def bn(dim_in, eps=1e-5):
"""The BatchNorm.""" """The BatchNorm."""
return torch.nn.BatchNorm2d(dim_in, eps=eps) return torch.nn.BatchNorm2d(dim_in, eps=eps)
def conv1x1(dim_in, dim_out, stride=1, bias=False): def conv1x1(dim_in, dim_out, stride=1, bias=False):
"""1x1 convolution.""" """1x1 convolution."""
return torch.nn.Conv2d( return torch.nn.Conv2d(
dim_in, dim_in,
dim_out, dim_out,
kernel_size=1, kernel_size=1,
stride=stride, stride=stride,
bias=bias, bias=bias,
) )
def conv3x3(dim_in, dim_out, stride=1, bias=False): def conv3x3(dim_in, dim_out, stride=1, bias=False):
"""3x3 convolution with padding.""" """3x3 convolution with padding."""
return torch.nn.Conv2d( return torch.nn.Conv2d(
dim_in, dim_in,
dim_out, dim_out,
kernel_size=3, kernel_size=3,
stride=stride, stride=stride,
padding=1, padding=1,
bias=bias, bias=bias,
) )
...@@ -35,11 +35,13 @@ class Detector(torch.nn.Module): ...@@ -35,11 +35,13 @@ class Detector(torch.nn.Module):
``lib.core.config`` for their hyper-parameters. ``lib.core.config`` for their hyper-parameters.
""" """
def __init__(self): def __init__(self):
super(Detector, self).__init__() super(Detector, self).__init__()
model = cfg.MODEL.TYPE model = cfg.MODEL.TYPE
backbone = cfg.MODEL.BACKBONE.lower().split('.') backbone = cfg.MODEL.BACKBONE.lower().split('.')
body, modules = backbone[0], backbone[1:] body, modules = backbone[0], backbone[1:]
self.recorder = None
# + Data Loader # + Data Loader
self.data_layer = importlib.import_module( self.data_layer = importlib.import_module(
...@@ -92,9 +94,14 @@ class Detector(torch.nn.Module): ...@@ -92,9 +94,14 @@ class Detector(torch.nn.Module):
Parameters Parameters
---------- ----------
inputs : dict or None inputs : dict, optional
The inputs. The inputs.
Returns
-------
dict
The outputs.
""" """
# 0. Get the inputs # 0. Get the inputs
if inputs is None: if inputs is None:
...@@ -161,7 +168,6 @@ class Detector(torch.nn.Module): ...@@ -161,7 +168,6 @@ class Detector(torch.nn.Module):
"""Optimize the graph for the inference. """Optimize the graph for the inference.
It usually involves the removing of BN or Affine. It usually involves the removing of BN or Affine.
""" """
################################## ##################################
# Merge Affine into Convolution # # Merge Affine into Convolution #
......
# ------------------------------------------------------------ # ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd. # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# #
# Licensed under the BSD 2-Clause License. # Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License # You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See, # along with the software. If not, See,
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import importlib import importlib
_STORE = collections.defaultdict(dict) _STORE = collections.defaultdict(dict)
########################################### ###########################################
# # # #
# Body # # Body #
# # # #
########################################### ###########################################
# ResNet # ResNet
for D in [18, 34, 50, 101, 152, 200, 269]: for D in [18, 34, 50, 101, 152, 200, 269]:
_STORE['BODY']['resnet{}'.format(D)] = \ _STORE['BODY']['resnet{}'.format(D)] = \
'lib.modeling.resnet.make_resnet_{}'.format(D) 'lib.modeling.resnet.make_resnet_{}'.format(D)
# VGG # VGG
for D in [16, 19]: for D in [16, 19]:
for T in ['', '_reduced_300', '_reduced_512']: for T in ['', '_reduced_300', '_reduced_512']:
_STORE['BODY']['vgg{}{}'.format(D, T)] = \ _STORE['BODY']['vgg{}{}'.format(D, T)] = \
'lib.modeling.vgg.make_vgg_{}{}'.format(D, T) 'lib.modeling.vgg.make_vgg_{}{}'.format(D, T)
# AirNet # AirNet
for D in ['', '3b', '4b', '5b']: for D in ['', '3b', '4b', '5b']:
_STORE['BODY']['airnet{}'.format(D)] = \ _STORE['BODY']['airnet{}'.format(D)] = \
'lib.modeling.airnet.make_airnet_{}'.format(D) 'lib.modeling.airnet.make_airnet_{}'.format(D)
def get_template_func(name, sets, desc): def get_template_func(name, sets, desc):
name = name.lower() name = name.lower()
if name not in sets: if name not in sets:
raise ValueError( raise ValueError(
'The {} for {} was not registered.\n' 'The {} for {} was not registered.\n'
'Registered modules: [{}]'.format( 'Registered modules: [{}]'.format(
name, desc, ', '.join(sets.keys()))) name, desc, ', '.join(sets.keys())))
module_name = '.'.join(sets[name].split('.')[0:-1]) module_name = '.'.join(sets[name].split('.')[0:-1])
func_name = sets[name].split('.')[-1] func_name = sets[name].split('.')[-1]
try: try:
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
return getattr(module, func_name) return getattr(module, func_name)
except ImportError as e: except ImportError as e:
raise ValueError('Cannot import module from: ' + module_name) raise ValueError('Cannot import module from: ' + module_name)
def get_body_func(name): def get_body_func(name):
return get_template_func( return get_template_func(
name, _STORE['BODY'], 'Body') name, _STORE['BODY'], 'Body')
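The registry resolves a backbone name to a dotted import path and returns the factory function via importlib. A sketch of the lookup, assuming lib.modeling.resnet is importable and that the factory takes no arguments:

# 'resnet50' resolves to lib.modeling.resnet.make_resnet_50.
make_body = get_body_func('resnet50')
body = make_body()  # call signature assumed; builds the backbone module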
# ------------------------------------------------------------ # ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd. # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# #
# Licensed under the BSD 2-Clause License. # Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License # You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See, # along with the software. If not, See,
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import dragon.vm.torch as torch import dragon.vm.torch as torch
from lib.core.config import cfg from lib.core.config import cfg
from lib.modeling import conv1x1 from lib.modeling import conv1x1
from lib.modeling import conv3x3 from lib.modeling import conv3x3
HIGHEST_BACKBONE_LVL = 5 # E.g., "conv5"-like level HIGHEST_BACKBONE_LVL = 5 # E.g., "conv5"-like level
class FPN(torch.nn.Module): class FPN(torch.nn.Module):
"""Feature Pyramid Networks for R-CNN and RetinaNet.""" """Feature Pyramid Networks for R-CNN and RetinaNet."""
def __init__(self, feature_dims): def __init__(self, feature_dims):
super(FPN, self).__init__() super(FPN, self).__init__()
self.C = torch.nn.ModuleList() self.C = torch.nn.ModuleList()
self.P = torch.nn.ModuleList() self.P = torch.nn.ModuleList()
self.apply_func = self.apply_on_rcnn self.apply_func = self.apply_on_rcnn
for lvl in range(cfg.FPN.RPN_MIN_LEVEL, HIGHEST_BACKBONE_LVL + 1): for lvl in range(cfg.FPN.RPN_MIN_LEVEL, HIGHEST_BACKBONE_LVL + 1):
self.C.append(conv1x1(feature_dims[lvl - 1], cfg.FPN.DIM, bias=True)) self.C.append(conv1x1(feature_dims[lvl - 1], cfg.FPN.DIM, bias=True))
self.P.append(conv3x3(cfg.FPN.DIM, cfg.FPN.DIM, bias=True)) self.P.append(conv3x3(cfg.FPN.DIM, cfg.FPN.DIM, bias=True))
if 'retinanet' in cfg.MODEL.TYPE: if 'retinanet' in cfg.MODEL.TYPE:
for lvl in range(HIGHEST_BACKBONE_LVL + 1, cfg.FPN.RPN_MAX_LEVEL + 1): for lvl in range(HIGHEST_BACKBONE_LVL + 1, cfg.FPN.RPN_MAX_LEVEL + 1):
dim_in = feature_dims[-1] if lvl == HIGHEST_BACKBONE_LVL + 1 else cfg.FPN.DIM dim_in = feature_dims[-1] if lvl == HIGHEST_BACKBONE_LVL + 1 else cfg.FPN.DIM
self.P.append(conv3x3(dim_in, cfg.FPN.DIM, stride=2, bias=True)) self.P.append(conv3x3(dim_in, cfg.FPN.DIM, stride=2, bias=True))
self.apply_func = self.apply_on_retinanet self.apply_func = self.apply_on_retinanet
self.relu = torch.nn.ReLU(inplace=False) self.relu = torch.nn.ReLU(inplace=False)
self.maxpool = torch.nn.MaxPool2d(1, 2, ceil_mode=True) self.maxpool = torch.nn.MaxPool2d(1, 2, ceil_mode=True)
self.reset_parameters() self.reset_parameters()
self.feature_dims = [cfg.FPN.DIM] self.feature_dims = [cfg.FPN.DIM]
def reset_parameters(self): def reset_parameters(self):
for m in self.modules(): for m in self.modules():
if isinstance(m, torch.nn.Conv2d): if isinstance(m, torch.nn.Conv2d):
torch.nn.init.kaiming_uniform_( torch.nn.init.kaiming_uniform_(
m.weight, m.weight,
a=1, # Fix the gain for [-127, 127] a=1, # Fix the gain for [-127, 127]
) # Xavier Initialization ) # Xavier Initialization
torch.nn.init.constant_(m.bias, 0) torch.nn.init.constant_(m.bias, 0)
def apply_on_rcnn(self, features): def apply_on_rcnn(self, features):
fpn_input = self.C[-1](features[-1]) fpn_input = self.C[-1](features[-1])
min_lvl, max_lvl = cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL min_lvl, max_lvl = cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL
outputs = [self.P[HIGHEST_BACKBONE_LVL - min_lvl](fpn_input)] outputs = [self.P[HIGHEST_BACKBONE_LVL - min_lvl](fpn_input)]
# Apply MaxPool for higher features # Apply MaxPool for higher features
for i in range(HIGHEST_BACKBONE_LVL + 1, max_lvl + 1): for i in range(HIGHEST_BACKBONE_LVL + 1, max_lvl + 1):
outputs.append(self.maxpool(outputs[-1])) outputs.append(self.maxpool(outputs[-1]))
# Build Pyramids between [MIN_LEVEL, HIGHEST_LEVEL] # Build Pyramids between [MIN_LEVEL, HIGHEST_LEVEL]
for i in range(HIGHEST_BACKBONE_LVL - 1, min_lvl - 1, -1): for i in range(HIGHEST_BACKBONE_LVL - 1, min_lvl - 1, -1):
lateral_output = self.C[i - min_lvl](features[i - 1]) lateral_output = self.C[i - min_lvl](features[i - 1])
upscale_output = torch.vision.ops.nn_resize( upscale_output = torch.vision.ops.nn_resize(
fpn_input, dsize=lateral_output.shape[-2:]) fpn_input, dsize=lateral_output.shape[-2:])
fpn_input = lateral_output.__iadd__(upscale_output) fpn_input = lateral_output.__iadd__(upscale_output)
outputs.insert(0, self.P[i - min_lvl](fpn_input)) outputs.insert(0, self.P[i - min_lvl](fpn_input))
return outputs return outputs
def apply_on_retinanet(self, features): def apply_on_retinanet(self, features):
fpn_input = self.C[-1](features[-1]) fpn_input = self.C[-1](features[-1])
min_lvl, max_lvl = cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL min_lvl, max_lvl = cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL
outputs = [self.P[HIGHEST_BACKBONE_LVL- min_lvl](fpn_input)] outputs = [self.P[HIGHEST_BACKBONE_LVL - min_lvl](fpn_input)]
# Add extra convolutions for higher features # Add extra convolutions for higher features
extra_input = features[-1] extra_input = features[-1]
for i in range(HIGHEST_BACKBONE_LVL + 1, max_lvl + 1): for i in range(HIGHEST_BACKBONE_LVL + 1, max_lvl + 1):
outputs.append(self.P[i - min_lvl](extra_input)) outputs.append(self.P[i - min_lvl](extra_input))
if i != max_lvl: if i != max_lvl:
extra_input = self.relu(outputs[-1]) extra_input = self.relu(outputs[-1])
# Build Pyramids between [MIN_LEVEL, HIGHEST_LEVEL] # Build Pyramids between [MIN_LEVEL, HIGHEST_LEVEL]
for i in range(HIGHEST_BACKBONE_LVL - 1, min_lvl - 1, -1): for i in range(HIGHEST_BACKBONE_LVL - 1, min_lvl - 1, -1):
lateral_output = self.C[i - min_lvl](features[i - 1]) lateral_output = self.C[i - min_lvl](features[i - 1])
upscale_output = torch.vision.ops.nn_resize( upscale_output = torch.vision.ops.nn_resize(
fpn_input, dsize=lateral_output.shape[-2:]) fpn_input, dsize=lateral_output.shape[-2:])
fpn_input = lateral_output.__iadd__(upscale_output) fpn_input = lateral_output.__iadd__(upscale_output)
outputs.insert(0, self.P[i - min_lvl](fpn_input)) outputs.insert(0, self.P[i - min_lvl](fpn_input))
return outputs return outputs
def forward(self, features): def forward(self, features):
return self.apply_func(features) return self.apply_func(features)
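The top-down pathway above is the standard FPN recursion: resize the coarser map to the lateral map's spatial size, then add. A minimal sketch of that single step with plain arrays, independent of the dragon-specific nn_resize (approximated here with nearest-neighbor 2x upsampling):

import numpy as np

def nn_resize_2x(x):
    # Nearest-neighbor 2x upsampling on an NCHW array (illustrative only).
    return x.repeat(2, axis=2).repeat(2, axis=3)

coarse = np.ones((1, 256, 25, 25), 'float32')   # e.g. the coarser pyramid input
lateral = np.ones((1, 256, 50, 50), 'float32')  # e.g. a feature after 1x1 conv
fused = lateral + nn_resize_2x(coarse)          # top-down merge
print(fused.shape)  # (1, 256, 50, 50)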
...@@ -59,8 +59,7 @@ class RetinaNet(torch.nn.Module): ...@@ -59,8 +59,7 @@ class RetinaNet(torch.nn.Module):
gamma=cfg.MODEL.FOCAL_LOSS_GAMMA, gamma=cfg.MODEL.FOCAL_LOSS_GAMMA,
) )
self.bbox_loss = torch.nn.SmoothL1Loss( self.bbox_loss = torch.nn.SmoothL1Loss(
beta=1. / 9., reduction='batch_size', beta=.11, reduction='batch_size')
)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
...@@ -133,26 +132,22 @@ class RetinaNet(torch.nn.Module): ...@@ -133,26 +132,22 @@ class RetinaNet(torch.nn.Module):
gt_boxes=gt_boxes, gt_boxes=gt_boxes,
ims_info=ims_info, ims_info=ims_info,
) )
return collections.OrderedDict({ return collections.OrderedDict([
'cls_loss': ('cls_loss', self.cls_loss(
self.cls_loss( cls_score, self.retinanet_data['labels'])),
cls_score, ('bbox_loss', self.bbox_loss(
self.retinanet_data['labels'], bbox_pred,
), self.retinanet_data['bbox_targets'],
'bbox_loss': self.retinanet_data['bbox_inside_weights'],
self.bbox_loss( self.retinanet_data['bbox_outside_weights'],
bbox_pred, )),
self.retinanet_data['bbox_targets'], ])
self.retinanet_data['bbox_inside_weights'],
self.retinanet_data['bbox_outside_weights'],
)
})
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
cls_score, bbox_pred = self.compute_outputs(kwargs['features']) cls_score, bbox_pred = self.compute_outputs(kwargs['features'])
cls_score, bbox_pred = cls_score.float(), bbox_pred.float() cls_score, bbox_pred = cls_score.float(), bbox_pred.float()
outputs = collections.OrderedDict({'bbox_pred': bbox_pred}) outputs = collections.OrderedDict([('bbox_pred', bbox_pred)])
if self.training: if self.training:
outputs.update( outputs.update(
......
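A note on the OrderedDict change above (the SSD losses below get the same treatment): constructing an OrderedDict from a {...} literal only preserves key order on Python 3.7+, while an explicit list of pairs pins the loss order down on any version. A quick check:

import collections

# A list of pairs fixes the iteration order regardless of Python version.
losses = collections.OrderedDict([('cls_loss', 0.7), ('bbox_loss', 0.2)])
print(list(losses.keys()))  # ['cls_loss', 'bbox_loss']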
...@@ -136,32 +136,29 @@ class SSD(torch.nn.Module): ...@@ -136,32 +136,29 @@ class SSD(torch.nn.Module):
gt_boxes=gt_boxes, gt_boxes=gt_boxes,
) )
) )
return collections.OrderedDict({ return collections.OrderedDict([
# A compensating factor of 4.0 is used # A compensating factor of 4.0 is used
# As we normalize both the pos and neg samples # As we normalize both the pos and neg samples
'cls_loss': ('cls_loss', self.cls_loss(
self.cls_loss( cls_score.view(-1, cfg.MODEL.NUM_CLASSES),
cls_score.view(-1, cfg.MODEL.NUM_CLASSES), self.ssd_data['labels']) * 4.),
self.ssd_data['labels'] ('bbox_loss', self.bbox_loss(
) * 4., bbox_pred,
'bbox_loss': self.ssd_data['bbox_targets'],
self.bbox_loss( self.ssd_data['bbox_inside_weights'],
bbox_pred, self.ssd_data['bbox_outside_weights'],
self.ssd_data['bbox_targets'], )),
self.ssd_data['bbox_inside_weights'], ])
self.ssd_data['bbox_outside_weights'],
)
})
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
prior_boxes = self.prior_box_layer(kwargs['features']) prior_boxes = self.prior_box_layer(kwargs['features'])
cls_score, bbox_pred = self.compute_outputs(kwargs['features']) cls_score, bbox_pred = self.compute_outputs(kwargs['features'])
cls_score, bbox_pred = cls_score.float(), bbox_pred.float() cls_score, bbox_pred = cls_score.float(), bbox_pred.float()
outputs = collections.OrderedDict({ outputs = collections.OrderedDict([
'prior_boxes': prior_boxes, ('bbox_pred', bbox_pred),
'bbox_pred': bbox_pred, ('prior_boxes', prior_boxes),
}) ])
if self.training: if self.training:
outputs.update( outputs.update(
......
# ------------------------------------------------------------ # ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd. # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# #
# Licensed under the BSD 2-Clause License. # Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License # You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See, # along with the software. If not, See,
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on: # Codes are based on:
# #
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/fast_rcnn/nms_wrapper.py> # <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/fast_rcnn/nms_wrapper.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from lib.core.config import cfg from lib.core.config import cfg
from lib.utils import logger from lib.utils import logger
try: try:
from lib.nms.cpu_nms import cpu_nms, cpu_soft_nms from lib.nms.cpu_nms import cpu_nms, cpu_soft_nms
except ImportError as e: except ImportError as e:
print('Failed to import cpu nms. Error: {0}'.format(str(e))) print('Failed to import cpu nms. Error: {0}'.format(str(e)))
try: try:
from lib.nms.gpu_nms import gpu_nms from lib.nms.gpu_nms import gpu_nms
except ImportError as e: except ImportError as e:
print('Failed to import gpu nms. Error: {0}'.format(str(e))) print('Failed to import gpu nms. Error: {0}'.format(str(e)))
def nms(detections, thresh, force_cpu=False): def nms(detections, thresh, force_cpu=False):
"""Perform either CPU or GPU Hard-NMS.""" """Perform either CPU or GPU Hard-NMS."""
if detections.shape[0] == 0: if detections.shape[0] == 0:
return [] return []
if cfg.USE_GPU_NMS and not force_cpu: if cfg.USE_GPU_NMS and not force_cpu:
return gpu_nms(detections, thresh, device_id=cfg.GPU_ID) return gpu_nms(detections, thresh, device_id=cfg.GPU_ID)
else: else:
return cpu_nms(detections, thresh) return cpu_nms(detections, thresh)
def soft_nms( def soft_nms(
detections, detections,
thresh, thresh,
method='linear', method='linear',
sigma=0.5, sigma=0.5,
score_thresh=0.001, score_thresh=0.001,
): ):
"""Perform CPU Soft-NMS.""" """Perform CPU Soft-NMS."""
if detections.shape[0] == 0: if detections.shape[0] == 0:
return [] return []
methods = {'hard': 0, 'linear': 1, 'gaussian': 2} methods = {'hard': 0, 'linear': 1, 'gaussian': 2}
if method not in methods: if method not in methods:
logger.fatal('Unknown soft nms method: {}'.format(method)) logger.fatal('Unknown soft nms method: {}'.format(method))
return cpu_soft_nms( return cpu_soft_nms(
detections, detections,
thresh, thresh,
methods[method], methods[method],
sigma, sigma,
score_thresh, score_thresh,
) )
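Both entry points take detections as an (N, 5) array of (x1, y1, x2, y2, score) rows and return the indices to keep. A hedged usage sketch with invented boxes:

import numpy as np

dets = np.array([
    [10, 10, 50, 50, .9],
    [12, 12, 52, 52, .8],       # heavy overlap with the first box
    [100, 100, 150, 150, .7],
], 'float32')
keep = nms(dets, 0.5, force_cpu=True)  # expected to keep indices [0, 2]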
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
\ No newline at end of file
syntax = "proto2";
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
optional bytes data = 4;
optional int32 label = 5;
repeated float float_data = 6;
optional bool encoded = 7 [default = false];
repeated int32 labels = 8;
}
message Annotation {
optional float x1 = 1;
optional float y1 = 2;
optional float x2 = 3;
optional float y2 = 4;
optional string name = 5;
optional bool difficult = 6 [default = false];
optional string mask = 7;
}
message AnnotatedDatum {
optional Datum datum = 1;
optional string filename = 2;
repeated Annotation annotation = 3;
}
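With the protobuf schema above removed, each SeetaRecord example plays the role of the old AnnotatedDatum. Judging from the fields the new transformers read, an example is a plain dict along these lines (a sketch; only the keys actually used above are shown, and the values are placeholders):

example = {
    'id': '000001',          # replaces AnnotatedDatum.filename
    'content': b'...',       # encoded image bytes, decoded with cv2.imdecode
    'width': 500,
    'height': 375,
    'object': [{
        'name': 'dog',       # class name, mapped through class_to_ind
        'difficult': 0,      # optional; treated as 0 when absent
        'xmin': 48.0, 'ymin': 240.0, 'xmax': 195.0, 'ymax': 371.0,
    }],
}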
...@@ -13,5 +13,5 @@ from __future__ import absolute_import ...@@ -13,5 +13,5 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from lib.faster_rcnn.layers.data_layer import DataLayer from lib.faster_rcnn.data_layer import DataLayer
from lib.retinanet.layers.anchor_target_layer import AnchorTargetLayer from lib.retinanet.anchor_target_layer import AnchorTargetLayer
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
...@@ -13,8 +13,8 @@ from __future__ import absolute_import ...@@ -13,8 +13,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from lib.ssd.layers.data_layer import DataLayer from lib.ssd.data_layer import DataLayer
from lib.ssd.layers.hard_mining_layer import HardMiningLayer from lib.ssd.hard_mining_layer import HardMiningLayer
from lib.ssd.layers.multibox_layer import MultiBoxMatchLayer from lib.ssd.multibox_layer import MultiBoxMatchLayer
from lib.ssd.layers.multibox_layer import MultiBoxTargetLayer from lib.ssd.multibox_layer import MultiBoxTargetLayer
from lib.ssd.layers.priorbox_layer import PriorBoxLayer from lib.ssd.priorbox_layer import PriorBoxLayer
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import numpy as np
from lib.core.config import cfg
class BlobFetcher(multiprocessing.Process):
def __init__(self, **kwargs):
super(BlobFetcher, self).__init__()
self._img_blob_size = (
cfg.TRAIN.IMS_PER_BATCH,
cfg.SSD.RESIZE.HEIGHT,
cfg.SSD.RESIZE.WIDTH, 3,
)
self.q_in = self.q_out = None
self.daemon = True
def get(self):
img_blob, boxes_blob = np.zeros(self._img_blob_size, 'uint8'), []
for i in range(cfg.TRAIN.IMS_PER_BATCH):
img_blob[i], gt_boxes = self.q_in.get()
# Pack the boxes by adding the index of images
boxes = np.zeros((gt_boxes.shape[0], gt_boxes.shape[1] + 1), np.float32)
boxes[:, :gt_boxes.shape[1]] = gt_boxes
boxes[:, -1] = i
boxes_blob.append(boxes)
return {
'data': img_blob,
'gt_boxes': np.concatenate(boxes_blob, 0),
}
def run(self):
while True:
self.q_out.put(self.get())
...@@ -13,54 +13,69 @@ from __future__ import absolute_import ...@@ -13,54 +13,69 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from multiprocessing import Queue import multiprocessing as mp
import time import time
import dragon import dragon
import pprint import dragon.vm.torch as torch
import numpy as np
from lib.faster_rcnn.data.data_reader import DataReader from lib.core.config import cfg
from lib.ssd.data.data_transformer import DataTransformer from lib.datasets.factory import get_imdb
from lib.ssd.data.blob_fetcher import BlobFetcher from lib.ssd.data_transformer import DataTransformer
from lib.utils import logger from lib.utils import logger
class DataBatch(object): class DataLayer(torch.nn.Module):
"""DataBatch aims to prefetch data by ``Triple-Buffering``. """Generate a mini-batch of data."""
It takes full advantages of the Process/Thread of Python, def __init__(self):
super(DataLayer, self).__init__()
database = get_imdb(cfg.TRAIN.DATABASE)
self.data_batch = DataBatch(**{
'dataset': lambda: dragon.io.SeetaRecordDataset(database.source),
'classes': database.classes,
'shuffle': cfg.TRAIN.USE_SHUFFLE,
'num_chunks': cfg.TRAIN.NUM_SHUFFLE_CHUNKS,
'batch_size': cfg.TRAIN.IMS_PER_BATCH * 2,
})
which provides remarkable I/O speed up for scalable distributed training. def forward(self):
# Get an array blob from the Queue
outputs = self.data_batch.get()
# Zero-Copy the array to tensor
outputs['data'] = torch.from_numpy(outputs['data'])
return outputs
class DataBatch(mp.Process):
"""Prefetch the batch of data."""
"""
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Construct a ``DataBatch``. """Construct a ``DataBatch``.
Parameters Parameters
---------- ----------
source : str dataset : callable
The path of the database. A callable that creates the dataset.
shuffle : bool, optional, default=False shuffle : bool, optional, default=False
Whether to shuffle the data. Whether to shuffle the data.
num_chunks : int, optional, default=2048 num_chunks : int, optional, default=0
The number of chunks to split. The number of chunks to split.
batch_size : int, optional, default=128 batch_size : int, optional, default=32
The size of a mini-batch. The size of a mini-batch.
prefetch : int, optional, default=5 prefetch : int, optional, default=5
The prefetch count. The prefetch count.
""" """
super(DataBatch, self).__init__() super(DataBatch, self).__init__()
# Init mpi # Distributed settings
global_rank, local_rank, group_size = 0, 0, 1 rank, group_size = 0, 1
if dragon.mpi.is_init(): process_group = dragon.distributed.get_default_process_group()
group = dragon.mpi.is_parallel() if process_group is not None and kwargs.get(
if group is not None: # DataParallel 'phase', 'TRAIN') == 'TRAIN':
global_rank = dragon.mpi.rank() group_size = process_group.size
group_size = len(group) rank = dragon.distributed.get_rank(process_group)
for i, node in enumerate(group):
if global_rank == node:
local_rank = i
kwargs['group_size'] = group_size kwargs['group_size'] = group_size
# Configuration # Configuration
...@@ -77,63 +92,50 @@ class DataBatch(object): ...@@ -77,63 +92,50 @@ class DataBatch(object):
self._num_transformers = min( self._num_transformers = min(
self._num_transformers, self._max_transformers) self._num_transformers, self._max_transformers)
# Init queues # Initialize queues
self.Q1 = Queue(self._prefetch * self._num_readers * self._batch_size) num_batches = self._prefetch * self._num_readers
self.Q2 = Queue(self._prefetch * self._num_readers * self._batch_size) self.Q1 = mp.Queue(num_batches * self._batch_size)
self.Q3 = Queue(self._prefetch * self._num_readers) self.Q2 = mp.Queue(num_batches * self._batch_size)
self.Q3 = mp.Queue(num_batches)
# Init readers # Initialize readers
self._readers = [] self._readers = []
for i in range(self._num_readers): for i in range(self._num_readers):
self._readers.append(DataReader(**kwargs))
self._readers[-1].q_out = self.Q1
for i in range(self._num_readers):
part_idx, num_parts = i, self._num_readers part_idx, num_parts = i, self._num_readers
num_parts *= group_size num_parts *= group_size
part_idx += local_rank * self._num_readers part_idx += rank * self._num_readers
self._readers[i]._num_parts = num_parts self._readers.append(dragon.io.DataReader(
self._readers[i]._part_idx = part_idx num_parts=num_parts, part_idx=part_idx, **kwargs))
self._readers[i]._rng_seed += part_idx self._readers[i]._seed += part_idx
self._readers[i].q_out = self.Q1
self._readers[i].start() self._readers[i].start()
time.sleep(0.1) time.sleep(0.1)
# Init transformers # Initialize transformers
self._transformers = [] self._transformers = []
for i in range(self._num_transformers): for i in range(self._num_transformers):
transformer = DataTransformer(**kwargs) transformer = DataTransformer(**kwargs)
transformer._rng_seed += (i + local_rank * self._num_transformers) transformer._rng_seed += (i + rank * self._num_transformers)
transformer.q_in = self.Q1 transformer.q_in, transformer.q_out = self.Q1, self.Q2
transformer.q_out = self.Q2
transformer.start() transformer.start()
self._transformers.append(transformer) self._transformers.append(transformer)
time.sleep(0.1) time.sleep(0.1)
# Init blob fetchers # Initialize batch-producer
self._fetchers = [] self.start()
for i in range(self._num_fetchers):
fetcher = BlobFetcher(**kwargs)
fetcher.q_in = self.Q2
fetcher.q_out = self.Q3
fetcher.start()
self._fetchers.append(fetcher)
time.sleep(0.1)
# Prevent to echo multiple nodes
if local_rank == 0:
self.echo()
# Register cleanup callbacks
def cleanup(): def cleanup():
def terminate(processes): def terminate(processes):
for process in processes: for process in processes:
process.terminate() process.terminate()
process.join() process.join()
terminate(self._fetchers) terminate([self])
logger.info('Terminating BlobFetcher ......') logger.info('Terminate DataBatch.')
terminate(self._transformers) terminate(self._transformers)
logger.info('Terminating DataTransformer ......') logger.info('Terminate DataTransformer.')
terminate(self._readers) terminate(self._readers)
logger.info('Terminating DataReader......') logger.info('Terminate DataReader.')
import atexit import atexit
atexit.register(cleanup) atexit.register(cleanup)
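The reader partitioning above splits the dataset into num_readers * group_size shards and gives each rank a contiguous block: with 2 readers and 4 ranks there are 8 shards, and rank 1 reads shards 2 and 3. The arithmetic as a standalone check:

num_readers, group_size, rank = 2, 4, 1
shards = [i + rank * num_readers for i in range(num_readers)]
print(num_readers * group_size, shards)  # 8 [2, 3]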
...@@ -149,14 +151,24 @@ class DataBatch(object): ...@@ -149,14 +151,24 @@ class DataBatch(object):
""" """
return self.Q3.get() return self.Q3.get()
def echo(self): def run(self):
"""Print I/O Information.""" """Start the process to produce batches."""
print('---------------------------------------------------------') image_batch_shape = (
print('BatchFetcher({} Threads), Using config:'.format( cfg.TRAIN.IMS_PER_BATCH,
self._num_readers + self._num_transformers + self._num_fetchers)) cfg.SSD.RESIZE.HEIGHT,
params = {'queue_size': self._prefetch, cfg.SSD.RESIZE.WIDTH, 3,
'n_readers': self._num_readers, )
'n_transformers': self._num_transformers,
'n_fetchers': self._num_fetchers} while True:
pprint.pprint(params) boxes_to_pack = []
print('---------------------------------------------------------') image_batch = np.zeros(image_batch_shape, 'uint8')
for image_index in range(cfg.TRAIN.IMS_PER_BATCH):
image_batch[image_index], gt_boxes = self.Q2.get()
boxes = np.zeros((gt_boxes.shape[0], gt_boxes.shape[1] + 1), 'float32')
boxes[:, :gt_boxes.shape[1]], boxes[:, -1] = gt_boxes, image_index
boxes_to_pack.append(boxes)
self.Q3.put({
'data': image_batch,
'gt_boxes': np.concatenate(boxes_to_pack),
})
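The packed gt_boxes blob appends the image index as a last column so boxes from different images can share one array, mirroring the loop above. A small illustration with invented boxes:

import numpy as np

boxes_im0 = np.array([[10., 10., 50., 50., 1.]], 'float32')  # (x1, y1, x2, y2, cls)
boxes_im1 = np.array([[20., 20., 80., 80., 2.]], 'float32')
packed = []
for index, gt in enumerate([boxes_im0, boxes_im1]):
    with_index = np.zeros((gt.shape[0], gt.shape[1] + 1), 'float32')
    with_index[:, :-1], with_index[:, -1] = gt, index
    packed.append(with_index)
print(np.concatenate(packed))  # the last column holds the image index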
...@@ -13,14 +13,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import cv2
 import multiprocessing
+import cv2
 import numpy as np

 from lib.core.config import cfg
-from lib.proto import anno_pb2 as pb
-from lib.ssd.data import transforms
-from lib.utils import logger
+from lib.ssd import transforms
+from lib.utils.boxes import flip_boxes


 class DataTransformer(multiprocessing.Process):
...@@ -41,38 +41,41 @@ class DataTransformer(multiprocessing.Process):
         self.q_in = self.q_out = None
         self.daemon = True

-    def make_roi_dict(self, ann_datum, flip=False):
-        annotations = ann_datum.annotation
+    def make_roi_dict(self, example, flip=False):
         n_objects = 0
         if not self._use_diff:
-            for ann in annotations:
-                if not ann.difficult: n_objects += 1
-        else: n_objects = len(annotations)
+            for obj in example['object']:
+                if obj.get('difficult', 0) == 0:
+                    n_objects += 1
+        else:
+            n_objects = len(example['object'])

         roi_dict = {
-            'width': ann_datum.datum.width,
-            'height': ann_datum.datum.height,
-            'gt_classes': np.zeros((n_objects,), dtype=np.int32),
-            'boxes': np.zeros((n_objects, 4), dtype=np.float32),
-            'normalized_boxes': np.zeros((n_objects, 4), dtype=np.float32),
+            'width': example['width'],
+            'height': example['height'],
+            'gt_classes': np.zeros((n_objects,), 'int32'),
+            'boxes': np.zeros((n_objects, 4), 'float32'),
+            'normalized_boxes': np.zeros((n_objects, 4), 'float32'),
         }

-        rec_idx = 0
-        for ann in annotations:
-            if not self._use_diff and ann.difficult:
+        # Filter the difficult instances
+        object_idx = 0
+        for obj in example['object']:
+            if not self._use_diff and \
+                    obj.get('difficult', 0) > 0:
                 continue
-            roi_dict['boxes'][rec_idx, :] = [
-                max(0, ann.x1),
-                max(0, ann.y1),
-                min(ann.x2, ann_datum.datum.width - 1),
-                min(ann.y2, ann_datum.datum.height - 1),
+            roi_dict['boxes'][object_idx, :] = [
+                max(0, obj['xmin']),
+                max(0, obj['ymin']),
+                min(obj['xmax'], example['width'] - 1),
+                min(obj['ymax'], example['height'] - 1),
             ]
-            roi_dict['gt_classes'][rec_idx] = \
-                self._class_to_ind[ann.name]
-            rec_idx += 1
+            roi_dict['gt_classes'][object_idx] = \
+                self._class_to_ind[obj['name']]
+            object_idx += 1

         if flip:
-            roi_dict['boxes'] = _flip_boxes(
+            roi_dict['boxes'] = flip_boxes(
                 roi_dict['boxes'], roi_dict['width'])
         roi_dict['boxes'][:, 0::2] /= roi_dict['width']
...@@ -80,26 +83,19 @@ class DataTransformer(multiprocessing.Process):
         return roi_dict
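Note: make_roi_dict now reads a plain dict decoded from a SeetaRecord entry instead of a
protobuf AnnotatedDatum. A hypothetical example of the expected fields, inferred from the
accesses above:

    example = {
        'content': b'...',  # encoded image bytes, decoded by cv2.imdecode in get()
        'width': 500,
        'height': 375,
        'object': [
            {'name': 'dog', 'xmin': 48, 'ymin': 240,
             'xmax': 195, 'ymax': 371, 'difficult': 0},
        ],
    }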
-    def get(self, serialized):
-        ann_datum = pb.AnnotatedDatum()
-        ann_datum.ParseFromString(serialized)
-        img_datum = ann_datum.datum
-        img = np.fromstring(img_datum.data, np.uint8)
-        if img_datum.encoded is True:
-            img = cv2.imdecode(img, -1)
-        else:
-            h, w = img_datum.height, img_datum.width
-            img = img.reshape((h, w, img_datum.channels))
+    def get(self, example):
+        img = np.frombuffer(example['content'], np.uint8)
+        img = cv2.imdecode(img, -1)

         # Flip
         flip = False
         if self._mirror:
-            if np.random.randint(0, 2) > 0:
+            if np.random.randint(2) > 0:
                 img = img[:, ::-1, :]
                 flip = True

-        # Datum -> RoIDB
-        roi_dict = self.make_roi_dict(ann_datum, flip)
+        # Example -> RoIDict
+        roi_dict = self.make_roi_dict(example, flip)

         # Post-Process for gt boxes
         # Shape like: [num_objects, {x1, y1, x2, y2, cls}]
...@@ -120,19 +116,7 @@ class DataTransformer(multiprocessing.Process):
     def run(self):
         np.random.seed(self._rng_seed)
         while True:
-            serialized = self.q_in.get()
-            im, gt_boxes = self.get(serialized)
-            if len(gt_boxes) < 1:
-                continue
-            self.q_out.put((im, gt_boxes))
-
-
-def _flip_boxes(boxes, width):
-    flip_boxes = boxes.copy()
-    old_x1 = boxes[:, 0].copy()
-    old_x2 = boxes[:, 2].copy()
-    flip_boxes[:, 0] = width - old_x2 - 1
-    flip_boxes[:, 2] = width - old_x1 - 1
-    if not (flip_boxes[:, 2] >= flip_boxes[:, 0]).all():
-        logger.fatal('Encounter invalid coordinates after flipping boxes.')
-    return flip_boxes
+            outputs = self.get(self.q_in.get())
+            if len(outputs[1]) < 1:
+                continue  # Ignore the non-object image
+            self.q_out.put(outputs)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def generate_anchors(min_sizes, max_sizes, ratios):
    """
    Generate anchor (reference) windows by enumerating
    aspect ratios, min_sizes, max_sizes wrt a reference ctr (x, y, w, h).
    """
    total_anchors = []
    for idx, min_size in enumerate(min_sizes):
        # Note that SSD assumes it is a ctr-anchor
        base_anchor = np.array([0, 0, min_size, min_size])
        anchors = _ratio_enum(base_anchor, ratios)
        if len(max_sizes) > 0:
            max_size = max_sizes[idx]
            _anchors = anchors[0].reshape((1, 4))
            _anchors = np.vstack([_anchors, _max_size_enum(
                base_anchor, min_size, max_size)])
            anchors = np.vstack([_anchors, anchors[1:]])
        total_anchors.append(anchors)
    return np.vstack(total_anchors)


def _whctrs(anchor):
    """Return width, height, x center, and y center for an anchor (window)."""
    w, h = anchor[2], anchor[3]
    x_ctr, y_ctr = anchor[0], anchor[1]
    return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
    """
    Given a vector of widths (ws) and heights (hs) around a center
    (x_ctr, y_ctr), output a set of anchors (windows).
    """
    ws = ws[:, np.newaxis]
    hs = hs[:, np.newaxis]
    anchors = np.hstack((x_ctr - 0.5 * ws,
                         y_ctr - 0.5 * hs,
                         x_ctr + 0.5 * ws,
                         y_ctr + 0.5 * hs))
    return anchors


def _ratio_enum(anchor, ratios):
    """Enumerate a set of anchors for each aspect ratio wrt an anchor."""
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    size = w * h
    size_ratios = size / ratios
    hs = np.round(np.sqrt(size_ratios))
    ws = np.round(hs * ratios)
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


def _max_size_enum(base_anchor, min_size, max_size):
    """Enumerate an anchor for max_size wrt base_anchor."""
    w, h, x_ctr, y_ctr = _whctrs(base_anchor)
    ws = hs = np.sqrt([min_size * max_size])
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


if __name__ == '__main__':
    print(generate_anchors(min_sizes=[30], max_sizes=[60], ratios=[1, 0.5, 2]))
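For this demo configuration the call returns four anchors: the min-size square, the
sqrt(min_size * max_size) square, and one anchor per non-unit ratio. Approximate values,
assuming the rounding behavior above:

    # [[-15.    -15.     15.     15.  ]   # 30 x 30, ratio 1
    #  [-21.21  -21.21   21.21   21.21]   # sqrt(30 * 60) square
    #  [-10.5   -21.     10.5    21.  ]   # 21 x 42, ratio 0.5
    #  [-21.    -10.5    21.     10.5 ]]  # 42 x 21, ratio 2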
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#    <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import dragon.vm.torch as torch
import numpy as np

from lib.core.config import cfg
from lib.utils.blob import blob_to_tensor


class HardMiningLayer(torch.nn.Module):
    def __init__(self):
        super(HardMiningLayer, self).__init__()

    def forward(self, conf_prob, match_labels, max_overlaps):
        # Confidence of each matched box
        conf_prob_wide = conf_prob.numpy(True)
        # Label of each matched box
        match_labels_wide = match_labels
        # Max overlaps between default boxes and gt boxes
        max_overlaps_wide = max_overlaps
        # label ``-1`` will be ignored
        labels_wide = -np.ones(match_labels_wide.shape, dtype=np.int64)
        for ix in range(match_labels_wide.shape[0]):
            match_labels = match_labels_wide[ix]
            max_overlaps = max_overlaps_wide[ix]
            conf_prob = conf_prob_wide[ix]
            conf_loss = np.zeros(match_labels.shape, dtype=np.float32)
            inds = np.where(match_labels >= 0)[0]
            flt_min = np.finfo(float).eps
            # Softmax cross-entropy
            conf_loss[inds] = -np.log(np.maximum(
                conf_prob[inds, match_labels[inds]], flt_min))
            # Filter negatives
            fg_inds = np.where(match_labels > 0)[0]
            neg_inds = np.where(match_labels == 0)[0]
            neg_overlaps = max_overlaps[neg_inds]
            eligible_neg_inds = np.where(neg_overlaps < cfg.SSD.OHEM.NEG_OVERLAP)[0]
            sel_inds = neg_inds[eligible_neg_inds]
            # Do Mining
            sel_loss = conf_loss[sel_inds]
            num_pos = len(fg_inds)
            num_sel = min(int(num_pos * cfg.SSD.OHEM.NEG_POS_RATIO), len(sel_inds))
            sorted_sel_inds = sel_inds[np.argsort(-sel_loss)]
            bg_inds = sorted_sel_inds[:num_sel]
            labels_wide[ix][fg_inds] = match_labels[fg_inds]  # Keep fg indices
            labels_wide[ix][bg_inds] = 0  # Use hard negatives as bg indices
        # Feed labels to compute cls loss
        return {'labels': blob_to_tensor(labels_wide)}
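For intuition, the mining step keeps at most NEG_POS_RATIO negatives per positive,
choosing those with the highest classification loss. A standalone sketch of that
selection with hypothetical values:

    import numpy as np

    num_pos, neg_pos_ratio = 10, 3  # e.g. cfg.SSD.OHEM.NEG_POS_RATIO == 3
    sel_inds = np.arange(100)       # indices of eligible negative anchors
    sel_loss = np.random.rand(100)  # their softmax cross-entropy losses
    num_sel = min(int(num_pos * neg_pos_ratio), len(sel_inds))
    bg_inds = sel_inds[np.argsort(-sel_loss)][:num_sel]  # the 30 hardest negatives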
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dragon.vm.torch as torch
from lib.core.config import cfg
from lib.datasets.factory import get_imdb
from lib.ssd.data.data_batch import DataBatch
class DataLayer(torch.nn.Module):
def __init__(self):
super(DataLayer, self).__init__()
database = get_imdb(cfg.TRAIN.DATABASE)
self.data_batch = DataBatch(**{
'source': database.source,
'classes': database.classes,
'shuffle': cfg.TRAIN.USE_SHUFFLE,
'num_chunks': 2048, # Chunk-Wise Shuffle
'batch_size': cfg.TRAIN.IMS_PER_BATCH * 2,
})
def forward(self):
# Get an array blob from the Queue
outputs = self.data_batch.get()
# Zero-Copy the array to tensor
outputs['data'] = torch.from_numpy(outputs['data'])
return outputs
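A hedged usage sketch (assuming cfg is loaded and cfg.TRAIN.DATABASE names a dataset
registered in the factory):

    layer = DataLayer()
    outputs = layer()  # {'data': uint8 tensor, 'gt_boxes': float32 ndarray}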
...@@ -19,7 +19,7 @@ sys.path.append('../../')
 import cv2
 import numpy as np

-from lib.ssd.data import transforms
+from lib.ssd import transforms

 if __name__ == '__main__':
...
...@@ -201,6 +201,16 @@ def expand_boxes(boxes, scale):
     return boxes_exp


+def flip_boxes(boxes, width):
+    """Flip the boxes horizontally."""
+    flip_boxes = boxes.copy()
+    old_x1 = boxes[:, 0].copy()
+    old_x2 = boxes[:, 2].copy()
+    flip_boxes[:, 0] = width - old_x2 - 1
+    flip_boxes[:, 2] = width - old_x1 - 1
+    return flip_boxes
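A quick worked example: with width=100, new_x1 = width - old_x2 - 1 and
new_x2 = width - old_x1 - 1, so:

    boxes = np.array([[10., 5., 20., 15.]], 'float32')
    flip_boxes(boxes, width=100)  # -> [[79., 5., 89., 15.]]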
 def filter_boxes(boxes, min_size):
     """Remove all boxes with any side smaller than min size."""
     ws = boxes[:, 2] - boxes[:, 0] + 1
...
...@@ -62,22 +62,20 @@ if __name__ == '__main__':
     if checkpoint is not None:
         cfg.TRAIN.WEIGHTS = checkpoint

-    # Setup MPI
-    if cfg.NUM_GPUS != dragon.mpi.size():
+    # Setup the distributed environment
+    world_rank = dragon.distributed.get_rank()
+    world_size = dragon.distributed.get_world_size()
+    if cfg.NUM_GPUS != world_size:
         raise ValueError(
-            'Excepted {} mpi nodes, but got {}.'
-            .format(len(args.gpus), dragon.mpi.size())
+            'Expected starting of {} processes, got {}.'
+            .format(cfg.NUM_GPUS, world_size)
         )
-    GPUs = [i for i in range(cfg.NUM_GPUS)]
-    cfg.GPU_ID = GPUs[dragon.mpi.rank()]
-    dragon.mpi.add_parallel_group([i for i in range(cfg.NUM_GPUS)])
-    dragon.mpi.set_parallel_mode('NCCL' if cfg.USE_NCCL else 'MPI')
+    logger.set_root_logger(world_rank == 0)

-    # Setup logger
-    if dragon.mpi.rank() != 0:
-        logger.set_root_logger(False)
+    # Select the GPU depending on the rank of process
+    cfg.GPU_ID = [i for i in range(cfg.NUM_GPUS)][world_rank]

-    # Fix the random seeds (numpy and dragon) for reproducibility
+    # Fix the random seed for reproducibility
     numpy.random.seed(cfg.RNG_SEED)
     dragon.config.set_random_seed(cfg.RNG_SEED)

...@@ -89,7 +87,8 @@ if __name__ == '__main__':
     # Ready to train the network
     logger.info('Output will be saved to `{:s}`'
                 .format(coordinator.checkpoints_dir()))
-    train_net(coordinator, start_iter)
-
-    # Finalize mpi
-    dragon.mpi.finalize()
+    with dragon.distributed.new_group(
+            ranks=[i for i in range(cfg.NUM_GPUS)],
+            backend='NCCL' if cfg.USE_NCCL else 'MPI',
+            verbose=True).as_default():
+        train_net(coordinator, start_iter)
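A hedged launch sketch: the check above expects exactly cfg.NUM_GPUS processes, one per
GPU, typically spawned via MPI (the script name and flags below are hypothetical):

    # mpirun -n 4 python train.py --cfg <your_config>.yml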
...@@ -82,7 +82,7 @@ if __name__ == '__main__':
     if checkpoint is not None:
         cfg.TRAIN.WEIGHTS = checkpoint

-    # Fix the random seeds (numpy and dragon) for reproducibility
+    # Fix the random seed for reproducibility
     numpy.random.seed(cfg.RNG_SEED)
     dragon.config.set_random_seed(cfg.RNG_SEED)
...