Commit c8535116 by Ting PAN

Compile proto files

1 parent 89f1ee28
------------------------------------------------------------------------
A list of the most significant changes made to SeetaDet over time.
SeetaDet 0.1.0 (20190311)
SeetaDet 0.1.0 (20190314)
Dragon Minimum Required (Version 0.3.0.0)
......
# delete cache
rm -r build install *.c *.cpp
# compile proto files
protoc -I ../lib/proto --python_out=../lib/proto ../lib/proto/anno.proto
# compile cython modules
python setup.py build_ext --inplace
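# optional sanity check (an assumption: run from the repository root) that the generated module imports
python -c "from lib.proto import anno_pb2; print(anno_pb2.AnnotatedDatum)"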
......
......@@ -23,6 +23,7 @@ MODEL:
SOLVER:
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
WARM_UP_ITERS: 2000 # default: 500
LR_POLICY: steps_with_decay
STEPS: [120000, 160000]
MAX_ITERS: 180000
......
syntax = "proto2";
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
optional bytes data = 4;
optional int32 label = 5;
repeated float float_data = 6;
optional bool encoded = 7 [default = false];
}
message Annotation {
optional float x1 = 1;
optional float y1 = 2;
optional float x2 = 3;
optional float y2 = 4;
optional string name = 5;
optional bool difficult = 6 [default = false];
optional bool crowd = 7 [default = false];
optional string mask = 8;
}
message AnnotatedDatum {
optional Datum datum = 1;
optional string filename = 2;
repeated Annotation annotation = 3;
}
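A minimal usage sketch (not part of this commit) for the messages above, assuming the generated module is importable as lib.proto.anno_pb2 (the import path used in the diffs below); the filename and box values are hypothetical:
from lib.proto import anno_pb2 as pb

anno_datum = pb.AnnotatedDatum()
anno_datum.filename = '000001'          # hypothetical image id
datum = anno_datum.datum
datum.height, datum.width, datum.channels = 480, 640, 3
datum.encoded = True
datum.data = b''                        # JPEG-encoded bytes in practice
anno = anno_datum.annotation.add()
anno.name = 'person'
anno.x1, anno.y1, anno.x2, anno.y2 = 10., 20., 100., 200.
anno.difficult = False
blob = anno_datum.SerializeToString()   # the bytes a make_db() writer would store in LMDB
restored = pb.AnnotatedDatum.FromString(blob)
assert restored.annotation[0].name == 'person'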
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: anno.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='anno.proto',
package='',
serialized_pb=_b('\n\nanno.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\"\x88\x01\n\nAnnotation\x12\n\n\x02x1\x18\x01 \x01(\x02\x12\n\n\x02y1\x18\x02 \x01(\x02\x12\n\n\x02x2\x18\x03 \x01(\x02\x12\n\n\x02y2\x18\x04 \x01(\x02\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x18\n\tdifficult\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x14\n\x05\x63rowd\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x0c\n\x04mask\x18\x08 \x01(\t\"Z\n\x0e\x41nnotatedDatum\x12\x15\n\x05\x64\x61tum\x18\x01 \x01(\x0b\x32\x06.Datum\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\x12\x1f\n\nannotation\x18\x03 \x03(\x0b\x32\x0b.Annotation')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATUM = _descriptor.Descriptor(
name='Datum',
full_name='Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='channels', full_name='Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='height', full_name='Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='width', full_name='Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data', full_name='Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='label', full_name='Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='encoded', full_name='Datum.encoded', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=15,
serialized_end=144,
)
_ANNOTATION = _descriptor.Descriptor(
name='Annotation',
full_name='Annotation',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='x1', full_name='Annotation.x1', index=0,
number=1, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y1', full_name='Annotation.y1', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='x2', full_name='Annotation.x2', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y2', full_name='Annotation.y2', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='Annotation.name', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='difficult', full_name='Annotation.difficult', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='crowd', full_name='Annotation.crowd', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mask', full_name='Annotation.mask', index=7,
number=8, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=147,
serialized_end=283,
)
_ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum',
full_name='AnnotatedDatum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='datum', full_name='AnnotatedDatum.datum', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='filename', full_name='AnnotatedDatum.filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='annotation', full_name='AnnotatedDatum.annotation', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=285,
serialized_end=375,
)
_ANNOTATEDDATUM.fields_by_name['datum'].message_type = _DATUM
_ANNOTATEDDATUM.fields_by_name['annotation'].message_type = _ANNOTATION
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION
DESCRIPTOR.message_types_by_name['AnnotatedDatum'] = _ANNOTATEDDATUM
Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
DESCRIPTOR = _DATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Datum)
))
_sym_db.RegisterMessage(Datum)
Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATION,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Annotation)
))
_sym_db.RegisterMessage(Annotation)
AnnotatedDatum = _reflection.GeneratedProtocolMessageType('AnnotatedDatum', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATEDDATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:AnnotatedDatum)
))
_sym_db.RegisterMessage(AnnotatedDatum)
# @@protoc_insertion_point(module_scope)
......@@ -14,11 +14,14 @@ from __future__ import division
from __future__ import print_function
import os
import sys
import time
import cv2
from . import anno_pb2 as pb
from dragon.tools.db import LMDB
sys.path.insert(0, '../../..')
from lib.proto import anno_pb2 as pb
ZFILL = 8
ENCODE_QUALITY = 95
......@@ -88,9 +91,6 @@ def make_db(database_file, images_path, gt_recs, ext='.png'):
print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
db.commit()
db.close()
# Compress the empty space
db.open(database_file, mode='w')
db.commit()
end_time = time.time()
print('{0} images have been stored in the database.'.format(total_line))
......
......@@ -14,12 +14,14 @@ from __future__ import division
from __future__ import print_function
import os
import sys
import time
import cv2
import xml.etree.ElementTree as ET
from dragon.tools.db import LMDB
from . import anno_pb2 as pb
sys.path.insert(0, '../../..')
from lib.proto import anno_pb2 as pb
ZFILL = 8
ENCODE_QUALITY = 95
......@@ -124,9 +126,6 @@ def make_db(database_file,
print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
db.commit()
db.close()
# Compress the empty space
db.open(database_file, mode='w')
db.commit()
end_time = time.time()
print('{0} images have been stored in the database.'.format(total_line))
......
# --------------------------------------------------------
# Detectron @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make LMDB for cityscape dataset."""
import os
import sys
import shutil
import numpy as np
np.random.seed(1337)
try:
import cPickle
except ImportError:
import pickle as cPickle
sys.path.insert(0, '../../../')
from database.mrcnn.utils.make import make_db
from database.mrcnn.cityscape.make_mask import make_mask
if __name__ == '__main__':
CITYSCAPE_ROOT = '/data/cityscape'
# make RLE masks
if not os.path.exists('build'): os.makedirs('build')
cs_train = make_mask(
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest'),
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists/train.lst'))
cs_val = make_mask(
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest'),
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists/val.lst'))
with open('build/cs_train_mask.pkl', 'wb') as f:
cPickle.dump(cs_train, f, cPickle.HIGHEST_PROTOCOL)
with open('build/cs_val_mask.pkl', 'wb') as f:
cPickle.dump(cs_val, f, cPickle.HIGHEST_PROTOCOL)
# make image splits
for split in ['train', 'val', 'test']:
with open(os.path.join(CITYSCAPE_ROOT,
'gtFine_trainvaltest/imglists', split + '.lst'), 'r') as f:
entries = [line.strip().split('\t') for line in f.readlines()]
if split == 'train': np.random.shuffle(entries)
with open(os.path.join(CITYSCAPE_ROOT,
'gtFine_trainvaltest/imglists', split + '.txt'), 'w') as w:
for entry in entries: w.write(entry[1].split('.')[0] + '\n')
# make database
make_db(database_file=os.path.join(CITYSCAPE_ROOT, 'cache/cs_train_lmdb'),
images_path=os.path.join(CITYSCAPE_ROOT, 'leftImg8bit_trainvaltest'),
mask_file='build/cs_train_mask.pkl',
splits_path=os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists'),
splits=['train'], ext='.png')
make_db(database_file=os.path.join(CITYSCAPE_ROOT, 'cache/cs_val_lmdb'),
images_path=os.path.join(CITYSCAPE_ROOT, 'leftImg8bit_trainvaltest'),
mask_file='build/cs_val_mask.pkl',
splits_path=os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists'),
splits=['val'], ext='.png')
# clean!
shutil.rmtree('build')
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make masks for cityscape dataset."""
import os
import sys
import cv2
from collections import OrderedDict
import PIL.Image as Image
import numpy as np
np.random.seed(1337)
sys.path.insert(0, '../../..')
from lib.pycocotools.mask_utils import mask_bin2rle
from database.mrcnn.utils.process_pool import ProcessPool
class_id = [0,
24, 25, 26, 27,
28, 31, 32, 33]
classes = ['__background__',
'person', 'rider', 'car', 'truck',
'bus', 'train', 'motorcycle', 'bicycle']
ind_to_class = dict(zip(range(len(classes)), classes))
def parse_gt(gt_file, im_scale=1.0):
im = Image.open(gt_file)
pixel = list(im.getdata())
pixel = np.array(pixel).reshape([im.size[1], im.size[0]])
objects = []
for c in range(1, len(class_id)):
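# Cityscapes 'instanceIds' images encode each instance as class_id * 1000 + instance index,
# so pixels of class c fall in [class_id[c] * 1000, (class_id[c] + 1) * 1000).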
px = np.where((pixel >= class_id[c] * 1000) & (pixel < (class_id[c] + 1) * 1000))
if len(px[0]) == 0: continue
uids = np.unique(pixel[px])
for idx, uid in enumerate(uids):
px = np.where(pixel == uid)
x1 = np.min(px[1])
y1 = np.min(px[0])
x2 = np.max(px[1])
y2 = np.max(px[0])
if x2 - x1 <= 1 or y2 - y1 <= 1: continue
mask = np.zeros([im.size[1], im.size[0]], dtype=np.uint8)
mask[px] = 1
if im_scale != 1:
mask = cv2.resize(mask, None, fx=im_scale, fy=im_scale,
interpolation=cv2.INTER_NEAREST)
x1 = min(int(x1 * im_scale), mask.shape[1])
y1 = min(int(y1 * im_scale), mask.shape[0])
x2 = min(int(x2 * im_scale), mask.shape[1])
y2 = min(int(y2 * im_scale), mask.shape[0])
objects.append({'bbox': [x1, y1, x2, y2],
'mask': mask_bin2rle([mask])[0],
'name': ind_to_class[c],
'difficult': False})
return objects
def map_func(gts, Q):
for image_id, gt_file in gts:
objects = parse_gt(gt_file)
Q.put((image_id, objects))
def make_mask(gt_root, split_file):
# Create tasks
gt_tasks, gt_recs = [], OrderedDict()
with open(split_file, 'r') as f:
for line in f:
_, image_path, gt_path = line.strip().split('\t')
image_id = image_path.split('.')[0]
gt_file = os.path.join(gt_root, gt_path.replace('labelTrainIds', 'instanceIds'))
gt_tasks.append((image_id, gt_file))
num_tasks = len(gt_tasks)
# Run!
with ProcessPool(16) as pool:
pool.run(gt_tasks, func=map_func)
for idx in range(num_tasks):
image_id, objects = pool.get()
gt_recs[image_id] = objects
print('\rProcess: {} / {}'.format(idx + 1, num_tasks), end='')
return gt_recs
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import cv2
from collections import defaultdict
from lib.pycocotools.mask_utils import mask_rle2im
CITYSCAPE_ROOT = '/data/cityscape'
def write_results(json_file, img_list):
with open(json_file, 'r') as f:
json_results = json.load(f)
class_id = [0, 24, 25, 26, 27, 28, 31, 32, 33]
category_id_to_class_id = dict(zip(range(9), class_id))
result_path = os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest', 'results', 'pred')
if not os.path.exists(result_path): os.makedirs(result_path)
counts = defaultdict(int)
txt_results = defaultdict(list)
for idx, rec in enumerate(json_results):
class_id = category_id_to_class_id[rec['category_id']]
if class_id == 0: continue
im_h, im_w = rec['segmentation']['size']
mask_rle = rec['segmentation']['counts']
mask_image = mask_rle2im([mask_rle], im_h, im_w)[0] * 200
image_name = rec['image_id'].split('_leftImg8bit')[0]
mask_name = image_name + '_' + str(counts[image_name]) + '.png'
counts[image_name] += 1
mask_path = os.path.join(result_path, mask_name)
cv2.imwrite(mask_path, mask_image)
txt_results[image_name].append((mask_name, class_id, rec['score']))
print('\rWriting masks ({} / {})'.format(idx + 1, len(json_results)), end='')
with open(img_list, 'r') as F:
for line in F.readlines():
image_name = line.strip().split('/')[-1].split('_leftImg8bit')[0]
txt_path = os.path.join(result_path, image_name + '.txt')
with open(txt_path, 'w') as f:
for rec in txt_results[image_name]:
f.write('{} {} {:.8f}\n'.format(rec[0], rec[1], rec[2]))
if __name__ == '__main__':
write_results(
'/results/segmentations.json',
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest', 'imglists', 'val.txt')
)
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make LMDB for COCO dataset."""
import os
import sys
import shutil
sys.path.insert(0, '../../../')
from database.mrcnn.utils.make import make_db
from database.mrcnn.coco.make_mask import make_mask, merge_mask
if __name__ == '__main__':
COCO_ROOT = '/data/coco'
# make RLE masks
if not os.path.exists('build'): os.makedirs('build')
make_mask('train', '2014', COCO_ROOT)
make_mask('valminusminival', '2014', COCO_ROOT)
make_mask('minival', '2014', COCO_ROOT)
merge_mask('trainval35k', '2014', [
'build/coco_2014_train_mask.pkl',
'build/coco_2014_valminusminival_mask.pkl'])
# train database: coco_2014_trainval35k
make_db(database_file=os.path.join(COCO_ROOT, 'cache/coco_2014_trainval35k_lmdb'),
images_path=[os.path.join(COCO_ROOT, 'images/train2014'),
os.path.join(COCO_ROOT, 'images/val2014')],
splits_path=[os.path.join(COCO_ROOT, 'ImageSets'),
os.path.join(COCO_ROOT, 'ImageSets')],
mask_file='build/coco_2014_trainval35k_mask.pkl',
splits=['train', 'valminusminival'])
# val database: coco_2014_minival
make_db(database_file=os.path.join(COCO_ROOT, 'cache/coco_2014_minival_lmdb'),
images_path=os.path.join(COCO_ROOT, 'images/val2014'),
mask_file='build/coco_2014_minival_mask.pkl',
splits_path=os.path.join(COCO_ROOT, 'ImageSets'),
splits=['minival'])
# clean!
shutil.rmtree('build')
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import sys
import os.path as osp
from collections import OrderedDict
try:
import cPickle
except ImportError:
import pickle as cPickle
sys.path.insert(0, '../../..')
from lib.pycocotools.coco import COCO
from lib.pycocotools.mask_utils import mask_poly2rle
class imdb(object):
def __init__(self, image_set, year, data_dir):
self._year = year
self._image_set = image_set
self._data_path = data_dir
self.invalid_cnt = 0
self.ignore_cnt = 0
#################
# CLASSES #
#################
# load COCO API, classes, class <-> id mappings
self._COCO = COCO(self._get_ann_file())
cats = self._COCO.loadCats(self._COCO.getCatIds())
self._classes = tuple(['__background__'] + [c['name'] for c in cats])
self._class_to_ind = dict(zip(self._classes, range(self.num_classes)))
self._ind_to_class = dict(zip(range(self.num_classes), self._classes))
self._class_to_coco_cat_id = dict(zip([c['name'] for c in cats],
self._COCO.getCatIds()))
self._coco_cat_id_to_class_ind = dict([(self._class_to_coco_cat_id[cls],
self._class_to_ind[cls]) for cls in self._classes[1:]])
#################
# SET #
#################
self._view_map = {
'minival2014': 'val2014', # 5k val2014 subset
'valminusminival2014': 'val2014', # val2014 \setminus minival2014
}
coco_name = image_set + year # e.g., "val2014"
self._data_name = (self._view_map[coco_name]
if coco_name in self._view_map else coco_name)
#################
# IMAGES #
#################
self._image_index = self._load_image_set_index()
self._annotations = self._load_annotations()
def _get_ann_file(self):
prefix = 'instances' if self._image_set.find('test') == -1 \
else 'image_info'
return osp.join(self._data_path, 'annotations',
prefix + '_' + self._image_set + self._year + '.json')
def _load_image_set_index(self):
"""
Load image ids.
"""
image_ids = self._COCO.getImgIds()
return image_ids
def _load_annotations(self):
"""
Load annotations.
"""
annotations = [self._load_coco_annotation(index)
for index in self._image_index]
return annotations
def image_path_from_index(self, index):
"""
Construct an image path from the image's "index" identifier.
"""
# Example image path for index=119993:
# images/train2014/COCO_train2014_000000119993.jpg
file_name = ('COCO_' + self._data_name + '_' +
str(index).zfill(12) + '.jpg')
image_path = osp.join(self._data_path, 'images',
self._data_name, file_name)
assert osp.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path
def image_path_at(self, i):
"""
Return the absolute path to image i in the image sequence.
"""
return self.image_path_from_index(self._image_index[i])
def annotation_at(self, i):
"""
Return the annotation record for image i in the image sequence.
"""
return self._annotations[i]
def _load_coco_annotation(self, index):
"""
Loads COCO bounding-box instance annotations. Crowd instances are
handled by marking their overlaps (with all categories) to -1. This
overlap value means that crowd "instances" are excluded from training.
"""
im_ann = self._COCO.loadImgs(index)[0]
width = im_ann['width']
height = im_ann['height']
annIds = self._COCO.getAnnIds(imgIds=index, iscrowd=None)
objs = self._COCO.loadAnns(annIds)
# Sanitize boxes -- some are invalid
valid_objs = []
for obj in objs:
x1 = int(max(0, obj['bbox'][0]))
y1 = int(max(0, obj['bbox'][1]))
x2 = int(min(width - 1, x1 + max(0, obj['bbox'][2] - 1)))
y2 = int(min(height - 1, y1 + max(0, obj['bbox'][3] - 1)))
if type(obj['segmentation']) is list:
for p in obj['segmentation']:
if len(p) < 6: print('Removing invalid segmentation.')
# Valid polygons have >= 3 points, so require >= 6 coordinates
obj['segmentation'] = [p for p in obj['segmentation'] if len(p) >= 6]
rle_masks = mask_poly2rle([obj['segmentation']], height, width)
else:
# crowd masks
rle_masks = [obj['segmentation']]
if obj['area'] > 0 and x2 > x1 and y2 > y1:
obj['clean_bbox'] = [x1, y1, x2, y2]
# Exclude the crowd masks
# TODO(PhyscalX): You may encounter crashes when decoding crowd masks.
mask = rle_masks[0] if not obj['iscrowd'] else ''
valid_objs.append(
{'bbox': [x1, y1, x2, y2],
'mask': mask,
'category_id': obj['category_id'],
'class_id': self._coco_cat_id_to_class_ind[obj['category_id']],
'crowd': obj['iscrowd']})
valid_objs[-1]['name'] = self._ind_to_class[valid_objs[-1]['class_id']]
return height, width, valid_objs
@property
def num_images(self):
return len(self._image_index)
@property
def num_classes(self):
return len(self._classes)
def make_mask(split, year, data_dir):
coco = imdb(split, year, data_dir)
print('Preparing to make split: {}, total {} images'.format(split, coco.num_images))
if not osp.exists(osp.join(coco._data_path, 'ImageSets')):
os.makedirs(osp.join(coco._data_path, 'ImageSets'))
gt_recs = OrderedDict()
for i in range(coco.num_images):
filename = (coco.image_path_at(i).split('/')[-1]).split('.')[0]
h, w, objects = coco.annotation_at(i)
gt_recs[filename] = objects
with open(osp.join('build',
'coco_' + year + '_' + split + '_mask.pkl'), 'wb') as f:
cPickle.dump(gt_recs, f, cPickle.HIGHEST_PROTOCOL)
with open(osp.join(coco._data_path, 'ImageSets', split + '.txt'), 'w') as f:
for i in range(coco.num_images):
filename = (coco.image_path_at(i).split('/')[-1]).split('.')[0]
if i != coco.num_images - 1: filename += '\n'
f.write(filename)
def merge_mask(split, year, mask_files):
gt_recs = OrderedDict()
data_path = os.path.dirname(mask_files[0])
for mask_file in mask_files:
with open(mask_file, 'rb') as f:
recs = cPickle.load(f)
gt_recs.update(recs)
with open(osp.join(data_path,
'coco_' + year + '_' + split + '_mask.pkl'), 'wb') as f:
cPickle.dump(gt_recs, f, cPickle.HIGHEST_PROTOCOL)
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import sys
import time
import json
import cv2
from dragon.tools.db import LMDB, wrapper_str
sys.path.insert(0, '../../../')
import database.mrcnn.utils.anno_pb2 as pb
IMAGE_INFO = '/data/image_info_test-dev2017.json'
def load_image_list(image_info):
num_images = len(image_info['images'])
image_list = []
print('The split has {} images.'.format(num_images))
for image in image_info['images']:
image_list.append(image['file_name'])
return image_list
def make_datum(image_file):
anno_datum = pb.AnnotatedDatum()
datum = pb.Datum()
im = cv2.imread(image_file)
datum.height, datum.width, datum.channels = im.shape
datum.encoded = True
if datum.encoded:
result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
datum.data = im.tostring()
anno_datum.datum.CopyFrom(datum)
anno_datum.filename = os.path.split(image_file)[-1]
return anno_datum
def make_db(database_file, images_path, image_list):
if os.path.isdir(database_file):
raise ValueError('The database path already exists.')
else:
root_dir = database_file[:database_file.rfind('/')]
if not os.path.exists(root_dir):
os.makedirs(root_dir)
print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
db = LMDB(max_commit=10000)
db.open(database_file, mode='w')
count = 0
start_time = time.time()
zfill_flag = '{0:0%d}' % (8)
for image_file in image_list:
count += 1
if count % 10000 == 0:
now_time = time.time()
print('{0} / {1} in {2:.2f} sec'.format(
count, len(image_list), now_time - start_time))
db.commit()
datum = make_datum(os.path.join(images_path, image_file))
db.put(zfill_flag.format(count - 1), datum.SerializeToString())
now_time = time.time()
print('{0} / {1} in {2:.2f} sec'.format(count, len(image_list), now_time - start_time))
db.commit()
db.close()
end_time = time.time()
print('{0} images have been stored in the database.'.format(len(image_list)))
print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
print('The size of database is {0} MB.'.format(
float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
if __name__ == '__main__':
image_info = json.load(open(IMAGE_INFO, 'r'))
image_list = load_image_list(image_info)
make_db('/data/coco_2017_test-dev_lmdb',
'/data/test2017', image_list)
\ No newline at end of file
# --------------------------------------------------------
# FPN @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
from .make import set_zfill, set_quality, make_db
\ No newline at end of file
syntax = "proto2";
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
optional bytes data = 4;
optional int32 label = 5;
repeated float float_data = 6;
optional bool encoded = 7 [default = false];
repeated int32 labels = 8;
}
message Annotation {
optional float x1 = 1;
optional float y1 = 2;
optional float x2 = 3;
optional float y2 = 4;
optional string name = 5;
optional bool difficult = 6 [default = false];
optional string mask = 7;
}
message AnnotatedDatum {
optional Datum datum = 1;
optional string filename = 2;
repeated Annotation annotation = 3;
}
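Note: compared with the anno.proto above, this variant adds a repeated 'labels' field to Datum and drops Annotation's 'crowd' flag (crowd objects are folded into 'difficult' by the writer below).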
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: anno.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='anno.proto',
package='',
serialized_pb=_b('\n\nanno.proto\"\x91\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x0e\n\x06labels\x18\x08 \x03(\x05\"r\n\nAnnotation\x12\n\n\x02x1\x18\x01 \x01(\x02\x12\n\n\x02y1\x18\x02 \x01(\x02\x12\n\n\x02x2\x18\x03 \x01(\x02\x12\n\n\x02y2\x18\x04 \x01(\x02\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x18\n\tdifficult\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x0c\n\x04mask\x18\x07 \x01(\t\"Z\n\x0e\x41nnotatedDatum\x12\x15\n\x05\x64\x61tum\x18\x01 \x01(\x0b\x32\x06.Datum\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\x12\x1f\n\nannotation\x18\x03 \x03(\x0b\x32\x0b.Annotation')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATUM = _descriptor.Descriptor(
name='Datum',
full_name='Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='channels', full_name='Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='height', full_name='Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='width', full_name='Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data', full_name='Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='label', full_name='Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='encoded', full_name='Datum.encoded', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='labels', full_name='Datum.labels', index=7,
number=8, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=15,
serialized_end=160,
)
_ANNOTATION = _descriptor.Descriptor(
name='Annotation',
full_name='Annotation',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='x1', full_name='Annotation.x1', index=0,
number=1, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y1', full_name='Annotation.y1', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='x2', full_name='Annotation.x2', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y2', full_name='Annotation.y2', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='Annotation.name', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='difficult', full_name='Annotation.difficult', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mask', full_name='Annotation.mask', index=6,
number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=162,
serialized_end=276,
)
_ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum',
full_name='AnnotatedDatum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='datum', full_name='AnnotatedDatum.datum', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='filename', full_name='AnnotatedDatum.filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='annotation', full_name='AnnotatedDatum.annotation', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=278,
serialized_end=368,
)
_ANNOTATEDDATUM.fields_by_name['datum'].message_type = _DATUM
_ANNOTATEDDATUM.fields_by_name['annotation'].message_type = _ANNOTATION
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION
DESCRIPTOR.message_types_by_name['AnnotatedDatum'] = _ANNOTATEDDATUM
Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
DESCRIPTOR = _DATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Datum)
))
_sym_db.RegisterMessage(Datum)
Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATION,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Annotation)
))
_sym_db.RegisterMessage(Annotation)
AnnotatedDatum = _reflection.GeneratedProtocolMessageType('AnnotatedDatum', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATEDDATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:AnnotatedDatum)
))
_sym_db.RegisterMessage(AnnotatedDatum)
# @@protoc_insertion_point(module_scope)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import time
import cv2
try:
import cPickle
except ImportError:
import pickle as cPickle
from dragon.tools.db import LMDB
from . import anno_pb2 as pb
ZFILL = 8
ENCODE_QUALITY = 95
def set_zfill(value):
global ZFILL
ZFILL = value
def set_quality(value):
global ENCODE_QUALITY
ENCODE_QUALITY = value
def make_datum(image_file, mask_objects, im_scale=None):
filename = os.path.split(image_file)[-1]
anno_datum = pb.AnnotatedDatum()
datum = pb.Datum()
im = cv2.imread(image_file)
if im_scale: im = cv2.resize(im, None,
fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
datum.height, datum.width, datum.channels = im.shape
datum.encoded = ENCODE_QUALITY != 100
if datum.encoded:
result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), ENCODE_QUALITY])
datum.data = im.tostring()
anno_datum.datum.CopyFrom(datum)
anno_datum.filename = filename.split('.')[0]
for ix, obj in enumerate(mask_objects):
anno = pb.Annotation()
x1, y1, x2, y2 = obj['bbox']
anno.name = obj['name']
anno.x1, anno.y1, anno.x2, anno.y2 = x1, y1, x2, y2
if 'difficult' in obj: anno.difficult = obj['difficult']
if 'crowd' in obj: anno.difficult = obj['crowd']
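# NOTE: this proto's Annotation has no 'crowd' field, so crowd objects are folded into 'difficult' here.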
anno.mask = obj['mask']
anno_datum.annotation.add().CopyFrom(anno)
return anno_datum
def make_db(database_file, images_path, mask_file,
splits_path, splits, ext='.jpg', im_scale=None):
if os.path.isdir(database_file):
raise ValueError('The database path already exists.')
else:
root_dir = database_file[:database_file.rfind('/')]
if not os.path.exists(root_dir):
os.makedirs(root_dir)
if not isinstance(images_path, list):
images_path = [images_path]
if not isinstance(splits_path, list):
splits_path = [splits_path]
assert len(splits) == len(splits_path)
assert len(splits) == len(images_path)
if mask_file is not None:
with open(mask_file, 'rb') as f:
all_masks = cPickle.load(f)
else:
all_masks = {}
print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
db = LMDB(max_commit=10000)
db.open(database_file, mode='w')
count = 0
total_line = 0
start_time = time.time()
zfill_flag = '{0:0%d}' % (ZFILL)
for db_idx, split in enumerate(splits):
split_file = os.path.join(splits_path[db_idx], split + '.txt')
assert os.path.exists(split_file)
with open(split_file, 'r') as f:
lines = f.readlines()
total_line += len(lines)
for line in lines:
count += 1
if count % 10000 == 0:
now_time = time.time()
print('{0} / {1} in {2:.2f} sec'.format(
count, total_line, now_time - start_time))
db.commit()
filename = line.strip()
image_file = os.path.join(images_path[db_idx], filename + ext)
mask_objects = all_masks[filename] if filename in all_masks else None
if mask_objects is None:
raise ValueError('The image ({}) has no valid mask record.'.format(filename))
datum = make_datum(image_file, mask_objects, im_scale)
db.put(zfill_flag.format(count - 1), datum.SerializeToString())
now_time = time.time()
print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
db.commit()
db.close()
# Compress the empty space
db.open(database_file, mode='w')
db.commit()
end_time = time.time()
print('{0} images have been stored in the database.'.format(total_line))
print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
print('The size of database is {0} MB.'.format(
float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""A simple process pool to map tasks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
class ProcessPool(object):
def __init__(self, num_processes=8, max_qsize=100):
self.num_tasks = self.fetch_tasks = 0
self.num_processes = num_processes
self.Q = multiprocessing.Queue(max_qsize)
def __enter__(self):
return self
def __exit__(self, *excinfo):
pass
def map(self, tasks, func):
n_tasks_each = int(len(tasks) / self.num_processes)
remain_tasks = len(tasks) - n_tasks_each * self.num_processes
pos = 0
for i in range(self.num_processes):
if i != self.num_processes - 1:
work_set = tasks[pos: pos + n_tasks_each]
pos += n_tasks_each
else:
work_set = tasks[pos: pos + n_tasks_each + remain_tasks]
print('[Main]: Process #{} Got {} tasks.'.format(i, len(work_set)))
p = multiprocessing.Process(target=func, args=(work_set, self.Q))
p.start()
def wait(self):
displays = {}
while True:
qsize = self.Q.qsize()
if qsize == self.num_tasks: break
if qsize > 0 and qsize % 100 == 0:
if qsize not in displays:
displays[qsize] = True
print('[Queue]: Cached {} tasks.'.format(qsize))
outputs = []
while self.Q.qsize() > 0:
outputs.append(self.Q.get())
assert len(outputs) == self.num_tasks
print('[Main]: Got {} outputs.'.format(len(outputs)))
return outputs
def get(self):
self.fetch_tasks += 1
if self.fetch_tasks > self.num_tasks:
return None
return self.Q.get()
def run_all(self, tasks, func):
self.num_tasks = len(tasks)
self.map(tasks, func)
self.wait()
def run(self, tasks, func):
self.num_tasks = len(tasks)
self.map(tasks, func)
\ No newline at end of file
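A minimal usage sketch (not part of this commit) of the producer/consumer pattern this pool expects, mirroring map_func in make_mask.py; it assumes a fork-based multiprocessing start method (the Linux default), and square_func is a hypothetical worker:
def square_func(tasks, Q):
    # each worker receives a slice of the task list and the shared queue
    for key, value in tasks:
        Q.put((key, value * value))

if __name__ == '__main__':
    tasks = [(i, i) for i in range(100)]
    with ProcessPool(num_processes=4) as pool:
        pool.run(tasks, func=square_func)   # forks workers; does not block
        results = [pool.get() for _ in range(len(tasks))]
    print('Got {} results.'.format(len(results)))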
......@@ -292,6 +292,9 @@ __C.RETINANET.SOFTMAX = False
__C.FPN = edict()
# Channel dimension of the FPN feature levels
__C.FPN.DIM = 256
# Coarsest level of the FPN pyramid
__C.FPN.RPN_MAX_LEVEL = 6
# Finest level of the FPN pyramid
......
......@@ -63,16 +63,16 @@ class Coordinator(object):
steps = []
for ix, file in enumerate(files):
step = int(file.split('_iter_')[-1].split('.')[0])
if global_step == step: return os.path.join(self.checkpoints_dir(), files[ix])
if global_step == step:
return os.path.join(self.checkpoints_dir(), files[ix]), step
steps.append(step)
if global_step is None:
if len(files) == 0:
raise ValueError('Dir({}) is empty.'.format(self.checkpoints_dir()))
if len(files) == 0: return None, 0
last_idx = int(np.argmax(steps)); last_step = steps[last_idx]
return os.path.join(self.checkpoints_dir(), files[last_idx])
return None
return os.path.join(self.checkpoints_dir(), files[last_idx]), last_step
return None, 0
result = locate()
while not result and wait:
while result[0] is None and wait:
print('\rWaiting for step_{}.checkpoint to exist...'.format(global_step), end='')
time.sleep(10)
result = locate()
......
......@@ -24,7 +24,7 @@ from lib.utils import logger
class Solver(object):
def __init__(self):
# Define the generic detector
self.detector = Detector().cuda(cfg.GPU_ID)
self.detector = Detector()
# Define the optimizer and its arguments
self.optimizer = None
self.opt_arguments = {
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
import os
import datetime
from collections import OrderedDict
import dragon.vm.torch as torch
from lib.core.config import cfg
......@@ -45,7 +46,9 @@ class SolverWrapper(object):
# Mixed precision training?
if cfg.MODEL.DATA_TYPE.lower() == 'float16':
self.solver.detector.half() # Powerful FP16 Support
self.solver.detector.half() # Powerful FP16 Support
self.solver.detector.cuda(cfg.GPU_ID)
# Plan the metrics
self.metrics = OrderedDict()
......
......@@ -13,7 +13,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import importlib
import dragon.vm.torch as torch
from collections import OrderedDict
......@@ -122,7 +121,7 @@ class Detector(torch.nn.Module):
outputs.update(
self.rpn(
features=features,
**inputs,
**inputs
)
)
outputs.update(
......@@ -130,7 +129,7 @@ class Detector(torch.nn.Module):
features=features,
rpn_cls_score=outputs['rpn_cls_score'],
rpn_bbox_pred=outputs['rpn_bbox_pred'],
**inputs,
**inputs
)
)
......@@ -139,7 +138,7 @@ class Detector(torch.nn.Module):
outputs.update(
self.retinanet(
features=features,
**inputs,
**inputs
)
)
......@@ -148,7 +147,7 @@ class Detector(torch.nn.Module):
outputs.update(
self.ssd(
features=features,
**inputs,
**inputs
)
)
......@@ -187,5 +186,11 @@ class Detector(torch.nn.Module):
term = torch.sqrt(e.running_var.data + e.eps)
term = e.weight.data / term
last_module.bias = e.bias.data - term * e.running_mean.data
last_module.weight.data.mul_(term)
if last_module.weight.dtype == 'float16':
last_module.bias.half_()
weight = last_module.weight.data.float()
weight.mul_(term)
last_module.weight.copy_(weight)
else:
last_module.weight.data.mul_(term)
last_module = e
\ No newline at end of file
......@@ -29,17 +29,19 @@ class FPN(torch.nn.Module):
super(FPN, self).__init__()
self.C = torch.nn.ModuleList()
self.P = torch.nn.ModuleList()
self.apply_func = self.apply_on_rcnn
for lvl in range(cfg.FPN.RPN_MIN_LEVEL, HIGHEST_BACKBONE_LVL + 1):
self.C.append(conv1x1(feature_dims[lvl - 1], 256, bias=True))
self.P.append(conv3x3(256, 256, bias=True))
self.C.append(conv1x1(feature_dims[lvl - 1], cfg.FPN.DIM, bias=True))
self.P.append(conv3x3(cfg.FPN.DIM, cfg.FPN.DIM, bias=True))
if 'retinanet' in cfg.MODEL.TYPE:
for lvl in range(HIGHEST_BACKBONE_LVL + 1, cfg.FPN.RPN_MAX_LEVEL + 1):
dim_in = feature_dims[-1] if lvl == HIGHEST_BACKBONE_LVL + 1 else 256
self.P.append(conv3x3(dim_in, 256, stride=2, bias=True))
dim_in = feature_dims[-1] if lvl == HIGHEST_BACKBONE_LVL + 1 else cfg.FPN.DIM
self.P.append(conv3x3(dim_in, cfg.FPN.DIM, stride=2, bias=True))
self.apply_func = self.apply_on_retinanet
self.relu = torch.nn.ReLU(inplace=False)
self.maxpool = torch.nn.MaxPool2d(1, 2, ceil_mode=True)
self.reset_parameters()
self.feature_dims = [256]
self.feature_dims = [cfg.FPN.DIM]
def reset_parameters(self):
for m in self.modules():
......@@ -51,7 +53,7 @@ class FPN(torch.nn.Module):
) # Xavier Initialization
torch.nn.init.constant_(m.bias, 0)
def apply_with_rcnn(self, features):
def apply_on_rcnn(self, features):
fpn_input = self.C[-1](features[-1])
min_lvl, max_lvl = cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL
outputs = [self.P[HIGHEST_BACKBONE_LVL - min_lvl](fpn_input)]
......@@ -70,7 +72,7 @@ class FPN(torch.nn.Module):
return outputs
def apply_with_retinanet(self, features):
def apply_on_retinanet(self, features):
fpn_input = self.C[-1](features[-1])
min_lvl, max_lvl = cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL
outputs = [self.P[HIGHEST_BACKBONE_LVL - min_lvl](fpn_input)]
......@@ -92,9 +94,4 @@ class FPN(torch.nn.Module):
return outputs
def forward(self, features):
if 'rcnn' in cfg.MODEL.TYPE:
return self.apply_with_rcnn(features)
elif 'retinanet' in cfg.MODEL.TYPE:
return self.apply_with_retinanet(features)
else:
raise NotImplementedError()
\ No newline at end of file
return self.apply_func(features)
\ No newline at end of file
......@@ -46,7 +46,7 @@ class RetinaNet(torch.nn.Module):
self.bbox_pred = conv3x3(dim_in, 4 * A, bias=True)
self.cls_prob = torch.nn.Softmax(dim=1, inplace=True) \
if cfg.RETINANET.SOFTMAX else torch.nn.Sigmoid(inplace=True)
self.relu = torch.nn.ReLU(inplace=True)
self.relu = torch.nn.ELU(inplace=True)
self.proposal_layer = ProposalLayer()
########################################
......
......@@ -60,11 +60,6 @@ class RPN(torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.normal_(m.weight, std=0.01)
torch.nn.init.constant_(m.bias, 0)
if cfg.MODEL.DATA_TYPE.lower() == 'float16':
# Zero the weights of linear layers for FP16
# Numerical stability is guaranteed
self.cls_score.weight.zero_()
self.bbox_pred.weight.zero_()
def compute_outputs(self, features):
"""Compute the RPN logits.
......
......@@ -78,10 +78,10 @@ class SSD(torch.nn.Module):
for i, feature in enumerate(features):
cls_score_wide.append(
self.cls_score[i](feature)
.permute((0, 2, 3, 1)).view(0, -1))
.permute(0, 2, 3, 1).view(0, -1))
bbox_pred_wide.append(
self.bbox_pred[i](feature)
.permute((0, 2, 3, 1)).view(0, -1))
.permute(0, 2, 3, 1).view(0, -1))
# Concat them if necessary
return torch.cat(cls_score_wide, dim=1).view(
......
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: anno.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='anno.proto',
package='',
serialized_pb=_b('\n\nanno.proto\"\x91\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x0e\n\x06labels\x18\x08 \x03(\x05\"r\n\nAnnotation\x12\n\n\x02x1\x18\x01 \x01(\x02\x12\n\n\x02y1\x18\x02 \x01(\x02\x12\n\n\x02x2\x18\x03 \x01(\x02\x12\n\n\x02y2\x18\x04 \x01(\x02\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x18\n\tdifficult\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x0c\n\x04mask\x18\x07 \x01(\t\"Z\n\x0e\x41nnotatedDatum\x12\x15\n\x05\x64\x61tum\x18\x01 \x01(\x0b\x32\x06.Datum\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\x12\x1f\n\nannotation\x18\x03 \x03(\x0b\x32\x0b.Annotation')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATUM = _descriptor.Descriptor(
name='Datum',
full_name='Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='channels', full_name='Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='height', full_name='Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='width', full_name='Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data', full_name='Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='label', full_name='Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='encoded', full_name='Datum.encoded', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='labels', full_name='Datum.labels', index=7,
number=8, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=15,
serialized_end=160,
)
_ANNOTATION = _descriptor.Descriptor(
name='Annotation',
full_name='Annotation',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='x1', full_name='Annotation.x1', index=0,
number=1, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y1', full_name='Annotation.y1', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='x2', full_name='Annotation.x2', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y2', full_name='Annotation.y2', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='Annotation.name', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='difficult', full_name='Annotation.difficult', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mask', full_name='Annotation.mask', index=6,
number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=162,
serialized_end=276,
)
_ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum',
full_name='AnnotatedDatum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='datum', full_name='AnnotatedDatum.datum', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='filename', full_name='AnnotatedDatum.filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='annotation', full_name='AnnotatedDatum.annotation', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=278,
serialized_end=368,
)
_ANNOTATEDDATUM.fields_by_name['datum'].message_type = _DATUM
_ANNOTATEDDATUM.fields_by_name['annotation'].message_type = _ANNOTATION
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION
DESCRIPTOR.message_types_by_name['AnnotatedDatum'] = _ANNOTATEDDATUM
Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
DESCRIPTOR = _DATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Datum)
))
_sym_db.RegisterMessage(Datum)
Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATION,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Annotation)
))
_sym_db.RegisterMessage(Annotation)
AnnotatedDatum = _reflection.GeneratedProtocolMessageType('AnnotatedDatum', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATEDDATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:AnnotatedDatum)
))
_sym_db.RegisterMessage(AnnotatedDatum)
# @@protoc_insertion_point(module_scope)
......@@ -61,7 +61,7 @@ if __name__ == '__main__':
logger.info('Using config:\n' + pprint.pformat(cfg))
# Load the checkpoint and test engine
checkpoint = coordinator.checkpoint(global_step=None, wait=True)
checkpoint, _ = coordinator.checkpoint(global_step=None, wait=True)
# Ready to export the network
logger.info('Exporting model will be saved to `{:s}`'
......
......@@ -34,9 +34,6 @@ def parse_args():
parser.add_argument('--exp_dir', dest='exp_dir',
help='experiment dir',
default=None, type=str)
parser.add_argument('--resume', dest='resume',
help='resume training?',
action='store_true')
if len(sys.argv) == 1:
parser.print_help()
......@@ -55,11 +52,8 @@ if __name__ == '__main__':
.format(os.path.abspath(args.exp_dir)) if args.exp_dir else 'None')
coordinator = Coordinator(args.cfg_file, exp_dir=args.exp_dir)
start_iter = 0
if args.resume:
cfg.TRAIN.WEIGHTS, start_iter = \
coordinator.checkpoint(global_step=None)
checkpoint, start_iter = coordinator.checkpoint(wait=False)
if checkpoint is not None: cfg.TRAIN.WEIGHTS = checkpoint
# Setup MPI
if cfg.NUM_GPUS != mpi.Size():
......@@ -86,7 +80,7 @@ if __name__ == '__main__':
# Ready to train the network
logger.info('Output will be saved to `{:s}`'
.format(coordinator.checkpoints_dir()))
train_net(coordinator)
train_net(coordinator, start_iter)
# Finalize mpi
mpi.Finalize()
\ No newline at end of file
......@@ -57,7 +57,7 @@ if __name__ == '__main__':
logger.info('Using config:\n' + pprint.pformat(cfg))
# Load the checkpoint and test engine
checkpoint = coordinator.checkpoint(global_step=args.iter, wait=args.wait)
checkpoint, _ = coordinator.checkpoint(global_step=args.iter, wait=args.wait)
if checkpoint is None:
raise RuntimeError('The checkpoint of global step {} does not exist.'.format(args.iter))
test_engine = importlib.import_module('lib.{}.test'.format(cfg.MODEL.TYPE))
......
......@@ -68,7 +68,7 @@ def mpi_train(cfg_file, exp_dir):
for i, host in enumerate(cfg.HOSTS):
mpi_args += (host + ':{},'.format(cfg.NUM_GPUS // len(cfg.HOSTS)))
if i > 0: subprocess.call('scp -r {} {}:{}'.format(
osp.abspath(exp_dir), host, osp.abspath(exp_dir)), shell=True)
osp.abspath(exp_dir), host, osp.dirname(exp_dir)), shell=True)
return subprocess.call('{} {} {} {}'.format(
mpi_args, sys.executable, 'mpi_train.py', args), shell=True)
......