Commit c8535116 by Ting PAN

Compile proto files

1 parent 89f1ee28
------------------------------------------------------------------------
The list of most significant changes made over time in SeetaDet.
-SeetaDet 0.1.0 (20190311)
+SeetaDet 0.1.0 (20190314)
Dragon Minimum Required (Version 0.3.0.0)
......
# delete cache
rm -r build install *.c *.cpp
+# compile proto files
+protoc -I ../lib/proto --python_out=../lib/proto ../lib/proto/anno.proto
# compile cython modules
python setup.py build_ext --inplace
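After the protoc step, the generated module should import cleanly; a minimal check (a sketch, assuming the repository root is on PYTHONPATH; the helper file name is hypothetical):

# check_proto.py -- hypothetical helper, not part of this commit
from lib.proto import anno_pb2
# True: every field in AnnotatedDatum is optional/repeated
print(anno_pb2.AnnotatedDatum().IsInitialized())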
......
@@ -23,6 +23,7 @@ MODEL:
SOLVER:
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
+  WARM_UP_ITERS: 2000  # default: 500
  LR_POLICY: steps_with_decay
  STEPS: [120000, 160000]
  MAX_ITERS: 180000
......
syntax = "proto2";
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
optional bytes data = 4;
optional int32 label = 5;
repeated float float_data = 6;
optional bool encoded = 7 [default = false];
}
message Annotation {
optional float x1 = 1;
optional float y1 = 2;
optional float x2 = 3;
optional float y2 = 4;
optional string name = 5;
optional bool difficult = 6 [default = false];
optional bool crowd = 7 [default = false];
optional string mask = 8;
}
message AnnotatedDatum {
optional Datum datum = 1;
optional string filename = 2;
repeated Annotation annotation = 3;
}
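For reference, records conforming to this schema are built and serialized as below; a minimal sketch, assuming the bindings compiled above are importable as lib.proto.anno_pb2 (all values are illustrative):

from lib.proto import anno_pb2 as pb

anno_datum = pb.AnnotatedDatum()
anno_datum.filename = '000001'
# the nested Datum carries the (optionally JPEG-encoded) image bytes
anno_datum.datum.channels, anno_datum.datum.height, anno_datum.datum.width = 3, 600, 800
anno = anno_datum.annotation.add()           # one box + mask per instance
anno.x1, anno.y1, anno.x2, anno.y2 = 10., 20., 110., 220.
anno.name = 'person'
serialized = anno_datum.SerializeToString()  # the bytes stored as an LMDB value
assert pb.AnnotatedDatum.FromString(serialized).annotation[0].name == 'person'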
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: anno.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='anno.proto',
package='',
serialized_pb=_b('\n\nanno.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\"\x88\x01\n\nAnnotation\x12\n\n\x02x1\x18\x01 \x01(\x02\x12\n\n\x02y1\x18\x02 \x01(\x02\x12\n\n\x02x2\x18\x03 \x01(\x02\x12\n\n\x02y2\x18\x04 \x01(\x02\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x18\n\tdifficult\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x14\n\x05\x63rowd\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x0c\n\x04mask\x18\x08 \x01(\t\"Z\n\x0e\x41nnotatedDatum\x12\x15\n\x05\x64\x61tum\x18\x01 \x01(\x0b\x32\x06.Datum\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\x12\x1f\n\nannotation\x18\x03 \x03(\x0b\x32\x0b.Annotation')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATUM = _descriptor.Descriptor(
name='Datum',
full_name='Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='channels', full_name='Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='height', full_name='Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='width', full_name='Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data', full_name='Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='label', full_name='Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='encoded', full_name='Datum.encoded', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=15,
serialized_end=144,
)
_ANNOTATION = _descriptor.Descriptor(
name='Annotation',
full_name='Annotation',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='x1', full_name='Annotation.x1', index=0,
number=1, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y1', full_name='Annotation.y1', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='x2', full_name='Annotation.x2', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y2', full_name='Annotation.y2', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='Annotation.name', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='difficult', full_name='Annotation.difficult', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='crowd', full_name='Annotation.crowd', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mask', full_name='Annotation.mask', index=7,
number=8, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=147,
serialized_end=283,
)
_ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum',
full_name='AnnotatedDatum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='datum', full_name='AnnotatedDatum.datum', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='filename', full_name='AnnotatedDatum.filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='annotation', full_name='AnnotatedDatum.annotation', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=285,
serialized_end=375,
)
_ANNOTATEDDATUM.fields_by_name['datum'].message_type = _DATUM
_ANNOTATEDDATUM.fields_by_name['annotation'].message_type = _ANNOTATION
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION
DESCRIPTOR.message_types_by_name['AnnotatedDatum'] = _ANNOTATEDDATUM
Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
DESCRIPTOR = _DATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Datum)
))
_sym_db.RegisterMessage(Datum)
Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATION,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Annotation)
))
_sym_db.RegisterMessage(Annotation)
AnnotatedDatum = _reflection.GeneratedProtocolMessageType('AnnotatedDatum', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATEDDATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:AnnotatedDatum)
))
_sym_db.RegisterMessage(AnnotatedDatum)
# @@protoc_insertion_point(module_scope)
@@ -14,11 +14,14 @@ from __future__ import division
from __future__ import print_function
import os
+import sys
import time
import cv2
-from . import anno_pb2 as pb
from dragon.tools.db import LMDB
+sys.path.insert(0, '../../..')
+from lib.proto import anno_pb2 as pb
ZFILL = 8
ENCODE_QUALITY = 95
@@ -88,9 +91,6 @@ def make_db(database_file, images_path, gt_recs, ext='.png'):
print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
db.commit()
db.close()
-# Compress the empty space
-db.open(database_file, mode='w')
-db.commit()
end_time = time.time()
print('{0} images have been stored in the database.'.format(total_line))
......
@@ -14,12 +14,14 @@ from __future__ import division
from __future__ import print_function
import os
+import sys
import time
import cv2
import xml.etree.ElementTree as ET
from dragon.tools.db import LMDB
-from . import anno_pb2 as pb
+sys.path.insert(0, '../../..')
+from lib.proto import anno_pb2 as pb
ZFILL = 8
ENCODE_QUALITY = 95
@@ -124,9 +126,6 @@ def make_db(database_file,
print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
db.commit()
db.close()
-# Compress the empty space
-db.open(database_file, mode='w')
-db.commit()
end_time = time.time()
print('{0} images have been stored in the database.'.format(total_line))
......
# --------------------------------------------------------
# Detectron @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make LMDB for cityscape dataset."""
import os
import sys
import shutil
import numpy as np
np.random.seed(1337)
try:
import cPickle
except:
import pickle as cPickle
sys.path.insert(0, '../../../')
from database.mrcnn.utils.make import make_db
from database.mrcnn.cityscape.make_mask import make_mask
if __name__ == '__main__':
CITYSCAPE_ROOT = '/data/cityscape'
# make RLE masks
if not os.path.exists('build'): os.makedirs('build')
cs_train = make_mask(
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest'),
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists/train.lst'))
cs_val = make_mask(
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest'),
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists/val.lst'))
with open('build/cs_train_mask.pkl', 'wb') as f:
cPickle.dump(cs_train, f, cPickle.HIGHEST_PROTOCOL)
with open('build/cs_val_mask.pkl', 'wb') as f:
cPickle.dump(cs_val, f, cPickle.HIGHEST_PROTOCOL)
# make image splits
for split in ['train', 'val', 'test']:
with open(os.path.join(CITYSCAPE_ROOT,
'gtFine_trainvaltest/imglists', split + '.lst'), 'r') as f:
entries = [line.strip().split('\t') for line in f.readlines()]
if split == 'train': np.random.shuffle(entries)
with open(os.path.join(CITYSCAPE_ROOT,
'gtFine_trainvaltest/imglists', split + '.txt'), 'w') as w:
for entry in entries: w.write(entry[1].split('.')[0] + '\n')
# make database
make_db(database_file=os.path.join(CITYSCAPE_ROOT, 'cache/cs_train_lmdb'),
images_path=os.path.join(CITYSCAPE_ROOT, 'leftImg8bit_trainvaltest'),
mask_file='build/cs_train_mask.pkl',
splits_path=os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists'),
splits=['train'], ext='.png')
make_db(database_file=os.path.join(CITYSCAPE_ROOT, 'cache/cs_val_lmdb'),
images_path=os.path.join(CITYSCAPE_ROOT, 'leftImg8bit_trainvaltest'),
mask_file='build/cs_val_mask.pkl',
splits_path=os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists'),
splits=['val'], ext='.png')
# clean!
shutil.rmtree('build')
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make masks for cityscape dataset."""
import os
import sys
import cv2
from collections import OrderedDict
import PIL.Image as Image
import numpy as np
np.random.seed(1337)
sys.path.insert(0, '../../..')
from lib.pycocotools.mask_utils import mask_bin2rle
from database.mrcnn.utils.process_pool import ProcessPool
class_id = [0,
24, 25, 26, 27,
28, 31, 32, 33]
classes = ['__background__',
'person', 'rider', 'car', 'truck',
'bus', 'train', 'motorcycle', 'bicycle']
ind_to_class = dict(zip(range(len(classes)), classes))
def parse_gt(gt_file, im_scale=1.0):
im = Image.open(gt_file)
pixel = list(im.getdata())
pixel = np.array(pixel).reshape([im.size[1], im.size[0]])
objects = []
for c in range(1, len(class_id)):
px = np.where((pixel >= class_id[c] * 1000) & (pixel < (class_id[c] + 1) * 1000))
if len(px[0]) == 0: continue
uids = np.unique(pixel[px])
for idx, uid in enumerate(uids):
px = np.where(pixel == uid)
x1 = np.min(px[1])
y1 = np.min(px[0])
x2 = np.max(px[1])
y2 = np.max(px[0])
if x2 - x1 <= 1 or y2 - y1 <= 1: continue
mask = np.zeros([im.size[1], im.size[0]], dtype=np.uint8)
mask[px] = 1
if im_scale != 1:
mask = cv2.resize(mask, None, fx=im_scale, fy=im_scale,
interpolation=cv2.INTER_NEAREST)
x1 = min(int(x1 * im_scale), mask.shape[1])
y1 = min(int(y1 * im_scale), mask.shape[0])
x2 = min(int(x2 * im_scale), mask.shape[1])
y2 = min(int(y2 * im_scale), mask.shape[0])
objects.append({'bbox': [x1, y1, x2, y2],
'mask': mask_bin2rle([mask])[0],
'name': ind_to_class[c],
'difficult': False})
return objects
def map_func(gts, Q):
for image_id, gt_file in gts:
objects = parse_gt(gt_file)
Q.put((image_id, objects))
def make_mask(gt_root, split_file):
# Create tasks
gt_tasks, gt_recs = [], OrderedDict()
with open(split_file, 'r') as f:
for line in f:
_, image_path, gt_path = line.strip().split('\t')
image_id = image_path.split('.')[0]
gt_file = os.path.join(gt_root, gt_path.replace('labelTrainIds', 'instanceIds'))
gt_tasks.append((image_id, gt_file))
num_tasks = len(gt_tasks)
# Run!
with ProcessPool(16) as pool:
pool.run(gt_tasks, func=map_func)
for idx in range(num_tasks):
image_id, objects = pool.get()
gt_recs[image_id] = objects
print('\rProcess: {} / {}'.format(idx + 1, num_tasks), end='')
return gt_recs
\ No newline at end of file
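parse_gt() above relies on the Cityscapes instanceIds convention: for the instance-level classes, a pixel stores class_id * 1000 + instance_index, which is exactly the window the >= class_id[c] * 1000 test selects. A toy check of that assumption (the pixel value is illustrative):

uid = 26003                         # hypothetical pixel from an *_instanceIds.png
class_part, instance_part = divmod(uid, 1000)
assert class_part == 26             # 26 -> 'car' in the class_id table above
assert instance_part == 3           # the fourth 'car' instance in this image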
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import cv2
from collections import defaultdict
from lib.pycocotools.mask_utils import mask_rle2im
CITYSCAPE_ROOT = '/data/cityscape'
def write_results(json_file, img_list):
    with open(json_file, 'r') as f:
        json_results = json.load(f)
    class_id = [0, 24, 25, 26, 27, 28, 31, 32, 33]
    category_id_to_class_id = dict(zip(range(9), class_id))
    result_path = os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest', 'results', 'pred')
    if not os.path.exists(result_path): os.makedirs(result_path)
    counts = defaultdict(int)
    txt_results = defaultdict(list)
    for idx, rec in enumerate(json_results):
        class_id = category_id_to_class_id[rec['category_id']]
        if class_id == 0: continue
        im_h, im_w = rec['segmentation']['size']
        mask_rle = rec['segmentation']['counts']
        mask_image = mask_rle2im([mask_rle], im_h, im_w)[0] * 200
        image_name = rec['image_id'].split('_leftImg8bit')[0]
        mask_name = image_name + '_' + str(counts[image_name]) + '.png'
        counts[image_name] += 1
        mask_path = os.path.join(result_path, mask_name)
        cv2.imwrite(mask_path, mask_image)
        txt_results[image_name].append((mask_name, class_id, rec['score']))
        print('\rWriting masks ({} / {})'.format(idx + 1, len(json_results)), end='')
    with open(img_list, 'r') as F:
        for line in F.readlines():
            image_name = line.strip().split('/')[-1].split('_leftImg8bit')[0]
            txt_path = os.path.join(result_path, image_name + '.txt')
            with open(txt_path, 'w') as f:
                for rec in txt_results[image_name]:
                    f.write('{} {} {:.8f}\n'.format(rec[0], rec[1], rec[2]))


if __name__ == '__main__':
    write_results(
        '/results/segmentations.json',
        os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest', 'imglists', 'val.txt')
    )
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make LMDB for COCO dataset."""
import os
import sys
import shutil
sys.path.insert(0, '../../../')
from database.mrcnn.utils.make import make_db
from database.mrcnn.coco.make_mask import make_mask, merge_mask
if __name__ == '__main__':
    COCO_ROOT = '/data/coco'

    # make RLE masks
    if not os.path.exists('build'): os.makedirs('build')
    make_mask('train', '2014', COCO_ROOT)
    make_mask('valminusminival', '2014', COCO_ROOT)
    make_mask('minival', '2014', COCO_ROOT)
    merge_mask('trainval35k', '2014', [
        'build/coco_2014_train_mask.pkl',
        'build/coco_2014_valminusminival_mask.pkl'])

    # train database: coco_2014_trainval35k
    make_db(database_file=os.path.join(COCO_ROOT, 'cache/coco_2014_trainval35k_lmdb'),
            images_path=[os.path.join(COCO_ROOT, 'images/train2014'),
                         os.path.join(COCO_ROOT, 'images/val2014')],
            splits_path=[os.path.join(COCO_ROOT, 'ImageSets'),
                         os.path.join(COCO_ROOT, 'ImageSets')],
            mask_file='build/coco_2014_trainval35k_mask.pkl',
            splits=['train', 'valminusminival'])

    # val database: coco_2014_minival
    make_db(database_file=os.path.join(COCO_ROOT, 'cache/coco_2014_minival_lmdb'),
            images_path=os.path.join(COCO_ROOT, 'images/val2014'),
            mask_file='build/coco_2014_minival_mask.pkl',
            splits_path=os.path.join(COCO_ROOT, 'ImageSets'),
            splits=['minival'])

    # clean!
    shutil.rmtree('build')
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import sys
import os.path as osp
from collections import OrderedDict

try:
    import cPickle
except ImportError:
    import pickle as cPickle

sys.path.insert(0, '../../..')
from lib.pycocotools.coco import COCO
from lib.pycocotools.mask_utils import mask_poly2rle


class imdb(object):
    def __init__(self, image_set, year, data_dir):
        self._year = year
        self._image_set = image_set
        self._data_path = osp.join(data_dir)
        self.invalid_cnt = 0
        self.ignore_cnt = 0
        #################
        #    CLASSES    #
        #################
        # load COCO API, classes, class <-> id mappings
        self._COCO = COCO(self._get_ann_file())
        cats = self._COCO.loadCats(self._COCO.getCatIds())
        self._classes = tuple(['__background__'] + [c['name'] for c in cats])
        self._class_to_ind = dict(zip(self._classes, range(self.num_classes)))
        self._ind_to_class = dict(zip(range(self.num_classes), self._classes))
        self._class_to_coco_cat_id = dict(zip([c['name'] for c in cats],
                                              self._COCO.getCatIds()))
        self._coco_cat_id_to_class_ind = dict([(self._class_to_coco_cat_id[cls],
            self._class_to_ind[cls]) for cls in self._classes[1:]])
        #################
        #      SET      #
        #################
        self._view_map = {
            'minival2014': 'val2014',          # 5k val2014 subset
            'valminusminival2014': 'val2014',  # val2014 \setminus minival2014
        }
        coco_name = image_set + year  # e.g., "val2014"
        self._data_name = (self._view_map[coco_name]
                           if coco_name in self._view_map else coco_name)
        #################
        #    IMAGES     #
        #################
        self._image_index = self._load_image_set_index()
        self._annotations = self._load_annotations()

    def _get_ann_file(self):
        prefix = 'instances' if self._image_set.find('test') == -1 \
            else 'image_info'
        return osp.join(self._data_path, 'annotations',
                        prefix + '_' + self._image_set + self._year + '.json')

    def _load_image_set_index(self):
        """Load image ids."""
        image_ids = self._COCO.getImgIds()
        return image_ids

    def _load_annotations(self):
        """Load annotations."""
        annotations = [self._load_coco_annotation(index)
                       for index in self._image_index]
        return annotations

    def image_path_from_index(self, index):
        """Construct an image path from the image's "index" identifier."""
        # Example image path for index=119993:
        #   images/train2014/COCO_train2014_000000119993.jpg
        file_name = ('COCO_' + self._data_name + '_' +
                     str(index).zfill(12) + '.jpg')
        image_path = osp.join(self._data_path, 'images',
                              self._data_name, file_name)
        assert osp.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

    def image_path_at(self, i):
        """Return the absolute path to image i in the image sequence."""
        return self.image_path_from_index(self._image_index[i])

    def annotation_at(self, i):
        """Return the annotation of image i in the image sequence."""
        return self._annotations[i]

    def _load_coco_annotation(self, index):
        """Loads COCO bounding-box instance annotations. Crowd instances are
        handled by marking their overlaps (with all categories) to -1. This
        overlap value means that crowd "instances" are excluded from training.
        """
        im_ann = self._COCO.loadImgs(index)[0]
        width = im_ann['width']
        height = im_ann['height']
        annIds = self._COCO.getAnnIds(imgIds=index, iscrowd=None)
        objs = self._COCO.loadAnns(annIds)
        # Sanitize boxes -- some are invalid
        valid_objs = []
        for obj in objs:
            x1 = int(max(0, obj['bbox'][0]))
            y1 = int(max(0, obj['bbox'][1]))
            x2 = int(min(width - 1, x1 + max(0, obj['bbox'][2] - 1)))
            y2 = int(min(height - 1, y1 + max(0, obj['bbox'][3] - 1)))
            if type(obj['segmentation']) is list:
                for p in obj['segmentation']:
                    if len(p) < 6: print('Remove invalid segm.')
                # Valid polygons have >= 3 points, so require >= 6 coordinates
                obj['segmentation'] = [p for p in obj['segmentation'] if len(p) >= 6]
                rle_masks = mask_poly2rle([obj['segmentation']], height, width)
            else:
                # crowd masks
                rle_masks = [obj['segmentation']]
            if obj['area'] > 0 and x2 > x1 and y2 > y1:
                obj['clean_bbox'] = [x1, y1, x2, y2]
                # Exclude the crowd masks
                # TODO(PhyscalX): You may encounter crashes when decoding crowd masks.
                mask = rle_masks[0] if not obj['iscrowd'] else ''
                valid_objs.append(
                    {'bbox': [x1, y1, x2, y2],
                     'mask': mask,
                     'category_id': obj['category_id'],
                     'class_id': self._coco_cat_id_to_class_ind[obj['category_id']],
                     'crowd': obj['iscrowd']})
                valid_objs[-1]['name'] = self._ind_to_class[valid_objs[-1]['class_id']]
        return height, width, valid_objs

    @property
    def num_images(self):
        return len(self._image_index)

    @property
    def num_classes(self):
        return len(self._classes)


def make_mask(split, year, data_dir):
    coco = imdb(split, year, data_dir)
    print('Preparing to make split: {}, total {} images'.format(split, coco.num_images))
    if not osp.exists(osp.join(coco._data_path, 'ImageSets')):
        os.makedirs(osp.join(coco._data_path, 'ImageSets'))
    gt_recs = OrderedDict()
    for i in range(coco.num_images):
        filename = (coco.image_path_at(i).split('/')[-1]).split('.')[0]
        h, w, objects = coco.annotation_at(i)
        gt_recs[filename] = objects
    with open(osp.join('build',
            'coco_' + year + '_' + split + '_mask.pkl'), 'wb') as f:
        cPickle.dump(gt_recs, f, cPickle.HIGHEST_PROTOCOL)
    with open(osp.join(coco._data_path, 'ImageSets', split + '.txt'), 'w') as f:
        for i in range(coco.num_images):
            filename = (coco.image_path_at(i).split('/')[-1]).split('.')[0]
            if i != coco.num_images - 1: filename += '\n'
            f.write(filename)


def merge_mask(split, year, mask_files):
    gt_recs = OrderedDict()
    data_path = os.path.dirname(mask_files[0])
    for mask_file in mask_files:
        with open(mask_file, 'rb') as f:
            recs = cPickle.load(f)
        gt_recs.update(recs)
    with open(osp.join(data_path,
            'coco_' + year + '_' + split + '_mask.pkl'), 'wb') as f:
        cPickle.dump(gt_recs, f, cPickle.HIGHEST_PROTOCOL)
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import sys
import time
import json
import cv2
from dragon.tools.db import LMDB, wrapper_str
sys.path.insert(0, '../../../')
import database.mrcnn.utils.anno_pb2 as pb
IMAGE_INFO = '/data/image_info_test-dev2017.json'
def load_image_list(image_info):
    num_images = len(image_info['images'])
    image_list = []
    print('The split has {} images.'.format(num_images))
    for image in image_info['images']:
        image_list.append(image['file_name'])
    return image_list


def make_datum(image_file):
    anno_datum = pb.AnnotatedDatum()
    datum = pb.Datum()
    im = cv2.imread(image_file)
    datum.height, datum.width, datum.channels = im.shape
    datum.encoded = True
    if datum.encoded:
        result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
        datum.data = im.tostring()
    anno_datum.datum.CopyFrom(datum)
    anno_datum.filename = os.path.split(image_file)[-1]
    return anno_datum


def make_db(database_file, images_path, image_list):
    if os.path.isdir(database_file) is True:
        raise ValueError('The database path already exists.')
    else:
        root_dir = database_file[:database_file.rfind('/')]
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)
    print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
    db = LMDB(max_commit=10000)
    db.open(database_file, mode='w')
    count = 0
    start_time = time.time()
    zfill_flag = '{0:0%d}' % (8)
    for image_file in image_list:
        count += 1
        if count % 10000 == 0:
            now_time = time.time()
            print('{0} / {1} in {2:.2f} sec'.format(
                count, len(image_list), now_time - start_time))
            db.commit()
        datum = make_datum(os.path.join(images_path, image_file))
        db.put(zfill_flag.format(count - 1), datum.SerializeToString())
    now_time = time.time()
    print('{0} / {1} in {2:.2f} sec'.format(count, len(image_list), now_time - start_time))
    db.commit()
    db.close()
    end_time = time.time()
    print('{0} images have been stored in the database.'.format(len(image_list)))
    print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
    print('The size of database is {0} MB.'.format(
        float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))


if __name__ == '__main__':
    image_info = json.load(open(IMAGE_INFO, 'r'))
    image_list = load_image_list(image_info)
    make_db('/data/coco_2017_test-dev_lmdb',
            '/data/test2017', image_list)
\ No newline at end of file
# --------------------------------------------------------
# FPN @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
from .make import set_zfill, set_quality, make_db
\ No newline at end of file
syntax = "proto2";

message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  optional bytes data = 4;
  optional int32 label = 5;
  repeated float float_data = 6;
  optional bool encoded = 7 [default = false];
  repeated int32 labels = 8;
}

message Annotation {
  optional float x1 = 1;
  optional float y1 = 2;
  optional float x2 = 3;
  optional float y2 = 4;
  optional string name = 5;
  optional bool difficult = 6 [default = false];
  optional string mask = 7;
}

message AnnotatedDatum {
  optional Datum datum = 1;
  optional string filename = 2;
  repeated Annotation annotation = 3;
}
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: anno.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='anno.proto',
package='',
serialized_pb=_b('\n\nanno.proto\"\x91\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x0e\n\x06labels\x18\x08 \x03(\x05\"r\n\nAnnotation\x12\n\n\x02x1\x18\x01 \x01(\x02\x12\n\n\x02y1\x18\x02 \x01(\x02\x12\n\n\x02x2\x18\x03 \x01(\x02\x12\n\n\x02y2\x18\x04 \x01(\x02\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x18\n\tdifficult\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x0c\n\x04mask\x18\x07 \x01(\t\"Z\n\x0e\x41nnotatedDatum\x12\x15\n\x05\x64\x61tum\x18\x01 \x01(\x0b\x32\x06.Datum\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\x12\x1f\n\nannotation\x18\x03 \x03(\x0b\x32\x0b.Annotation')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATUM = _descriptor.Descriptor(
name='Datum',
full_name='Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='channels', full_name='Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='height', full_name='Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='width', full_name='Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data', full_name='Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='label', full_name='Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='encoded', full_name='Datum.encoded', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='labels', full_name='Datum.labels', index=7,
number=8, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=15,
serialized_end=160,
)
_ANNOTATION = _descriptor.Descriptor(
name='Annotation',
full_name='Annotation',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='x1', full_name='Annotation.x1', index=0,
number=1, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y1', full_name='Annotation.y1', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='x2', full_name='Annotation.x2', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y2', full_name='Annotation.y2', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='Annotation.name', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='difficult', full_name='Annotation.difficult', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mask', full_name='Annotation.mask', index=6,
number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=162,
serialized_end=276,
)
_ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum',
full_name='AnnotatedDatum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='datum', full_name='AnnotatedDatum.datum', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='filename', full_name='AnnotatedDatum.filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='annotation', full_name='AnnotatedDatum.annotation', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=278,
serialized_end=368,
)
_ANNOTATEDDATUM.fields_by_name['datum'].message_type = _DATUM
_ANNOTATEDDATUM.fields_by_name['annotation'].message_type = _ANNOTATION
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION
DESCRIPTOR.message_types_by_name['AnnotatedDatum'] = _ANNOTATEDDATUM
Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
DESCRIPTOR = _DATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Datum)
))
_sym_db.RegisterMessage(Datum)
Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATION,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Annotation)
))
_sym_db.RegisterMessage(Annotation)
AnnotatedDatum = _reflection.GeneratedProtocolMessageType('AnnotatedDatum', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATEDDATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:AnnotatedDatum)
))
_sym_db.RegisterMessage(AnnotatedDatum)
# @@protoc_insertion_point(module_scope)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import time
import cv2
try:
    import cPickle
except ImportError:
    import pickle as cPickle
from dragon.tools.db import LMDB
from . import anno_pb2 as pb
ZFILL = 8
ENCODE_QUALITY = 95
def set_zfill(value):
    global ZFILL
    ZFILL = value


def set_quality(value):
    global ENCODE_QUALITY
    ENCODE_QUALITY = value


def make_datum(image_file, mask_objects, im_scale=None):
    filename = os.path.split(image_file)[-1]
    anno_datum = pb.AnnotatedDatum()
    datum = pb.Datum()
    im = cv2.imread(image_file)
    if im_scale:
        im = cv2.resize(im, None,
            fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
    datum.height, datum.width, datum.channels = im.shape
    datum.encoded = ENCODE_QUALITY != 100
    if datum.encoded:
        result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), ENCODE_QUALITY])
    datum.data = im.tostring()
    anno_datum.datum.CopyFrom(datum)
    anno_datum.filename = filename.split('.')[0]
    for ix, obj in enumerate(mask_objects):
        anno = pb.Annotation()
        x1, y1, x2, y2 = obj['bbox']
        anno.name = obj['name']
        anno.x1, anno.y1, anno.x2, anno.y2 = x1, y1, x2, y2
        if 'difficult' in obj: anno.difficult = obj['difficult']
        # this Annotation proto has no 'crowd' field; crowd objects are marked difficult
        if 'crowd' in obj: anno.difficult = obj['crowd']
        anno.mask = obj['mask']
        anno_datum.annotation.add().CopyFrom(anno)
    return anno_datum


def make_db(database_file, images_path, mask_file,
            splits_path, splits, ext='.jpg', im_scale=None):
    if os.path.isdir(database_file) is True:
        raise ValueError('The database path already exists.')
    else:
        root_dir = database_file[:database_file.rfind('/')]
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)
    if not isinstance(images_path, list):
        images_path = [images_path]
    if not isinstance(splits_path, list):
        splits_path = [splits_path]
    assert len(splits) == len(splits_path)
    assert len(splits) == len(images_path)
    if mask_file is not None:
        with open(mask_file, 'rb') as f:
            all_masks = cPickle.load(f)
    else:
        all_masks = {}
    print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
    db = LMDB(max_commit=10000)
    db.open(database_file, mode='w')
    count = 0
    total_line = 0
    start_time = time.time()
    zfill_flag = '{0:0%d}' % (ZFILL)
    for db_idx, split in enumerate(splits):
        split_file = os.path.join(splits_path[db_idx], split + '.txt')
        assert os.path.exists(split_file)
        with open(split_file, 'r') as f:
            lines = f.readlines()
        total_line += len(lines)
        for line in lines:
            count += 1
            if count % 10000 == 0:
                now_time = time.time()
                print('{0} / {1} in {2:.2f} sec'.format(
                    count, total_line, now_time - start_time))
                db.commit()
            filename = line.strip()
            image_file = os.path.join(images_path[db_idx], filename + ext)
            mask_objects = all_masks[filename] if filename in all_masks else None
            if mask_objects is None:
                raise ValueError('The image ({}) has no mask annotations.'.format(filename))
            datum = make_datum(image_file, mask_objects, im_scale)
            db.put(zfill_flag.format(count - 1), datum.SerializeToString())
    now_time = time.time()
    print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
    db.commit()
    db.close()
    # Compress the empty space
    db.open(database_file, mode='w')
    db.commit()
    end_time = time.time()
    print('{0} images have been stored in the database.'.format(total_line))
    print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
    print('The size of database is {0} MB.'.format(
        float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
\ No newline at end of file
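A typical driver for this module mirrors the make_lmdb scripts above; a sketch using the helpers exported by database.mrcnn.utils (all paths are placeholders):

from database.mrcnn.utils import set_zfill, set_quality, make_db

set_zfill(8)       # LMDB keys become '00000000', '00000001', ...
set_quality(95)    # JPEG-encode images; 100 stores the raw pixels instead
make_db(database_file='/data/cache/my_train_lmdb',
        images_path='/data/images',
        mask_file='build/my_train_mask.pkl',
        splits_path='/data/ImageSets',
        splits=['train'], ext='.jpg')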
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""A simple process pool to map tasks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
class ProcessPool(object):
    def __init__(self, num_processes=8, max_qsize=100):
        self.num_tasks = self.fetch_tasks = 0
        self.num_processes = num_processes
        self.Q = multiprocessing.Queue(max_qsize)

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        pass

    def map(self, tasks, func):
        n_tasks_each = int(len(tasks) / self.num_processes)
        remain_tasks = len(tasks) - n_tasks_each * self.num_processes
        pos = 0
        for i in range(self.num_processes):
            if i != self.num_processes - 1:
                work_set = tasks[pos: pos + n_tasks_each]
                pos += n_tasks_each
            else:
                work_set = tasks[pos: pos + n_tasks_each + remain_tasks]
            print('[Main]: Process #{} Got {} tasks.'.format(i, len(work_set)))
            p = multiprocessing.Process(target=func, args=(work_set, self.Q))
            p.start()

    def wait(self):
        displays = {}
        while True:
            qsize = self.Q.qsize()
            if qsize == self.num_tasks: break
            if qsize > 0 and qsize % 100 == 0:
                if qsize not in displays:
                    displays[qsize] = True
                    print('[Queue]: Cached {} tasks.'.format(qsize))
        outputs = []
        while self.Q.qsize() > 0:
            outputs.append(self.Q.get())
        assert len(outputs) == self.num_tasks
        print('[Main]: Got {} outputs.'.format(len(outputs)))
        return outputs

    def get(self):
        self.fetch_tasks += 1
        if self.fetch_tasks > self.num_tasks:
            return None
        return self.Q.get()

    def run_all(self, tasks, func):
        self.num_tasks = len(tasks)
        self.map(tasks, func)
        self.wait()

    def run(self, tasks, func):
        self.num_tasks = len(tasks)
        self.map(tasks, func)
\ No newline at end of file
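The contract assumed by make_mask.py above: func drains its slice of tasks and puts exactly one result per task onto the queue, and get() is then called once per task. A self-contained sketch (the worker and task values are illustrative):

from database.mrcnn.utils.process_pool import ProcessPool

def square_func(tasks, Q):
    # worker: one (id, value) in, one (id, value * value) out
    for task_id, value in tasks:
        Q.put((task_id, value * value))

if __name__ == '__main__':
    tasks = [(i, i) for i in range(32)]
    with ProcessPool(num_processes=4) as pool:
        pool.run(tasks, func=square_func)          # fork workers, no join
        results = dict(pool.get() for _ in tasks)  # one fetch per task
    assert results[5] == 25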
@@ -292,6 +292,9 @@ __C.RETINANET.SOFTMAX = False
__C.FPN = edict()
+# Channel dimension of the FPN feature levels
+__C.FPN.DIM = 256
# Coarsest level of the FPN pyramid
__C.FPN.RPN_MAX_LEVEL = 6
# Finest level of the FPN pyramid
......
@@ -63,16 +63,16 @@ class Coordinator(object):
steps = []
for ix, file in enumerate(files):
step = int(file.split('_iter_')[-1].split('.')[0])
-if global_step == step: return os.path.join(self.checkpoints_dir(), files[ix])
+if global_step == step:
+    return os.path.join(self.checkpoints_dir(), files[ix]), step
steps.append(step)
if global_step is None:
-if len(files) == 0:
-    raise ValueError('Dir({}) is empty.'.format(self.checkpoints_dir()))
+if len(files) == 0: return None, 0
last_idx = int(np.argmax(steps)); last_step = steps[last_idx]
-return os.path.join(self.checkpoints_dir(), files[last_idx])
+return os.path.join(self.checkpoints_dir(), files[last_idx]), last_step
-return None
+return None, 0
result = locate()
-while not result and wait:
+while result[0] is None and wait:
print('\rWaiting for step_{}.checkpoint to exist...'.format(global_step), end='')
time.sleep(10)
result = locate()
......
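Callers of this locator now receive a (path, step) pair instead of a bare path, with (None, 0) meaning no checkpoint exists yet; a sketch of the new calling convention (the enclosing method name checkpoint_path is hypothetical):

path, step = coordinator.checkpoint_path(global_step=None, wait=False)
if path is None:
    print('No checkpoint found; starting from step 0.')
else:
    print('Resuming {} at step {}.'.format(path, step))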
@@ -24,7 +24,7 @@ from lib.utils import logger
class Solver(object):
def __init__(self):
# Define the generic detector
-self.detector = Detector().cuda(cfg.GPU_ID)
+self.detector = Detector()
# Define the optimizer and its arguments
self.optimizer = None
self.opt_arguments = {
......
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
import datetime
from collections import OrderedDict
import dragon.vm.torch as torch
from lib.core.config import cfg
@@ -47,6 +48,8 @@ class SolverWrapper(object):
if cfg.MODEL.DATA_TYPE.lower() == 'float16':
    self.solver.detector.half()  # Powerful FP16 Support
+self.solver.detector.cuda(cfg.GPU_ID)
# Plan the metrics
self.metrics = OrderedDict()
if cfg.ENABLE_TENSOR_BOARD:
......
@@ -13,7 +13,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
import importlib
import dragon.vm.torch as torch
from collections import OrderedDict
@@ -122,7 +121,7 @@ class Detector(torch.nn.Module):
outputs.update(
self.rpn(
features=features,
-**inputs,
+**inputs
)
)
outputs.update(
@@ -130,7 +129,7 @@ class Detector(torch.nn.Module):
features=features,
rpn_cls_score=outputs['rpn_cls_score'],
rpn_bbox_pred=outputs['rpn_bbox_pred'],
-**inputs,
+**inputs
)
)
@@ -139,7 +138,7 @@ class Detector(torch.nn.Module):
outputs.update(
self.retinanet(
features=features,
-**inputs,
+**inputs
)
)
@@ -148,7 +147,7 @@ class Detector(torch.nn.Module):
outputs.update(
self.ssd(
features=features,
-**inputs,
+**inputs
)
)
@@ -187,5 +186,11 @@ class Detector(torch.nn.Module):
term = torch.sqrt(e.running_var.data + e.eps)
term = e.weight.data / term
last_module.bias = e.bias.data - term * e.running_mean.data
-last_module.weight.data.mul_(term)
+if last_module.weight.dtype == 'float16':
+    last_module.bias.half_()
+    weight = last_module.weight.data.float()
+    weight.mul_(term)
+    last_module.weight.copy_(weight)
+else:
+    last_module.weight.data.mul_(term)
last_module = e
\ No newline at end of file
@@ -29,17 +29,19 @@ class FPN(torch.nn.Module):
super(FPN, self).__init__()
self.C = torch.nn.ModuleList()
self.P = torch.nn.ModuleList()
+self.apply_func = self.apply_on_rcnn
for lvl in range(cfg.FPN.RPN_MIN_LEVEL, HIGHEST_BACKBONE_LVL + 1):
-self.C.append(conv1x1(feature_dims[lvl - 1], 256, bias=True))
-self.P.append(conv3x3(256, 256, bias=True))
+self.C.append(conv1x1(feature_dims[lvl - 1], cfg.FPN.DIM, bias=True))
+self.P.append(conv3x3(cfg.FPN.DIM, cfg.FPN.DIM, bias=True))
if 'retinanet' in cfg.MODEL.TYPE:
for lvl in range(HIGHEST_BACKBONE_LVL + 1, cfg.FPN.RPN_MAX_LEVEL + 1):
-dim_in = feature_dims[-1] if lvl == HIGHEST_BACKBONE_LVL + 1 else 256
-self.P.append(conv3x3(dim_in, 256, stride=2, bias=True))
+dim_in = feature_dims[-1] if lvl == HIGHEST_BACKBONE_LVL + 1 else cfg.FPN.DIM
+self.P.append(conv3x3(dim_in, cfg.FPN.DIM, stride=2, bias=True))
+self.apply_func = self.apply_on_retinanet
self.relu = torch.nn.ReLU(inplace=False)
self.maxpool = torch.nn.MaxPool2d(1, 2, ceil_mode=True)
self.reset_parameters()
-self.feature_dims = [256]
+self.feature_dims = [cfg.FPN.DIM]
def reset_parameters(self):
for m in self.modules():
@@ -51,7 +53,7 @@ class FPN(torch.nn.Module):
)  # Xavier Initialization
torch.nn.init.constant_(m.bias, 0)
-def apply_with_rcnn(self, features):
+def apply_on_rcnn(self, features):
fpn_input = self.C[-1](features[-1])
min_lvl, max_lvl = cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL
outputs = [self.P[HIGHEST_BACKBONE_LVL - min_lvl](fpn_input)]
@@ -70,7 +72,7 @@ class FPN(torch.nn.Module):
return outputs
-def apply_with_retinanet(self, features):
+def apply_on_retinanet(self, features):
fpn_input = self.C[-1](features[-1])
min_lvl, max_lvl = cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL
outputs = [self.P[HIGHEST_BACKBONE_LVL - min_lvl](fpn_input)]
@@ -92,9 +94,4 @@ class FPN(torch.nn.Module):
return outputs
def forward(self, features):
-if 'rcnn' in cfg.MODEL.TYPE:
-    return self.apply_with_rcnn(features)
-elif 'retinanet' in cfg.MODEL.TYPE:
-    return self.apply_with_retinanet(features)
-else:
-    raise NotImplementedError()
+return self.apply_func(features)
\ No newline at end of file
@@ -46,7 +46,7 @@ class RetinaNet(torch.nn.Module):
self.bbox_pred = conv3x3(dim_in, 4 * A, bias=True)
self.cls_prob = torch.nn.Softmax(dim=1, inplace=True) \
    if cfg.RETINANET.SOFTMAX else torch.nn.Sigmoid(inplace=True)
-self.relu = torch.nn.ReLU(inplace=True)
+self.relu = torch.nn.ELU(inplace=True)
self.proposal_layer = ProposalLayer()
########################################
......
@@ -60,11 +60,6 @@ class RPN(torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.normal_(m.weight, std=0.01)
torch.nn.init.constant_(m.bias, 0)
-if cfg.MODEL.DATA_TYPE.lower() == 'float16':
-    # Zero the weights of linear layers for FP16
-    # Numerical stability is guaranteed
-    self.cls_score.weight.zero_()
-    self.bbox_pred.weight.zero_()
def compute_outputs(self, features):
"""Compute the RPN logits.
......
@@ -78,10 +78,10 @@ class SSD(torch.nn.Module):
for i, feature in enumerate(features):
cls_score_wide.append(
self.cls_score[i](feature)
-.permute((0, 2, 3, 1)).view(0, -1))
+.permute(0, 2, 3, 1).view(0, -1))
bbox_pred_wide.append(
self.bbox_pred[i](feature)
-.permute((0, 2, 3, 1)).view(0, -1))
+.permute(0, 2, 3, 1).view(0, -1))
# Concat them if necessary
return torch.cat(cls_score_wide, dim=1).view(
......
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: anno.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='anno.proto',
package='',
serialized_pb=_b('\n\nanno.proto\"\x91\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x0e\n\x06labels\x18\x08 \x03(\x05\"r\n\nAnnotation\x12\n\n\x02x1\x18\x01 \x01(\x02\x12\n\n\x02y1\x18\x02 \x01(\x02\x12\n\n\x02x2\x18\x03 \x01(\x02\x12\n\n\x02y2\x18\x04 \x01(\x02\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x18\n\tdifficult\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x0c\n\x04mask\x18\x07 \x01(\t\"Z\n\x0e\x41nnotatedDatum\x12\x15\n\x05\x64\x61tum\x18\x01 \x01(\x0b\x32\x06.Datum\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\x12\x1f\n\nannotation\x18\x03 \x03(\x0b\x32\x0b.Annotation')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATUM = _descriptor.Descriptor(
name='Datum',
full_name='Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='channels', full_name='Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='height', full_name='Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='width', full_name='Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data', full_name='Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='label', full_name='Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='encoded', full_name='Datum.encoded', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='labels', full_name='Datum.labels', index=7,
number=8, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=15,
serialized_end=160,
)
_ANNOTATION = _descriptor.Descriptor(
name='Annotation',
full_name='Annotation',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='x1', full_name='Annotation.x1', index=0,
number=1, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y1', full_name='Annotation.y1', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='x2', full_name='Annotation.x2', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y2', full_name='Annotation.y2', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='Annotation.name', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='difficult', full_name='Annotation.difficult', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mask', full_name='Annotation.mask', index=6,
number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=162,
serialized_end=276,
)
_ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum',
full_name='AnnotatedDatum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='datum', full_name='AnnotatedDatum.datum', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='filename', full_name='AnnotatedDatum.filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='annotation', full_name='AnnotatedDatum.annotation', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=278,
serialized_end=368,
)
_ANNOTATEDDATUM.fields_by_name['datum'].message_type = _DATUM
_ANNOTATEDDATUM.fields_by_name['annotation'].message_type = _ANNOTATION
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION
DESCRIPTOR.message_types_by_name['AnnotatedDatum'] = _ANNOTATEDDATUM
Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
DESCRIPTOR = _DATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Datum)
))
_sym_db.RegisterMessage(Datum)
Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATION,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Annotation)
))
_sym_db.RegisterMessage(Annotation)
AnnotatedDatum = _reflection.GeneratedProtocolMessageType('AnnotatedDatum', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATEDDATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:AnnotatedDatum)
))
_sym_db.RegisterMessage(AnnotatedDatum)
# @@protoc_insertion_point(module_scope)
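The file above is protoc output for anno.proto (hence the "compile proto files" step in this commit) and is regenerated rather than hand-edited. A minimal usage sketch of the generated messages; the import path is assumed from the repo layout and the field values are made up:

from lib.proto import anno_pb2   # import path assumed from the repo layout

anno_datum = anno_pb2.AnnotatedDatum()
anno_datum.filename = '000001.jpg'        # example value
anno_datum.datum.channels = 3
anno_datum.datum.height = 375
anno_datum.datum.width = 500

box = anno_datum.annotation.add()         # repeated Annotation field
box.x1, box.y1, box.x2, box.y2 = 48.0, 240.0, 195.0, 371.0
box.name = 'dog'
box.difficult = False

blob = anno_datum.SerializeToString()     # bytes suitable for a KV database
restored = anno_pb2.AnnotatedDatum()
restored.ParseFromString(blob)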
...@@ -61,7 +61,7 @@ if __name__ == '__main__': ...@@ -61,7 +61,7 @@ if __name__ == '__main__':
logger.info('Using config:\n' + pprint.pformat(cfg)) logger.info('Using config:\n' + pprint.pformat(cfg))
# Load the checkpoint and test engine # Load the checkpoint and test engine
checkpoint = coordinator.checkpoint(global_step=None, wait=True) checkpoint, _ = coordinator.checkpoint(global_step=None, wait=True)
# Ready to export the network # Ready to export the network
logger.info('The exported model will be saved to `{:s}`' logger.info('The exported model will be saved to `{:s}`'
......
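The export and test tools now unpack a (path, step) pair from coordinator.checkpoint instead of a bare path. A hypothetical sketch of that contract (the real Coordinator lives elsewhere in the repo; the function name, file pattern, and layout here are assumptions):

import glob
import os.path as osp
import re

def checkpoint(checkpoints_dir, global_step=None):
    # Return (path, step) for the requested snapshot, or the newest one
    # when global_step is None; (None, 0) if nothing matches.
    best = (None, 0)
    for path in glob.glob(osp.join(checkpoints_dir, '*_iter_*')):
        match = re.search(r'_iter_(\d+)', path)
        if match is None:
            continue
        step = int(match.group(1))
        if global_step is not None and step != global_step:
            continue
        if step >= best[1]:
            best = (path, step)
    return best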
...@@ -34,9 +34,6 @@ def parse_args(): ...@@ -34,9 +34,6 @@ def parse_args():
parser.add_argument('--exp_dir', dest='exp_dir', parser.add_argument('--exp_dir', dest='exp_dir',
help='experiment dir', help='experiment dir',
default=None, type=str) default=None, type=str)
parser.add_argument('--resume', dest='resume',
help='resume training?',
action='store_true')
if len(sys.argv) == 1: if len(sys.argv) == 1:
parser.print_help() parser.print_help()
...@@ -55,11 +52,8 @@ if __name__ == '__main__': ...@@ -55,11 +52,8 @@ if __name__ == '__main__':
.format(os.path.abspath(args.exp_dir)) if args.exp_dir else 'None') .format(os.path.abspath(args.exp_dir)) if args.exp_dir else 'None')
coordinator = Coordinator(args.cfg_file, exp_dir=args.exp_dir) coordinator = Coordinator(args.cfg_file, exp_dir=args.exp_dir)
checkpoint, start_iter = coordinator.checkpoint(wait=False)
start_iter = 0 if checkpoint is not None: cfg.TRAIN.WEIGHTS = checkpoint
if args.resume:
cfg.TRAIN.WEIGHTS, start_iter = \
coordinator.checkpoint(global_step=None)
# Setup MPI # Setup MPI
if cfg.NUM_GPUS != mpi.Size(): if cfg.NUM_GPUS != mpi.Size():
...@@ -86,7 +80,7 @@ if __name__ == '__main__': ...@@ -86,7 +80,7 @@ if __name__ == '__main__':
# Ready to train the network # Ready to train the network
logger.info('Output will be saved to `{:s}`' logger.info('Output will be saved to `{:s}`'
.format(coordinator.checkpoints_dir())) .format(coordinator.checkpoints_dir()))
train_net(coordinator) train_net(coordinator, start_iter)
# Finalize mpi # Finalize mpi
mpi.Finalize() mpi.Finalize()
\ No newline at end of file
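With --resume gone, resumption is automatic: coordinator.checkpoint(wait=False) returns the newest (checkpoint, step) pair, the weights (if any) override cfg.TRAIN.WEIGHTS, and the step is threaded into train_net. The resulting flow, paraphrased rather than quoted from the script:

# Auto-resume flow introduced by this hunk (paraphrased):
checkpoint, start_iter = coordinator.checkpoint(wait=False)
if checkpoint is not None:
    cfg.TRAIN.WEIGHTS = checkpoint    # newest snapshot becomes the init weights
train_net(coordinator, start_iter)    # loop resumes at the recovered step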
...@@ -57,7 +57,7 @@ if __name__ == '__main__': ...@@ -57,7 +57,7 @@ if __name__ == '__main__':
logger.info('Using config:\n' + pprint.pformat(cfg)) logger.info('Using config:\n' + pprint.pformat(cfg))
# Load the checkpoint and test engine # Load the checkpoint and test engine
checkpoint = coordinator.checkpoint(global_step=args.iter, wait=args.wait) checkpoint, _ = coordinator.checkpoint(global_step=args.iter, wait=args.wait)
if checkpoint is None: if checkpoint is None:
raise RuntimeError('The checkpoint of global step {} does not exist.'.format(args.iter)) raise RuntimeError('The checkpoint of global step {} does not exist.'.format(args.iter))
test_engine = importlib.import_module('lib.{}.test'.format(cfg.MODEL.TYPE)) test_engine = importlib.import_module('lib.{}.test'.format(cfg.MODEL.TYPE))
......
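Beyond the tuple unpack, the surrounding code is a small dynamic-dispatch idiom: the test engine module is chosen at runtime from cfg.MODEL.TYPE. Sketch with an example value:

import importlib

model_type = 'ssd'   # example stand-in for cfg.MODEL.TYPE
# Resolves to lib/ssd/test.py; raises ImportError for unknown model types.
test_engine = importlib.import_module('lib.{}.test'.format(model_type))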
...@@ -68,7 +68,7 @@ def mpi_train(cfg_file, exp_dir): ...@@ -68,7 +68,7 @@ def mpi_train(cfg_file, exp_dir):
for i, host in enumerate(cfg.HOSTS): for i, host in enumerate(cfg.HOSTS):
mpi_args += (host + ':{},'.format(cfg.NUM_GPUS // len(cfg.HOSTS))) mpi_args += (host + ':{},'.format(cfg.NUM_GPUS // len(cfg.HOSTS)))
if i > 0: subprocess.call('scp -r {} {}:{}'.format( if i > 0: subprocess.call('scp -r {} {}:{}'.format(
osp.abspath(exp_dir), host, osp.abspath(exp_dir)), shell=True) osp.abspath(exp_dir), host, osp.dirname(exp_dir)), shell=True)
return subprocess.call('{} {} {} {}'.format( return subprocess.call('{} {} {} {}'.format(
mpi_args, sys.executable, 'mpi_train.py', args), shell=True) mpi_args, sys.executable, 'mpi_train.py', args), shell=True)
......
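The one-word change matters because scp -r copies the source directory into the destination directory: with abspath(exp_dir) as the target, the experiment landed nested inside itself on the remote host, while dirname places the copy alongside the original path. Illustration with example values:

import os.path as osp
import subprocess

exp_dir, host = '/data/experiments/ssd_300', 'node2'   # example values

# Destination is the parent directory, so node2 ends up with
# /data/experiments/ssd_300 rather than /data/experiments/ssd_300/ssd_300.
subprocess.call('scp -r {} {}:{}'.format(
    osp.abspath(exp_dir), host, osp.dirname(osp.abspath(exp_dir))), shell=True)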