Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
SeetaResearch
/
SeetaDet
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit 4bcab266
authored
Apr 21, 2020
by
Ting PAN
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Plan the queueing of testing images
1 parent
f4ecc7c7
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
81 additions
and
81 deletions
CHANGES
README.md
configs/retinanet/coco_retinanet_400_R-50-FPN.yml → configs/retinanet/coco_retinanet_416_R-50-FPN.yml
configs/retinanet/voc_retinanet_320_AirNet-FPN.yml
configs/retinanet/voc_retinanet_320_R-50-FPN.yml
seetadet/core/test_engine.py
seetadet/datasets/kpl_record.py
tools/export.py
tools/mpi_train.py
tools/test.py
tools/test_all.py
tools/train.py
CHANGES
View file @
4bcab26
------------------------------------------------------------------------
------------------------------------------------------------------------
The list of most significant changes made over time in SeetaDet.
The list of most significant changes made over time in SeetaDet.
SeetaDet 0.4.1 (20200421)
Dragon Minimum Required (Version 0.3.0.dev20200421)
Changes:
- Plan the queueing of testing images instead of reading them all.
Preview Features:
- None
Bugs fixed:
- None
------------------------------------------------------------------------
SeetaDet 0.4.0 (20200408)
SeetaDet 0.4.0 (20200408)
Dragon Minimum Required (Version 0.3.0.dev20200408)
Dragon Minimum Required (Version 0.3.0.dev20200408)
...
...
README.md
View file @
4bcab26
...
@@ -14,7 +14,7 @@ The torch-style codes help us to simplify the hierarchical pipeline of modern de
...
@@ -14,7 +14,7 @@ The torch-style codes help us to simplify the hierarchical pipeline of modern de
## Requirements
## Requirements
seeta-dragon >= 0.3.0.dev202004
08
seeta-dragon >= 0.3.0.dev202004
21
## Installation
## Installation
...
...
configs/retinanet/coco_retinanet_4
00
_R-50-FPN.yml
→
configs/retinanet/coco_retinanet_4
16
_R-50-FPN.yml
View file @
4bcab26
...
@@ -22,10 +22,10 @@ MODEL:
...
@@ -22,10 +22,10 @@ MODEL:
NUM_CLASSES
:
81
NUM_CLASSES
:
81
SOLVER
:
SOLVER
:
BASE_LR
:
0.01
BASE_LR
:
0.01
DECAY_STEPS
:
[
60000
,
8
0000
]
DECAY_STEPS
:
[
120000
,
16
0000
]
MAX_STEPS
:
9
0000
MAX_STEPS
:
18
0000
SNAPSHOT_EVERY
:
5000
SNAPSHOT_EVERY
:
5000
SNAPSHOT_PREFIX
:
coco_retinanet_4
00
SNAPSHOT_PREFIX
:
coco_retinanet_4
16
FPN
:
FPN
:
RPN_MIN_LEVEL
:
3
RPN_MIN_LEVEL
:
3
RPN_MAX_LEVEL
:
7
RPN_MAX_LEVEL
:
7
...
@@ -34,15 +34,13 @@ TRAIN:
...
@@ -34,15 +34,13 @@ TRAIN:
DATASET
:
'
/data/coco_2014_trainval35k'
DATASET
:
'
/data/coco_2014_trainval35k'
USE_DIFF
:
False
# Do not use crowd objects
USE_DIFF
:
False
# Do not use crowd objects
USE_COLOR_JITTER
:
True
USE_COLOR_JITTER
:
True
IMS_PER_BATCH
:
8
IMS_PER_BATCH
:
16
SCALES
:
[
400
]
SCALES
:
[
416
]
MAX_SIZE
:
666
RANDOM_SCALES
:
[
0.25
,
1.0
]
RANDOM_SCALES
:
[
0.75
,
1.0
]
TEST
:
TEST
:
DATASET
:
'
/data/coco_2014_minival'
DATASET
:
'
/data/coco_2014_minival'
JSON_FILE
:
'
/data/instances_minival2014.json'
JSON_FILE
:
'
/data/instances_minival2014.json'
PROTOCOL
:
'
coco'
PROTOCOL
:
'
coco'
IMS_PER_BATCH
:
1
IMS_PER_BATCH
:
1
SCALES
:
[
400
]
SCALES
:
[
416
]
MAX_SIZE
:
666
NMS
:
0.5
NMS
:
0.5
\ No newline at end of file
configs/retinanet/voc_retinanet_320_AirNet-FPN.yml
View file @
4bcab26
...
@@ -26,7 +26,7 @@ TRAIN:
...
@@ -26,7 +26,7 @@ TRAIN:
USE_COLOR_JITTER
:
True
USE_COLOR_JITTER
:
True
IMS_PER_BATCH
:
32
IMS_PER_BATCH
:
32
SCALES
:
[
320
]
SCALES
:
[
320
]
RANDOM_SCALES
:
[
0.5
,
1.0
]
RANDOM_SCALES
:
[
0.
2
5
,
1.0
]
TEST
:
TEST
:
DATASET
:
'
/data/voc_2007_test'
DATASET
:
'
/data/voc_2007_test'
PROTOCOL
:
'
voc2007'
# 'voc2007', 'voc2010', 'coco'
PROTOCOL
:
'
voc2007'
# 'voc2007', 'voc2010', 'coco'
...
...
configs/retinanet/voc_retinanet_320_R-50-FPN.yml
View file @
4bcab26
...
@@ -27,7 +27,7 @@ TRAIN:
...
@@ -27,7 +27,7 @@ TRAIN:
USE_COLOR_JITTER
:
True
USE_COLOR_JITTER
:
True
IMS_PER_BATCH
:
32
IMS_PER_BATCH
:
32
SCALES
:
[
320
]
SCALES
:
[
320
]
RANDOM_SCALES
:
[
0.5
,
2.0
]
RANDOM_SCALES
:
[
0.
2
5
,
2.0
]
TEST
:
TEST
:
DATASET
:
'
/data/voc_2007_test'
DATASET
:
'
/data/voc_2007_test'
PROTOCOL
:
'
voc2007'
# 'voc2007', 'voc2010', 'coco'
PROTOCOL
:
'
voc2007'
# 'voc2007', 'voc2010', 'coco'
...
...
seetadet/core/test_engine.py
View file @
4bcab26
...
@@ -23,7 +23,7 @@ from seetadet.utils import time_util
...
@@ -23,7 +23,7 @@ from seetadet.utils import time_util
from
seetadet.utils.vis
import
vis_one_image
from
seetadet.utils.vis
import
vis_one_image
def
run_test_net
(
checkpoint
,
server
,
devices
):
def
run_test_net
(
checkpoint
,
server
,
devices
,
read_every
=
1000
):
classes
=
server
.
classes
classes
=
server
.
classes
num_images
=
server
.
num_images
num_images
=
server
.
num_images
num_classes
=
server
.
num_classes
num_classes
=
server
.
num_classes
...
@@ -60,17 +60,21 @@ def run_test_net(checkpoint, server, devices):
...
@@ -60,17 +60,21 @@ def run_test_net(checkpoint, server, devices):
for
process
in
workers
:
for
process
in
workers
:
process
.
start
()
process
.
start
()
for
i
in
range
(
num_images
):
num_sends
=
0
image_id
,
raw_image
=
server
.
get_image
()
queues
[
i
%
num_workers
]
.
put
((
i
,
raw_image
))
# Hold the image until the visualization
if
cfg
.
VIS
or
cfg
.
VIS_ON_FILE
:
vis_image_dict
[
i
]
=
(
image_id
,
raw_image
)
for
i
in
range
(
num_workers
):
queues
[
i
]
.
put
((
-
1
,
None
))
for
count
in
range
(
num_images
):
for
count
in
range
(
num_images
):
if
count
>=
num_sends
:
num_to_send
=
min
(
read_every
,
num_images
-
num_sends
)
for
i
in
range
(
count
,
count
+
num_to_send
):
image_id
,
raw_image
=
server
.
get_image
()
queues
[
i
%
num_workers
]
.
put
((
i
,
raw_image
))
if
cfg
.
VIS
or
cfg
.
VIS_ON_FILE
:
vis_image_dict
[
i
]
=
(
image_id
,
raw_image
)
num_sends
+=
num_to_send
if
num_sends
==
num_images
:
for
i
in
range
(
num_workers
):
queues
[
i
]
.
put
((
-
1
,
None
))
i
,
time_diffs
,
results
=
queues
[
-
1
]
.
get
()
i
,
time_diffs
,
results
=
queues
[
-
1
]
.
get
()
# Unpack the diverse results
# Unpack the diverse results
...
...
seetadet/datasets/kpl_record.py
View file @
4bcab26
...
@@ -36,7 +36,7 @@ class KPLRecordDataset(Dataset):
...
@@ -36,7 +36,7 @@ class KPLRecordDataset(Dataset):
def
dump_detections
(
self
,
all_boxes
,
output_dir
):
def
dump_detections
(
self
,
all_boxes
,
output_dir
):
dataset
=
self
.
cls
(
self
.
source
)
dataset
=
self
.
cls
(
self
.
source
)
for
file
in
(
'
data.data'
,
'data.index'
,
'data
.meta'
):
for
file
in
(
'
root.data'
,
'root.index'
,
'root
.meta'
):
file
=
os
.
path
.
join
(
output_dir
,
file
)
file
=
os
.
path
.
join
(
output_dir
,
file
)
if
os
.
path
.
exists
(
file
):
if
os
.
path
.
exists
(
file
):
os
.
remove
(
file
)
os
.
remove
(
file
)
...
...
tools/export.py
View file @
4bcab26
...
@@ -13,14 +13,13 @@ from __future__ import absolute_import
...
@@ -13,14 +13,13 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
sys
import
sys
sys
.
path
.
insert
(
0
,
'..'
)
import
argparse
import
argparse
import
dragon.vm.torch
as
torch
import
dragon.vm.torch
as
torch
import
pprint
import
pprint
sys
.
path
.
insert
(
0
,
'..'
)
from
seetadet
import
onnx
as
_
from
seetadet
import
onnx
as
_
from
seetadet.core.config
import
cfg
from
seetadet.core.config
import
cfg
from
seetadet.core.coordinator
import
Coordinator
from
seetadet.core.coordinator
import
Coordinator
...
@@ -30,12 +29,14 @@ from seetadet.utils import logger
...
@@ -30,12 +29,14 @@ from seetadet.utils import logger
def
parse_args
():
def
parse_args
():
"""Parse input arguments"""
"""Parse input arguments"""
parser
=
argparse
.
ArgumentParser
(
description
=
'Export a Detection Network'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Export a detection network into the onnx model'
)
parser
.
add_argument
(
'--cfg'
,
dest
=
'cfg_file'
,
parser
.
add_argument
(
'--cfg'
,
dest
=
'cfg_file'
,
help
=
'optional config file'
,
default
=
None
,
type
=
str
)
help
=
'optional config file'
,
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'--exp_dir'
,
dest
=
'exp_dir'
,
parser
.
add_argument
(
'--exp_dir'
,
dest
=
'exp_dir'
,
help
=
'experiment dir'
,
help
=
'experiment dir'
,
default
=
None
,
type
=
str
)
default
=
''
,
type
=
str
)
parser
.
add_argument
(
'--input_shape'
,
dest
=
'input_shape'
,
parser
.
add_argument
(
'--input_shape'
,
dest
=
'input_shape'
,
help
=
'The shape of dummy input'
,
help
=
'The shape of dummy input'
,
default
=
(
1
,
224
,
224
,
3
),
type
=
tuple
)
default
=
(
1
,
224
,
224
,
3
),
type
=
tuple
)
...
@@ -50,16 +51,7 @@ def parse_args():
...
@@ -50,16 +51,7 @@ def parse_args():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
args
=
parse_args
()
args
=
parse_args
()
logger
.
info
(
'Called with args:
\n
'
+
str
(
args
))
if
args
.
exp_dir
is
None
or
\
not
os
.
path
.
exists
(
args
.
exp_dir
):
raise
ValueError
(
'Excepted a existing experiment dir.
\n
Got {}.'
.
format
(
os
.
path
.
abspath
(
args
.
exp_dir
))
)
logger
.
info
(
'Called with args:'
)
logger
.
info
(
args
)
coordinator
=
Coordinator
(
args
.
cfg_file
,
exp_dir
=
args
.
exp_dir
)
coordinator
=
Coordinator
(
args
.
cfg_file
,
exp_dir
=
args
.
exp_dir
)
logger
.
info
(
'Using config:
\n
'
+
pprint
.
pformat
(
cfg
))
logger
.
info
(
'Using config:
\n
'
+
pprint
.
pformat
(
cfg
))
...
...
tools/mpi_train.py
View file @
4bcab26
...
@@ -13,14 +13,13 @@ from __future__ import absolute_import
...
@@ -13,14 +13,13 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
sys
import
sys
sys
.
path
.
insert
(
0
,
'..'
)
import
argparse
import
argparse
import
dragon
import
dragon
import
numpy
import
numpy
sys
.
path
.
insert
(
0
,
'..'
)
from
seetadet.core.config
import
cfg
from
seetadet.core.config
import
cfg
from
seetadet.core.coordinator
import
Coordinator
from
seetadet.core.coordinator
import
Coordinator
from
seetadet.core.train
import
train_net
from
seetadet.core.train
import
train_net
...
@@ -30,13 +29,14 @@ from seetadet.utils import logger
...
@@ -30,13 +29,14 @@ from seetadet.utils import logger
def
parse_args
():
def
parse_args
():
"""Parse input arguments."""
"""Parse input arguments."""
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a Fast R-CNN network'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detection network with mpi utilities'
)
parser
.
add_argument
(
'--cfg'
,
dest
=
'cfg_file'
,
parser
.
add_argument
(
'--cfg'
,
dest
=
'cfg_file'
,
help
=
'config file'
,
help
=
'config file'
,
default
=
None
,
type
=
str
)
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'--exp_dir'
,
dest
=
'exp_dir'
,
parser
.
add_argument
(
'--exp_dir'
,
dest
=
'exp_dir'
,
help
=
'experiment dir'
,
help
=
'experiment dir'
,
default
=
None
,
type
=
str
)
default
=
''
,
type
=
str
)
if
len
(
sys
.
argv
)
==
1
:
if
len
(
sys
.
argv
)
==
1
:
parser
.
print_help
()
parser
.
print_help
()
...
@@ -49,13 +49,6 @@ def parse_args():
...
@@ -49,13 +49,6 @@ def parse_args():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
args
=
parse_args
()
args
=
parse_args
()
if
args
.
exp_dir
is
None
or
\
not
os
.
path
.
exists
(
args
.
exp_dir
):
raise
ValueError
(
'Excepted a existing experiment dir.
\n
Got {}.'
.
format
(
os
.
path
.
abspath
(
args
.
exp_dir
))
)
coordinator
=
Coordinator
(
args
.
cfg_file
,
exp_dir
=
args
.
exp_dir
)
coordinator
=
Coordinator
(
args
.
cfg_file
,
exp_dir
=
args
.
exp_dir
)
checkpoint
,
start_iter
=
coordinator
.
checkpoint
(
wait
=
False
)
checkpoint
,
start_iter
=
coordinator
.
checkpoint
(
wait
=
False
)
...
...
tools/test.py
View file @
4bcab26
...
@@ -13,13 +13,12 @@ from __future__ import absolute_import
...
@@ -13,13 +13,12 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
sys
import
sys
sys
.
path
.
insert
(
0
,
'..'
)
import
argparse
import
argparse
import
pprint
import
pprint
sys
.
path
.
insert
(
0
,
'..'
)
from
seetadet.core
import
test_engine
from
seetadet.core
import
test_engine
from
seetadet.core.config
import
cfg
from
seetadet.core.config
import
cfg
from
seetadet.core.coordinator
import
Coordinator
from
seetadet.core.coordinator
import
Coordinator
...
@@ -30,7 +29,8 @@ from seetadet.utils import logger
...
@@ -30,7 +29,8 @@ from seetadet.utils import logger
def
parse_args
():
def
parse_args
():
"""Parse input arguments"""
"""Parse input arguments"""
parser
=
argparse
.
ArgumentParser
(
description
=
'Test a Detection Network'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Test a detection network with a specified checkpoint'
)
parser
.
add_argument
(
'--gpus'
,
dest
=
'gpus'
,
parser
.
add_argument
(
'--gpus'
,
dest
=
'gpus'
,
help
=
'index of GPUs to use'
,
help
=
'index of GPUs to use'
,
default
=
None
,
nargs
=
'+'
,
type
=
int
)
default
=
None
,
nargs
=
'+'
,
type
=
int
)
...
@@ -39,12 +39,15 @@ def parse_args():
...
@@ -39,12 +39,15 @@ def parse_args():
default
=
None
,
type
=
str
)
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'--exp_dir'
,
dest
=
'exp_dir'
,
parser
.
add_argument
(
'--exp_dir'
,
dest
=
'exp_dir'
,
help
=
'experiment dir'
,
help
=
'experiment dir'
,
default
=
None
,
type
=
str
)
default
=
''
,
type
=
str
)
parser
.
add_argument
(
'--output_dir'
,
dest
=
'output_dir'
,
parser
.
add_argument
(
'--output_dir'
,
dest
=
'output_dir'
,
help
=
'output dir'
,
help
=
'output dir'
,
default
=
None
,
type
=
str
)
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'--iter'
,
dest
=
'iter'
,
help
=
'global step'
,
parser
.
add_argument
(
'--iter'
,
dest
=
'iter'
,
help
=
'global step'
,
default
=
None
,
type
=
int
)
default
=
None
,
type
=
int
)
parser
.
add_argument
(
'--read_every'
,
dest
=
'read_every'
,
help
=
'read every n images for testing'
,
default
=
1000
,
type
=
int
)
parser
.
add_argument
(
'--dump'
,
dest
=
'dump'
,
parser
.
add_argument
(
'--dump'
,
dest
=
'dump'
,
help
=
'dump the result back to record?'
,
help
=
'dump the result back to record?'
,
action
=
'store_true'
)
action
=
'store_true'
)
...
@@ -62,16 +65,7 @@ def parse_args():
...
@@ -62,16 +65,7 @@ def parse_args():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
args
=
parse_args
()
args
=
parse_args
()
logger
.
info
(
'Called with args:
\n
'
+
str
(
args
))
if
args
.
exp_dir
is
None
or
\
not
os
.
path
.
exists
(
args
.
exp_dir
):
raise
ValueError
(
'Excepted a existing experiment dir.
\n
Got {}.'
.
format
(
os
.
path
.
abspath
(
args
.
exp_dir
))
)
logger
.
info
(
'Called with args:'
)
logger
.
info
(
args
)
coordinator
=
Coordinator
(
args
.
cfg_file
,
args
.
exp_dir
)
coordinator
=
Coordinator
(
args
.
cfg_file
,
args
.
exp_dir
)
logger
.
info
(
'Using config:
\n
'
+
pprint
.
pformat
(
cfg
))
logger
.
info
(
'Using config:
\n
'
+
pprint
.
pformat
(
cfg
))
...
@@ -79,7 +73,10 @@ if __name__ == '__main__':
...
@@ -79,7 +73,10 @@ if __name__ == '__main__':
# Load the checkpoint and test engine
# Load the checkpoint and test engine
checkpoint
,
_
=
coordinator
.
checkpoint
(
args
.
iter
,
wait
=
args
.
wait
)
checkpoint
,
_
=
coordinator
.
checkpoint
(
args
.
iter
,
wait
=
args
.
wait
)
if
checkpoint
is
None
:
if
checkpoint
is
None
:
raise
RuntimeError
(
'The checkpoint of global step {} does not exist.'
.
format
(
args
.
iter
))
raise
RuntimeError
(
'The checkpoint of step {} does not exist.'
.
format
(
args
.
iter
)
)
# Inspect the dataset
# Inspect the dataset
dataset
=
get_dataset
(
cfg
.
TEST
.
DATASET
)
dataset
=
get_dataset
(
cfg
.
TEST
.
DATASET
)
...
@@ -93,4 +90,4 @@ if __name__ == '__main__':
...
@@ -93,4 +90,4 @@ if __name__ == '__main__':
# Bind the server and run the test
# Bind the server and run the test
server
=
TestServer
(
coordinator
.
results_dir
(
checkpoint
))
server
=
TestServer
(
coordinator
.
results_dir
(
checkpoint
))
test_engine
.
run_test_net
(
checkpoint
,
server
,
args
.
gpus
)
test_engine
.
run_test_net
(
checkpoint
,
server
,
args
.
gpus
,
args
.
read_every
)
tools/test_all.py
View file @
4bcab26
...
@@ -15,23 +15,25 @@ from __future__ import print_function
...
@@ -15,23 +15,25 @@ from __future__ import print_function
import
os
import
os
import
sys
import
sys
sys
.
path
.
insert
(
0
,
'..'
)
import
argparse
import
argparse
import
numpy
import
numpy
sys
.
path
.
insert
(
0
,
'..'
)
from
seetadet.core.coordinator
import
Coordinator
from
seetadet.core.coordinator
import
Coordinator
from
seetadet.utils
import
logger
from
seetadet.utils
import
logger
def
parse_args
():
def
parse_args
():
"""Parse input arguments"""
"""Parse input arguments"""
parser
=
argparse
.
ArgumentParser
(
description
=
'Test a Detection Network'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Test a detection network with all checkpoints'
)
parser
.
add_argument
(
'--cfg'
,
dest
=
'cfg_file'
,
parser
.
add_argument
(
'--cfg'
,
dest
=
'cfg_file'
,
help
=
'optional config file'
,
default
=
None
,
type
=
str
)
help
=
'optional config file'
,
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'--exp_dir'
,
dest
=
'exp_dir'
,
parser
.
add_argument
(
'--exp_dir'
,
dest
=
'exp_dir'
,
help
=
'experiment dir'
,
help
=
'experiment dir'
,
default
=
None
,
type
=
str
)
default
=
''
,
type
=
str
)
if
len
(
sys
.
argv
)
==
1
:
if
len
(
sys
.
argv
)
==
1
:
parser
.
print_help
()
parser
.
print_help
()
...
@@ -52,13 +54,7 @@ def test(cfg_file, exp_dir, global_step):
...
@@ -52,13 +54,7 @@ def test(cfg_file, exp_dir, global_step):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
args
=
parse_args
()
args
=
parse_args
()
logger
.
info
(
'Called with args:
\n
'
+
str
(
args
))
if
args
.
exp_dir
is
None
or
\
not
os
.
path
.
exists
(
args
.
exp_dir
):
raise
ValueError
(
'Excepted a existing experiment dir.
\n
Got {}.'
.
format
(
os
.
path
.
abspath
(
args
.
exp_dir
))
)
coordinator
=
Coordinator
(
args
.
cfg_file
,
exp_dir
=
args
.
exp_dir
)
coordinator
=
Coordinator
(
args
.
cfg_file
,
exp_dir
=
args
.
exp_dir
)
...
@@ -66,7 +62,7 @@ if __name__ == '__main__':
...
@@ -66,7 +62,7 @@ if __name__ == '__main__':
files
=
os
.
listdir
(
coordinator
.
checkpoints_dir
())
files
=
os
.
listdir
(
coordinator
.
checkpoints_dir
())
for
file
in
files
:
for
file
in
files
:
step
=
int
(
file
.
split
(
'_iter_'
)[
-
1
]
.
split
(
b
'.'
)[
0
])
step
=
int
(
file
.
split
(
'_iter_'
)[
-
1
]
.
split
(
'.'
)[
0
])
global_steps
.
append
(
step
)
global_steps
.
append
(
step
)
order
=
numpy
.
argsort
(
-
numpy
.
array
(
global_steps
))
order
=
numpy
.
argsort
(
-
numpy
.
array
(
global_steps
))
...
...
tools/train.py
View file @
4bcab26
...
@@ -13,15 +13,15 @@ from __future__ import absolute_import
...
@@ -13,15 +13,15 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
sys
import
sys
sys
.
path
.
insert
(
0
,
'..'
)
import
os.path
as
osp
import
argparse
import
argparse
import
dragon
import
dragon
import
numpy
import
numpy
import
pprint
import
pprint
sys
.
path
.
insert
(
0
,
'..'
)
from
seetadet.core.config
import
cfg
from
seetadet.core.config
import
cfg
from
seetadet.core.coordinator
import
Coordinator
from
seetadet.core.coordinator
import
Coordinator
from
seetadet.core.train
import
train_net
from
seetadet.core.train
import
train_net
...
@@ -31,7 +31,8 @@ from seetadet.utils import logger
...
@@ -31,7 +31,8 @@ from seetadet.utils import logger
def
parse_args
():
def
parse_args
():
"""Parse input arguments."""
"""Parse input arguments."""
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a Detection Network'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detection network'
)
parser
.
add_argument
(
'--cfg'
,
dest
=
'cfg_file'
,
parser
.
add_argument
(
'--cfg'
,
dest
=
'cfg_file'
,
help
=
'optional config file'
,
help
=
'optional config file'
,
default
=
None
,
type
=
str
)
default
=
None
,
type
=
str
)
...
@@ -61,7 +62,7 @@ def mpi_train(cfg_file, exp_dir):
...
@@ -61,7 +62,7 @@ def mpi_train(cfg_file, exp_dir):
import
subprocess
import
subprocess
args
=
'mpirun --allow-run-as-root -n {} --bind-to none '
.
format
(
cfg
.
NUM_GPUS
)
args
=
'mpirun --allow-run-as-root -n {} --bind-to none '
.
format
(
cfg
.
NUM_GPUS
)
args
+=
'{} {} '
.
format
(
sys
.
executable
,
'mpi_train.py'
)
args
+=
'{} {} '
.
format
(
sys
.
executable
,
'mpi_train.py'
)
args
+=
'--cfg {} --exp_dir {} '
.
format
(
os
p
.
abspath
(
cfg_file
),
exp_dir
)
args
+=
'--cfg {} --exp_dir {} '
.
format
(
os
.
path
.
abspath
(
cfg_file
),
exp_dir
)
return
subprocess
.
call
(
args
,
shell
=
True
)
return
subprocess
.
call
(
args
,
shell
=
True
)
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment