Commit 19c489b6 by Ting PAN

Init repository

0 parents
Showing with 4854 additions and 0 deletions
## General
# Compiled Object files
*.slo
*.lo
*.o
*.cuo
# Compiled Dynamic libraries
# *.so
*.dylib
# Compiled Static libraries
*.lai
*.la
#*.a
# Compiled python
*.pyc
__pycache__
# Compiled MATLAB
*.mex*
# IPython notebook checkpoints
.ipynb_checkpoints
# Editor temporaries
*.swp
*~
# Sublime Text settings
*.sublime-workspace
*.sublime-project
# Eclipse Project settings
*.*project
.settings
# QtCreator files
*.user
# PyCharm files
.idea
# OSX dir files
.DS_Store
------------------------------------------------------------------------
The list of most significant changes made over time in SeetaDet.
SeetaDet 0.1.0 (20190311)
Recommended docker for Dragon:
seetaresearch/dragon:0.3.0.0-rc4-cuda9.1-ubuntu16.04
Changes:
Preview Features:
- Init repository.
Bugs fixed:
- None
## SeetaDet
## What's SeetaDet?
SeetaDet contains many useful object detectors, including the R-CNN series, SSD,
and the recent RetinaNet. We have achieved the same or higher performance than
the baselines reported in the original papers.
This repository is based on our [Dragon](https://github.com/seetaresearch/Dragon),
while the code is written in PyTorch style. The torch-style code helps us simplify the
hierarchical pipeline of modern detection.
## Installation
#### 1. Install the required python packages
```bash
pip install cython pyyaml matplotlib
pip install opencv-python Pillow
```
#### 2. Compile the C Extensions
```bash
cd SeetaDet/compile
bash ./make.sh
```
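After the build finishes, a quick import check can confirm that the compiled extensions are reachable. The sketch below is only an illustration, not part of SeetaDet: `find_missing` is a hypothetical helper, and the module names in `CANDIDATE_MODULES` are assumptions inferred from the extension names in `compile/setup.py`, so they may differ in your checkout.

```python
import importlib.util

# Hypothetical module names, inferred from the extension names in compile/setup.py.
CANDIDATE_MODULES = (
    "lib.utils.cython_bbox",
    "lib.nms.cpu_nms",
    "lib.pycocotools._mask",
)


def find_missing(modules=CANDIDATE_MODULES):
    """Return the subset of `modules` that cannot be located on sys.path."""
    missing = []
    for name in modules:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            # Raised when a parent package of `name` is itself missing.
            spec = None
        if spec is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    for name in find_missing():
        print("not importable:", name)
```

Run it from the repository root so that the `lib` package is on the import path.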
## Resources
#### Pre-trained ImageNet models
| Model | Usage |
| :------: | :------: |
| [VGG16.SSD](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/VGG16.SSD.pth)| SSD |
| [VGG16.RCNN](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/VGG16.RCNN.pth)| R-CNN |
| [R-50.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-50.Affine.pth)| R-CNN, RetinaNet |
| [R-101.Affine](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/R-101.Affine.pth)| R-CNN, RetinaNet |
| [AirNet.SSD](http://dragon.seetatech.com/download/models/SeetaDet/imagenet/AirNet.SSD.pth)| SSD |
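
The download links in the table share one prefix, so they can be generated rather than copied by hand. A minimal sketch, assuming the URL pattern visible in the table keeps holding; `model_url` and `fetch` are hypothetical helper names, and the default target directory is an assumption based on the `../data/imagenet_models/` path used by the training configs.

```python
import os
import urllib.request

BASE_URL = "http://dragon.seetatech.com/download/models/SeetaDet/imagenet"
MODELS = ("VGG16.SSD", "VGG16.RCNN", "R-50.Affine", "R-101.Affine", "AirNet.SSD")


def model_url(name):
    """Build the download URL for one of the models listed in the table."""
    if name not in MODELS:
        raise ValueError("unknown model: %s" % name)
    return "%s/%s.pth" % (BASE_URL, name)


def fetch(name, target_dir="data/imagenet_models"):
    """Download a model into target_dir (an assumed location) and return its path."""
    os.makedirs(target_dir, exist_ok=True)
    path = os.path.join(target_dir, name + ".pth")
    if not os.path.exists(path):
        urllib.request.urlretrieve(model_url(name), path)
    return path
```
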
## References
[1] [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497). Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. NIPS, 2015.
[2] [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385). Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. CVPR, 2016.
[3] [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325). Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, and Alexander C. Berg. ECCV, 2016.
[4] [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144). Tsung-Yi Lin, Piotr Dollár, Ross Girshick, Kaiming He, Bharath Hariharan, and Serge Belongie. CVPR, 2017.
[5] [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002). Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, and Piotr Dollár. ICCV, 2017.
[6] [Mask R-CNN](https://arxiv.org/abs/1703.06870). Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. ICCV, 2017.
[7] [Detectron](https://github.com/facebookresearch/Detectron). Ross Girshick, Ilija Radosavovic, Georgia Gkioxari, Piotr Dollár, and Kaiming He. 2018.
# - Find the NumPy libraries
# This module finds if NumPy is installed, and sets the following variables
# indicating where it is.
#
# TODO: Update to provide the libraries and paths for linking npymath lib.
#
# NUMPY_FOUND - was NumPy found
# NUMPY_VERSION - the version of NumPy found as a string
# NUMPY_VERSION_MAJOR - the major version number of NumPy
# NUMPY_VERSION_MINOR - the minor version number of NumPy
# NUMPY_VERSION_PATCH - the patch version number of NumPy
# NUMPY_VERSION_DECIMAL - e.g. version 1.6.1 is 10601
# NUMPY_INCLUDE_DIR - path to the NumPy include files
unset(NUMPY_VERSION)
unset(NUMPY_INCLUDE_DIR)
if(PYTHONINTERP_FOUND)
execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
"import numpy as n; print(n.__version__); print(n.get_include());"
RESULT_VARIABLE __result
OUTPUT_VARIABLE __output
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(__result MATCHES 0)
string(REGEX REPLACE ";" "\\\\;" __values ${__output})
string(REGEX REPLACE "\r?\n" ";" __values ${__values})
list(GET __values 0 NUMPY_VERSION)
list(GET __values 1 NUMPY_INCLUDE_DIR)
# Put the "+" inside each group so that multi-digit components (e.g. 1.16.2) are captured whole.
string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+)" __ver_check "${NUMPY_VERSION}")
if(NOT "${__ver_check}" STREQUAL "")
set(NUMPY_VERSION_MAJOR ${CMAKE_MATCH_1})
set(NUMPY_VERSION_MINOR ${CMAKE_MATCH_2})
set(NUMPY_VERSION_PATCH ${CMAKE_MATCH_3})
math(EXPR NUMPY_VERSION_DECIMAL
"(${NUMPY_VERSION_MAJOR} * 10000) + (${NUMPY_VERSION_MINOR} * 100) + ${NUMPY_VERSION_PATCH}")
string(REGEX REPLACE "\\\\" "/" NUMPY_INCLUDE_DIR ${NUMPY_INCLUDE_DIR})
else()
unset(NUMPY_VERSION)
unset(NUMPY_INCLUDE_DIR)
message(STATUS "Requested NumPy version and include path, but got instead:\n${__output}\n")
endif()
endif()
else()
message("Cannot find the Python interpreter.")
message(FATAL_ERROR "Did you set PYTHON_EXECUTABLE correctly?")
endif()
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NumPy REQUIRED_VARS NUMPY_INCLUDE_DIR NUMPY_VERSION
VERSION_VAR NUMPY_VERSION)
if(NUMPY_FOUND)
message(STATUS "NumPy ver. ${NUMPY_VERSION} found (include: ${NUMPY_INCLUDE_DIR})")
endif()
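
For reference, the version arithmetic described in the header of this module (`NUMPY_VERSION_DECIMAL`, e.g. 1.6.1 becomes 10601) can be reproduced in a few lines of Python. This is an illustrative sketch only, not part of the build; `version_decimal` is a hypothetical name.

```python
import re


def version_decimal(version):
    """Mirror NUMPY_VERSION_DECIMAL: '1.6.1' -> 10601."""
    match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", version)
    if match is None:
        raise ValueError("unexpected version string: %r" % version)
    major, minor, patch = (int(part) for part in match.groups())
    return major * 10000 + minor * 100 + patch
```

Each capture group takes the whole numeric component, so a two-digit minor version such as `1.16.2` maps to `11602`.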
# - Find python libraries
# This module finds the libraries corresponding to the Python interpreter
# FindPythonInterp provides.
# This code sets the following variables:
#
# PYTHONLIBS_FOUND - have the Python libs been found
# PYTHON_PREFIX - path to the Python installation
# PYTHON_LIBRARIES - path to the python library
# PYTHON_INCLUDE_DIRS - path to where Python.h is found
# PYTHON_MODULE_EXTENSION - lib extension, e.g. '.so' or '.pyd'
# PYTHON_MODULE_PREFIX - lib name prefix: usually an empty string
# PYTHON_SITE_PACKAGES - path to installation site-packages
# PYTHON_IS_DEBUG - whether the Python interpreter is a debug build
#
# Thanks to talljimbo for the patch adding the 'LDVERSION' config
# variable usage.
#=============================================================================
# Copyright 2001-2009 Kitware, Inc.
# Copyright 2012 Continuum Analytics, Inc.
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the names of Kitware, Inc., the Insight Software Consortium,
# nor the names of their contributors may be used to endorse or promote
# products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#=============================================================================
# Checking for the extension makes sure that `LibsNew` was found and not just `Libs`.
if(PYTHONLIBS_FOUND AND PYTHON_MODULE_EXTENSION)
return()
endif()
# Use the Python interpreter to find the libs.
if(PythonLibsNew_FIND_REQUIRED)
find_package(PythonInterp ${PythonLibsNew_FIND_VERSION} REQUIRED)
else()
find_package(PythonInterp ${PythonLibsNew_FIND_VERSION})
endif()
if(NOT PYTHONINTERP_FOUND)
set(PYTHONLIBS_FOUND FALSE)
return()
endif()
# According to http://stackoverflow.com/questions/646518/python-how-to-detect-debug-interpreter
# testing whether sys has the gettotalrefcount function is a reliable, cross-platform
# way to detect a CPython debug interpreter.
#
# The library suffix is from the config var LDVERSION sometimes, otherwise
# VERSION. VERSION will typically be like "2.7" on unix, and "27" on windows.
execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
"from distutils import sysconfig as s;import sys;import struct;
print('.'.join(str(v) for v in sys.version_info));
print(sys.prefix);
print(s.get_python_inc(plat_specific=True));
print(s.get_python_lib(plat_specific=True));
print(s.get_config_var('SO'));
print(hasattr(sys, 'gettotalrefcount')+0);
print(struct.calcsize('@P'));
print(s.get_config_var('LDVERSION') or s.get_config_var('VERSION'));
print(s.get_config_var('LIBDIR') or '');
print(s.get_config_var('MULTIARCH') or '');
"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE _PYTHON_VALUES
ERROR_VARIABLE _PYTHON_ERROR_VALUE)
if(NOT _PYTHON_SUCCESS MATCHES 0)
if(PythonLibsNew_FIND_REQUIRED)
message(FATAL_ERROR
"Python config failure:\n${_PYTHON_ERROR_VALUE}")
endif()
set(PYTHONLIBS_FOUND FALSE)
return()
endif()
# Convert the process output into a list
string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
list(GET _PYTHON_VALUES 0 _PYTHON_VERSION_LIST)
list(GET _PYTHON_VALUES 1 PYTHON_PREFIX)
list(GET _PYTHON_VALUES 2 PYTHON_INCLUDE_DIR)
list(GET _PYTHON_VALUES 3 PYTHON_SITE_PACKAGES)
list(GET _PYTHON_VALUES 4 PYTHON_MODULE_EXTENSION)
list(GET _PYTHON_VALUES 5 PYTHON_IS_DEBUG)
list(GET _PYTHON_VALUES 6 PYTHON_SIZEOF_VOID_P)
list(GET _PYTHON_VALUES 7 PYTHON_LIBRARY_SUFFIX)
list(GET _PYTHON_VALUES 8 PYTHON_LIBDIR)
list(GET _PYTHON_VALUES 9 PYTHON_MULTIARCH)
# Make sure the Python has the same pointer-size as the chosen compiler
# Skip if CMAKE_SIZEOF_VOID_P is not defined
if(CMAKE_SIZEOF_VOID_P AND (NOT "${PYTHON_SIZEOF_VOID_P}" STREQUAL "${CMAKE_SIZEOF_VOID_P}"))
if(PythonLibsNew_FIND_REQUIRED)
math(EXPR _PYTHON_BITS "${PYTHON_SIZEOF_VOID_P} * 8")
math(EXPR _CMAKE_BITS "${CMAKE_SIZEOF_VOID_P} * 8")
message(FATAL_ERROR
"Python config failure: Python is ${_PYTHON_BITS}-bit, "
"chosen compiler is ${_CMAKE_BITS}-bit")
endif()
set(PYTHONLIBS_FOUND FALSE)
return()
endif()
# The built-in FindPython didn't always give the version numbers
string(REGEX REPLACE "\\." ";" _PYTHON_VERSION_LIST ${_PYTHON_VERSION_LIST})
list(GET _PYTHON_VERSION_LIST 0 PYTHON_VERSION_MAJOR)
list(GET _PYTHON_VERSION_LIST 1 PYTHON_VERSION_MINOR)
list(GET _PYTHON_VERSION_LIST 2 PYTHON_VERSION_PATCH)
# Make sure all directory separators are '/'
string(REGEX REPLACE "\\\\" "/" PYTHON_PREFIX ${PYTHON_PREFIX})
string(REGEX REPLACE "\\\\" "/" PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE_DIR})
string(REGEX REPLACE "\\\\" "/" PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES})
if(CMAKE_HOST_WIN32)
set(PYTHON_LIBRARY
"${PYTHON_PREFIX}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")
# when run in a venv, PYTHON_PREFIX points to it. But the libraries remain in the
# original python installation. They may be found relative to PYTHON_INCLUDE_DIR.
if(NOT EXISTS "${PYTHON_LIBRARY}")
get_filename_component(_PYTHON_ROOT ${PYTHON_INCLUDE_DIR} DIRECTORY)
set(PYTHON_LIBRARY
"${_PYTHON_ROOT}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")
endif()
# raise an error if the python libs are still not found.
if(NOT EXISTS "${PYTHON_LIBRARY}")
message(FATAL_ERROR "Python libraries not found")
endif()
else()
if(PYTHON_MULTIARCH)
set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}/${PYTHON_MULTIARCH}" "${PYTHON_LIBDIR}")
else()
set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}")
endif()
#message(STATUS "Searching for Python libs in ${_PYTHON_LIBS_SEARCH}")
# Probably this needs to be more involved. It would be nice if the config
# information the python interpreter itself gave us were more complete.
find_library(PYTHON_LIBRARY
NAMES "python${PYTHON_LIBRARY_SUFFIX}"
PATHS ${_PYTHON_LIBS_SEARCH}
NO_DEFAULT_PATH)
# If all else fails, just set the name/version and let the linker figure out the path.
if(NOT PYTHON_LIBRARY)
set(PYTHON_LIBRARY python${PYTHON_LIBRARY_SUFFIX})
endif()
endif()
MARK_AS_ADVANCED(
PYTHON_LIBRARY
PYTHON_INCLUDE_DIR
)
# We use PYTHON_INCLUDE_DIR, PYTHON_LIBRARY and PYTHON_DEBUG_LIBRARY for the
# cache entries because they are meant to specify the location of a single
# library. We now set the variables listed by the documentation for this
# module.
SET(PYTHON_INCLUDE_DIRS "${PYTHON_INCLUDE_DIR}")
SET(PYTHON_LIBRARIES "${PYTHON_LIBRARY}")
SET(PYTHON_DEBUG_LIBRARIES "${PYTHON_DEBUG_LIBRARY}")
find_package_message(PYTHON
"Found PythonLibs: ${PYTHON_LIBRARY}"
"${PYTHON_EXECUTABLE}${PYTHON_VERSION}")
set(PYTHONLIBS_FOUND TRUE)
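
The interpreter query above can be tried outside CMake to see what values a given Python reports. The sketch below is an approximation, not the module's own code: it uses the standard `sysconfig` module instead of `distutils.sysconfig` (distutils was removed in Python 3.12), and `python_build_info` and its dictionary keys are hypothetical names.

```python
import struct
import sys
import sysconfig


def python_build_info():
    """Collect roughly the same facts the CMake module asks the interpreter for."""
    paths = sysconfig.get_paths()
    return {
        "version": ".".join(str(v) for v in sys.version_info[:3]),
        "prefix": sys.prefix,
        "include_dir": paths["platinclude"],
        "site_packages": paths["platlib"],
        "module_extension": sysconfig.get_config_var("EXT_SUFFIX"),
        # A debug build of CPython exposes sys.gettotalrefcount.
        "is_debug": hasattr(sys, "gettotalrefcount"),
        # Pointer size in bytes, compared against CMAKE_SIZEOF_VOID_P above.
        "pointer_size": struct.calcsize("@P"),
        "libdir": sysconfig.get_config_var("LIBDIR") or "",
        "multiarch": sysconfig.get_config_var("MULTIARCH") or "",
    }


if __name__ == "__main__":
    for key, value in python_build_info().items():
        print("%s: %s" % (key, value))
```
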
CMAKE_MINIMUM_REQUIRED(VERSION 3.0.2)
PROJECT(gpu_nms)
# ---------------- User Config ----------------
# Set your python "interpreter" if necessary
# if not, a default interpreter will be used
# here, provide several examples:
# set(PYTHON_EXECUTABLE /usr/bin/python) # Linux & OSX, Builtin Python
# set(PYTHON_EXECUTABLE /X/anaconda/bin/python) # Linux & OSX, Anaconda
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda
# Set CUDA compiling architecture
# Remove "compute_70/sm_70" if using CUDA 8.0
set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
-gencode arch=compute_35,code=sm_35
-gencode arch=compute_50,code=sm_50
-gencode arch=compute_60,code=sm_60
-gencode arch=compute_70,code=sm_70)
# ---------------- User Config ----------------
# ---[ Dependencies
include(${PROJECT_SOURCE_DIR}/CMake/FindPythonLibs.cmake)
include(${PROJECT_SOURCE_DIR}/CMake/FindNumPy.cmake)
FIND_PACKAGE(CUDA REQUIRED)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
message(STATUS "C++11 support has been enabled by default.")
# ---[ Config types
set(CMAKE_BUILD_TYPE Release CACHE STRING "set build type to release")
set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release" FORCE)
# ---[ Includes
set(INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include)
include_directories(${INCLUDE_DIR})
include_directories(${PROJECT_SOURCE_DIR}/src)
include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${NUMPY_INCLUDE_DIR})
include_directories(${CUDA_INCLUDE_DIRS})
# ---[ libs
link_directories(${PYTHON_LIBRARIES})
# ---[ Install
set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" FORCE)
set(CMAKE_SHARED_LIBRARY_PREFIX "")
# ---[ Flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
if(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /O2 /Oi /GL /Ot /Gy")
endif()
if(UNIX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -w -fPIC -O3 -m64 -std=c++11")
endif()
# ---[ Files
set(HEADER_FILES gpu_nms.h)
set(SRC_FILES gpu_nms.cpp nms_kernel.cu)
# ---[ Add Target
CUDA_ADD_LIBRARY(${PROJECT_NAME} SHARED ${HEADER_FILES} ${SRC_FILES})
# ---[ Link Libs
TARGET_LINK_LIBRARIES(${PROJECT_NAME} ${CUDA_LIBRARIES} ${CUDA_cublas_LIBRARY} ${CUDA_curand_LIBRARY})
if(WIN32)
TARGET_LINK_LIBRARIES(${PROJECT_NAME} ${PYTHON_LIBRARIES})
endif()
# ---[ Install Target
set_target_properties(${PROJECT_NAME} PROPERTIES OUTPUT_NAME "gpu_nms")
install (TARGETS ${PROJECT_NAME} DESTINATION ${PROJECT_BINARY_DIR}/../install/lib/nms)
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Sergey Karayev
# --------------------------------------------------------
cimport cython
import numpy as np
cimport numpy as np

# np.float was only an alias of the builtin float (64-bit) and is gone from
# recent NumPy releases, so the 64-bit type is named explicitly.
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t


@cython.boundscheck(False)
def bbox_overlaps(
        np.ndarray[DTYPE_t, ndim=2] boxes,
        np.ndarray[DTYPE_t, ndim=2] query_boxes):
    """
    Parameters
    ----------
    boxes: (N, 4) ndarray of float
    query_boxes: (K, 4) ndarray of float
    Returns
    -------
    overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    cdef unsigned int N = boxes.shape[0]
    cdef unsigned int K = query_boxes.shape[0]
    cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE)
    cdef DTYPE_t iw, ih, box_area
    cdef DTYPE_t ua
    cdef unsigned int k, n
    with nogil:
        for k in range(K):
            box_area = (
                (query_boxes[k, 2] - query_boxes[k, 0] + 1) *
                (query_boxes[k, 3] - query_boxes[k, 1] + 1)
            )
            for n in range(N):
                iw = (
                    min(boxes[n, 2], query_boxes[k, 2]) -
                    max(boxes[n, 0], query_boxes[k, 0]) + 1
                )
                if iw > 0:
                    ih = (
                        min(boxes[n, 3], query_boxes[k, 3]) -
                        max(boxes[n, 1], query_boxes[k, 1]) + 1
                    )
                    if ih > 0:
                        ua = float(
                            (boxes[n, 2] - boxes[n, 0] + 1) *
                            (boxes[n, 3] - boxes[n, 1] + 1) +
                            box_area - iw * ih
                        )
                        overlaps[n, k] = iw * ih / ua
    return overlaps
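
A plain NumPy version of the same computation can serve as a reference when checking the compiled extension. This is a sketch, not part of the repository: `bbox_overlaps_reference` is a hypothetical name, and it assumes the same inclusive-pixel (`+ 1`) box convention as the Cython code above.

```python
import numpy as np


def bbox_overlaps_reference(boxes, query_boxes):
    """Pairwise IoU with the inclusive-pixel (+1) convention.

    boxes: (N, 4) array, query_boxes: (K, 4) array, rows are (x1, y1, x2, y2).
    Returns an (N, K) array of overlaps.
    """
    boxes = np.asarray(boxes, dtype=np.float64)
    query_boxes = np.asarray(query_boxes, dtype=np.float64)
    areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
    query_areas = ((query_boxes[:, 2] - query_boxes[:, 0] + 1) *
                   (query_boxes[:, 3] - query_boxes[:, 1] + 1))
    # Broadcast to (N, K): intersection width and height, clipped at zero.
    iw = (np.minimum(boxes[:, None, 2], query_boxes[None, :, 2]) -
          np.maximum(boxes[:, None, 0], query_boxes[None, :, 0]) + 1).clip(min=0)
    ih = (np.minimum(boxes[:, None, 3], query_boxes[None, :, 3]) -
          np.maximum(boxes[:, None, 1], query_boxes[None, :, 1]) + 1).clip(min=0)
    inter = iw * ih
    union = areas[:, None] + query_areas[None, :] - inter
    return inter / union
```
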
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
cimport cython
import numpy as np
cimport numpy as np


cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
    return a if a >= b else b


cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
    return a if a <= b else b


@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, float thresh):
    cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
    cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
    cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
    cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
    cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
    cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    cdef np.ndarray[np.intp_t, ndim=1] order = scores.argsort()[::-1]
    cdef int ndets = dets.shape[0]
    # np.int is no longer available in recent NumPy; use a fixed-width flag array.
    cdef np.ndarray[np.int32_t, ndim=1] suppressed = \
        np.zeros((ndets,), dtype=np.int32)
    # nominal indices
    cdef int _i, _j
    # sorted indices
    cdef int i, j
    # temp variables for box i's (the box currently under consideration)
    cdef np.float32_t ix1, iy1, ix2, iy2, iarea
    # variables for computing overlap with box j (lower scoring box)
    cdef np.float32_t xx1, yy1, xx2, yy2
    cdef np.float32_t w, h
    cdef np.float32_t inter, ovr
    keep = []
    for _i in range(ndets):
        i = order[_i]
        if suppressed[i] == 1:
            continue
        keep.append(i)
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        iarea = areas[i]
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            xx1 = max(ix1, x1[j])
            yy1 = max(iy1, y1[j])
            xx2 = min(ix2, x2[j])
            yy2 = min(iy2, y2[j])
            w = max(0.0, xx2 - xx1 + 1)
            h = max(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (iarea + areas[j] - inter)
            if ovr >= thresh:
                suppressed[j] = 1
    return keep


@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float thresh,
                 unsigned int method=0, float sigma=0.5, float score_thresh=0.001):
    cdef unsigned int N = boxes.shape[0]
    cdef float iw, ih, box_area
    cdef float ua
    cdef int pos = 0
    cdef float maxscore = 0
    cdef int maxpos = 0
    cdef float x1, x2, y1, y2, tx1, tx2, ty1, ty2, ts, area, weight, ov
    for i in range(N):
        maxscore = boxes[i, 4]
        maxpos = i
        tx1 = boxes[i, 0]
        ty1 = boxes[i, 1]
        tx2 = boxes[i, 2]
        ty2 = boxes[i, 3]
        ts = boxes[i, 4]
        pos = i + 1
        # get max box
        while pos < N:
            if maxscore < boxes[pos, 4]:
                maxscore = boxes[pos, 4]
                maxpos = pos
            pos = pos + 1
        # add max box as a detection
        boxes[i, 0] = boxes[maxpos, 0]
        boxes[i, 1] = boxes[maxpos, 1]
        boxes[i, 2] = boxes[maxpos, 2]
        boxes[i, 3] = boxes[maxpos, 3]
        boxes[i, 4] = boxes[maxpos, 4]
        # swap ith box with position of max box
        boxes[maxpos, 0] = tx1
        boxes[maxpos, 1] = ty1
        boxes[maxpos, 2] = tx2
        boxes[maxpos, 3] = ty2
        boxes[maxpos, 4] = ts
        tx1 = boxes[i, 0]
        ty1 = boxes[i, 1]
        tx2 = boxes[i, 2]
        ty2 = boxes[i, 3]
        ts = boxes[i, 4]
        pos = i + 1
        # NMS iterations, note that N changes if detection boxes fall below threshold
        while pos < N:
            x1 = boxes[pos, 0]
            y1 = boxes[pos, 1]
            x2 = boxes[pos, 2]
            y2 = boxes[pos, 3]
            s = boxes[pos, 4]
            area = (x2 - x1 + 1) * (y2 - y1 + 1)
            iw = (min(tx2, x2) - max(tx1, x1) + 1)
            if iw > 0:
                ih = (min(ty2, y2) - max(ty1, y1) + 1)
                if ih > 0:
                    ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
                    ov = iw * ih / ua  # iou between max box and detection box
                    if method == 1:  # linear
                        if ov > thresh:
                            weight = 1 - ov
                        else:
                            weight = 1
                    elif method == 2:  # gaussian
                        weight = np.exp(-(ov * ov) / sigma)
                    else:  # original NMS
                        if ov > thresh:
                            weight = 0
                        else:
                            weight = 1
                    boxes[pos, 4] = weight * boxes[pos, 4]
                    # if box score falls below threshold, discard the box by swapping with last box
                    # update N
                    if boxes[pos, 4] < score_thresh:
                        boxes[pos, 0] = boxes[N-1, 0]
                        boxes[pos, 1] = boxes[N-1, 1]
                        boxes[pos, 2] = boxes[N-1, 2]
                        boxes[pos, 3] = boxes[N-1, 3]
                        boxes[pos, 4] = boxes[N-1, 4]
                        N = N - 1
                        pos = pos - 1
            pos = pos + 1
    keep = [i for i in range(N)]
    return keep
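
The two routines above can be cross-checked against a slower NumPy reference. The following is a sketch under stated assumptions rather than the repository's code: `nms_reference` and `soft_nms_weight` are hypothetical names, the first mirrors the greedy rule of `cpu_nms` (suppress when the overlap reaches the threshold), and the second isolates the score multiplier that `cpu_soft_nms` applies for each `method`.

```python
import math

import numpy as np


def nms_reference(dets, thresh):
    """Greedy NMS over (N, 5) rows of (x1, y1, x2, y2, score); returns kept indices."""
    dets = np.asarray(dets, dtype=np.float32)
    x1, y1, x2, y2, scores = (dets[:, c] for c in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)
        # Keep only the boxes whose overlap with box i stays below the threshold.
        order = rest[ovr < thresh]
    return keep


def soft_nms_weight(ov, thresh, method=0, sigma=0.5):
    """Score multiplier for an overlap `ov`, following the branches of cpu_soft_nms."""
    if method == 1:  # linear
        return 1.0 - ov if ov > thresh else 1.0
    if method == 2:  # gaussian
        return math.exp(-(ov * ov) / sigma)
    return 0.0 if ov > thresh else 1.0  # original NMS
```
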
void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
int boxes_dim, float nms_overlap_thresh, int device_id);
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
cimport numpy as np

assert sizeof(int) == sizeof(np.int32_t)

cdef extern from "gpu_nms.h":
    void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)


def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, float thresh, int device_id=0):
    cdef int boxes_num = dets.shape[0]
    cdef int boxes_dim = dets.shape[1]
    cdef int num_out
    cdef np.ndarray[np.int32_t, ndim=1] \
        keep = np.zeros(boxes_num, dtype=np.int32)
    cdef np.ndarray[np.float32_t, ndim=1] \
        scores = dets[:, 4]
    cdef np.ndarray[np.intp_t, ndim=1] \
        order = scores.argsort()[::-1]
    cdef np.ndarray[np.float32_t, ndim=2] \
        sorted_dets = dets[order, :]
    _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
    keep = keep[:num_out]
    return list(order[keep])
# delete cache
rm -r build install *.c *.cpp
# compile cython modules
python setup.py build_ext --inplace
# compile cuda modules
cd build
cmake .. && make install && cd ..
# setup
cp -r install/lib ../
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
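#include <cstdio>
#include <cstring>
#include <vector>
#include "gpu_nms.h"
#define CUDA_CHECK(condition) \
/* Code block avoids redefinition of cudaError_t error */ \
do { \
cudaError_t error = condition; \
if (error != cudaSuccess) { \
/* Report the failure instead of silently ignoring it */ \
fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(error)); \
} \
} while (0)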
void SetDevice(int device_id) {
int current_device;
CUDA_CHECK(cudaGetDevice(&current_device));
if (current_device == device_id) return;
CUDA_CHECK(cudaSetDevice(device_id));
}
#define DIV_UP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define NMS_BLOCK_SIZE 64
template <typename T>
__device__ T iou(const T* A, const T* B) {
const T x1 = max(A[0], B[0]);
const T y1 = max(A[1], B[1]);
const T x2 = min(A[2], B[2]);
const T y2 = min(A[3], B[3]);
const T width = max((T)0, x2 - x1 + 1);
const T height = max((T)0, y2 - y1 + 1);
const T area = width * height;
const T A_area = (A[2] - A[0] + 1) * (A[3] - A[1] + 1);
const T B_area = (B[2] - B[0] + 1) * (B[3] - B[1] + 1);
return area / (A_area + B_area - area);
}
template <typename T>
__global__ void nms_mask(const int num_boxes, const T nms_thresh,
const T* boxes, unsigned long long* mask) {
const int i_start = blockIdx.x * NMS_BLOCK_SIZE;
const int di_end = min(num_boxes - i_start, NMS_BLOCK_SIZE);
const int j_start = blockIdx.y * NMS_BLOCK_SIZE;
const int dj_end = min(num_boxes - j_start, NMS_BLOCK_SIZE);
const int num_blocks = DIV_UP(num_boxes, NMS_BLOCK_SIZE);
const int bid = blockIdx.x;
const int tid = threadIdx.x;
__shared__ T boxes_i[NMS_BLOCK_SIZE * 4];
if (tid < di_end) {
boxes_i[tid * 4 + 0] = boxes[(i_start + tid) * 5 + 0];
boxes_i[tid * 4 + 1] = boxes[(i_start + tid) * 5 + 1];
boxes_i[tid * 4 + 2] = boxes[(i_start + tid) * 5 + 2];
boxes_i[tid * 4 + 3] = boxes[(i_start + tid) * 5 + 3];
}
__syncthreads();
if (tid < dj_end) {
const T* const box_j = boxes + (j_start + tid) * 5;
unsigned long long mask_j = 0;
const int di_start = (i_start == j_start) ? (tid + 1) : 0;
for (int di = di_start; di < di_end; ++di)
if (iou(box_j, boxes_i + di * 4) > nms_thresh)
mask_j |= 1ULL << di;
mask[(j_start + tid) * num_blocks + bid] = mask_j;
}
}
template <typename T>
void ApplyNMS(const int num_boxes, const int max_keeps, const float thresh,
const T* boxes, int* keep_indices, int& num_keep) {
const int num_blocks = DIV_UP(num_boxes, NMS_BLOCK_SIZE);
const dim3 blocks(num_blocks, num_blocks);
size_t mask_nbytes = num_boxes * num_blocks * sizeof(unsigned long long);
size_t boxes_nbytes = num_boxes * 5 * sizeof(T);
void* boxes_dev, *mask_dev;
CUDA_CHECK(cudaMalloc(&boxes_dev, boxes_nbytes));
CUDA_CHECK(cudaMalloc(&mask_dev, mask_nbytes));
CUDA_CHECK(cudaMemcpy(boxes_dev, boxes, boxes_nbytes, cudaMemcpyHostToDevice));
nms_mask<T><<<blocks, NMS_BLOCK_SIZE>>>(num_boxes, thresh,
(T*)boxes_dev,
(unsigned long long*)mask_dev);
CUDA_CHECK(cudaPeekAtLastError());
std::vector<unsigned long long> mask_host(num_boxes * num_blocks);
CUDA_CHECK(cudaMemcpy(&mask_host[0], mask_dev, mask_nbytes, cudaMemcpyDeviceToHost));
std::vector<unsigned long long> dead_bit(num_blocks);
memset(&dead_bit[0], 0, sizeof(unsigned long long) * num_blocks);
int num_selected = 0;
for (int i = 0; i < num_boxes; ++i) {
const int nblock = i / NMS_BLOCK_SIZE;
const int inblock = i % NMS_BLOCK_SIZE;
if (!(dead_bit[nblock] & (1ULL << inblock))) {
keep_indices[num_selected++] = i;
unsigned long long* mask_i = &mask_host[0] + i * num_blocks;
for (int j = nblock; j < num_blocks; ++j) dead_bit[j] |= mask_i[j];
if (num_selected == max_keeps) break;
}
}
num_keep = num_selected;
CUDA_CHECK(cudaFree(mask_dev));
CUDA_CHECK(cudaFree(boxes_dev));
}
void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
int boxes_dim, float nms_overlap_thresh, int device_id) {
// set the device to use
SetDevice(device_id);
// apply gpu nms
ApplyNMS<float>(boxes_num, boxes_num, nms_overlap_thresh,
boxes_host, keep_out, *num_out);
}
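
The host-side loop in `ApplyNMS` is easier to follow in a small CPU-only model. The sketch below is an illustration in Python, not the kernel itself: `build_mask` and `reduce_mask` are hypothetical names, boxes are assumed to be sorted by descending score (as `gpu_nms.pyx` does before calling `_nms`), and the mask only records later boxes, whereas the kernel also fills entries for earlier blocks that the host loop never reads.

```python
BLOCK = 64  # mirrors NMS_BLOCK_SIZE


def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes with the +1 pixel convention."""
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]) + 1)
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]) + 1)
    inter = w * h
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)


def build_mask(boxes, thresh, block=BLOCK):
    """mask[i][b] has bit d set when box i overlaps the later box b * block + d."""
    n = len(boxes)
    num_blocks = (n + block - 1) // block
    mask = [[0] * num_blocks for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            if iou(boxes[i], boxes[j]) > thresh:
                mask[i][j // block] |= 1 << (j % block)
    return mask


def reduce_mask(mask, n, block=BLOCK):
    """Walk boxes in score order, keeping each one not yet marked dead."""
    num_blocks = (n + block - 1) // block
    dead = [0] * num_blocks
    keep = []
    for i in range(n):
        if not ((dead[i // block] >> (i % block)) & 1):
            keep.append(i)
            # A kept box kills every later box it overlaps.
            for b in range(i // block, num_blocks):
                dead[b] |= mask[i][b]
    return keep
```

Passing a small `block` (for example 2) exercises the multi-block path with only a handful of boxes.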
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from distutils.extension import Extension
from distutils.core import setup
from Cython.Distutils import build_ext
import numpy as np
numpy_include = np.get_include()
ext_modules = [
Extension(
"install.lib.utils.cython_bbox",
["bbox.pyx"],
extra_compile_args=["-Wno-cpp", "-Wno-unused-function"],
include_dirs = [numpy_include]),
Extension(
"install.lib.nms.cpu_nms",
["cpu_nms.pyx"],
extra_compile_args=["-Wno-cpp", "-Wno-unused-function"],
include_dirs = [numpy_include]),
Extension(
"install.deprecated.gpu_nms",
["gpu_nms.pyx"],
extra_compile_args=["-Wno-cpp", "-Wno-unused-function"],
language='c++',
include_dirs = [numpy_include]),
Extension(
'install.lib.pycocotools._mask',
['../lib/pycocotools/maskApi.c', '../lib/pycocotools/_mask.pyx'],
include_dirs=[numpy_include, 'pycocotools'],
extra_compile_args=['-Wno-cpp', '-Wno-unused-function', '-std=c99']),
]
setup(name='Detectron',ext_modules=ext_modules,cmdclass = {'build_ext': build_ext})
NUM_GPUS: 8
VIS: False
ENABLE_TENSOR_BOARD: False
MODEL:
  TYPE: faster_rcnn
  BACKBONE: resnet101.fpn
  CLASSES: ['__background__',
            'person', 'bicycle', 'car', 'motorcycle', 'airplane',
            'bus', 'train', 'truck', 'boat', 'traffic light',
            'fire hydrant', 'stop sign', 'parking meter', 'bench',
            'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
            'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
            'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
            'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
            'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
            'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
            'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
            'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
            'teddy bear', 'hair drier', 'toothbrush']
  NUM_CLASSES: 81
SOLVER:
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  LR_POLICY: steps_with_decay
  STEPS: [60000, 80000]
  MAX_ITERS: 90000
  SNAPSHOT_ITERS: 5000
  SNAPSHOT_PREFIX: coco_faster_rcnn
FRCNN:
  ROI_XFORM_METHOD: RoIAlign
  ROI_XFORM_RESOLUTION: 7
TRAIN:
  WEIGHTS: '../data/imagenet_models/R-101.Affine.pth'
  DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb'
  IMS_PER_BATCH: 2
  USE_DIFF: False  # Do not use crowd objects
  BATCH_SIZE: 512
  SCALES: [800]
  MAX_SIZE: 1333
TEST:
  DATABASE: 'taas:/data/coco_2014_minival_lmdb'
  JSON_FILE: '/data/instances_minival2014.json'
  PROTOCOL: 'coco'
  RPN_POST_NMS_TOP_N: 1000
  SCALES: [800]
  MAX_SIZE: 1333
  NMS: 0.5
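
The `SOLVER` block above describes a stepped learning-rate schedule. A rough sketch of how such a schedule is typically evaluated follows; it is an assumption-based illustration, not SeetaDet's solver. In particular, the decay factor `gamma=0.1` does not appear in the config and is assumed here, as is the choice to apply the decay at the step iteration itself. `learning_rate` is a hypothetical name.

```python
def learning_rate(iteration, base_lr=0.02, steps=(60000, 80000), gamma=0.1):
    """Step schedule: multiply base_lr by gamma once for every step already reached.

    gamma=0.1 is an assumed decay factor; the config shown does not set it.
    """
    passed = sum(1 for step in steps if iteration >= step)
    return base_lr * gamma ** passed


if __name__ == "__main__":
    for it in (0, 59999, 60000, 80000, 89999):
        print(it, learning_rate(it))
```
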
NUM_GPUS: 8
VIS: False
ENABLE_TENSOR_BOARD: False
MODEL:
  TYPE: faster_rcnn
  BACKBONE: resnet101.fpn
  CLASSES: ['__background__',
            'person', 'bicycle', 'car', 'motorcycle', 'airplane',
            'bus', 'train', 'truck', 'boat', 'traffic light',
            'fire hydrant', 'stop sign', 'parking meter', 'bench',
            'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
            'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
            'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
            'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
            'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
            'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
            'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
            'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
            'teddy bear', 'hair drier', 'toothbrush']
  NUM_CLASSES: 81
SOLVER:
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  LR_POLICY: steps_with_decay
  STEPS: [120000, 160000]
  MAX_ITERS: 180000
  SNAPSHOT_ITERS: 5000
  SNAPSHOT_PREFIX: coco_faster_rcnn
FRCNN:
  ROI_XFORM_METHOD: RoIAlign
  ROI_XFORM_RESOLUTION: 7
TRAIN:
  WEIGHTS: '../data/imagenet_models/R-101.Affine.pth'
  DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb'
  IMS_PER_BATCH: 2
  USE_DIFF: False  # Do not use crowd objects
  BATCH_SIZE: 512
  SCALES: [800]
  MAX_SIZE: 1333
TEST:
  DATABASE: 'taas:/data/coco_2014_minival_lmdb'
  JSON_FILE: '/data/instances_minival2014.json'
  PROTOCOL: 'coco'
  RPN_POST_NMS_TOP_N: 1000
  SCALES: [800]
  MAX_SIZE: 1333
  NMS: 0.5
NUM_GPUS: 1
VIS: False
ENABLE_TENSOR_BOARD: False
MODEL:
  TYPE: faster_rcnn
  BACKBONE: resnet50.fpn
  CLASSES: ['__background__',
            'aeroplane', 'bicycle', 'bird', 'boat',
            'bottle', 'bus', 'car', 'cat', 'chair',
            'cow', 'diningtable', 'dog', 'horse',
            'motorbike', 'person', 'pottedplant',
            'sheep', 'sofa', 'train', 'tvmonitor']
  NUM_CLASSES: 21
SOLVER:
  BASE_LR: 0.002
  WEIGHT_DECAY: 0.0001
  LR_POLICY: steps_with_decay
  STEPS: [100000, 140000]
  MAX_ITERS: 140000
  SNAPSHOT_ITERS: 5000
  SNAPSHOT_PREFIX: voc_faster_rcnn
FRCNN:
  ROI_XFORM_METHOD: RoIAlign
  ROI_XFORM_RESOLUTION: 7
TRAIN:
  WEIGHTS: '../data/imagenet_models/R-50.Affine.pth'
  DATABASE: 'taas:/data/voc_0712_trainval_lmdb'
  IMS_PER_BATCH: 2
  BATCH_SIZE: 128
  SCALES: [600]
  MAX_SIZE: 1000
TEST:
  DATABASE: 'taas:/data/voc_2007_test_lmdb'
  PROTOCOL: 'voc2007'  # 'voc2007', 'voc2010', 'coco'
  RPN_POST_NMS_TOP_N: 1000
  SCALES: [600]
  MAX_SIZE: 1000
  NMS: 0.45
\ No newline at end of file
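The `SCALES` and `MAX_SIZE` pairs in the config above (`[600]` and `1000`) are commonly read as in py-faster-rcnn: resize so the shorter image side reaches the target scale, unless that would push the longer side past `MAX_SIZE`. The sketch below assumes SeetaDet follows that convention. `compute_im_scale` is a hypothetical helper, not a function from this repository.

```python
def compute_im_scale(height, width, target_size=600, max_size=1000):
    """Return the resize factor under the assumed shorter-side rule."""
    short_side = min(height, width)
    long_side = max(height, width)
    im_scale = float(target_size) / short_side
    # Cap the factor so the longer side does not exceed max_size.
    if im_scale * long_side > max_size:
        im_scale = float(max_size) / long_side
    return im_scale


if __name__ == '__main__':
    # A 375x500 image: the shorter side is scaled up to 600.
    print(compute_im_scale(375, 500))
    # A very wide image: the cap on the longer side takes over.
    print(compute_im_scale(300, 1000))
```

With the COCO settings used elsewhere in these configs, the same helper would be called as `compute_im_scale(h, w, target_size=800, max_size=1333)`.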
NUM_GPUS: 1
VIS: False
ENABLE_TENSOR_BOARD: False
MODEL:
  TYPE: faster_rcnn
  BACKBONE: vgg16.c4
  CLASSES: ['__background__',
            'aeroplane', 'bicycle', 'bird', 'boat',
            'bottle', 'bus', 'car', 'cat', 'chair',
            'cow', 'diningtable', 'dog', 'horse',
            'motorbike', 'person', 'pottedplant',
            'sheep', 'sofa', 'train', 'tvmonitor']
  NUM_CLASSES: 21
SOLVER:
  BASE_LR: 0.001
  WEIGHT_DECAY: 0.0005
  LR_POLICY: steps_with_decay
  STEPS: [100000, 140000]
  MAX_ITERS: 140000
  SNAPSHOT_ITERS: 5000
  SNAPSHOT_PREFIX: voc_faster_rcnn
RPN:
  STRIDES: [16]
  SCALES: [8, 16, 32] # RField: [128, 256, 512]
  ASPECT_RATIOS: [0.5, 1.0, 2.0]
FRCNN:
  ROI_XFORM_METHOD: RoIPool
  ROI_XFORM_RESOLUTION: 7
  MLP_HEAD_DIM: 4096
TRAIN:
  WEIGHTS: '../data/imagenet_models/VGG16.RCNN.pth'
  DATABASE: 'taas:/data/voc_0712_trainval_lmdb'
  RPN_MIN_SIZE: 16
  IMS_PER_BATCH: 2
  BATCH_SIZE: 128
  SCALES: [600]
  MAX_SIZE: 1000
TEST:
  DATABASE: 'taas:/data/voc_2007_test_lmdb'
  PROTOCOL: 'voc2007' # 'voc2007', 'voc2010', 'coco'
  RPN_MIN_SIZE: 16
  RPN_POST_NMS_TOP_N: 300
  SCALES: [600]
  MAX_SIZE: 1000
  NMS: 0.45
\ No newline at end of file
NUM_GPUS: 4
VIS: False
ENABLE_TENSOR_BOARD: False
MODEL:
TYPE: retinanet
BACKBONE: resnet50.fpn
CLASSES: ['__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
NUM_CLASSES: 81
SOLVER:
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
STEPS: [30000, 40000]
MAX_ITERS: 45000
SNAPSHOT_ITERS: 5000
SNAPSHOT_PREFIX: coco_retinanet_400
FPN:
RPN_MIN_LEVEL: 3
RPN_MAX_LEVEL: 7
TRAIN:
WEIGHTS: '../data/imagenet_models/R-50.Affine.pth'
DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb'
IMS_PER_BATCH: 8
SCALES: [400]
MAX_SIZE: 666
TEST:
DATABASE: 'taas:/data/coco_2014_minival_lmdb'
JSON_FILE: '/data/instances_minival2014.json'
PROTOCOL: 'coco'
IMS_PER_BATCH: 1
SCALES: [400]
MAX_SIZE: 666
NMS: 0.5
\ No newline at end of file
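The `SOLVER` block above names a `steps_with_decay` policy with `STEPS: [30000, 40000]`. In Detectron-style solvers this usually means the base rate is multiplied by a fixed factor each time a step boundary is passed. The sketch assumes that reading and a decay factor `gamma = 0.1`, which does not appear in this config. `lr_at` is a hypothetical helper.

```python
def lr_at(iteration, base_lr=0.02, steps=(30000, 40000), gamma=0.1):
    """Learning rate under an assumed step-decay rule.

    `gamma` is an assumption: the config does not state the decay factor.
    """
    # Count how many step boundaries have been reached so far.
    num_decays = sum(1 for step in steps if iteration >= step)
    return base_lr * (gamma ** num_decays)


if __name__ == '__main__':
    for it in (0, 29999, 30000, 40000, 44999):
        print(it, lr_at(it))
```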
NUM_GPUS: 4
VIS: False
ENABLE_TENSOR_BOARD: False
MODEL:
TYPE: retinanet
BACKBONE: resnet50.fpn
CLASSES: ['__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
NUM_CLASSES: 81
SOLVER:
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
STEPS: [120000, 160000]
MAX_ITERS: 180000
SNAPSHOT_ITERS: 5000
SNAPSHOT_PREFIX: coco_retinanet_400
FPN:
RPN_MIN_LEVEL: 3
RPN_MAX_LEVEL: 7
DROPBLOCK:
DROP_ON: True
DECREMENT: 0.000005 # * 20000 = 0.1
TRAIN:
WEIGHTS: '../data/imagenet_models/R-50.Affine.pth'
DATABASE: 'taas:/data/coco_2014_trainval35k_lmdb'
IMS_PER_BATCH: 8
SCALES: [400]
MAX_SIZE: 666
SCALE_JITTERING: True
COLOR_JITTERING: True
SCALE_RANGE: [0.8, 1.2]
TEST:
DATABASE: 'taas:/data/coco_2014_minival_lmdb'
JSON_FILE: '/data/instances_minival2014.json'
PROTOCOL: 'coco'
IMS_PER_BATCH: 1
SCALES: [400]
MAX_SIZE: 666
NMS: 0.5
\ No newline at end of file
NUM_GPUS: 1
VIS: False
ENABLE_TENSOR_BOARD: False
MODEL:
  TYPE: ssd
  BACKBONE: airnet5b.mbox
  CLASSES: ['__background__',
            'aeroplane', 'bicycle', 'bird', 'boat',
            'bottle', 'bus', 'car', 'cat', 'chair',
            'cow', 'diningtable', 'dog', 'horse',
            'motorbike', 'person', 'pottedplant',
            'sheep', 'sofa', 'train', 'tvmonitor']
  NUM_CLASSES: 21
SOLVER:
  BASE_LR: 0.001
  WEIGHT_DECAY: 0.0001
  LR_POLICY: steps_with_decay
  STEPS: [80000, 100000, 120000]
  MAX_ITERS: 120000
  SNAPSHOT_ITERS: 5000
  SNAPSHOT_PREFIX: voc_ssd_300
SSD:
  RESIZE:
    HEIGHT: 300
    WIDTH: 300
  MULTIBOX:
    MIN_SIZES: [30, 90, 150]
    MAX_SIZES: [90, 150, 210]
    STRIDES: [8, 16, 32]
    ASPECT_RATIOS: [[1, 2, 0.5], [1, 2, 0.5], [1, 2, 0.5]]
TRAIN:
  WEIGHTS: '../data/imagenet_models/AirNet.SSD.pth'
  DATABASE: 'taas:/data/voc_0712_trainval_lmdb'
  IMS_PER_BATCH: 32
TEST:
  DATABASE: 'taas:/data/voc_2007_test_lmdb'
  PROTOCOL: 'voc2007' # 'voc2007', 'voc2010', 'coco'
  IMS_PER_BATCH: 8
  NMS_TOP_K: 400
  NMS: 0.45
  SCORE_THRESH: 0.01
  DETECTIONS_PER_IM: 200
\ No newline at end of file
NUM_GPUS: 1
VIS: False
ENABLE_TENSOR_BOARD: False
MODEL:
  TYPE: ssd
  BACKBONE: vgg16_reduced_300.mbox
  FREEZE_AT: 0
  CLASSES: ['__background__',
            'aeroplane', 'bicycle', 'bird', 'boat',
            'bottle', 'bus', 'car', 'cat', 'chair',
            'cow', 'diningtable', 'dog', 'horse',
            'motorbike', 'person', 'pottedplant',
            'sheep', 'sofa', 'train', 'tvmonitor']
  NUM_CLASSES: 21
SOLVER:
  BASE_LR: 0.002
  WARM_UP_FACTOR: 0.
  WEIGHT_DECAY: 0.0005
  LR_POLICY: steps_with_decay
  STEPS: [80000, 100000, 120000]
  MAX_ITERS: 120000
  SNAPSHOT_ITERS: 5000
  SNAPSHOT_PREFIX: voc_ssd_300
SSD:
  RESIZE:
    HEIGHT: 300
    WIDTH: 300
  MULTIBOX:
    STRIDES: [8, 16, 32, 64, 100, 300]
    MIN_SIZES: [30, 60, 110, 162, 213, 264]
    MAX_SIZES: [60, 110, 162, 213, 264, 315]
    ASPECT_RATIOS: [[1, 2, 0.5], [1, 2, 0.5, 3, 0.33], [1, 2, 0.5, 3, 0.33],
                    [1, 2, 0.5, 3, 0.33], [1, 2, 0.5], [1, 2, 0.5]]
TRAIN:
  WEIGHTS: '../data/imagenet_models/VGG16.SSD.pth'
  DATABASE: 'taas:/data/voc_0712_trainval_lmdb'
  IMS_PER_BATCH: 32
TEST:
  DATABASE: 'taas:/data/voc_2007_test_lmdb'
  PROTOCOL: 'voc2007' # 'voc2007', 'voc2010', 'coco'
  IMS_PER_BATCH: 8
  NMS_TOP_K: 400
  NMS: 0.45
  SCORE_THRESH: 0.01
  DETECTIONS_PER_IM: 200
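Each `MULTIBOX` level in the config above pairs one `MIN_SIZES` entry with one `MAX_SIZES` entry and a list of aspect ratios. In the original SSD formulation, every ratio yields a box with the area of a `min_size` square, plus one extra square box of side `sqrt(min_size * max_size)`. The sketch below assumes SeetaDet generates priors this way. `prior_shapes` is a hypothetical helper, and the ordering of the returned boxes is arbitrary.

```python
import math


def prior_shapes(min_size, max_size, aspect_ratios):
    """Return (width, height) pairs for one feature-map location."""
    shapes = []
    for ar in aspect_ratios:
        # Same area as a min_size square, stretched by the aspect ratio.
        shapes.append((min_size * math.sqrt(ar), min_size / math.sqrt(ar)))
    # Extra square prior between this level's size and the next one's.
    side = math.sqrt(min_size * max_size)
    shapes.append((side, side))
    return shapes


if __name__ == '__main__':
    min_sizes = [30, 60, 110, 162, 213, 264]
    max_sizes = [60, 110, 162, 213, 264, 315]
    ratios = [[1, 2, 0.5], [1, 2, 0.5, 3, 0.33], [1, 2, 0.5, 3, 0.33],
              [1, 2, 0.5, 3, 0.33], [1, 2, 0.5], [1, 2, 0.5]]
    for lo, hi, ars in zip(min_sizes, max_sizes, ratios):
        print(lo, hi, len(prior_shapes(lo, hi, ars)))
```

Under these assumptions a level with three ratios carries four priors per location and a level with five ratios carries six.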
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
\ No newline at end of file
# --------------------------------------------------------
# Detectron @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os.path as osp
sys.path.insert(0, '../../../')
from database.frcnn.utils.make_from_xml import make_db
if __name__ == '__main__':
    VOC_ROOT_DIR = '/home/workspace/datasets/VOC'
    # train database: voc_2007_trainval + voc_2012_trainval
    make_db(database_file=osp.join(VOC_ROOT_DIR, 'cache/voc_0712_trainval_lmdb'),
            images_path=[osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/JPEGImages'),
                         osp.join(VOC_ROOT_DIR, 'VOCdevkit2012/VOC2012/JPEGImages')],
            annotations_path=[osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/Annotations'),
                              osp.join(VOC_ROOT_DIR, 'VOCdevkit2012/VOC2012/Annotations')],
            imagesets_path=[osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/ImageSets/Main'),
                            osp.join(VOC_ROOT_DIR, 'VOCdevkit2012/VOC2012/ImageSets/Main')],
            splits=['trainval', 'trainval'])
    # test database: voc_2007_test
    make_db(database_file=osp.join(VOC_ROOT_DIR, 'cache/voc_2007_test_lmdb'),
            images_path=osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/JPEGImages'),
            annotations_path=osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/Annotations'),
            imagesets_path=osp.join(VOC_ROOT_DIR, 'VOCdevkit2007/VOC2007/ImageSets/Main'),
            splits=['test'])
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
\ No newline at end of file
syntax = "proto2";
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
optional bytes data = 4;
optional int32 label = 5;
repeated float float_data = 6;
optional bool encoded = 7 [default = false];
}
message Annotation {
optional float x1 = 1;
optional float y1 = 2;
optional float x2 = 3;
optional float y2 = 4;
optional string name = 5;
optional bool difficult = 6 [default = false];
optional bool crowd = 7 [default = false];
optional string mask = 8;
}
message AnnotatedDatum {
optional Datum datum = 1;
optional string filename = 2;
repeated Annotation annotation = 3;
}
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: anno.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='anno.proto',
package='',
serialized_pb=_b('\n\nanno.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\"\x88\x01\n\nAnnotation\x12\n\n\x02x1\x18\x01 \x01(\x02\x12\n\n\x02y1\x18\x02 \x01(\x02\x12\n\n\x02x2\x18\x03 \x01(\x02\x12\n\n\x02y2\x18\x04 \x01(\x02\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x18\n\tdifficult\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x14\n\x05\x63rowd\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x0c\n\x04mask\x18\x08 \x01(\t\"Z\n\x0e\x41nnotatedDatum\x12\x15\n\x05\x64\x61tum\x18\x01 \x01(\x0b\x32\x06.Datum\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\x12\x1f\n\nannotation\x18\x03 \x03(\x0b\x32\x0b.Annotation')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATUM = _descriptor.Descriptor(
name='Datum',
full_name='Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='channels', full_name='Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='height', full_name='Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='width', full_name='Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data', full_name='Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='label', full_name='Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='encoded', full_name='Datum.encoded', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=15,
serialized_end=144,
)
_ANNOTATION = _descriptor.Descriptor(
name='Annotation',
full_name='Annotation',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='x1', full_name='Annotation.x1', index=0,
number=1, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y1', full_name='Annotation.y1', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='x2', full_name='Annotation.x2', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y2', full_name='Annotation.y2', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='Annotation.name', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='difficult', full_name='Annotation.difficult', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='crowd', full_name='Annotation.crowd', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mask', full_name='Annotation.mask', index=7,
number=8, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=147,
serialized_end=283,
)
_ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum',
full_name='AnnotatedDatum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='datum', full_name='AnnotatedDatum.datum', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='filename', full_name='AnnotatedDatum.filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='annotation', full_name='AnnotatedDatum.annotation', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=285,
serialized_end=375,
)
_ANNOTATEDDATUM.fields_by_name['datum'].message_type = _DATUM
_ANNOTATEDDATUM.fields_by_name['annotation'].message_type = _ANNOTATION
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION
DESCRIPTOR.message_types_by_name['AnnotatedDatum'] = _ANNOTATEDDATUM
Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
DESCRIPTOR = _DATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Datum)
))
_sym_db.RegisterMessage(Datum)
Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATION,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Annotation)
))
_sym_db.RegisterMessage(Annotation)
AnnotatedDatum = _reflection.GeneratedProtocolMessageType('AnnotatedDatum', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATEDDATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:AnnotatedDatum)
))
_sym_db.RegisterMessage(AnnotatedDatum)
# @@protoc_insertion_point(module_scope)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import cv2
from . import anno_pb2 as pb
from dragon.tools.db import LMDB
ZFILL = 8
ENCODE_QUALITY = 95
def set_zfill(value):
    global ZFILL
    ZFILL = value
def set_quality(value):
    global ENCODE_QUALITY
    ENCODE_QUALITY = value
def make_datum(image_id, image_file, objects):
    anno_datum = pb.AnnotatedDatum()
    datum = pb.Datum()
    im = cv2.imread(image_file)
    datum.height, datum.width, datum.channels = im.shape
    # A quality of 100 stores the raw pixels instead of a JPEG buffer
    datum.encoded = ENCODE_QUALITY != 100
    if datum.encoded:
        result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), ENCODE_QUALITY])
    # tobytes() replaces the deprecated ndarray.tostring()
    datum.data = im.tobytes()
    anno_datum.datum.CopyFrom(datum)
    anno_datum.filename = image_id
    for ix, obj in enumerate(objects):
        anno = pb.Annotation()
        anno.x1, anno.y1, anno.x2, anno.y2 = obj['bbox']
        anno.name = obj['name']
        anno.difficult = obj['difficult']
        anno_datum.annotation.add().CopyFrom(anno)
    return anno_datum
def make_db(database_file, images_path, gt_recs, ext='.png'):
    if os.path.isdir(database_file):
        raise ValueError('The database path already exists.')
    else:
        # dirname() also handles paths that contain no '/'
        root_dir = os.path.dirname(database_file)
        if root_dir and not os.path.exists(root_dir):
            os.makedirs(root_dir)
    print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
    db = LMDB(max_commit=10000)
    db.open(database_file, mode='w')
    count = 0
    total_line = len(gt_recs)
    start_time = time.time()
    zfill_flag = '{0:0%d}' % (ZFILL)
    for image_id, objects in gt_recs.items():
        count += 1
        if count % 10000 == 0:
            now_time = time.time()
            print('{0} / {1} in {2:.2f} sec'.format(
                count, total_line, now_time - start_time))
            db.commit()
        image_file = os.path.join(images_path, image_id + ext)
        datum = make_datum(image_id, image_file, objects)
        # Keys are zero-padded record indices, starting from 0
        db.put(zfill_flag.format(count - 1), datum.SerializeToString())
    now_time = time.time()
    print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
    db.commit()
    db.close()
    # Compress the empty space
    db.open(database_file, mode='w')
    db.commit()
    end_time = time.time()
    print('{0} images have been stored in the database.'.format(total_line))
    print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
    print('The size of database is {0} MB.'.format(
        float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
\ No newline at end of file
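`make_db` above builds each record key with `zfill_flag.format(count - 1)`, a zero-padded decimal index of width `ZFILL`. LMDB orders keys bytewise by default, so the padding keeps key order aligned with insertion order. A minimal sketch of that idea, using the hypothetical helper `make_key`:

```python
ZFILL = 8


def make_key(index, zfill=ZFILL):
    """Format a record index as a fixed-width, zero-padded key."""
    # Same two-stage formatting as in make_db: '{0:0%d}' % 8 gives '{0:08}'.
    zfill_flag = '{0:0%d}' % zfill
    return zfill_flag.format(index)


if __name__ == '__main__':
    padded = [make_key(i) for i in (2, 10, 1)]
    plain = [str(i) for i in (2, 10, 1)]
    print(sorted(padded))  # numeric order is preserved
    print(sorted(plain))   # '10' sorts before '2' without padding
```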
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import cv2
import xml.etree.ElementTree as ET
from dragon.tools.db import LMDB
from . import anno_pb2 as pb
ZFILL = 8
ENCODE_QUALITY = 95
def set_zfill(value):
    global ZFILL
    ZFILL = value
def set_quality(value):
    global ENCODE_QUALITY
    ENCODE_QUALITY = value
def make_datum(image_file, xml_file):
    tree = ET.parse(xml_file)
    filename = os.path.split(xml_file)[-1]
    objs = tree.findall('object')
    anno_datum = pb.AnnotatedDatum()
    datum = pb.Datum()
    im = cv2.imread(image_file)
    datum.height, datum.width, datum.channels = im.shape
    # A quality of 100 stores the raw pixels instead of a JPEG buffer
    datum.encoded = ENCODE_QUALITY != 100
    if datum.encoded:
        result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), ENCODE_QUALITY])
    # tobytes() replaces the deprecated ndarray.tostring()
    datum.data = im.tobytes()
    anno_datum.datum.CopyFrom(datum)
    anno_datum.filename = filename.split('.')[0]
    for ix, obj in enumerate(objs):
        anno = pb.Annotation()
        bbox = obj.find('bndbox')
        x1 = float(bbox.find('xmin').text)
        y1 = float(bbox.find('ymin').text)
        x2 = float(bbox.find('xmax').text)
        y2 = float(bbox.find('ymax').text)
        cls = obj.find('name').text.strip()
        anno.x1, anno.y1, anno.x2, anno.y2 = (x1, y1, x2, y2)
        anno.name = cls
        anno.difficult = False
        if obj.find('difficult') is not None:
            anno.difficult = int(obj.find('difficult').text) == 1
        anno_datum.annotation.add().CopyFrom(anno)
    return anno_datum
def make_db(database_file,
            images_path,
            annotations_path,
            imagesets_path,
            splits):
    if os.path.isdir(database_file):
        raise ValueError('The database path already exists.')
    else:
        # dirname() also handles paths that contain no '/'
        root_dir = os.path.dirname(database_file)
        if root_dir and not os.path.exists(root_dir):
            os.makedirs(root_dir)
    # Accept a single path as well as one path per split
    if not isinstance(images_path, list):
        images_path = [images_path]
    if not isinstance(annotations_path, list):
        annotations_path = [annotations_path]
    if not isinstance(imagesets_path, list):
        imagesets_path = [imagesets_path]
    assert len(splits) == len(imagesets_path)
    assert len(splits) == len(images_path)
    assert len(splits) == len(annotations_path)
    print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
    db = LMDB(max_commit=10000)
    db.open(database_file, mode='w')
    count = 0
    total_line = 0
    start_time = time.time()
    zfill_flag = '{0:0%d}' % (ZFILL)
    for db_idx, split in enumerate(splits):
        split_file = os.path.join(imagesets_path[db_idx], split + '.txt')
        assert os.path.exists(split_file)
        with open(split_file, 'r') as f:
            lines = f.readlines()
            total_line += len(lines)
        for line in lines:
            count += 1
            if count % 10000 == 0:
                now_time = time.time()
                print('{0} / {1} in {2:.2f} sec'.format(
                    count, total_line, now_time - start_time))
                db.commit()
            filename = line.strip()
            image_file = os.path.join(images_path[db_idx], filename + '.jpg')
            xml_file = os.path.join(annotations_path[db_idx], filename + '.xml')
            datum = make_datum(image_file, xml_file)
            # Keys are zero-padded record indices, starting from 0
            db.put(zfill_flag.format(count - 1), datum.SerializeToString())
    now_time = time.time()
    print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
    db.commit()
    db.close()
    # Compress the empty space
    db.open(database_file, mode='w')
    db.commit()
    end_time = time.time()
    print('{0} images have been stored in the database.'.format(total_line))
    print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
    print('The size of database is {0} MB.'.format(
        float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
\ No newline at end of file
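`make_datum` above reads each `object` element from a Pascal VOC annotation and copies its `bndbox`, `name`, and optional `difficult` flag. The sketch below isolates that parsing step on an inline XML string so it can run without the dataset. `parse_objects` and the sample annotation are illustrative and not part of the repository.

```python
import xml.etree.ElementTree as ET

SAMPLE_XML = """
<annotation>
  <filename>000001.jpg</filename>
  <object>
    <name>dog</name>
    <difficult>0</difficult>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
  <object>
    <name> person </name>
    <bndbox><xmin>8</xmin><ymin>12</ymin><xmax>352</xmax><ymax>498</ymax></bndbox>
  </object>
</annotation>
"""


def parse_objects(xml_text):
    """Extract name, box and difficult flag for every <object> element."""
    root = ET.fromstring(xml_text.strip())
    objects = []
    for obj in root.findall('object'):
        bbox = obj.find('bndbox')
        box = [float(bbox.find(tag).text)
               for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        # A missing <difficult> tag is treated as "not difficult".
        difficult = obj.find('difficult')
        objects.append({
            'name': obj.find('name').text.strip(),
            'bbox': box,
            'difficult': difficult is not None and int(difficult.text) == 1,
        })
    return objects


if __name__ == '__main__':
    for rec in parse_objects(SAMPLE_XML):
        print(rec)
```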
# --------------------------------------------------------
# Detectron @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make LMDB for cityscape dataset."""
import os
import sys
import shutil
import numpy as np
np.random.seed(1337)
try:
    import cPickle
except ImportError:
    import pickle as cPickle
sys.path.insert(0, '../../../')
from database.mrcnn.utils.make import make_db
from database.mrcnn.cityscape.make_mask import make_mask
if __name__ == '__main__':
    CITYSCAPE_ROOT = '/data/cityscape'
    # make RLE masks
    if not os.path.exists('build'): os.makedirs('build')
    cs_train = make_mask(
        os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest'),
        os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists/train.lst'))
    cs_val = make_mask(
        os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest'),
        os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists/val.lst'))
    with open('build/cs_train_mask.pkl', 'wb') as f:
        cPickle.dump(cs_train, f, cPickle.HIGHEST_PROTOCOL)
    with open('build/cs_val_mask.pkl', 'wb') as f:
        cPickle.dump(cs_val, f, cPickle.HIGHEST_PROTOCOL)
    # make image splits
    for split in ['train', 'val', 'test']:
        with open(os.path.join(CITYSCAPE_ROOT,
                  'gtFine_trainvaltest/imglists', split + '.lst'), 'r') as f:
            entries = [line.strip().split('\t') for line in f.readlines()]
        if split == 'train': np.random.shuffle(entries)
        with open(os.path.join(CITYSCAPE_ROOT,
                  'gtFine_trainvaltest/imglists', split + '.txt'), 'w') as w:
            for entry in entries: w.write(entry[1].split('.')[0] + '\n')
    # make database
    make_db(database_file=os.path.join(CITYSCAPE_ROOT, 'cache/cs_train_lmdb'),
            images_path=os.path.join(CITYSCAPE_ROOT, 'leftImg8bit_trainvaltest'),
            mask_file='build/cs_train_mask.pkl',
            splits_path=os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists'),
            splits=['train'], ext='.png')
    make_db(database_file=os.path.join(CITYSCAPE_ROOT, 'cache/cs_val_lmdb'),
            images_path=os.path.join(CITYSCAPE_ROOT, 'leftImg8bit_trainvaltest'),
            mask_file='build/cs_val_mask.pkl',
            splits_path=os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest/imglists'),
            splits=['val'], ext='.png')
    # clean!
    shutil.rmtree('build')
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make masks for cityscape dataset."""
from __future__ import print_function
import os
import sys
import cv2
from collections import OrderedDict
import PIL.Image as Image
import numpy as np
np.random.seed(1337)
sys.path.insert(0, '../../..')
from lib.pycocotools.mask_utils import mask_bin2rle
from database.mrcnn.utils.process_pool import ProcessPool
class_id = [0,
24, 25, 26, 27,
28, 31, 32, 33]
classes = ['__background__',
'person', 'rider', 'car', 'truck',
'bus', 'train', 'motorcycle', 'bicycle']
ind_to_class = dict(zip(range(len(classes)), classes))
def parse_gt(gt_file, im_scale=1.0):
    im = Image.open(gt_file)
    pixel = list(im.getdata())
    pixel = np.array(pixel).reshape([im.size[1], im.size[0]])
    objects = []
    for c in range(1, len(class_id)):
        # Instance pixels store class_id * 1000 + instance index
        px = np.where((pixel >= class_id[c] * 1000) & (pixel < (class_id[c] + 1) * 1000))
        if len(px[0]) == 0: continue
        uids = np.unique(pixel[px])
        for idx, uid in enumerate(uids):
            px = np.where(pixel == uid)
            x1 = np.min(px[1])
            y1 = np.min(px[0])
            x2 = np.max(px[1])
            y2 = np.max(px[0])
            # Skip degenerate instances
            if x2 - x1 <= 1 or y2 - y1 <= 1: continue
            mask = np.zeros([im.size[1], im.size[0]], dtype=np.uint8)
            mask[px] = 1
            if im_scale != 1:
                mask = cv2.resize(mask, None, fx=im_scale, fy=im_scale,
                                  interpolation=cv2.INTER_NEAREST)
                x1 = min(int(x1 * im_scale), mask.shape[1])
                y1 = min(int(y1 * im_scale), mask.shape[0])
                x2 = min(int(x2 * im_scale), mask.shape[1])
                y2 = min(int(y2 * im_scale), mask.shape[0])
            objects.append({'bbox': [x1, y1, x2, y2],
                            'mask': mask_bin2rle([mask])[0],
                            'name': ind_to_class[c],
                            'difficult': False})
    return objects
def map_func(gts, Q):
    for image_id, gt_file in gts:
        objects = parse_gt(gt_file)
        Q.put((image_id, objects))
def make_mask(gt_root, split_file):
    # Create tasks
    gt_tasks, gt_recs = [], OrderedDict()
    with open(split_file, 'r') as f:
        for line in f:
            _, image_path, gt_path = line.strip().split('\t')
            image_id = image_path.split('.')[0]
            gt_file = os.path.join(gt_root, gt_path.replace('labelTrainIds', 'instanceIds'))
            gt_tasks.append((image_id, gt_file))
    num_tasks = len(gt_tasks)
    # Run!
    with ProcessPool(16) as pool:
        pool.run(gt_tasks, func=map_func)
        for idx in range(num_tasks):
            image_id, objects = pool.get()
            gt_recs[image_id] = objects
            print('\rProcess: {} / {}'.format(idx + 1, num_tasks), end='')
    return gt_recs
\ No newline at end of file
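`parse_gt` above relies on the Cityscapes `instanceIds` encoding, where a pixel that belongs to an instance stores `class_id * 1000 + instance_index`. That is why it selects values in `[class_id * 1000, (class_id + 1) * 1000)`. The pure-Python sketch below decodes a toy label map the same way. The helper `instance_boxes` and the toy data are illustrative and not part of the repository.

```python
def instance_boxes(label_map, class_id):
    """Return {uid: [x1, y1, x2, y2]} for all instances of one class."""
    boxes = {}
    for y, row in enumerate(label_map):
        for x, uid in enumerate(row):
            # Keep only pixels whose value encodes this class.
            if uid // 1000 != class_id:
                continue
            if uid not in boxes:
                boxes[uid] = [x, y, x, y]
            else:
                box = boxes[uid]
                box[0], box[1] = min(box[0], x), min(box[1], y)
                box[2], box[3] = max(box[2], x), max(box[3], y)
    return boxes


if __name__ == '__main__':
    # 26 is the 'car' id in the class_id list above;
    # 7 stands for a label that carries no instance index.
    toy = [
        [7, 26000, 26000, 7],
        [7, 26000, 26000, 26001],
        [7, 7, 26001, 26001],
    ]
    print(instance_boxes(toy, 26))
```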
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import cv2
from collections import defaultdict
from lib.pycocotools.mask_utils import mask_rle2im
CITYSCAPE_ROOT = '/data/cityscape'
def write_results(json_file, img_list):
with open(json_file, 'r') as f:
json_results = json.load(f)
class_id = [0, 24, 25, 26, 27, 28, 31, 32, 33]
category_id_to_class_id = dict(zip(range(9), class_id))
result_path = os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest', 'results', 'pred')
if not os.path.exists(result_path): os.makedirs(result_path)
counts = defaultdict(int)
txt_results = defaultdict(list)
for idx, rec in enumerate(json_results):
class_id = category_id_to_class_id[rec['category_id']]
if class_id == 0: continue
im_h, im_w = rec['segmentation']['size']
mask_rle = rec['segmentation']['counts']
mask_image = mask_rle2im([mask_rle], im_h, im_w)[0] * 200
image_name = rec['image_id'].split('_leftImg8bit')[0]
mask_name = image_name + '_' + str(counts[image_name]) + '.png'
counts[image_name] += 1
mask_path = os.path.join(result_path, mask_name)
cv2.imwrite(mask_path, mask_image)
txt_results[image_name].append((mask_name, class_id, rec['score']))
print('\rWriting masks ({} / {})'.format(idx + 1, len(json_results)), end='')
with open(img_list, 'r') as F:
for line in F.readlines():
image_name = line.strip().split('/')[-1].split('_leftImg8bit')[0]
txt_path = os.path.join(result_path, image_name + '.txt')
with open(txt_path, 'w') as f:
for rec in txt_results[image_name]:
f.write('{} {} {:.8f}\n'.format(rec[0], rec[1], rec[2]))
if __name__ == '__main__':
write_results(
'/results/segmentations.json',
os.path.join(CITYSCAPE_ROOT, 'gtFine_trainvaltest', 'imglists', 'val.txt')
)
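# A minimal sketch of the naming scheme used by write_results above: every
# prediction of an image gets its own PNG, numbered by a per-image counter.
# The helper name and the sample image ids are hypothetical, not part of the
# repository.

```python
from collections import defaultdict


def assign_mask_names(image_ids):
    """Return one mask file name per prediction, numbered per image."""
    counts = defaultdict(int)
    names = []
    for image_id in image_ids:
        # Strip the Cityscapes suffix, e.g. 'a_leftImg8bit' -> 'a'.
        image_name = image_id.split('_leftImg8bit')[0]
        names.append('{}_{}.png'.format(image_name, counts[image_name]))
        counts[image_name] += 1
    return names


if __name__ == '__main__':
    # Two predictions for image 'a', one for image 'b'.
    print(assign_mask_names(['a_leftImg8bit', 'a_leftImg8bit', 'b_leftImg8bit']))
```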
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Make LMDB for COCO dataset."""
import os
import sys
import shutil
sys.path.insert(0, '../../../')
from database.mrcnn.utils.make import make_db
from database.mrcnn.coco.make_mask import make_mask, merge_mask
if __name__ == '__main__':
COCO_ROOT = '/data/coco'
# make RLE masks
if not os.path.exists('build'): os.makedirs('build')
make_mask('train', '2014', COCO_ROOT)
make_mask('valminusminival', '2014', COCO_ROOT)
make_mask('minival', '2014', COCO_ROOT)
merge_mask('trainval35k', '2014', [
'build/coco_2014_train_mask.pkl',
'build/coco_2014_valminusminival_mask.pkl'])
# train database: coco_2014_trainval35k
make_db(database_file=os.path.join(COCO_ROOT, 'cache/coco_2014_trainval35k_lmdb'),
images_path=[os.path.join(COCO_ROOT, 'images/train2014'),
os.path.join(COCO_ROOT, 'images/val2014')],
splits_path=[os.path.join(COCO_ROOT, 'ImageSets'),
os.path.join(COCO_ROOT, 'ImageSets')],
mask_file='build/coco_2014_trainval35k_mask.pkl',
splits=['train', 'valminusminival'])
# val database: coco_2014_minival
make_db(database_file=os.path.join(COCO_ROOT, 'cache/coco_2014_minival_lmdb'),
images_path=os.path.join(COCO_ROOT, 'images/val2014'),
mask_file='build/coco_2014_minival_mask.pkl',
splits_path=os.path.join(COCO_ROOT, 'ImageSets'),
splits=['minival'])
# clean!
shutil.rmtree('build')
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import sys
import os.path as osp
from collections import OrderedDict
try:
    import cPickle
except ImportError:
    # Python 3 ships the accelerated pickle as ``pickle``
    import pickle as cPickle
sys.path.insert(0, '../../..')
from lib.pycocotools.coco import COCO
from lib.pycocotools.mask_utils import mask_poly2rle
class imdb(object):
def __init__(self, image_set, year, data_dir):
self._year = year
self._image_set = image_set
self._data_path = osp.join(data_dir)
self.invalid_cnt = 0
self.ignore_cnt = 0
#################
# CLASSES #
#################
# load COCO API, classes, class <-> id mappings
self._COCO = COCO(self._get_ann_file())
cats = self._COCO.loadCats(self._COCO.getCatIds())
self._classes = tuple(['__background__'] + [c['name'] for c in cats])
self._class_to_ind = dict(zip(self._classes, range(self.num_classes)))
self._ind_to_class = dict(zip(range(self.num_classes), self._classes))
self._class_to_coco_cat_id = dict(zip([c['name'] for c in cats],
self._COCO.getCatIds()))
self._coco_cat_id_to_class_ind = dict([(self._class_to_coco_cat_id[cls],
self._class_to_ind[cls]) for cls in self._classes[1:]])
#################
# SET #
#################
self._view_map = {
'minival2014': 'val2014', # 5k val2014 subset
'valminusminival2014': 'val2014', # val2014 \setminus minival2014
}
coco_name = image_set + year # e.g., "val2014"
self._data_name = (self._view_map[coco_name]
if coco_name in self._view_map else coco_name)
#################
# IMAGES #
#################
self._image_index = self._load_image_set_index()
self._annotations = self._load_annotations()
def _get_ann_file(self):
prefix = 'instances' if self._image_set.find('test') == -1 \
else 'image_info'
return osp.join(self._data_path, 'annotations',
prefix + '_' + self._image_set + self._year + '.json')
def _load_image_set_index(self):
"""
Load image ids.
"""
image_ids = self._COCO.getImgIds()
return image_ids
def _load_annotations(self):
"""
Load annotations.
"""
annotations = [self._load_coco_annotation(index)
for index in self._image_index]
return annotations
def image_path_from_index(self, index):
"""
Construct an image path from the image's "index" identifier.
"""
# Example image path for index=119993:
# images/train2014/COCO_train2014_000000119993.jpg
file_name = ('COCO_' + self._data_name + '_' +
str(index).zfill(12) + '.jpg')
image_path = osp.join(self._data_path, 'images',
self._data_name, file_name)
assert osp.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path
def image_path_at(self, i):
"""
Return the absolute path to image i in the image sequence.
"""
return self.image_path_from_index(self._image_index[i])
def annotation_at(self, i):
"""
Return the (height, width, objects) annotation of image i in the image sequence.
"""
return self._annotations[i]
def _load_coco_annotation(self, index):
"""
Load COCO bounding-box and mask annotations for one image. Boxes are
clipped to the image and degenerate ones are dropped. Crowd instances
are kept and flagged, but their masks are replaced by an empty string.
"""
im_ann = self._COCO.loadImgs(index)[0]
width = im_ann['width']
height = im_ann['height']
annIds = self._COCO.getAnnIds(imgIds=index, iscrowd=None)
objs = self._COCO.loadAnns(annIds)
# Sanitize boxes -- some are invalid
valid_objs = []
for obj in objs:
x1 = int(max(0, obj['bbox'][0]))
y1 = int(max(0, obj['bbox'][1]))
x2 = int(min(width - 1, x1 + max(0, obj['bbox'][2] - 1)))
y2 = int(min(height - 1, y1 + max(0, obj['bbox'][3] - 1)))
if type(obj['segmentation']) is list:
for p in obj['segmentation']:
if len(p) < 6: print('Remove invalid segm.')
# Valid polygons have >= 3 points, so require >= 6 coordinates
obj['segmentation'] = [p for p in obj['segmentation'] if len(p) >= 6]
rle_masks = mask_poly2rle([obj['segmentation']], height, width)
else:
# crowd masks
rle_masks = [obj['segmentation']]
if obj['area'] > 0 and x2 > x1 and y2 > y1:
obj['clean_bbox'] = [x1, y1, x2, y2]
# Exclude the crowd masks
# TODO(PhyscalX): You may encounter crashes when decoding crowd masks.
mask = rle_masks[0] if not obj['iscrowd'] else ''
valid_objs.append(
{'bbox': [x1, y1, x2, y2],
'mask': mask,
'category_id': obj['category_id'],
'class_id': self._coco_cat_id_to_class_ind[obj['category_id']],
'crowd': obj['iscrowd']})
valid_objs[-1]['name'] = self._ind_to_class[valid_objs[-1]['class_id']]
return height, width, valid_objs
@property
def num_images(self):
return len(self._image_index)
@property
def num_classes(self):
return len(self._classes)
def make_mask(split, year, data_dir):
coco = imdb(split, year, data_dir)
print('Preparing to make split: {}, total {} images'.format(split, coco.num_images))
if not osp.exists(osp.join(coco._data_path, 'ImageSets')):
os.makedirs(osp.join(coco._data_path, 'ImageSets'))
gt_recs = OrderedDict()
for i in range(coco.num_images):
filename = (coco.image_path_at(i).split('/')[-1]).split('.')[0]
h, w, objects = coco.annotation_at(i)
gt_recs[filename] = objects
with open(osp.join('build',
'coco_' + year + '_' + split + '_mask.pkl'), 'wb') as f:
cPickle.dump(gt_recs, f, cPickle.HIGHEST_PROTOCOL)
with open(osp.join(coco._data_path, 'ImageSets', split + '.txt'), 'w') as f:
for i in range(coco.num_images):
filename = (coco.image_path_at(i).split('/')[-1]).split('.')[0]
if i != coco.num_images - 1: filename += '\n'
f.write(filename)
def merge_mask(split, year, mask_files):
gt_recs = OrderedDict()
data_path = os.path.dirname(mask_files[0])
for mask_file in mask_files:
with open(mask_file, 'rb') as f:
recs = cPickle.load(f)
gt_recs.update(recs)
with open(osp.join(data_path,
'coco_' + year + '_' + split + '_mask.pkl'), 'wb') as f:
cPickle.dump(gt_recs, f, cPickle.HIGHEST_PROTOCOL)
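# A minimal sketch of the box sanitization performed in
# imdb._load_coco_annotation: a COCO (x, y, w, h) box is converted into
# (x1, y1, x2, y2) and clipped to the image. The helper name is hypothetical,
# and the additional ``area > 0`` check of the original is omitted here.

```python
def sanitize_bbox(bbox, width, height):
    """Convert a COCO (x, y, w, h) box into a clipped (x1, y1, x2, y2) box.

    Returns None when the clipped box is degenerate.
    """
    x1 = int(max(0, bbox[0]))
    y1 = int(max(0, bbox[1]))
    x2 = int(min(width - 1, x1 + max(0, bbox[2] - 1)))
    y2 = int(min(height - 1, y1 + max(0, bbox[3] - 1)))
    if x2 > x1 and y2 > y1:
        return [x1, y1, x2, y2]
    return None


if __name__ == '__main__':
    print(sanitize_bbox([10, 20, 30, 40], width=100, height=100))
    print(sanitize_bbox([-5, -5, 200, 200], width=100, height=100))
```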
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import sys
import time
import json
import cv2
from dragon.tools.db import LMDB, wrapper_str
sys.path.insert(0, '../../../')
import database.mrcnn.utils.anno_pb2 as pb
IMAGE_INFO = '/data/image_info_test-dev2017.json'
def load_image_list(image_info):
num_images = len(image_info['images'])
image_list = []
print('The split has {} images.'.format(num_images))
for image in image_info['images']:
image_list.append(image['file_name'])
return image_list
def make_datum(image_file):
anno_datum = pb.AnnotatedDatum()
datum = pb.Datum()
im = cv2.imread(image_file)
datum.height, datum.width, datum.channels = im.shape
datum.encoded = True
if datum.encoded:
result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
datum.data = im.tobytes()
anno_datum.datum.CopyFrom(datum)
anno_datum.filename = os.path.split(image_file)[-1]
return anno_datum
def make_db(database_file, images_path, image_list):
if os.path.isdir(database_file):
raise ValueError('The database path already exists.')
else:
root_dir = database_file[:database_file.rfind('/')]
if not os.path.exists(root_dir):
os.makedirs(root_dir)
print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
db = LMDB(max_commit=10000)
db.open(database_file, mode='w')
count = 0
start_time = time.time()
zfill_flag = '{0:0%d}' % (8)
for image_file in image_list:
count += 1
if count % 10000 == 0:
now_time = time.time()
print('{0} / {1} in {2:.2f} sec'.format(
count, len(image_list), now_time - start_time))
db.commit()
datum = make_datum(os.path.join(images_path, image_file))
db.put(zfill_flag.format(count - 1), datum.SerializeToString())
now_time = time.time()
print('{0} / {1} in {2:.2f} sec'.format(count, len(image_list), now_time - start_time))
db.commit()
db.close()
end_time = time.time()
print('{0} images have been stored in the database.'.format(len(image_list)))
print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
print('The size of database is {0} MB.'.format(
float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
if __name__ == '__main__':
image_info = json.load(open(IMAGE_INFO, 'r'))
image_list = load_image_list(image_info)
make_db('/data/coco_2017_test-dev_lmdb',
'/data/test2017', image_list)
# --------------------------------------------------------
# FPN @ Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
from .make import set_zfill, set_quality, make_db
syntax = "proto2";
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
optional bytes data = 4;
optional int32 label = 5;
repeated float float_data = 6;
optional bool encoded = 7 [default = false];
repeated int32 labels = 8;
}
message Annotation {
optional float x1 = 1;
optional float y1 = 2;
optional float x2 = 3;
optional float y2 = 4;
optional string name = 5;
optional bool difficult = 6 [default = false];
optional string mask = 7;
}
message AnnotatedDatum {
optional Datum datum = 1;
optional string filename = 2;
repeated Annotation annotation = 3;
}
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: anno.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='anno.proto',
package='',
serialized_pb=_b('\n\nanno.proto\"\x91\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x0e\n\x06labels\x18\x08 \x03(\x05\"r\n\nAnnotation\x12\n\n\x02x1\x18\x01 \x01(\x02\x12\n\n\x02y1\x18\x02 \x01(\x02\x12\n\n\x02x2\x18\x03 \x01(\x02\x12\n\n\x02y2\x18\x04 \x01(\x02\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x18\n\tdifficult\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x0c\n\x04mask\x18\x07 \x01(\t\"Z\n\x0e\x41nnotatedDatum\x12\x15\n\x05\x64\x61tum\x18\x01 \x01(\x0b\x32\x06.Datum\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\x12\x1f\n\nannotation\x18\x03 \x03(\x0b\x32\x0b.Annotation')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATUM = _descriptor.Descriptor(
name='Datum',
full_name='Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='channels', full_name='Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='height', full_name='Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='width', full_name='Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data', full_name='Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='label', full_name='Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='encoded', full_name='Datum.encoded', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='labels', full_name='Datum.labels', index=7,
number=8, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=15,
serialized_end=160,
)
_ANNOTATION = _descriptor.Descriptor(
name='Annotation',
full_name='Annotation',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='x1', full_name='Annotation.x1', index=0,
number=1, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y1', full_name='Annotation.y1', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='x2', full_name='Annotation.x2', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='y2', full_name='Annotation.y2', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='Annotation.name', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='difficult', full_name='Annotation.difficult', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mask', full_name='Annotation.mask', index=6,
number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=162,
serialized_end=276,
)
_ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum',
full_name='AnnotatedDatum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='datum', full_name='AnnotatedDatum.datum', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='filename', full_name='AnnotatedDatum.filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='annotation', full_name='AnnotatedDatum.annotation', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=278,
serialized_end=368,
)
_ANNOTATEDDATUM.fields_by_name['datum'].message_type = _DATUM
_ANNOTATEDDATUM.fields_by_name['annotation'].message_type = _ANNOTATION
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION
DESCRIPTOR.message_types_by_name['AnnotatedDatum'] = _ANNOTATEDDATUM
Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
DESCRIPTOR = _DATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Datum)
))
_sym_db.RegisterMessage(Datum)
Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATION,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:Annotation)
))
_sym_db.RegisterMessage(Annotation)
AnnotatedDatum = _reflection.GeneratedProtocolMessageType('AnnotatedDatum', (_message.Message,), dict(
DESCRIPTOR = _ANNOTATEDDATUM,
__module__ = 'anno_pb2'
# @@protoc_insertion_point(class_scope:AnnotatedDatum)
))
_sym_db.RegisterMessage(AnnotatedDatum)
# @@protoc_insertion_point(module_scope)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
import os
import time
import cv2
try:
    import cPickle
except ImportError:
    # Python 3 ships the accelerated pickle as ``pickle``
    import pickle as cPickle
from dragon.tools.db import LMDB
from . import anno_pb2 as pb
ZFILL = 8
ENCODE_QUALITY = 95
def set_zfill(value):
global ZFILL
ZFILL = value
def set_quality(value):
global ENCODE_QUALITY
ENCODE_QUALITY = value
def make_datum(image_file, mask_objects, im_scale=None):
filename = os.path.split(image_file)[-1]
anno_datum = pb.AnnotatedDatum()
datum = pb.Datum()
im = cv2.imread(image_file)
if im_scale: im = cv2.resize(im, None,
fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
datum.height, datum.width, datum.channels = im.shape
datum.encoded = ENCODE_QUALITY != 100
if datum.encoded:
result, im = cv2.imencode('.jpg', im, [int(cv2.IMWRITE_JPEG_QUALITY), ENCODE_QUALITY])
datum.data = im.tobytes()
anno_datum.datum.CopyFrom(datum)
anno_datum.filename = filename.split('.')[0]
for ix, obj in enumerate(mask_objects):
anno = pb.Annotation()
x1, y1, x2, y2 = obj['bbox']
anno.name = obj['name']
anno.x1, anno.y1, anno.x2, anno.y2 = x1, y1, x2, y2
if 'difficult' in obj: anno.difficult = obj['difficult']
if 'crowd' in obj: anno.difficult = obj['crowd']
anno.mask = obj['mask']
anno_datum.annotation.add().CopyFrom(anno)
return anno_datum
def make_db(database_file, images_path, mask_file,
splits_path, splits, ext='.jpg', im_scale=None):
if os.path.isdir(database_file):
raise ValueError('The database path already exists.')
else:
root_dir = database_file[:database_file.rfind('/')]
if not os.path.exists(root_dir):
os.makedirs(root_dir)
if not isinstance(images_path, list):
images_path = [images_path]
if not isinstance(splits_path, list):
splits_path = [splits_path]
assert len(splits) == len(splits_path)
assert len(splits) == len(images_path)
if mask_file is not None:
with open(mask_file, 'rb') as f:
all_masks = cPickle.load(f)
else:
all_masks = {}
print('Start Time: ', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
db = LMDB(max_commit=10000)
db.open(database_file, mode='w')
count = 0
total_line = 0
start_time = time.time()
zfill_flag = '{0:0%d}' % (ZFILL)
for db_idx, split in enumerate(splits):
split_file = os.path.join(splits_path[db_idx], split + '.txt')
assert os.path.exists(split_file)
with open(split_file, 'r') as f:
lines = f.readlines()
total_line += len(lines)
for line in lines:
count += 1
if count % 10000 == 0:
now_time = time.time()
print('{0} / {1} in {2:.2f} sec'.format(
count, total_line, now_time - start_time))
db.commit()
filename = line.strip()
image_file = os.path.join(images_path[db_idx], filename + ext)
mask_objects = all_masks[filename] if filename in all_masks else None
if mask_objects is None:
raise ValueError('The image({}) takes invalid mask settings.'.format(filename))
datum = make_datum(image_file, mask_objects, im_scale)
db.put(zfill_flag.format(count - 1), datum.SerializeToString())
now_time = time.time()
print('{0} / {1} in {2:.2f} sec'.format(count, total_line, now_time - start_time))
db.commit()
db.close()
# Compress the empty space
db.open(database_file, mode='w')
db.commit()
end_time = time.time()
print('{0} images have been stored in the database.'.format(total_line))
print('This task finishes within {0:.2f} seconds.'.format(end_time - start_time))
print('The size of database is {0} MB.'.format(
float(os.path.getsize(database_file + '/data.mdb') / 1000 / 1000)))
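# A minimal sketch of how make_db builds its record keys: the running index
# is zero-padded to ZFILL digits. Presumably this is done so that the
# lexicographic key order of the database matches the insertion order. The
# helper name ``make_key`` is hypothetical.

```python
def make_key(index, zfill=8):
    """Format a record index as a fixed-width, zero-padded key."""
    zfill_flag = '{0:0%d}' % zfill  # e.g. '{0:08}' for zfill=8
    return zfill_flag.format(index)


if __name__ == '__main__':
    keys = [make_key(i) for i in (10, 2, 1)]
    # Sorting the padded strings yields the numeric order.
    print(sorted(keys))
```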
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""A simple process pool to map tasks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
class ProcessPool(object):
    def __init__(self, num_processes=8, max_qsize=100):
        self.num_tasks = self.fetch_tasks = 0
        self.num_processes = num_processes
        self.Q = multiprocessing.Queue(max_qsize)

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        pass

    def map(self, tasks, func):
        # Split the tasks evenly; the last process also takes the remainder
        n_tasks_each = len(tasks) // self.num_processes
        remain_tasks = len(tasks) - n_tasks_each * self.num_processes
        pos = 0
        for i in range(self.num_processes):
            if i != self.num_processes - 1:
                work_set = tasks[pos: pos + n_tasks_each]
                pos += n_tasks_each
            else:
                work_set = tasks[pos: pos + n_tasks_each + remain_tasks]
            print('[Main]: Process #{} Got {} tasks.'.format(i, len(work_set)))
            p = multiprocessing.Process(target=func, args=(work_set, self.Q))
            p.start()

    def wait(self):
        # Drain with blocking gets, so that the workers never stall on a
        # full queue when there are more tasks than ``max_qsize``
        outputs = []
        for _ in range(self.num_tasks):
            outputs.append(self.Q.get())
            if len(outputs) % 100 == 0:
                print('[Queue]: Cached {} tasks.'.format(len(outputs)))
        print('[Main]: Got {} outputs.'.format(len(outputs)))
        return outputs

    def get(self):
        self.fetch_tasks += 1
        if self.fetch_tasks > self.num_tasks:
            return None
        return self.Q.get()

    def run_all(self, tasks, func):
        self.num_tasks = len(tasks)
        self.map(tasks, func)
        # Return the collected outputs to the caller
        return self.wait()

    def run(self, tasks, func):
        self.num_tasks = len(tasks)
        self.map(tasks, func)
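
# A minimal sketch of the partitioning done by ProcessPool.map, without
# spawning any process: every worker gets len(tasks) // num_processes tasks
# and the last one additionally takes the remainder. The helper name
# ``split_tasks`` is hypothetical.

```python
def split_tasks(tasks, num_processes):
    """Split ``tasks`` into ``num_processes`` contiguous work sets."""
    n_tasks_each = len(tasks) // num_processes
    work_sets = []
    pos = 0
    for i in range(num_processes):
        if i != num_processes - 1:
            work_sets.append(tasks[pos: pos + n_tasks_each])
            pos += n_tasks_each
        else:
            # The last work set takes everything that is left.
            work_sets.append(tasks[pos:])
    return work_sets


if __name__ == '__main__':
    print(split_tasks(list(range(10)), 3))
```

# Each work set is then handed to a worker called as func(work_set, queue),
# which is expected to put one result per task into the queue.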
# --------------------------------------------------------
# Detectron
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import time
import numpy as np
from lib.core.config import cfg, cfg_from_file
class Coordinator(object):
"""Coordinator is a simple tool to manage the
unique experiments from the YAML configurations.
"""
def __init__(self, cfg_file, exp_dir=None):
# Override the default configs
cfg_from_file(cfg_file)
if cfg.EXP_DIR != '':
exp_dir = cfg.EXP_DIR
if exp_dir is None:
model_id = time.strftime(
'%Y%m%d_%H%M%S', time.localtime(time.time()))
self.experiment_dir = '../experiments/{}'.format(model_id)
if not os.path.exists(self.experiment_dir):
os.makedirs(self.experiment_dir)
else:
if not os.path.exists(exp_dir):
raise ValueError('ExperimentDir({}) does not exist.'.format(exp_dir))
self.experiment_dir = exp_dir
def _path_at(self, file, auto_create=True):
path = os.path.abspath(os.path.join(self.experiment_dir, file))
if auto_create and not os.path.exists(path): os.makedirs(path)
return path
def checkpoints_dir(self):
return self._path_at('checkpoints')
def exports_dir(self):
return self._path_at('exports')
def results_dir(self, checkpoint=None):
sub_dir = os.path.splitext(os.path.basename(checkpoint))[0] if checkpoint else ''
return self._path_at(os.path.join('results', sub_dir))
def checkpoint(self, global_step=None, wait=True):
def locate():
files = os.listdir(self.checkpoints_dir())
steps = []
for ix, file in enumerate(files):
step = int(file.split('_iter_')[-1].split('.')[0])
if global_step == step: return os.path.join(self.checkpoints_dir(), files[ix])
steps.append(step)
if global_step is None:
if len(files) == 0:
raise ValueError('Dir({}) is empty.'.format(self.checkpoints_dir()))
last_idx = int(np.argmax(steps))
return os.path.join(self.checkpoints_dir(), files[last_idx])
return None
result = locate()
while not result and wait:
print('\rWaiting for step_{}.checkpoint to exist...'.format(global_step), end='')
time.sleep(10)
result = locate()
return result
def delete_experiment(self):
if os.path.exists(self.experiment_dir):
shutil.rmtree(self.experiment_dir)
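# A minimal sketch of how Coordinator.checkpoint locates a file: the step is
# parsed from the text between '_iter_' and the first following dot, and the
# largest step wins when no global step is requested. The helper names and
# the sample file names are hypothetical.

```python
def parse_step(filename):
    """Extract the iteration number from a '<prefix>_iter_<step>.<ext>' name."""
    return int(filename.split('_iter_')[-1].split('.')[0])


def latest_checkpoint(files):
    """Return the file name that carries the largest iteration number."""
    if len(files) == 0:
        raise ValueError('No checkpoints found.')
    return max(files, key=parse_step)


if __name__ == '__main__':
    files = ['model_iter_900.pth', 'model_iter_10000.pth', 'model_iter_5000.pth']
    print(latest_checkpoint(files))
```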
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import dragon.vm.torch as torch
from lib.core.config import cfg
from lib.modeling.detector import Detector
from lib.utils import logger
class Solver(object):
def __init__(self):
# Define the generic detector
self.detector = Detector().cuda(cfg.GPU_ID)
# Define the optimizer and its arguments
self.optimizer = None
self.opt_arguments = {
'scale_gradient': 1. / (
cfg.SOLVER.LOSS_SCALING *
cfg.SOLVER.ITER_SIZE),
'clip_gradient': float(cfg.SOLVER.CLIP_NORM),
'weight_decay': cfg.SOLVER.WEIGHT_DECAY,
}
# Define the global step
self.iter = 0
# Define the decay step
self._current_step = 0
def _get_param_groups(self):
param_groups = [
{
'params': [],
'lr_mult': 1.,
'decay_mult': 1.,
},
# Special treatment for biases (mainly to match historical impl.
# details):
# (1) Do not apply weight decay
# (2) Use a 2x higher learning rate
{
'params': [],
'lr_mult': 2.,
'decay_mult': 0.,
}
]
for name, param in self.detector.named_parameters():
if 'bias' in name: param_groups[1]['params'].append(param)
else: param_groups[0]['params'].append(param)
return param_groups
def set_learning_rate(self):
policy = cfg.SOLVER.LR_POLICY
if policy == 'steps_with_decay':
if self._current_step < len(cfg.SOLVER.STEPS) \
and self.iter >= cfg.SOLVER.STEPS[self._current_step]:
self._current_step = self._current_step + 1
logger.info('MultiStep Status: Iteration {}, step = {}' \
.format(self.iter, self._current_step))
new_lr = cfg.SOLVER.BASE_LR * (
cfg.SOLVER.GAMMA ** self._current_step)
self.optimizer.param_groups[0]['lr'] = \
self.optimizer.param_groups[1]['lr'] = new_lr
else:
raise ValueError('Unknown lr policy: ' + policy)
def one_step(self):
# Forward & Backward & Compute_loss
iter_size = cfg.SOLVER.ITER_SIZE
loss_scaling = cfg.SOLVER.LOSS_SCALING
run_time = 0.; stats = {'loss': {'total': 0.}, 'iter': self.iter}
add_loss = lambda x, y: y if x is None else x + y
tic = time.time()
if iter_size > 1:
# Dragon accumulates gradients manually, so ``zero_grad``
# is only required when ``accumulate_grad`` is called
self.optimizer.zero_grad()
for i in range(iter_size):
outputs, total_loss = self.detector(), None
# Sum the partial losses
for k, v in outputs.items():
if 'loss' in k:
if k not in stats['loss']:
stats['loss'][k] = 0.
total_loss = add_loss(total_loss, v)
stats['loss'][k] += float(v) * loss_scaling
if loss_scaling != 1.: total_loss *= loss_scaling
stats['loss']['total'] += float(total_loss)
total_loss.backward()
if iter_size > 1: self.optimizer.accumulate_grad()
run_time += (time.time() - tic)
# Apply Update
self.set_learning_rate()
tic = time.time()
self.optimizer.step()
run_time += (time.time() - tic)
self.iter += 1
# Average loss by the iter size
for k in stats['loss'].keys():
stats['loss'][k] /= cfg.SOLVER.ITER_SIZE
# Misc stats
stats['lr'] = self.base_lr
stats['time'] = run_time
return stats
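# A sketch of how the loss statistics above are summed over ``iter_size``
# micro-batches and then averaged, assuming plain floats in place of tensors.
# ``average_losses`` and the sample keys are hypothetical.

```python
def average_losses(micro_batches, loss_scaling=1.0):
    """Accumulate every output whose key contains 'loss', then average."""
    stats = {'total': 0.}
    for outputs in micro_batches:
        total = 0.
        for k, v in outputs.items():
            if 'loss' in k:
                stats[k] = stats.get(k, 0.) + float(v) * loss_scaling
                total += v
        stats['total'] += total * loss_scaling
    n = len(micro_batches)
    return {k: v / n for k, v in stats.items()}


# Two hypothetical micro-batches; 'rois' is ignored because it is not a loss.
batches = [{'cls_loss': 1.0, 'bbox_loss': 0.5, 'rois': 3},
           {'cls_loss': 3.0, 'bbox_loss': 1.5}]
print(average_losses(batches))
```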
@property
def base_lr(self):
return self.optimizer.param_groups[0]['lr']
@base_lr.setter
def base_lr(self, value):
self.optimizer.param_groups[0]['lr'] = \
self.optimizer.param_groups[1]['lr'] = value
class SGDSolver(Solver):
def __init__(self):
super(SGDSolver, self).__init__()
self.opt_arguments.update(**{
'lr': cfg.SOLVER.BASE_LR,
'momentum': cfg.SOLVER.MOMENTUM,
})
self.optimizer = torch.optim.SGD(
self._get_param_groups(), **self.opt_arguments)
class NesterovSolver(Solver):
def __init__(self):
super(NesterovSolver, self).__init__()
self.opt_arguments.update(**{
'lr': cfg.SOLVER.BASE_LR,
'momentum': cfg.SOLVER.MOMENTUM,
'nesterov': True,
})
self.optimizer = torch.optim.SGD(
self._get_param_groups(), **self.opt_arguments)
class RMSPropSolver(Solver):
def __init__(self):
super(RMSPropSolver, self).__init__()
self.opt_arguments.update(**{
'lr': cfg.SOLVER.BASE_LR,
'alpha': 0.9,
'eps': 1e-5,
})
self.optimizer = torch.optim.RMSprop(
self._get_param_groups(), **self.opt_arguments)
class AdamSolver(Solver):
def __init__(self):
super(AdamSolver, self).__init__()
self.opt_arguments.update(**{
'lr': cfg.SOLVER.BASE_LR,
'beta1': 0.9,
'beta2': 0.999,
'eps': 1e-5,
})
self.optimizer = torch.optim.Adam(
self._get_param_groups(), **self.opt_arguments)
def get_solver_func(type):
if type == 'MomentumSGD':
return SGDSolver
elif type == 'Nesterov':
return NesterovSolver
elif type == 'RMSProp':
return RMSPropSolver
elif type == 'Adam':
return AdamSolver
else:
raise ValueError('Unsupported solver type: {}.\n'
'Expected one of (MomentumSGD, Nesterov, RMSProp, Adam)'.format(type))
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import cv2
from multiprocessing import Queue
from collections import OrderedDict
from lib.core.config import cfg
from lib.datasets.factory import get_imdb
# All detectors share the same reader/transformer during testing
from lib.faster_rcnn.data.data_reader import DataReader
from lib.faster_rcnn.data.data_transformer import DataTransformer
class TestServer(object):
def __init__(self, output_dir):
self.imdb = get_imdb(cfg.TEST.DATABASE)
self.imdb.competition_mode(cfg.TEST.COMPETITION_MODE)
self.num_images, self.num_classes, self.classes = \
self.imdb.num_images, self.imdb.num_classes, self.imdb.classes
self.data_reader = DataReader(**{'source': self.imdb.source})
self.data_transformer = DataTransformer()
self.data_reader.Q_out = Queue(cfg.TEST.IMS_PER_BATCH)
self.data_reader.start()
self.gt_recs = OrderedDict()
self.output_dir = output_dir
if cfg.VIS_ON_FILE:
self.vis_dir = os.path.join(self.output_dir, 'vis')
if not os.path.exists(self.vis_dir): os.makedirs(self.vis_dir)
def set_transformer(self, transformer_cls):
self.data_transformer = transformer_cls()
def get_image(self):
serialized = self.data_reader.Q_out.get()
image = self.data_transformer.get_image(serialized)
image_id, objects = self.data_transformer.get_annotations(serialized)
self.gt_recs[image_id] = {
'objects': objects,
'width': image.shape[1],
'height': image.shape[0]}
return image_id, image
def get_save_filename(self, image_id, ext='.jpg'):
return os.path.join(self.vis_dir, image_id + ext) \
if cfg.VIS_ON_FILE else None
def get_records(self):
if len(self.gt_recs) != self.num_images:
raise RuntimeError('Loaded {} records, '
'but the specified database requires {}'.format(
len(self.gt_recs), self.num_images))
return self.gt_recs
def evaluate_detections(self, all_boxes):
self.imdb.evaluate_detections(
all_boxes, self.get_records(), self.output_dir)
def evaluate_segmentations(self, all_boxes, all_masks):
self.imdb.evaluate_segmentations(
all_boxes, all_masks, self.get_records(), self.output_dir)
class InferServer(object):
def __init__(self, output_dir):
self.images_dir = cfg.TEST.DATABASE
self.imdb = get_imdb('taas:/empty')
self.images = os.listdir(self.images_dir)
self.num_images, self.num_classes, self.classes = \
len(self.images), cfg.MODEL.NUM_CLASSES, cfg.MODEL.CLASSES
self.data_transformer = DataTransformer()
self.gt_recs = OrderedDict()
self.output_dir = output_dir
self.image_idx = 0
if cfg.VIS_ON_FILE:
self.vis_dir = os.path.join(self.output_dir, 'vis')
if not os.path.exists(self.vis_dir): os.makedirs(self.vis_dir)
def set_transformer(self, transformer_cls):
self.data_transformer = transformer_cls()
def get_image(self):
image_name = self.images[self.image_idx]
image_id = image_name.split('.')[0]
image = cv2.imread(os.path.join(self.images_dir, image_name))
self.image_idx = (self.image_idx + 1) % self.num_images
self.gt_recs[image_id] = {
'width': image.shape[1],
'height': image.shape[0]}
return image_id, image
def get_save_filename(self, image_id, ext='.jpg'):
return os.path.join(self.vis_dir, image_id + ext) \
if cfg.VIS_ON_FILE else None
def get_records(self):
if len(self.gt_recs) != self.num_images:
raise RuntimeError('Loaded {} records, '
'but the specified database requires {}'.format(
len(self.gt_recs), self.num_images))
return self.gt_recs
def evaluate_detections(self, all_boxes):
self.imdb.evaluate_detections(
all_boxes, self.get_records(), self.output_dir)
def evaluate_segmentations(self, all_boxes, all_masks):
self.imdb.evaluate_segmentations(
all_boxes, all_masks, self.get_records(), self.output_dir)
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/fast_rcnn/train.py>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import datetime
from collections import OrderedDict
import dragon.vm.torch as torch
from lib.core.config import cfg
from lib.core.solver import get_solver_func
from lib.utils.timer import Timer
from lib.utils.stats import SmoothedValue
from lib.utils import logger
class SolverWrapper(object):
def __init__(self, coordinator):
self.output_dir = coordinator.checkpoints_dir()
self.solver = get_solver_func('MomentumSGD')()
# Load the pre-trained weights
init_weights = cfg.TRAIN.WEIGHTS
if init_weights != '':
if os.path.exists(init_weights):
logger.info('Loading weights from {}.'.format(init_weights))
self.solver.detector.load_weights(init_weights)
else:
raise ValueError('Invalid path of weights: {}'.format(init_weights))
# Cast the detector to FP16 if half precision is requested
if cfg.MODEL.DATA_TYPE.lower() == 'float16':
self.solver.detector.half()
# Plan the metrics
self.metrics = OrderedDict()
if cfg.ENABLE_TENSOR_BOARD:
from dragon.tools.tensorboard import TensorBoard
self.board = TensorBoard(log_dir=coordinator.experiment_dir + '/logs')
def snapshot(self):
if not logger.is_root(): return None
filename = (cfg.SOLVER.SNAPSHOT_PREFIX + '_iter_{:d}'
.format(self.solver.iter) + '.pth')
filename = os.path.join(self.output_dir, filename)
torch.save(self.solver.detector.state_dict(), filename)
logger.info('Wrote snapshot to: {:s}'.format(filename))
return filename
def add_metrics(self, stats):
for k, v in stats['loss'].items():
if k not in self.metrics:
self.metrics[k] = SmoothedValue(20)
self.metrics[k].AddValue(v)
def send_metrics(self, stats):
if hasattr(self, 'board'):
self.board.scalar_summary('lr', stats['lr'], stats['iter'])
self.board.scalar_summary('time', stats['time'], stats['iter'])
for k, v in self.metrics.items():
if k == 'total':
self.board.scalar_summary('total_loss', v.GetMedianValue(), stats['iter'])
else: self.board.scalar_summary(k, v.GetMedianValue(), stats['iter'])
def step(self, display=False):
stats = self.solver.one_step()
self.add_metrics(stats)
self.send_metrics(stats)
if display:
logger.info('Iteration %d, lr = %.8f, loss = %f, time = %.2fs' % (stats['iter'],
stats['lr'], self.metrics['total'].GetMedianValue(), stats['time']))
for k, v in self.metrics.items():
if k == 'total': continue
logger.info(' Train net output({}): {}'.format(k, v.GetMedianValue()))
def train_model(self):
"""Network training loop."""
last_snapshot_iter = -1
timer = Timer()
model_paths = []
start_lr = self.solver.base_lr
while self.solver.iter < cfg.SOLVER.MAX_ITERS:
if self.solver.iter < cfg.SOLVER.WARM_UP_ITERS:
alpha = (self.solver.iter + 1.0) / cfg.SOLVER.WARM_UP_ITERS
self.solver.base_lr = \
start_lr * (cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha)
# Apply 1-step SGD update
timer.tic()
self.step(display=self.solver.iter % cfg.SOLVER.DISPLAY == 0)
timer.toc()
if self.solver.iter % (10 * cfg.SOLVER.DISPLAY) == 0:
average_time = timer.average_time
eta_seconds = average_time * (
cfg.SOLVER.MAX_ITERS - self.solver.iter)
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
progress = float(self.solver.iter + 1) / cfg.SOLVER.MAX_ITERS
logger.info('< PROGRESS: {:.2%} | SPEED: {:.3f}s / iter | ETA: {} >'
.format(progress, timer.average_time, eta))
if self.solver.iter % cfg.SOLVER.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
model_paths.append(self.snapshot())
if last_snapshot_iter != self.solver.iter:
model_paths.append(self.snapshot())
return model_paths
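# The warm-up rule in ``train_model`` interpolates linearly from
# ``WARM_UP_FACTOR * start_lr`` up to ``start_lr``. A minimal sketch, assuming
# scalar arguments in place of the ``cfg.SOLVER`` entries; ``warmup_lr`` is a
# hypothetical helper.

```python
def warmup_lr(iteration, start_lr, warmup_iters, warmup_factor):
    """Linear warm-up: factor * start_lr at the start, start_lr at the end."""
    if iteration >= warmup_iters:
        return start_lr
    alpha = (iteration + 1.0) / warmup_iters
    return start_lr * (warmup_factor * (1 - alpha) + alpha)


# Illustrative values only: 4 warm-up iterations starting from half the lr.
for it in range(6):
    print(it, warmup_lr(it, start_lr=1.0, warmup_iters=4, warmup_factor=0.5))
```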
def train_net(coordinator, start_iter=0):
sw = SolverWrapper(coordinator)
sw.solver.iter = start_iter
logger.info('Solving...')
model_paths = sw.train_model()
return model_paths
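# The ETA shown in the progress line is the average step time multiplied by
# the remaining iterations. A small sketch with the standard library only;
# ``format_eta`` is a hypothetical helper and the numbers are illustrative.

```python
import datetime


def format_eta(average_time, max_iters, current_iter):
    """Format the remaining training time as H:MM:SS (with days if needed)."""
    eta_seconds = average_time * (max_iters - current_iter)
    return str(datetime.timedelta(seconds=int(eta_seconds)))


print(format_eta(0.5, 10000, 2800))
```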
\ No newline at end of file
File mode changed
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/factory.py>
#
# ------------------------------------------------------------
from lib.datasets.pascal_voc import pascal_voc
from lib.datasets.coco import coco
from lib.datasets.taas import taas
__sets = {}
# pascal voc
for year in ['2007', '2012', '0712']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
# coco 2014
for year in ['2014']:
for split in ['train', 'val', 'trainval35k', 'minival', 'valminusminival']:
name = 'coco_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: coco(split, year))
# coco 2015 & 2017
for year in ['2015', '2017']:
for split in ['test', 'test-dev']:
name = 'coco_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: coco(split, year))
# taas
__sets['taas'] = (lambda source: taas(source))
def get_imdb(name):
"""Get an imdb (image database) by name."""
keys = name.split(':')
if len(keys) == 2:
cls, source = keys
if cls not in __sets:
raise KeyError('Unknown dataset: {}'.format(cls))
return __sets[cls](source)
elif len(keys) == 1:
return __sets[name]()
else:
raise ValueError('Illegal format of image database: {}'.format(name))
def list_imdbs():
"""List all registered imdbs."""
return __sets.keys()
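# A self-contained sketch of the registry pattern above, assuming tuples in
# place of the real dataset classes. ``registry`` and ``get_entry`` are
# hypothetical names that mirror ``__sets`` and ``get_imdb``.

```python
registry = {}
for year in ['2007', '2012']:
    for split in ['train', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        # Default arguments freeze the current loop values in each factory.
        registry[name] = (lambda split=split, year=year: ('voc', split, year))
# A factory that receives its source at call time, like the ``taas`` entry.
registry['taas'] = (lambda source: ('taas', source))


def get_entry(name):
    """Resolve ``name`` or ``cls:source`` against the registry."""
    keys = name.split(':')
    if len(keys) == 2:
        cls, source = keys
        if cls not in registry:
            raise KeyError('Unknown dataset: {}'.format(cls))
        return registry[cls](source)
    elif len(keys) == 1:
        return registry[name]()
    raise ValueError('Illegal format of image database: {}'.format(name))


print(get_entry('voc_2007_test'))
print(get_entry('taas:/empty'))
```

# Because the name is split on ':', a source that itself contains a colon
# (for example a drive letter) yields three keys and is rejected.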
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/imdb.py>
#
# ------------------------------------------------------------
import os
from dragon.tools.db import LMDB
from lib.core.config import cfg
class imdb(object):
def __init__(self, name):
self._name = name
self._num_classes = 0
self._classes = []
@property
def name(self):
return self._name
@property
def num_classes(self):
return len(self._classes)
@property
def classes(self):
return self._classes
@property
def cache_path(self):
cache_path = os.path.abspath(os.path.join(cfg.DATA_DIR, 'cache'))
if not os.path.exists(cache_path):
os.makedirs(cache_path)
return cache_path
@property
def source(self):
expected_source = os.path.join(self.cache_path, self.name + '_lmdb')
if not os.path.exists(expected_source):
raise RuntimeError('Expected LMDB source at: {}, '
'but it does not exist.'.format(expected_source))
return expected_source
@property
def num_images(self):
self._db = LMDB()
self._db.open(self.source)
num_entries = self._db.num_entries()
self._db.close()
return num_entries
def evaluate_detections(self, all_boxes, gt_recs, output_dir):
raise NotImplementedError
def evaluate_segmentations(self, all_boxes, all_masks, gt_recs, output_dir):
raise NotImplementedError
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import numpy as np
import uuid
try:
import cPickle
except ImportError:
import pickle as cPickle
from .imdb import imdb
from .voc_eval import voc_bbox_eval, voc_segm_eval
class pascal_voc(imdb):
def __init__(self, image_set, year, name='voc'):
imdb.__init__(self, name + '_' + year + '_' + image_set)
self._year = year
self._image_set = image_set
self._classes = ('__background__', # always index 0
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
self._salt = str(uuid.uuid4())
self.config = {'cleanup': True, 'use_salt': True}
def _get_comp_id(self):
return '_' + self._salt if self.config['use_salt'] else ''
def _get_prefix(self, type='bbox'):
if type == 'bbox': return 'detections_'
elif type == 'segm': return 'segmentations_'
elif type == 'kpt': return 'keypoints_'
return ''
def _get_voc_results_T(self, results_folder, type='bbox'):
# experiments/model_id/results/detections_voc_2007_test_<comp_id>_aeroplane.txt
filename = self._get_prefix(type) + self._name + self._get_comp_id() + '_{:s}.txt'
if not os.path.exists(results_folder): os.makedirs(results_folder)
return os.path.join(results_folder, filename)
def _write_voc_bbox_results(self, all_boxes, gt_recs, output_dir):
for cls_ind, cls in enumerate(self.classes):
if cls == '__background__': continue
print('Writing {} VOC format bbox results'.format(cls))
filename = self._get_voc_results_T(output_dir).format(cls)
with open(filename, 'wt') as f:
ix = 0
for image_id, rec in gt_recs.items():
dets = all_boxes[cls_ind][ix]; ix += 1
if len(dets) == 0: continue
for k in range(dets.shape[0]):
f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
format(image_id, dets[k, -1],
dets[k, 0] + 1, dets[k, 1] + 1,
dets[k, 2] + 1, dets[k, 3] + 1))
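# Each result line written above holds the image id, the score and the box
# shifted to 1-based pixel coordinates. A minimal sketch, assuming a detection
# is a plain ``(x1, y1, x2, y2, score)`` tuple; ``format_voc_line`` is a
# hypothetical helper.

```python
def format_voc_line(image_id, det):
    """Format one detection in the VOC results layout used above."""
    x1, y1, x2, y2, score = det
    # The +1 converts 0-based pixel indices to the 1-based VOC convention.
    return '{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.format(
        image_id, score, x1 + 1, y1 + 1, x2 + 1, y2 + 1)


print(format_voc_line('000001', (10., 20., 30., 40., 0.875)), end='')
```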
def _write_seg_results_file(self, all_boxes, all_masks):
for cls_inds, cls in enumerate(self.classes):
if cls == '__background__': continue
print('Writing {} VOC results file'.format(cls))
results_folder = os.path.join(self._devkit_path, 'results', 'seg')
if not os.path.exists(results_folder): os.makedirs(results_folder)
det_filename = os.path.join(results_folder, cls + '_det.pkl')
seg_filename = os.path.join(results_folder, cls + '_seg.pkl')
with open(det_filename, 'wb') as f:
cPickle.dump(all_boxes[cls_inds], f, cPickle.HIGHEST_PROTOCOL)
with open(seg_filename, 'wb') as f:
cPickle.dump(all_masks[cls_inds], f, cPickle.HIGHEST_PROTOCOL)
def _do_voc_bbox_eval(self, gt_recs, output_dir):
aps = []
# The PASCAL VOC metric changed in 2010
use_07_metric = True if int(self._year) < 2010 else False
print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No') + '\n')
for i, cls in enumerate(self._classes):
if cls == '__background__':
continue
det_file = self._get_voc_results_T(output_dir).format(cls)
rec, prec, ap = voc_bbox_eval(det_file, gt_recs, cls,
IoU=0.5, use_07_metric=use_07_metric)
aps += [ap]
print('AP for {} = {:.4f}'.format(cls, ap))
print('Mean AP = {:.4f}\n'.format(np.mean(aps)))
def _do_voc_segm_eval(self, gt_recs, output_dir):
aps = []
# Use the VOC07 metric, following the SDS evaluation protocol
use_07_metric = True
print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
print('~~~~~~ Evaluation use min overlap = 0.5 ~~~~~~')
for i, cls in enumerate(self.classes):
if cls == '__background__':
continue
det_file = os.path.join(output_dir, 'bbox_' + cls + '.pkl')
seg_file = os.path.join(output_dir, 'segm_' + cls + '.pkl')
ap = voc_segm_eval(det_file, seg_file, gt_recs, cls,
IoU=0.5, use_07_metric=use_07_metric)
aps += [ap]
print('AP for {} = {:.2f}'.format(cls, ap))
print('Mean AP@0.5 = {:.2f}'.format(np.mean(aps)))
print('~~~~~~ Evaluation use min overlap = 0.7 ~~~~~~')
aps = []
for i, cls in enumerate(self.classes):
if cls == '__background__':
continue
det_file = os.path.join(output_dir, 'bbox_' + cls + '.pkl')
seg_file = os.path.join(output_dir, 'segm_' + cls + '.pkl')
ap = voc_segm_eval(det_file, seg_file, gt_recs, cls,
IoU=0.7, use_07_metric=use_07_metric)
aps += [ap]
print('AP for {} = {:.2f}'.format(cls, ap))
print('Mean AP@0.7 = {:.2f}'.format(np.mean(aps)))
def evaluate_detections(self, all_boxes, gt_recs, output_dir):
self._write_voc_bbox_results(all_boxes, gt_recs, output_dir)
self._do_voc_bbox_eval(gt_recs, output_dir)
if self.config['cleanup']:
for cls in self._classes:
if cls == '__background__': continue
filename = self._get_voc_results_T(output_dir).format(cls)
os.remove(filename)
def competition_mode(self, on):
if on:
self.config['use_salt'] = False
self.config['cleanup'] = False
else:
self.config['use_salt'] = True
self.config['cleanup'] = True
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import numpy as np
try:
import cPickle
except ImportError:
import pickle as cPickle
from lib.core.config import cfg
from lib.utils.mask_transform import mask_overlap
from lib.utils.boxes import expand_boxes
from lib.pycocotools.mask_utils import mask_rle2im
def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
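# A worked example of the two AP definitions above, written with plain lists
# so it runs without NumPy. ``voc_ap_sketch`` is a hypothetical re-statement,
# and the precision/recall values are made up (three detections: TP, FP, TP,
# with two positives).

```python
def voc_ap_sketch(rec, prec, use_07_metric=False):
    """AP from recall/precision lists: 11-point or area under the envelope."""
    if use_07_metric:
        ap = 0.
        for i in range(11):
            t = i / 10.0
            # Best precision among points whose recall reaches the threshold.
            candidates = [p for r, p in zip(rec, prec) if r >= t]
            ap += (max(candidates) if candidates else 0.) / 11.
        return ap
    # Sentinels at both ends, then make precision monotonically decreasing.
    mrec = [0.] + list(rec) + [1.]
    mpre = [0.] + list(prec) + [0.]
    for i in range(len(mpre) - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])
    # Sum (delta recall) * precision wherever recall changes.
    return sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]
               for i in range(len(mrec) - 1) if mrec[i + 1] != mrec[i])


rec = [0.5, 0.5, 1.0]
prec = [1.0, 0.5, 2.0 / 3.0]
print(voc_ap_sketch(rec, prec), voc_ap_sketch(rec, prec, use_07_metric=True))
```

# For this curve the area metric gives 5/6 and the 11-point metric 28/33.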
def voc_bbox_eval(det_file, gt_recs, cls_name,
IoU=0.5, use_07_metric=False):
class_recs = {}
n_pos = 0
for image_name, rec in gt_recs.items():
R = [obj for obj in rec['objects'] if obj['name'] == cls_name]
bbox = np.array([x['bbox'] for x in R])
difficult = np.array([x['difficult'] for x in R]).astype(bool)
det = [False] * len(R)
n_pos = n_pos + sum(~difficult)
class_recs[image_name] = {
'bbox': bbox,
'difficult': difficult,
'det': det
}
# read detections
with open(det_file, 'r') as f: lines = f.readlines()
splitlines = [x.strip().split(' ') for x in lines]
image_ids = [x[0] for x in splitlines]
confidence = np.array([float(x[1]) for x in splitlines])
BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
# avoid IndexError if detecting nothing
if len(BB) == 0: return 0, 0, -1
# sort by confidence
sorted_ind = np.argsort(-confidence)
BB = BB[sorted_ind, :]
image_ids = [image_ids[x] for x in sorted_ind]
# go down dets and mark TPs and FPs
nd = len(image_ids)
tp = np.zeros(nd)
fp = np.zeros(nd)
for d in range(nd):
R = class_recs[image_ids[d]]
bb = BB[d, :].astype(float)
ovmax = -np.inf
BBGT = R['bbox'].astype(float)
if BBGT.size > 0:
# compute overlaps
# intersection
ixmin = np.maximum(BBGT[:, 0], bb[0])
iymin = np.maximum(BBGT[:, 1], bb[1])
ixmax = np.minimum(BBGT[:, 2], bb[2])
iymax = np.minimum(BBGT[:, 3], bb[3])
iw = np.maximum(ixmax - ixmin + 1., 0.)
ih = np.maximum(iymax - iymin + 1., 0.)
inters = iw * ih
# union
uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
(BBGT[:, 2] - BBGT[:, 0] + 1.) *
(BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
overlaps = inters / uni
ovmax = np.max(overlaps)
jmax = np.argmax(overlaps)
if ovmax > IoU:
if not R['difficult'][jmax]:
if not R['det'][jmax]:
tp[d] = 1.
R['det'][jmax] = 1
else:
fp[d] = 1.
else:
fp[d] = 1.
# compute precision recall
fp = np.cumsum(fp)
tp = np.cumsum(tp)
rec = tp / float(n_pos)
# avoid divide by zero in case the first detection matches a difficult
# ground truth
prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
ap = voc_ap(rec, prec, use_07_metric)
return rec, prec, ap
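# The overlap test above treats box corners as inclusive pixel indices, hence
# the ``+ 1`` in every width and height. A scalar sketch of the same formula
# for a single pair of boxes; ``iou_inclusive`` is a hypothetical helper.

```python
def iou_inclusive(a, b):
    """IoU of two (x1, y1, x2, y2) boxes with inclusive pixel corners."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    iw = max(ix2 - ix1 + 1., 0.)
    ih = max(iy2 - iy1 + 1., 0.)
    inter = iw * ih
    area_a = (a[2] - a[0] + 1.) * (a[3] - a[1] + 1.)
    area_b = (b[2] - b[0] + 1.) * (b[3] - b[1] + 1.)
    return inter / (area_a + area_b - inter)


# Two 10x10 boxes sharing a 5x5 corner: 25 / (100 + 100 - 25).
print(iou_inclusive((0, 0, 9, 9), (5, 5, 14, 14)))
```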
def voc_segm_eval(det_file, seg_file, gt_recs, cls_name,
IoU=0.5, use_07_metric=False):
# 0. Constants
M = cfg.MRCNN.RESOLUTION
binary_thresh = cfg.TEST.BINARY_THRESH
scale = (M + 2.0) / M
padded_mask = np.zeros((M + 2, M + 2), dtype=np.float32)
# 1. Get bbox & mask ground truths
image_names, class_recs, n_pos = [], {}, 0
for image_name, rec in gt_recs.items():
R = [obj for obj in rec['objects'] if obj['name'] == cls_name]
bbox = np.array([x['bbox'] for x in R])
mask = np.array([mask_rle2im([x['mask']], rec['height'], rec['width'])[0] for x in R])
difficult = np.array([x['difficult'] for x in R]).astype(bool)
det = [False] * len(R)
n_pos = n_pos + sum(~difficult)
class_recs[image_name] = {
'bbox': bbox,
'mask': mask,
'difficult': difficult,
'det': det
}
image_names.append(image_name)
# 2. Get predict pickle file for this class
with open(det_file, 'rb') as f: boxes_pkl = cPickle.load(f)
with open(seg_file, 'rb') as f: masks_pkl = cPickle.load(f)
# 3. Pre-compute number of total instances to allocate memory
num_images = len(gt_recs)
box_num = 0
for im_i in range(num_images):
box_num += len(boxes_pkl[im_i])
# avoid IndexError if detecting nothing
if box_num == 0: return -1
# 4. Re-organize all the predicted boxes
new_boxes = np.zeros((box_num, 5))
new_masks = np.zeros((box_num, M, M))
new_images = []
cnt = 0
for image_ind in range(num_images):
boxes = boxes_pkl[image_ind]
masks = masks_pkl[image_ind]
num_instance = len(boxes)
for box_ind in range(num_instance):
new_boxes[cnt] = boxes[box_ind]
new_masks[cnt] = masks[box_ind]
new_images.append(image_names[image_ind])
cnt += 1
# 5. Rearrange boxes according to their scores
seg_scores = new_boxes[:, -1]
keep_inds = np.argsort(-seg_scores)
new_boxes = new_boxes[keep_inds, :]
new_masks = new_masks[keep_inds, :, :]
num_pred = new_boxes.shape[0]
# 6. Calculate t/f positive
fp = np.zeros((num_pred, 1))
tp = np.zeros((num_pred, 1))
ref_boxes = expand_boxes(new_boxes, scale)
ref_boxes = ref_boxes.astype(np.int32)
for i in range(num_pred):
image_name = new_images[keep_inds[i]]
if image_name not in class_recs:
print('Warning: {} does not exist in the ground-truths.'.format(image_name))
fp[i] = 1
continue
R = class_recs[image_name]
im_h, im_w = \
gt_recs[image_name]['height'], \
gt_recs[image_name]['width']
# decode mask
ref_box = ref_boxes[i, :4]
mask = new_masks[i]
padded_mask[1:-1, 1:-1] = mask[:, :]
w = ref_box[2] - ref_box[0] + 1
h = ref_box[3] - ref_box[1] + 1
w = np.maximum(w, 1)
h = np.maximum(h, 1)
mask = cv2.resize(padded_mask, (w, h))
mask = np.array(mask > binary_thresh, dtype=np.uint8)
x1 = max(ref_box[0], 0)
y1 = max(ref_box[1], 0)
x2 = min(ref_box[2] + 1, im_w)
y2 = min(ref_box[3] + 1, im_h)
pred_mask = mask[(y1 - ref_box[1]): (y2 - ref_box[1]),
(x1 - ref_box[0]): (x2 - ref_box[0])]
# calculate max region overlap
ovmax = -1; jmax = -1
for j in range(len(R['det'])):
gt_mask_bound = R['bbox'][j].astype(int)
pred_mask_bound = new_boxes[i, :4].astype(int)
crop_mask = R['mask'][j][gt_mask_bound[1] : gt_mask_bound[3] + 1,
gt_mask_bound[0] : gt_mask_bound[2] + 1]
ov = mask_overlap(gt_mask_bound, pred_mask_bound, crop_mask, pred_mask)
if ov > ovmax:
ovmax = ov
jmax = j
if ovmax > IoU:
if not R['difficult'][jmax]:
if not R['det'][jmax]:
tp[i] = 1.
R['det'][jmax] = 1
else:
fp[i] = 1.
else:
fp[i] = 1
# 7. Calculate precision
fp = np.cumsum(fp)
tp = np.cumsum(tp)
rec = tp / float(n_pos)
# avoid divide by zero in case the first matches a difficult gt
prec = tp / np.maximum(fp + tp, np.finfo(np.float64).eps)
ap = voc_ap(rec, prec, use_07_metric=use_07_metric)
return ap
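# Both evaluators turn per-detection TP/FP flags (sorted by score) into
# cumulative precision and recall. A minimal sketch with the standard library
# only; ``precision_recall`` is a hypothetical helper and the flags are made
# up.

```python
import sys
from itertools import accumulate


def precision_recall(tp_flags, fp_flags, n_pos):
    """Cumulative recall and precision from score-sorted TP/FP flags."""
    tp = list(accumulate(tp_flags))
    fp = list(accumulate(fp_flags))
    rec = [t / float(n_pos) for t in tp]
    # The epsilon guards against 0 / 0 when the first detection matched a
    # difficult ground truth and counts as neither TP nor FP.
    eps = sys.float_info.epsilon
    prec = [t / max(t + f, eps) for t, f in zip(tp, fp)]
    return rec, prec


rec, prec = precision_recall([1, 0, 1], [0, 1, 0], n_pos=2)
print(rec, prec)
```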
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from lib.faster_rcnn.layers.data_layer import DataLayer
from lib.faster_rcnn.layers.anchor_target_layer import AnchorTargetLayer
from lib.faster_rcnn.layers.proposal_layer import ProposalLayer
from lib.faster_rcnn.layers.proposal_target_layer import ProposalTargetLayer
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from multiprocessing import Process
from lib.core.config import cfg
from lib.utils.blob import im_list_to_blob
class BlobFetcher(Process):
def __init__(self, **kwargs):
super(BlobFetcher, self).__init__()
self.Q1_in = self.Q2_in = self.Q_out = None
self.daemon = True
def get(self, Q_in):
processed_ims = []; ims_info = []; all_boxes = []
for ix in range(cfg.TRAIN.IMS_PER_BATCH):
im, im_scale, gt_boxes = Q_in.get()
processed_ims.append(im)
ims_info.append(list(im.shape[0:2]) + [im_scale])
# Encode boxes by adding the idx of images
im_boxes = np.zeros((gt_boxes.shape[0], gt_boxes.shape[1] + 1), dtype=np.float32)
im_boxes[:, 0:gt_boxes.shape[1]] = gt_boxes
im_boxes[:, -1] = ix
all_boxes.append(im_boxes)
return {
'data': im_list_to_blob(processed_ims),
'ims_info': np.array(ims_info, dtype=np.float32),
'gt_boxes': np.concatenate(all_boxes, axis=0),
}
def run(self):
while True:
if self.Q1_in.qsize() >= cfg.TRAIN.IMS_PER_BATCH:
self.Q_out.put(self.get(self.Q1_in))
elif self.Q2_in.qsize() >= cfg.TRAIN.IMS_PER_BATCH:
self.Q_out.put(self.get(self.Q2_in))
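# ``get`` above appends the in-batch image index as an extra last column so
# that boxes from different images can share one array. A list-based sketch,
# assuming each box is ``[x1, y1, x2, y2, cls]``; ``tag_boxes`` is a
# hypothetical helper.

```python
def tag_boxes(batch_boxes):
    """Concatenate per-image boxes, appending the image index to each row."""
    tagged = []
    for ix, boxes in enumerate(batch_boxes):
        for box in boxes:
            tagged.append(list(box) + [float(ix)])
    return tagged


# Image 0 has two boxes and image 1 has one.
batch = [[[0., 0., 9., 9., 1.], [5., 5., 20., 20., 3.]],
         [[2., 2., 8., 8., 7.]]]
print(tag_boxes(batch))
```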
\ No newline at end of file
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import pprint
from multiprocessing import Queue
import dragon.core.mpi as mpi
from lib.core.config import cfg
import lib.utils.logger as logger
from lib.faster_rcnn.data.data_reader import DataReader
from lib.faster_rcnn.data.data_transformer import DataTransformer
from lib.faster_rcnn.data.blob_fetcher import BlobFetcher
class DataBatch(object):
"""DataBatch aims to prefetch data by ``Triple-Buffering``.
It runs readers, transformers and fetchers in separate Python processes,
so that data loading overlaps with training, including distributed training.
"""
def __init__(self, **kwargs):
"""Construct a ``DataBatch``.
Parameters
----------
source : str
The path of database.
multiple_nodes: boolean
Whether to split data for multiple parallel nodes. Default is ``False``.
shuffle : boolean
Whether to shuffle the data. Default is ``False``.
num_chunks : int
The number of chunks to split. Default is ``2048``.
chunk_size : int
The size (MB) of each chunk. Default is ``-1`` (use ``num_chunks`` instead).
batch_size : int
The size of a training batch.
partition : boolean
Whether to partition batch. Default is ``False``.
prefetch : int
The prefetch count. Default is ``5``.
"""
        super(DataBatch, self).__init__()
        # Init mpi
        global_rank, local_rank, group_size = 0, 0, 1
        if mpi.Is_Init():
            idx, group = mpi.AllowParallel()
            if idx != -1:  # DataParallel
                global_rank = mpi.Rank()
                group_size = len(group)
                for i, node in enumerate(group):
                    if global_rank == node:
                        local_rank = i
        kwargs['group_size'] = group_size

        # Configuration
        self._prefetch = kwargs.get('prefetch', 5)
        self._num_readers = kwargs.get('num_readers', 1)
        self._num_transformers = kwargs.get('num_transformers', -1)
        self._max_transformers = kwargs.get('max_transformers', 3)
        self._num_fetchers = kwargs.get('num_fetchers', 1)

        # Io-Aware Policy
        if self._num_transformers == -1:
            self._num_transformers = 2
            # Add 1 transformer for color augmentation
            if cfg.TRAIN.COLOR_JITTERING:
                self._num_transformers += 1
        self._num_transformers = min(self._num_transformers, self._max_transformers)

        self._batch_size = kwargs.get('batch_size', 100)
        self._partition = kwargs.get('partition', False)
        if self._partition:
            self._batch_size = int(self._batch_size / kwargs['group_size'])

        # Init queues
        self.Q_level_1 = Queue(self._prefetch * self._num_readers * self._batch_size)
        self.Q1_level_2 = Queue(self._prefetch * self._num_readers * self._batch_size)
        self.Q2_level_2 = Queue(self._prefetch * self._num_readers * self._batch_size)
        self.Q_level_3 = Queue(self._prefetch * self._num_readers)

        # Init readers
        self._readers = []
        for i in range(self._num_readers):
            self._readers.append(DataReader(**kwargs))
            self._readers[-1].Q_out = self.Q_level_1

        for i in range(self._num_readers):
            num_parts = self._num_readers
            part_idx = i

            if self._readers[i]._multiple_nodes or \
                    self._readers[i]._use_shuffle:
                num_parts *= group_size
                part_idx += local_rank * self._num_readers

            self._readers[i]._num_parts = num_parts
            self._readers[i]._part_idx = part_idx
            self._readers[i]._random_seed += part_idx
            self._readers[i].start()
            time.sleep(0.1)

        # Init transformers
        self._transformers = []
        for i in range(self._num_transformers):
            transformer = DataTransformer(**kwargs)
            transformer._random_seed += (i + local_rank * self._num_transformers)
            transformer.Q_in = self.Q_level_1
            transformer.Q1_out = self.Q1_level_2
            transformer.Q2_out = self.Q2_level_2
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

        # Init blob fetchers
        self._fetchers = []
        for i in range(self._num_fetchers):
            fetcher = BlobFetcher(**kwargs)
            fetcher.Q1_in = self.Q1_level_2
            fetcher.Q2_in = self.Q2_level_2
            fetcher.Q_out = self.Q_level_3
            fetcher.start()
            self._fetchers.append(fetcher)
            time.sleep(0.1)

        # Echo the configuration on one node only
        if local_rank == 0:
            self.echo()

        def cleanup():
            def terminate(processes):
                for process in processes:
                    process.terminate()
                    process.join()
            # Stop the consumers first, then the producers
            logger.info('Terminating BlobFetcher ......')
            terminate(self._fetchers)
            logger.info('Terminating DataTransformer ......')
            terminate(self._transformers)
            logger.info('Terminating DataReader ......')
            terminate(self._readers)

        import atexit
        atexit.register(cleanup)

    def get(self):
        """Get a batch.

        Returns
        -------
        dict
            The batch dict.

        """
        return self.Q_level_3.get()

    def echo(self):
        """Print I/O Information.

        Returns
        -------
        None

        """
        print('---------------------------------------------------------')
        print('BatchFetcher({} Threads), Using config:'.format(
            self._num_readers + self._num_transformers + self._num_fetchers))
        params = {'queue_size': self._prefetch,
                  'n_readers': self._num_readers,
                  'n_transformers': self._num_transformers,
                  'n_fetchers': self._num_fetchers}
        pprint.pprint(params)
        print('---------------------------------------------------------')
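
The three-level queue layout that `DataBatch.__init__` wires up (readers fill `Q_level_1`, transformers fill the level-2 queues, fetchers fill `Q_level_3`) can be illustrated with a small stand-alone sketch. It is not the SeetaDet implementation: it assumes threads instead of `multiprocessing.Process`, a single level-2 queue, a toy `record * 2` transform, and a `None` sentinel so that the pipeline terminates. The names `reader`, `transformer`, `fetcher`, and `run_pipeline` are hypothetical.

```python
import queue
import threading


def reader(records, q_out):
    """Stage 1: push raw records, then a sentinel marking the end."""
    for record in records:
        q_out.put(record)
    q_out.put(None)


def transformer(q_in, q_out):
    """Stage 2: transform each record (here: a toy doubling)."""
    while True:
        record = q_in.get()
        if record is None:
            q_out.put(None)
            break
        q_out.put(record * 2)


def fetcher(q_in, q_out, batch_size):
    """Stage 3: group transformed records into batches."""
    batch = []
    while True:
        item = q_in.get()
        if item is None:
            if batch:
                q_out.put(batch)  # flush the last, possibly smaller batch
            q_out.put(None)
            break
        batch.append(item)
        if len(batch) == batch_size:
            q_out.put(batch)
            batch = []


def run_pipeline(records, batch_size=2, prefetch=5):
    # Bounded queues: a fast producer blocks instead of exhausting memory.
    q1 = queue.Queue(prefetch * batch_size)
    q2 = queue.Queue(prefetch * batch_size)
    q3 = queue.Queue(prefetch)
    threads = [
        threading.Thread(target=reader, args=(records, q1)),
        threading.Thread(target=transformer, args=(q1, q2)),
        threading.Thread(target=fetcher, args=(q2, q3, batch_size)),
    ]
    for t in threads:
        t.start()
    batches = []
    while True:
        batch = q3.get()
        if batch is None:
            break
        batches.append(batch)
    for t in threads:
        t.join()
    return batches


print(run_pipeline(range(5), batch_size=2))  # [[0, 2], [4, 6], [8]]
```

The queue capacities mirror the sizing used above: the record-level queues hold `prefetch * batch_size` items and the batch-level queue holds `prefetch` batches.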
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import numpy.random as npr
from multiprocessing import Process
import dragon.config as config
from dragon.tools.db import LMDB
class DataReader(Process):
    """DataReader is deployed to queue encoded str from `LMDB`_.

    It supports adaptively partitioning and shuffling records over all distributed nodes.

    """
    def __init__(self, **kwargs):
        """Construct a ``DataReader``.

        Parameters
        ----------
        source : str
            The path of database.
        multiple_nodes : boolean
            Whether to split data for multiple parallel nodes. Default is ``False``.
        shuffle : boolean
            Whether to shuffle the data. Default is ``False``.
        num_chunks : int
            The number of chunks to split. Default is ``2048``.
        chunk_size : int
            The size (MB) of each chunk. Default is ``-1`` (derived from ``num_chunks``).

        """
        super(DataReader, self).__init__()
        self._source = kwargs.get('source', '')
        self._multiple_nodes = kwargs.get('multiple_nodes', False)
        self._use_shuffle = kwargs.get('shuffle', False)
        self._use_instance_chunk = kwargs.get('instance_chunk', False)
        self._num_chunks = kwargs.get('num_chunks', 2048)
        self._chunk_size = kwargs.get('chunk_size', -1)

        self._part_idx, self._num_parts = 0, 1
        self._cur_idx, self._cur_chunk_idx = 0, 0
        self._random_seed = config.GetRandomSeed()

        self.Q_out = None
        self.daemon = True

    def element(self):
        """Get the value of current record.

        Returns
        -------
        str
            The encoded str.

        """
        return self._db.value()

    def redirect(self, target_idx):
        """Redirect to the target position.

        Parameters
        ----------
        target_idx : int
            The key of instance in ``LMDB``.

        Returns
        -------
        None

        Notes
        -----
        The redirection reopens the ``LMDB``.

        You can drop caches by ``echo 3 > /proc/sys/vm/drop_caches``.

        This helps to avoid getting stuck when ``Database Size`` >> ``RAM Size``.

        """
        self._db.close()
        self._db.open(self._source)
        self._cur_idx = target_idx
        self._db.set(str(self._cur_idx).zfill(self._zfill))

    def reset(self):
        """Reset the cursor and environment.

        Returns
        -------
        None

        """
        if self._multiple_nodes or self._use_shuffle:
            if self._use_shuffle:
                self._perm = npr.permutation(self._num_shuffle_parts)
            self._cur_chunk_idx = 0
            self._start_idx = int(self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx])
            self._start_idx = int(self._start_idx * self._chunk_size)
            if self._start_idx >= self._num_entries:
                self.next_chunk()
            self._end_idx = self._start_idx + self._chunk_size
            self._end_idx = min(self._num_entries, self._end_idx)
        else:
            self._start_idx = 0
            self._end_idx = self._num_entries

        self.redirect(self._start_idx)

    def next_record(self):
        """Step the cursor of records.

        Returns
        -------
        None

        """
        self._cur_idx += 1
        self._db.next()

    def next_chunk(self):
        """Step the cursor of shuffling chunks.

        Returns
        -------
        None

        """
        self._cur_chunk_idx += 1
        if self._cur_chunk_idx >= self._num_shuffle_parts:
            self.reset()
        else:
            self._start_idx = self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx]
            self._start_idx = self._start_idx * self._chunk_size
            if self._start_idx >= self._num_entries:
                self.next_chunk()
            else:
                self._end_idx = self._start_idx + self._chunk_size
                self._end_idx = min(self._num_entries, self._end_idx)
            self.redirect(self._start_idx)

    def run(self):
        """Start the process.

        Returns
        -------
        None

        """
        # Fix the seed
        npr.seed(self._random_seed)

        # Init the database
        self._db = LMDB()
        self._db.open(self._source)
        self._zfill = self._db.zfill()
        self._num_entries = self._db.num_entries()
        self._epoch_size = int(self._num_entries / self._num_parts + 1)

        if self._use_shuffle:
            if self._chunk_size == 1:
                # Each chunk has at most 1 record [Full Shuffle]
                self._chunk_size, self._num_shuffle_parts = \
                    1, int(self._num_entries / self._num_parts) + 1
            else:
                if self._use_shuffle and self._chunk_size == -1:
                    # Search for an optimal chunk size (MB) from ``num_chunks`` [Chunk Shuffle]
                    max_chunk_size = self._db._total_size / ((self._num_chunks * (1 << 20)))
                    min_chunk_size = 1
                    while min_chunk_size * 2 < max_chunk_size:
                        min_chunk_size *= 2
                    self._chunk_size = min_chunk_size
                self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
                                                        (self._num_parts * self._chunk_size << 20)))
                # From here on, ``chunk_size`` counts records instead of MB
                self._chunk_size = int(self._num_entries / self._num_shuffle_parts / self._num_parts + 1)
                limit = (self._num_parts - 0.5) * self._num_shuffle_parts * self._chunk_size
                if self._num_entries <= limit:
                    # Fall back to full shuffle
                    self._chunk_size, self._num_shuffle_parts = \
                        1, int(self._num_entries / self._num_parts) + 1
        else:
            # Each chunk has at most K records [Multiple Nodes]
            # Note that if ``shuffle`` and ``multiple_nodes`` are both ``False``,
            # ``chunk_size`` and ``num_shuffle_parts`` are meaningless
            self._chunk_size = int(self._num_entries / self._num_parts) + 1
            self._num_shuffle_parts = 1

        self._perm = np.arange(self._num_shuffle_parts)

        # Init env
        self.reset()

        # Run!
        while True:
            self.Q_out.put(self.element())
            self.next_record()
            if self._cur_idx >= self._end_idx:
                if self._multiple_nodes or self._use_shuffle:
                    self.next_chunk()
                else:
                    self.reset()
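
The index arithmetic shared by `reset()` and `next_chunk()` is easier to follow on a tiny example. The sketch below is a simplified, stand-alone restatement under stated assumptions: the hypothetical helper `chunk_ranges` takes the permutation as an argument instead of drawing it with `npr.permutation`, and it only lists the `[start, end)` record ranges that one reader would visit in a single pass, without touching a database.

```python
def chunk_ranges(num_entries, num_parts, part_idx, num_shuffle_parts, perm):
    """List the [start, end) record ranges one reader visits in one pass."""
    # Records per chunk, as computed for chunk shuffling in ``run()``.
    chunk_size = int(num_entries / num_shuffle_parts / num_parts + 1)
    ranges = []
    for chunk_idx in perm:
        # Each part owns ``num_shuffle_parts`` consecutive chunks.
        start = (part_idx * num_shuffle_parts + chunk_idx) * chunk_size
        if start >= num_entries:
            continue  # the chunk lies past the end of the database: skip it
        end = min(num_entries, start + chunk_size)
        ranges.append((start, end))
    return ranges


# 10 records, 2 parts (readers), 2 shuffle chunks per part -> 3 records per chunk.
part0 = chunk_ranges(10, 2, 0, 2, perm=[0, 1])
part1 = chunk_ranges(10, 2, 1, 2, perm=[1, 0])
print(part0)  # [(0, 3), (3, 6)]
print(part1)  # [(9, 10), (6, 9)]
```

In this example the two parts together cover every record exactly once, and the permutation only changes the order in which a part visits its own chunks.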
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from multiprocessing import Process
import numpy as np
import numpy.random as npr
try:
import cv2
except ImportError as e:
print('Failed to import cv2. Error: {0}'.format(str(e)))
try:
import PIL.Image
except ImportError as e:
print('Failed to import PIL. Error: {0}'.format(str(e)))
from lib.core.config import cfg
from lib.proto import anno_pb2 as pb
from lib.utils.blob import prep_im_for_blob
import lib.utils.logger as logger
class DataTransformer(Process):
    """DataTransformer decodes serialized records into images and ground-truth boxes."""

    def __init__(self, **kwargs):
        super(DataTransformer, self).__init__()
        self._random_seed = cfg.RNG_SEED
        self._use_flipped = cfg.TRAIN.USE_FLIPPED
        self._use_diff = cfg.TRAIN.USE_DIFF
        self._classes = kwargs.get('classes', ('__background__',))
        self._num_classes = len(self._classes)
        self._class_to_ind = dict(zip(self._classes, range(self._num_classes)))
        self._queues = []
        self.Q_in = self.Q1_out = self.Q2_out = None
        self.daemon = True

    def make_roidb(self, ann_datum, im_scale, flip=False, offsets=None):
        """Build the roidb dict (image size, gt classes, gt boxes) of a datum."""
        annotations = ann_datum.annotation
        n_objects = 0
        if not self._use_diff:
            # Count the non-difficult objects only
            for ann in annotations:
                if not ann.difficult:
                    n_objects += 1
        else:
            n_objects = len(annotations)

        roidb = {
            'width': ann_datum.datum.width,
            'height': ann_datum.datum.height,
            'gt_classes': np.zeros((n_objects,), dtype=np.int32),
            'boxes': np.zeros((n_objects, 4), dtype=np.float32),
        }

        ix = 0
        for ann in annotations:
            if not self._use_diff and ann.difficult:
                continue
            # Clip the box to the image boundary
            roidb['boxes'][ix, :] = [
                max(0, ann.x1), max(0, ann.y1),
                min(ann.x2, ann_datum.datum.width - 1),
                min(ann.y2, ann_datum.datum.height - 1)]
            roidb['gt_classes'][ix] = self._class_to_ind[ann.name]
            ix += 1

        if flip:
            roidb['boxes'] = _flip_boxes(roidb['boxes'], roidb['width'])
        roidb['boxes'] *= im_scale

        if offsets is not None:
            # Shift by the crop/pad offsets, then clip to the new image size;
            # offsets[2] is the target size in (height, width) order
            roidb['boxes'][:, 0::2] += offsets[0]
            roidb['boxes'][:, 1::2] += offsets[1]
            roidb['boxes'][:, :] = np.minimum(
                np.maximum(roidb['boxes'][:, :], 0),
                [offsets[2][1] - 1, offsets[2][0] - 1] * 2)

        return roidb

    @classmethod
    def get_image(cls, serialized):
        datum = pb.AnnotatedDatum()
        datum.ParseFromString(serialized)
        datum = datum.datum
        # np.fromstring is deprecated for binary data; np.frombuffer
        # reads the same bytes (as a read-only array)
        im = np.frombuffer(datum.data, np.uint8)
        return cv2.imdecode(im, -1) if datum.encoded is True else \
            im.reshape((datum.height, datum.width, datum.channels))

    @classmethod
    def get_annotations(cls, serialized):
        datum = pb.AnnotatedDatum()
        datum.ParseFromString(serialized)
        filename = datum.filename
        annotations = datum.annotation
        objects = []
        for ix, ann in enumerate(annotations):
            objects.append({
                'name': ann.name,
                'difficult': int(ann.difficult),
                'bbox': [ann.x1, ann.y1, ann.x2, ann.y2],
                'mask': ann.mask,
            })
        return filename, objects

    def get(self, serialized):
        datum = pb.AnnotatedDatum()
        datum.ParseFromString(serialized)
        im_datum = datum.datum
        # np.fromstring is deprecated for binary data; use np.frombuffer
        im = np.frombuffer(im_datum.data, np.uint8)
        if im_datum.encoded is True:
            im = cv2.imdecode(im, -1)
        else:
            im = im.reshape((im_datum.height, im_datum.width, im_datum.channels))

        # Scale
        scale_indices = npr.randint(0, high=len(cfg.TRAIN.SCALES))
        target_size = cfg.TRAIN.SCALES[scale_indices]
        im, im_scale, jitter = prep_im_for_blob(im, target_size, cfg.TRAIN.MAX_SIZE)

        # Crop or Pad
        offsets = None
        if cfg.TRAIN.MAX_SIZE > 0:
            if jitter != 1.0:
                # To a rectangle (scale, max_size)
                # (the np.int alias was removed in NumPy 1.24; int is equivalent)
                target_size = (np.array(im.shape[0:2]) / jitter).astype(int)
                im, offsets = _get_image_with_target_size(target_size, im)
        else:
            # To a square (target_size, target_size)
            im, offsets = _get_image_with_target_size([target_size] * 2, im)

        # Flip
        flip = False
        if self._use_flipped:
            if npr.randint(0, 2) > 0:
                im = im[:, ::-1, :]
                flip = True

        # Datum -> RoIDB
        roidb = self.make_roidb(datum, im_scale, flip, offsets)

        # Post-Process for gt boxes
        # Shape like: [num_objects, {x1, y1, x2, y2, cls}]
        gt_boxes = np.empty((len(roidb['gt_classes']), 5), dtype=np.float32)
        gt_boxes[:, 0:4], gt_boxes[:, 4] = roidb['boxes'], roidb['gt_classes']

        return im, im_scale, gt_boxes

    def run(self):
        npr.seed(self._random_seed)
        while True:
            serialized = self.Q_in.get()
            data = self.get(serialized)
            # Ensure that there should be at least 1 ground-truth
            if len(data[2]) < 1:
                continue
            aspect_ratio = float(data[0].shape[0]) / data[0].shape[1]
            if aspect_ratio > 1.0:
                self.Q1_out.put(data)
            else:
                self.Q2_out.put(data)


def _flip_boxes(boxes, width):
    flip_boxes = boxes.copy()
    oldx1 = boxes[:, 0].copy()
    oldx2 = boxes[:, 2].copy()
    flip_boxes[:, 0] = width - oldx2 - 1
    flip_boxes[:, 2] = width - oldx1 - 1
    if not (flip_boxes[:, 2] >= flip_boxes[:, 0]).all():
        logger.fatal('Encounter invalid coordinates after flipping boxes.')
    return flip_boxes
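
A worked example makes the mirroring rule concrete: with 0-based pixel coordinates, column `x` of an image of width `W` moves to `W - x - 1`, so the old right edge becomes the new left edge. The snippet below is a minimal stand-alone restatement of that rule; the function name `flip_boxes` and the sample box are illustrative only, and the validity check with `logger.fatal` is left out.

```python
import numpy as np


def flip_boxes(boxes, width):
    """Mirror (x1, y1, x2, y2) boxes horizontally in an image of ``width``."""
    flipped = boxes.copy()
    flipped[:, 0] = width - boxes[:, 2] - 1  # new x1 comes from the old x2
    flipped[:, 2] = width - boxes[:, 0] - 1  # new x2 comes from the old x1
    return flipped


boxes = np.array([[1., 2., 4., 5.]], dtype=np.float32)
flipped = flip_boxes(boxes, width=10)
print(flipped)  # [[5. 2. 8. 5.]]
```

Flipping twice restores the original boxes, which is a convenient sanity check for this kind of transform.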


def _get_image_with_target_size(target_size, im):
    """Center-crop or pad (with value 127) ``im`` to ``target_size`` (height, width).

    Returns the new image and ``(x_offset, y_offset, target_size)``, where the
    offsets are the shifts to apply to box coordinates.
    """
    im_shape = list(im.shape)
    width_diff = target_size[1] - im_shape[1]
    offset_crop_width = max(-width_diff // 2, 0)
    offset_pad_width = max(width_diff // 2, 0)

    height_diff = target_size[0] - im_shape[0]
    offset_crop_height = max(-height_diff // 2, 0)
    offset_pad_height = max(height_diff // 2, 0)

    im_shape[0:2] = target_size
    new_im = np.empty(im_shape, dtype=im.dtype)
    new_im.fill(127)

    new_im[offset_pad_height:offset_pad_height + im.shape[0],
           offset_pad_width:offset_pad_width + im.shape[1]] = \
        im[offset_crop_height:offset_crop_height + target_size[0],
           offset_crop_width:offset_crop_width + target_size[1]]

    return new_im, (offset_pad_width - offset_crop_width,
                    offset_pad_height - offset_crop_height, target_size)
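
The crop-or-pad step can be tried on a tiny array. The sketch below is a stand-alone restatement written for illustration, not the function above: the name `crop_or_pad`, the explicit `copy_h`/`copy_w` sizes, and the two-element offset tuple are assumptions made here for clarity. It pads a 2x4 image to a 4x2 target, so the height is padded and the width is cropped in the same call.

```python
import numpy as np


def crop_or_pad(im, target_size, fill=127):
    """Center-crop or pad ``im`` to ``target_size`` given as (height, width)."""
    target_h, target_w = target_size
    h, w = im.shape[:2]
    # Along each axis at most one of crop/pad is non-zero.
    crop_h, pad_h = max((h - target_h) // 2, 0), max((target_h - h) // 2, 0)
    crop_w, pad_w = max((w - target_w) // 2, 0), max((target_w - w) // 2, 0)
    copy_h, copy_w = min(h, target_h), min(w, target_w)
    out = np.full((target_h, target_w) + im.shape[2:], fill, dtype=im.dtype)
    out[pad_h:pad_h + copy_h, pad_w:pad_w + copy_w] = \
        im[crop_h:crop_h + copy_h, crop_w:crop_w + copy_w]
    # (x, y) shift to add to box coordinates: positive when padded,
    # negative when cropped.
    return out, (pad_w - crop_w, pad_h - crop_h)


im = np.arange(24, dtype=np.uint8).reshape(2, 4, 3)
out, offsets = crop_or_pad(im, (4, 2))
print(out.shape, offsets)  # (4, 2, 3) (-1, 1)
```

Here one column is cropped on the left (x offset -1) and one row is padded on top (y offset 1), which are the shifts that `make_roidb` adds to the box coordinates before clipping them to the new image size.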
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py>
#
# ------------------------------------------------------------
import numpy as np
# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
# >> anchors
#
# anchors =
#
# -83 -39 100 56
# -175 -87 192 104
# -359 -183 376 200
# -55 -55 72 72
# -119 -119 136 136
# -247 -247 264 264
# -35 -79 52 96
# -79 -167 96 184
# -167 -343 184 360
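# The same table as a NumPy array literal. These are MATLAB's 1-based
# coordinates; generate_anchors() below uses a 0-based reference window,
# so every value it returns is smaller by 1.
#array([[ -83., -39., 100., 56.],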
# [-175., -87., 192., 104.],
# [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.],
# [-119., -119., 136., 136.],
# [-247., -247., 264., 264.],
# [ -35., -79., 52., 96.],
# [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]])


def generate_anchors(base_size=16, ratios=(0.5, 1, 2),
                     scales=2**np.arange(3, 6)):
    """
    Generate anchor (reference) windows by enumerating aspect ratios X
    scales wrt a reference (0, 0, 15, 15) window.
    """
    base_anchor = np.array([1, 1, base_size, base_size]) - 1
    ratio_anchors = _ratio_enum(base_anchor, ratios)
    anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
                         for i in range(ratio_anchors.shape[0])])
    return anchors


def generate_anchors_v2(stride=16, ratios=(0.5, 1, 2),
                        sizes=(32, 64, 128, 256, 512)):
    """
    Generates a matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
    are centered on stride / 2, have (approximate) sqrt areas of the specified
    sizes, and aspect ratios as given.
    """
    # The np.float alias was removed in NumPy 1.24; np.float64 is equivalent
    return generate_anchors(stride, ratios,
                            np.array(sizes, dtype=np.float64) / stride)


def _whctrs(anchor):
    """
    Return width, height, x center, and y center for an anchor (window).
    """
    w = anchor[2] - anchor[0] + 1
    h = anchor[3] - anchor[1] + 1
    x_ctr = anchor[0] + 0.5 * (w - 1)
    y_ctr = anchor[1] + 0.5 * (h - 1)
    return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
    """
    Given a vector of widths (ws) and heights (hs) around a center
    (x_ctr, y_ctr), output a set of anchors (windows).
    """
    ws = ws[:, np.newaxis]
    hs = hs[:, np.newaxis]
    anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
                         y_ctr - 0.5 * (hs - 1),
                         x_ctr + 0.5 * (ws - 1),
                         y_ctr + 0.5 * (hs - 1)))
    return anchors


def _ratio_enum(anchor, ratios):
    """
    Enumerate a set of anchors for each aspect ratio wrt an anchor.
    """
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    size = w * h
    size_ratios = size / ratios
    ws = np.round(np.sqrt(size_ratios))
    hs = np.round(ws * ratios)
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


def _scale_enum(anchor, scales):
    """
    Enumerate a set of anchors for each scale wrt an anchor.
    """
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    ws = w * scales
    hs = h * scales
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


if __name__ == '__main__':
    print(generate_anchors())
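
The relation between the MATLAB table quoted in the comments and the values computed in Python can be checked directly. The block below is a condensed, stand-alone restatement of the same ratio-then-scale enumeration, assuming the default `base_size=16`, ratios `(0.5, 1, 2)` and scales `(8, 16, 32)`; the helper names without leading underscores and the `MATLAB_ANCHORS` constant are introduced here for illustration. Because the reference window is the 0-based `(0, 0, 15, 15)`, each computed coordinate is one less than the 1-based MATLAB value.

```python
import numpy as np


def whctrs(anchor):
    """Width, height and center of an (x1, y1, x2, y2) window."""
    w = anchor[2] - anchor[0] + 1
    h = anchor[3] - anchor[1] + 1
    return w, h, anchor[0] + 0.5 * (w - 1), anchor[1] + 0.5 * (h - 1)


def mkanchors(ws, hs, x_ctr, y_ctr):
    """Windows of the given widths/heights around one center."""
    ws, hs = ws[:, np.newaxis], hs[:, np.newaxis]
    return np.hstack((x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
                      x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)))


def generate_anchors(base_size=16, ratios=(0.5, 1, 2), scales=(8, 16, 32)):
    ratios = np.asarray(ratios, dtype=np.float64)
    scales = np.asarray(scales, dtype=np.float64)
    base = np.array([0, 0, base_size - 1, base_size - 1], dtype=np.float64)
    w, h, x_ctr, y_ctr = whctrs(base)
    # Step 1: keep the area (roughly) fixed and vary the aspect ratio.
    ws = np.round(np.sqrt(w * h / ratios))
    hs = np.round(ws * ratios)
    rows = []
    for ratio_anchor in mkanchors(ws, hs, x_ctr, y_ctr):
        # Step 2: scale every ratio anchor around its own center.
        rw, rh, rx, ry = whctrs(ratio_anchor)
        rows.append(mkanchors(rw * scales, rh * scales, rx, ry))
    return np.vstack(rows)


MATLAB_ANCHORS = np.array([
    [-83, -39, 100, 56], [-175, -87, 192, 104], [-359, -183, 376, 200],
    [-55, -55, 72, 72], [-119, -119, 136, 136], [-247, -247, 264, 264],
    [-35, -79, 52, 96], [-79, -167, 96, 184], [-167, -343, 184, 360]])

anchors = generate_anchors()
print(anchors[0])  # [-84. -40.  99.  55.]
print(np.array_equal(anchors + 1, MATLAB_ANCHORS))  # True
```

With `generate_anchors_v2`, the scales are simply `sizes / stride`, so the default call with `stride=16` and sizes `(32, ..., 512)` corresponds to scales `(2, 4, 8, 16, 32)` in this formulation.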