Commit d8f612c8 by Ting PAN

Init sphinx documentation for C++ API

Summary:
This commit uses sphinx to generate C++ API documentation
whose style and theme are consistent with the Python API.
1 parent 8dbb73a7
Showing with 2128 additions and 1030 deletions
FROM ubuntu:16.04 FROM ubuntu:16.04
RUN \ RUN \
apt-get update && apt-get install -y --no-install-recommends \ apt-get update && apt-get install -y \
--no-install-recommends \
--allow-change-held-packages \
build-essential \ build-essential \
cmake \ cmake \
git \ git \
...@@ -17,7 +19,7 @@ RUN \ ...@@ -17,7 +19,7 @@ RUN \
python3-pip \ python3-pip \
python3-dev \ python3-dev \
python3-pyqt4 \ python3-pyqt4 \
python3-tk \ python3-tk \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
RUN \ RUN \
......
...@@ -2,7 +2,9 @@ FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 ...@@ -2,7 +2,9 @@ FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
RUN \ RUN \
rm /etc/apt/sources.list.d/cuda.list && \ rm /etc/apt/sources.list.d/cuda.list && \
apt-get update && apt-get install -y --no-install-recommends \ apt-get update && apt-get install -y \
--no-install-recommends \
--allow-change-held-packages \
build-essential \ build-essential \
cmake \ cmake \
git \ git \
......
...@@ -3,27 +3,43 @@ Building Dragon Documentation ...@@ -3,27 +3,43 @@ Building Dragon Documentation
This page will help you to build the following documentations: This page will help you to build the following documentations:
Dragon C++ API: https://dragon.seetatech.com/api/cc Python API: https://dragon.seetatech.com/api/python
Dragon Python API: https://dragon.seetatech.com/api/python C++ API: https://dragon.seetatech.com/api/cc
Build Documentation of C++ API Requirements
------------------------------ ------------
- sphinx >= 3.0.2
```bash ```bash
cd dragon/docs/api/cc pip install sphinx
doxygen Doxyfile
``` ```
Then, open the ```docs/api/cc/html/index.html``` in your browser. - sphinx_seeta_theme
```bash
pip install sphinx_seeta_theme
```
- doxygen (C++ API only)
See: http://www.doxygen.org/download.html
Build Documentation of Python API Build Documentation of Python API
--------------------------------- ---------------------------------
```bash ```bash
pip install sphinx_seeta_theme cd dragon/docs/api/python && make html
cd dragon/docs/api/python ```
make html
Then, open the ``docs/_build/api/python/index.html`` in your browser.
Build Documentation of C++ API
------------------------------
```bash
cd dragon/docs/api/cc && make doxygen && make html
``` ```
Then, open the ```docs/api/python/index.html``` in your browser. Then, open the ``docs/_build/api/cc/index.html`` in your browser.
...@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8 ...@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
# title of most generated pages and in a few other places. # title of most generated pages and in a few other places.
# The default value is: My Project. # The default value is: My Project.
PROJECT_NAME = "Dragon - C++ API" PROJECT_NAME =
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This # The PROJECT_NUMBER tag can be used to enter a project or revision number. This
# could be handy for archiving the generated documentation or if some version # could be handy for archiving the generated documentation or if some version
...@@ -44,21 +44,21 @@ PROJECT_NUMBER = ...@@ -44,21 +44,21 @@ PROJECT_NUMBER =
# for a project that appears at the top of each page and should give viewer a # for a project that appears at the top of each page and should give viewer a
# quick idea about the purpose of the project. Keep the description short. # quick idea about the purpose of the project. Keep the description short.
PROJECT_BRIEF = "A Computation Graph Virtual Machine Based Deep Learning Framework" PROJECT_BRIEF =
# With the PROJECT_LOGO tag one can specify a logo or an icon that is included # With the PROJECT_LOGO tag one can specify a logo or an icon that is included
# in the documentation. The maximum height of the logo should not exceed 55 # in the documentation. The maximum height of the logo should not exceed 55
# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy # pixels and the maximum width should not exceed 200 pixels. Doxygen will copy
# the logo to the output directory. # the logo to the output directory.
PROJECT_LOGO = images/logo.png PROJECT_LOGO =
# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path # The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
# into which the generated documentation will be written. If a relative path is # into which the generated documentation will be written. If a relative path is
# entered, it will be relative to the location where doxygen was started. If # entered, it will be relative to the location where doxygen was started. If
# left blank the current directory will be used. # left blank the current directory will be used.
OUTPUT_DIRECTORY = "" OUTPUT_DIRECTORY = "../../_build/api/cc_doxygen"
# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- # If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub-
# directories (in 2 levels) under the output directory of each output format and # directories (in 2 levels) under the output directory of each output format and
...@@ -143,7 +143,7 @@ ALWAYS_DETAILED_SEC = NO ...@@ -143,7 +143,7 @@ ALWAYS_DETAILED_SEC = NO
# operators of the base classes will not be shown. # operators of the base classes will not be shown.
# The default value is: NO. # The default value is: NO.
INLINE_INHERITED_MEMB = NO INLINE_INHERITED_MEMB = YES
# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path # If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path
# before files name in the file list and in the header files. If set to NO the # before files name in the file list and in the header files. If set to NO the
...@@ -1044,7 +1044,7 @@ VERBATIM_HEADERS = YES ...@@ -1044,7 +1044,7 @@ VERBATIM_HEADERS = YES
# generated with the -Duse-libclang=ON option for CMake. # generated with the -Duse-libclang=ON option for CMake.
# The default value is: NO. # The default value is: NO.
CLANG_ASSISTED_PARSING = NO # CLANG_ASSISTED_PARSING = NO
# If clang assisted parsing is enabled you can provide the compiler with command # If clang assisted parsing is enabled you can provide the compiler with command
# line options that you would normally use when invoking the compiler. Note that # line options that you would normally use when invoking the compiler. Note that
...@@ -1052,7 +1052,7 @@ CLANG_ASSISTED_PARSING = NO ...@@ -1052,7 +1052,7 @@ CLANG_ASSISTED_PARSING = NO
# specified with INPUT and INCLUDE_PATH. # specified with INPUT and INCLUDE_PATH.
# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. # This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES.
CLANG_OPTIONS = # CLANG_OPTIONS =
# If clang assisted parsing is enabled you can provide the clang parser with the # If clang assisted parsing is enabled you can provide the clang parser with the
# path to the compilation database (see: # path to the compilation database (see:
...@@ -1063,7 +1063,7 @@ CLANG_OPTIONS = ...@@ -1063,7 +1063,7 @@ CLANG_OPTIONS =
# generated with the -Duse-libclang=ON option for CMake. # generated with the -Duse-libclang=ON option for CMake.
# The default value is: 0. # The default value is: 0.
CLANG_COMPILATION_DATABASE_PATH = 0 # CLANG_COMPILATION_DATABASE_PATH = 0
#--------------------------------------------------------------------------- #---------------------------------------------------------------------------
# Configuration options related to the alphabetical class index # Configuration options related to the alphabetical class index
...@@ -1098,7 +1098,7 @@ IGNORE_PREFIX = ...@@ -1098,7 +1098,7 @@ IGNORE_PREFIX =
# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output # If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output
# The default value is: YES. # The default value is: YES.
GENERATE_HTML = YES GENERATE_HTML = NO
# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a # The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a
# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of # relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
...@@ -1930,7 +1930,7 @@ MAN_LINKS = NO ...@@ -1930,7 +1930,7 @@ MAN_LINKS = NO
# captures the structure of the code including all documentation. # captures the structure of the code including all documentation.
# The default value is: NO. # The default value is: NO.
GENERATE_XML = NO GENERATE_XML = YES
# The XML_OUTPUT tag is used to specify where the XML pages will be put. If a # The XML_OUTPUT tag is used to specify where the XML pages will be put. If a
# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of # relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
...@@ -2083,9 +2083,7 @@ INCLUDE_FILE_PATTERNS = ...@@ -2083,9 +2083,7 @@ INCLUDE_FILE_PATTERNS =
# recursively expanded use the := operator instead of the = operator. # recursively expanded use the := operator instead of the = operator.
# This tag requires that the tag ENABLE_PREPROCESSING is set to YES. # This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
PREDEFINED = WITH_MPI \ PREDEFINED = DRAGON_API= USE_MPI USE_CUDA USE_CUDNN USE_NCCL
WITH_CUDA \
WITH_CUDNN \
# If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this # If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this
# tag can be used to specify a list of macro names that should be expanded. The # tag can be used to specify a list of macro names that should be expanded. The
......
# Makefile for Sphinx documentation
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
PAPER =
# All generated artifacts (doctrees, HTML, LaTeX) land under this directory.
BUILDDIR = ../../_build/api/cc

# User-friendly check for sphinx-build
# ('command -v' is the POSIX-portable replacement for 'which').
ifeq ($(shell command -v $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed.)
endif

# Internal variables.
PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
# Parallelize the HTML build across all online processors.
NUMBER_OF_PROCESSORS:=$(shell getconf _NPROCESSORS_ONLN)

# BUGFIX: 'doxygen' was missing from .PHONY; a file or directory named
# 'doxygen' here would otherwise shadow the target and skip the build.
.PHONY: help clean doxygen html latex latexpdf

# Print the list of supported targets.
help:
	@echo "Please use \`make <target>' where <target> is one of"
	@echo " doxygen to make Doxygen XML files"
	@echo " html to make standalone HTML files"
	@echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
	@echo " latexpdf to make LaTeX files and run them through pdflatex"

clean:
	rm -rf $(BUILDDIR)/*

# Run doxygen (configured by ./Doxyfile) to emit the XML consumed by breathe.
doxygen:
	mkdir -p $(BUILDDIR)_doxygen && doxygen
	@echo
	@echo "Build finished. The Doxygen XML files are in $(BUILDDIR)_doxygen/xml."

html:
	$(SPHINXBUILD) -b html -j ${NUMBER_OF_PROCESSORS} $(ALLSPHINXOPTS) $(BUILDDIR)
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)."

latex:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)-latex
	@echo
	@echo "Build finished; the LaTeX files are in $(BUILDDIR)-latex."
	@echo "Run \`make' in that directory to run these through (pdf)latex" \
	"(use \`make latexpdf' here to do that automatically)."

# BUGFIX: the latex builder above writes to $(BUILDDIR)-latex, so pdflatex
# must run there as well (was '$(BUILDDIR)/latex', which never exists).
latexpdf:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)-latex
	@echo "Running LaTeX files through pdflatex..."
	$(MAKE) -C $(BUILDDIR)-latex all-pdf
	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)-latex."
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Sphinx configuration for C++ API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from sphinx_seeta_theme import HTMLTranslator
from sphinx_seeta_theme import HTMLTranslatorV2
from sphinx_seeta_theme import setup as setup_v1
def path_to(href, index=False):
    """Map a relative href to a concrete HTML page.

    Parameters
    ----------
    href : str
        The relative target (may be empty for the current directory).
    index : bool, optional
        If true, link to the directory's ``index.html`` instead of
        treating *href* as a page name.

    Returns
    -------
    str
        The resolved relative URL.
    """
    if not index:
        return href + '.html'
    # An empty href means "this directory": link straight to index.html.
    return (href + '/index.html') if href else 'index.html'
# Basic
# Static assets and exclusions are shared with the Python API docs tree.
html_static_path = ['../_static']
exclude_patterns = ['../_build']
master_doc = 'index'
source_suffix = '.rst'
# Extension
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.viewcode',
    'sphinx.ext.napoleon',
    'sphinxcontrib.katex',
    'breathe',  # bridges the Doxygen XML into Sphinx C++ domains
]
napoleon_use_rtype = False
# Project
project = 'dragon'
copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd'
author = 'SeetaTech'
# The version string is read from the repository-wide version file so the
# C++ docs always match the checked-out source tree.
with open('../../../dragon/version.txt', 'r') as f:
    version = f.read().strip()
# Sphinx
# Treat the DRAGON_API export macro as a harmless attribute so that C/C++
# declarations containing it still parse in the Sphinx domains.
c_id_attributes = ['DRAGON_API']
cpp_id_attributes = ['DRAGON_API']
# Breathe
# Points at the Doxygen XML output produced by `make doxygen`.
breathe_projects = {'dragon': '../../_build/api/cc_doxygen/xml/'}
breathe_default_project = 'dragon'
# HTML
html_theme = 'seeta'
html_title = ''
html_short_title = ''
html_logo = '../_static/images/dragon.png'
html_favicon = '../_static/favicon.ico'
html_copy_source = False
html_show_sourcelink = False
html_show_sphinx = False
html_show_copyright = False
html_scaled_image_link = False
# Navigation bar, sidebar and breadcrumb links, all resolved relative to
# this sub-site via path_to().
html_theme_options = {
    'navbar_links': {
        'Install': path_to('../../install', 1),
        'API': [
            ('master', path_to('../../api/python', 1)),
            ('versions...', path_to('../../versions', 1)),
        ],
        'Github': 'https://github.com/seetaresearch/dragon',
    },
    'navbar_logo_link': path_to('../..', 1),
    'sidebar_title': 'C++ v{}'.format(version),
    'sidebar_title_link': path_to('../../versions', 1),
    'breadcrumb_links': [
        ('Dragon', path_to('../..', 1)),
        ('API', path_to('../../versions', 1)),
        ('Dragon v{}'.format(version.replace('a0', '-a0')), path_to('../../api', 1)),
        ('C++', path_to('', 1)),
    ],
}
# Every page type uses only the local table of contents in the sidebar.
html_sidebars = {
    'index': ['localtoc.html'],
    'dragon': ['localtoc.html'],
    'dragon/**': ['localtoc.html'],
    '_modules/**': ['localtoc.html'],
    'search': ['localtoc.html'],
}
# LaTex
# (start file, target .tex name, title, author, documentclass).
latex_documents = [(
    master_doc,
    'dragon.tex',
    'Dragon - C++ API',
    author,
    'manual',
)]
latex_elements = {
    'utf8extra': '',
    'inputenc': '',
    'babel': r'''\usepackage[english]{babel}''',
    'preamble': r'''
\usepackage{enumitem}
\usepackage{tocloft}
\renewcommand{\cfttoctitlefont}{\huge\bfseries}
\usepackage{fontspec}
\setmainfont{Source Serif Pro}
\setsansfont{Source Serif Pro}
\setmonofont{Source Serif Pro}
\setcounter{tocdepth}{2}
\usepackage[draft]{minted}
\fvset{breaklines=true, breakanywhere=true}
\setlength{\headheight}{13.6pt}
\setlength{\itemindent}{-1pt}
\makeatletter
\renewcommand*\l@subsection{\@dottedtocline{2}{3.8em}{3.8em}}
\fancypagestyle{normal}{
\fancyhf{}
\fancyfoot[LE,RO]{{\py@HeaderFamily\thepage}}
\fancyfoot[LO]{{\py@HeaderFamily\nouppercase{\rightmark}}}
\fancyfoot[RE]{{\py@HeaderFamily\nouppercase{\leftmark}}}
\fancyhead[LE,RO]{{\py@HeaderFamily}}
}
\makeatother
''',
    # Custom title page: rule, right-aligned logo, title and author.
    'maketitle': r'''
\pagenumbering{Roman} %% % to avoid page 1 conflict with actual page 1
\makeatletter
\begin{titlepage}
\noindent\rule[0.25\baselineskip]{\textwidth}{1pt}
\vspace*{5mm}
\begin{figure}[!h]
\raggedleft
\includegraphics[scale=0.3]{logo.png}
\end{figure}
\raggedleft
\vspace*{5mm}
\textbf{\Huge \@title}
\vspace*{40mm}
\LARGE \@author
\end{titlepage}
\makeatother
\pagenumbering{arabic}
''',
    'pointsize': '10pt',
    'figure_align': 'H',
    'printindex': '',
    # NOTE(review): 'verbatimhintsturnover=false' appears twice in this
    # option string; the duplicate is harmless but could be dropped.
    'sphinxsetup': ' \
hmargin={0.75in,0.75in}, \
vmargin={0.5in,1in}, \
verbatimhintsturnover=false, \
verbatimsep=0.75em, \
verbatimhintsturnover=false, \
verbatimwithframe=false, \
VerbatimColor={rgb}{0.949,0.949,0.949}, \
HeaderFamily=\\rmfamily\\bfseries',
}
latex_domain_indices = False
latex_engine = 'xelatex'
latex_logo = '../_static/images/logo.png'
# Application API
class HTMLTranslatorV3(HTMLTranslatorV2):
    """HTML translator variant used for the C++ API pages."""

    def depart_desc_content(self, node):
        """Strip the auto-generated "Subclassed by ..." paragraph.

        After the base translator has emitted the description body,
        locate the ``<p>Subclassed by ...</p>`` fragment that breathe
        inserts and drop it from the output.
        """
        HTMLTranslatorV2.depart_desc_content(self, node)
        start, end = -1, -1
        for pos, fragment in enumerate(self.body):
            if start > 0 and fragment.startswith('</p>'):
                end = pos
                break
            if (fragment.startswith('<p>')
                    and self.body[pos + 1].startswith('Subclassed by')):
                start = pos
        # NOTE(review): the `> 0` checks mean a matching paragraph at
        # index 0 would never be removed — presumably unreachable in
        # practice; confirm before tightening to `>= 0`.
        if start > 0 and end > 0:
            self.body = self.body[:start] + self.body[end + 1:]

    def depart_desc_parameterlist(self, node):
        """Delegate to the V1 translator to skip V2's trailing newline,
        matching the Google C++ signature style."""
        HTMLTranslator.depart_desc_parameterlist(self, node)
def setup(app):
    """Sphinx application entry point.

    Reuses the theme's setup routine, substituting our customized
    HTML translator.
    """
    translator = HTMLTranslatorV3
    return setup_v1(app, translator)
dragon/core
===========
.. only:: html
Classes
-------
`class CPUContext <core/CPUContext.html>`_
: The cpu device context.
`class CUDAContext <core/CUDAContext.html>`_
: The cuda device context.
`class Graph <core/Graph.html>`_
: Graph to execute operators sequentially.
`class Operator <core/Operator.html>`_
: The base operator class with context.
`class Tensor <core/Tensor.html>`_
: The base tensor class, manage memory or not.
`class TypeMeta <core/TypeMeta.html>`_
: Metaclass for all types.
`class UnifiedMemory <core/UnifiedMemory.html>`_
: Memory to manage both the host and device data.
`class Workspace <core/Workspace.html>`_
: Sandbox to isolate the resources and computations.
.. toctree::
:hidden:
core/CPUContext
core/CUDAContext
core/Graph
core/Operator
core/Tensor
core/TypeMeta
core/UnifiedMemory
core/Workspace
.. raw:: html
<style>
h1:before {
content: "Routine: ";
color: #103d3e;
}
</style>
CPUContext
==========
.. doxygenclass:: dragon::CPUContext
Constructors
------------
.. doxygenfunction:: dragon::CPUContext::CPUContext()
.. doxygenfunction:: dragon::CPUContext::CPUContext(unsigned int random_seed)
.. doxygenfunction:: dragon::CPUContext::CPUContext(const DeviceOption &option)
Public Functions
----------------
Copy
####
.. doxygenfunction:: dragon::CPUContext::Copy
Delete
######
.. doxygenfunction:: dragon::CPUContext::Delete
FinishDeviceComputation
#######################
.. doxygenfunction:: dragon::CPUContext::FinishDeviceComputation
Memset
######
.. doxygenfunction:: dragon::CPUContext::Memset
MemsetAsync
###########
.. doxygenfunction:: dragon::CPUContext::MemsetAsync
Memcpy
######
.. doxygenfunction:: dragon::CPUContext::Memcpy
MemcpyAsync
###########
.. doxygenfunction:: dragon::CPUContext::MemcpyAsync
New
###
.. doxygenfunction:: dragon::CPUContext::New
SwitchToDevice
##############
.. doxygenfunction:: dragon::CPUContext::SwitchToDevice()
SwitchToDevice
##############
.. doxygenfunction:: dragon::CPUContext::SwitchToDevice(int stream)
device
######
.. doxygenfunction:: dragon::CPUContext::device
rand_generator
##############
.. doxygenfunction:: dragon::CPUContext::rand_generator
set_stream
##########
.. doxygenfunction:: dragon::CPUContext::set_stream
stream
######
.. doxygenfunction:: dragon::CPUContext::stream
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
CUDAContext
===========
.. doxygenclass:: dragon::CUDAContext
Constructors
------------
.. doxygenfunction:: dragon::CUDAContext::CUDAContext()
.. doxygenfunction:: dragon::CUDAContext::CUDAContext(int device)
.. doxygenfunction:: dragon::CUDAContext::CUDAContext(const DeviceOption &option)
Public Functions
----------------
Copy
####
.. doxygenfunction:: dragon::CUDAContext::Copy
Delete
######
.. doxygenfunction:: dragon::CUDAContext::Delete
FinishDeviceComputation
#######################
.. doxygenfunction:: dragon::CUDAContext::FinishDeviceComputation
Memset
######
.. doxygenfunction:: dragon::CUDAContext::Memset
MemsetAsync
###########
.. doxygenfunction:: dragon::CUDAContext::MemsetAsync
Memcpy
######
.. doxygenfunction:: dragon::CUDAContext::Memcpy(size_t n, void *dest, const void *src)
Memcpy
######
.. doxygenfunction:: dragon::CUDAContext::Memcpy(size_t n, void *dest, const void *src, int device)
MemcpyAsync
###########
.. doxygenfunction:: dragon::CUDAContext::MemcpyAsync
New
###
.. doxygenfunction:: dragon::CUDAContext::New
SwitchToDevice
##############
.. doxygenfunction:: dragon::CUDAContext::SwitchToDevice()
SwitchToDevice
##############
.. doxygenfunction:: dragon::CUDAContext::SwitchToDevice(int stream)
SynchronizeStream
#################
.. doxygenfunction:: dragon::CUDAContext::SynchronizeStream
cublas_handle
#############
.. doxygenfunction:: dragon::CUDAContext::cublas_handle
cuda_stream
###########
.. doxygenfunction:: dragon::CUDAContext::cuda_stream()
cuda_stream
###########
.. doxygenfunction:: dragon::CUDAContext::cuda_stream(int device, int stream)
cudnn_handle
############
.. doxygenfunction:: dragon::CUDAContext::cudnn_handle
curand_generator
################
.. doxygenfunction:: dragon::CUDAContext::curand_generator
rand_generator
##############
.. doxygenfunction:: dragon::CUDAContext::rand_generator
device
######
.. doxygenfunction:: dragon::CUDAContext::device
set_stream
##########
.. doxygenfunction:: dragon::CUDAContext::set_stream
stream
######
.. doxygenfunction:: dragon::CUDAContext::stream
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
Graph
=====
.. doxygenclass:: dragon::Graph
Constructors
------------
.. doxygenfunction:: dragon::Graph::Graph(const GraphDef& def, Workspace* ws)
Public Functions
----------------
Create
######
.. doxygenfunction:: dragon::Graph::Create
Run
###
.. doxygenfunction:: dragon::Graph::Run
arg
###
.. doxygenfunction:: dragon::Graph::arg
args
####
.. doxygenfunction:: dragon::Graph::args
def
###
.. doxygenfunction:: dragon::Graph::def
optimized_def
#############
.. doxygenfunction:: dragon::Graph::optimized_def
name
####
.. doxygenfunction:: dragon::Graph::name
phase
#####
.. doxygenfunction:: dragon::Graph::phase
ws
##
.. doxygenfunction:: dragon::Graph::ws
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
Operator
========
.. doxygenclass:: dragon::Operator
Constructors
------------
.. doxygenfunction:: dragon::Operator::Operator(const OperatorDef &def, Workspace *ws)
Public Functions
----------------
Arg
###
.. doxygenfunction:: dragon::Operator::Arg
Args
####
.. doxygenfunction:: dragon::Operator::Args
Buffer
######
.. doxygenfunction:: dragon::Operator::Buffer
Fuse
####
.. doxygenfunction:: dragon::Operator::Fuse
Input
#####
.. doxygenfunction:: dragon::Operator::Input
InputSize
#########
.. doxygenfunction:: dragon::Operator::InputSize
Output
######
.. doxygenfunction:: dragon::Operator::Output(int i)
MessageForUnsupported
#####################
.. doxygenfunction:: dragon::Operator::MessageForUnsupported
Output
######
.. doxygenfunction:: dragon::Operator::Output(int i, const vec32_t &inputs)
OutputSize
##########
.. doxygenfunction:: dragon::Operator::OutputSize
Run
###
.. doxygenfunction:: dragon::Operator::Run
UpdateFrom
##########
.. doxygenfunction:: dragon::Operator::UpdateFrom
data_format
###########
.. doxygenfunction:: dragon::Operator::data_format
arg
###
.. doxygenfunction:: dragon::Operator::arg
args
####
.. doxygenfunction:: dragon::Operator::args
def
###
.. doxygenfunction:: dragon::Operator::def
dtype
#####
.. doxygenfunction:: dragon::Operator::dtype
handle
######
.. doxygenfunction:: dragon::Operator::handle
name
####
.. doxygenfunction:: dragon::Operator::name
type
####
.. doxygenfunction:: dragon::Operator::type
phase
#####
.. doxygenfunction:: dragon::Operator::phase
ws
##
.. doxygenfunction:: dragon::Operator::ws
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
Tensor
======
.. doxygenclass:: dragon::Tensor
Constructors
------------
.. doxygenfunction:: dragon::Tensor::Tensor()
.. doxygenfunction:: dragon::Tensor::Tensor(const string &name)
.. doxygenfunction:: dragon::Tensor::Tensor(const vec64_t &dims)
.. doxygenfunction:: dragon::Tensor::Tensor(const vec32_t &dims)
.. doxygenfunction:: dragon::Tensor::Tensor(const TypeMeta &meta)
Public Functions
----------------
CopyFrom
########
.. doxygenfunction:: dragon::Tensor::CopyFrom(const Tensor &other, Context *ctx)
CopyFrom
########
.. doxygenfunction:: dragon::Tensor::CopyFrom(const vector<VectorType> &other)
CopyTo
######
.. doxygenfunction:: dragon::Tensor::CopyTo
DimString
#########
.. doxygenfunction:: dragon::Tensor::DimString() const
DimString
#########
.. doxygenfunction:: dragon::Tensor::DimString(const vector<int64_t> &dims)
IsType
######
.. doxygenfunction:: dragon::Tensor::IsType
Reset
#####
.. doxygenfunction:: dragon::Tensor::Reset
Reshape
#######
.. doxygenfunction:: dragon::Tensor::Reshape
ReshapeLike
###########
.. doxygenfunction:: dragon::Tensor::ReshapeLike
Share
#####
.. doxygenfunction:: dragon::Tensor::Share
SwitchToDevice
##############
.. doxygenfunction:: dragon::Tensor::SwitchToDevice
axis
####
.. doxygenfunction:: dragon::Tensor::axis
capacity
########
.. doxygenfunction:: dragon::Tensor::capacity
count
#####
.. doxygenfunction:: dragon::Tensor::count() const
count
#####
.. doxygenfunction:: dragon::Tensor::count(int64_t start) const
count
#####
.. doxygenfunction:: dragon::Tensor::count(int64_t start, int64_t end) const
data
####
.. doxygenfunction:: dragon::Tensor::data
dim
###
.. doxygenfunction:: dragon::Tensor::dim
dims
####
.. doxygenfunction:: dragon::Tensor::dims
empty
#####
.. doxygenfunction:: dragon::Tensor::empty
has_memory
##########
.. doxygenfunction:: dragon::Tensor::has_memory
has_name
########
.. doxygenfunction:: dragon::Tensor::has_name
meta
####
.. doxygenfunction:: dragon::Tensor::meta
memory
######
.. doxygenfunction:: dragon::Tensor::memory
memory_state
############
.. doxygenfunction:: dragon::Tensor::memory_state
mutable_data
############
.. doxygenfunction:: dragon::Tensor::mutable_data
name
####
.. doxygenfunction:: dragon::Tensor::name
nbytes
######
.. doxygenfunction:: dragon::Tensor::nbytes
ndim
####
.. doxygenfunction:: dragon::Tensor::ndim
raw_data
########
.. doxygenfunction:: dragon::Tensor::raw_data
raw_mutable_data
################
.. doxygenfunction:: dragon::Tensor::raw_mutable_data()
raw_mutable_data
################
.. doxygenfunction:: dragon::Tensor::raw_mutable_data(const TypeMeta &meta)
size
####
.. doxygenfunction:: dragon::Tensor::size
stride
######
.. doxygenfunction:: dragon::Tensor::stride
strides
#######
.. doxygenfunction:: dragon::Tensor::strides
version
#######
.. doxygenfunction:: dragon::Tensor::version
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
TypeMeta
========
.. doxygenclass:: dragon::TypeMeta
Constructors
------------
.. doxygenfunction:: dragon::TypeMeta::TypeMeta()
.. doxygenfunction:: dragon::TypeMeta::TypeMeta(const TypeMeta &src)
Public Functions
----------------
Copy
####
.. doxygenfunction:: dragon::TypeMeta::Copy
Ctor
####
.. doxygenfunction:: dragon::TypeMeta::Ctor
Dtor
####
.. doxygenfunction:: dragon::TypeMeta::Dtor
Id
##
.. doxygenfunction:: dragon::TypeMeta::Id
Itemsize
########
.. doxygenfunction:: dragon::TypeMeta::Itemsize
Make
####
.. doxygenfunction:: dragon::TypeMeta::Make
Match
#####
.. doxygenfunction:: dragon::TypeMeta::Match
copy
####
.. doxygenfunction:: dragon::TypeMeta::copy
ctor
####
.. doxygenfunction:: dragon::TypeMeta::ctor
dtor
####
.. doxygenfunction:: dragon::TypeMeta::dtor
id
##
.. doxygenfunction:: dragon::TypeMeta::id
itemsize
########
.. doxygenfunction:: dragon::TypeMeta::itemsize
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
UnifiedMemory
=============
.. doxygenclass:: dragon::UnifiedMemory
Constructors
------------
.. doxygenfunction:: dragon::UnifiedMemory::UnifiedMemory()
.. doxygenfunction:: dragon::UnifiedMemory::UnifiedMemory(const TypeMeta &meta, size_t size)
Public Types
------------
State
#####
.. doxygenenum:: dragon::UnifiedMemory::State
Public Functions
----------------
SwitchToDevice
##############
.. doxygenfunction:: dragon::UnifiedMemory::SwitchToDevice
SwitchToCUDADevice
##################
.. doxygenfunction:: dragon::UnifiedMemory::SwitchToCUDADevice
ToCPU
#####
.. doxygenfunction:: dragon::UnifiedMemory::ToCPU
ToCUDA
######
.. doxygenfunction:: dragon::UnifiedMemory::ToCUDA
cpu_data
########
.. doxygenfunction:: dragon::UnifiedMemory::cpu_data
cuda_data
#########
.. doxygenfunction:: dragon::UnifiedMemory::cuda_data
device
######
.. doxygenfunction:: dragon::UnifiedMemory::device
info
####
.. doxygenfunction:: dragon::UnifiedMemory::info
mutable_cpu_data
################
.. doxygenfunction:: dragon::UnifiedMemory::mutable_cpu_data
mutable_cuda_data
#################
.. doxygenfunction:: dragon::UnifiedMemory::mutable_cuda_data
set_cpu_data
############
.. doxygenfunction:: dragon::UnifiedMemory::set_cpu_data
set_cuda_data
#############
.. doxygenfunction:: dragon::UnifiedMemory::set_cuda_data
size
####
.. doxygenfunction:: dragon::UnifiedMemory::size
state
#####
.. doxygenfunction:: dragon::UnifiedMemory::state
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
Workspace
=========
.. doxygenclass:: dragon::Workspace
Constructors
------------
.. doxygenfunction:: dragon::Workspace::Workspace(const string &name)
Public Functions
----------------
Clear
#####
.. doxygenfunction:: dragon::Workspace::Clear
CreateGraph
###########
.. doxygenfunction:: dragon::Workspace::CreateGraph
CreateTensor
############
.. doxygenfunction:: dragon::Workspace::CreateTensor
GetFillerInfo
#############
.. doxygenfunction:: dragon::Workspace::GetFillerInfo
GetTensor
#########
.. doxygenfunction:: dragon::Workspace::GetTensor
HasTensor
#########
.. doxygenfunction:: dragon::Workspace::HasTensor
MergeFrom
#########
.. doxygenfunction:: dragon::Workspace::MergeFrom
RegisterAlias
#############
.. doxygenfunction:: dragon::Workspace::RegisterAlias
ResetTensor
###########
.. doxygenfunction:: dragon::Workspace::ResetTensor
RunGraph
########
.. doxygenfunction:: dragon::Workspace::RunGraph
RunOperator
###########
.. doxygenfunction:: dragon::Workspace::RunOperator
TryGetTensor
############
.. doxygenfunction:: dragon::Workspace::TryGetTensor
UniqueName
##########
.. doxygenfunction:: dragon::Workspace::UniqueName
data
####
.. doxygenfunction:: dragon::Workspace::data(const vector<size_t> &segments)
data
####
.. doxygenfunction:: dragon::Workspace::data(const vector<int64_t> &segments)
graphs
######
.. doxygenfunction:: dragon::Workspace::graphs
name
####
.. doxygenfunction:: dragon::Workspace::name
tensors
#######
.. doxygenfunction:: dragon::Workspace::tensors
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
Dragon - C++ API
================
Routines
--------
.. only:: html
`Routine core <dragon/core.html>`_
: Public API for ``dragon/core`` routine.
.. toctree::
:hidden:
dragon/core
:: #########################################################
:: Command file to build on Windows for Sphinx documentation
:: #########################################################
@echo off
:: You can set these variables from the command line
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
:: Output directory shared by all targets (relative to this file)
set BUILDDIR=..\..\_build\api\cc
set ALLSPHINXOPTS=-d %BUILDDIR%\doctrees %SPHINXOPTS% .
if NOT "%PAPER%" == "" (
set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
)
:: Dispatch on the first argument; no argument falls through to the help text
if "%1" == "" goto help
if "%1" == "help" (
:help
echo.Please use `make ^<target^>` where ^<target^> is one of
echo.  doxygen     to make Doxygen XML files
echo.  html        to make standalone HTML files
echo.  debughtml   to make debugging HTML files
echo.  latex       to make LaTeX files, you can set PAPER=a4 or PAPER=letter
echo.  latexpdf    to make LaTeX files and run them through pdflatex
goto end
)
:: "clean" removes every build artifact under BUILDDIR (subdirs, then files)
if "%1" == "clean" (
for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
del /q /s %BUILDDIR%\*
goto end
)
:: Check if sphinx-build is available and fallback to Python version if any
:: (exit code 9009 means "command not found" on Windows)
%SPHINXBUILD% 2> nul
if errorlevel 9009 goto sphinx_python
goto sphinx_ok
:sphinx_python
set SPHINXBUILD=python -m sphinx.__init__
%SPHINXBUILD% 2> nul
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
:: sphinx-build is usable from here on; the remaining targets build outputs
:sphinx_ok
:: "doxygen" recreates %BUILDDIR%_doxygen and runs doxygen into it
:: (XML output lands in %BUILDDIR%_doxygen/xml per the message below)
if "%1" == "doxygen" (
(if exist %BUILDDIR%_doxygen rmdir /q /s %BUILDDIR%_doxygen) && mkdir %BUILDDIR%_doxygen && doxygen
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The Doxygen XML files are in %BUILDDIR%_doxygen/xml.
goto end
)
:: "html" builds in parallel using all logical processors
:: (NUMBER_OF_PROCESSORS is a standard Windows environment variable)
if "%1" == "html" (
%SPHINXBUILD% -b html -j %NUMBER_OF_PROCESSORS% %ALLSPHINXOPTS% %BUILDDIR%
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%.
goto end
)
:: LaTeX targets write next to BUILDDIR in a sibling "-latex" directory
if "%1" == "latex" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%-latex
if errorlevel 1 exit /b 1
echo.
echo.Build finished; the LaTeX files are in %BUILDDIR%-latex.
goto end
)
:: "latexpdf" additionally runs the generated make in the latex directory,
:: then returns to this script's directory (%~dp0)
if "%1" == "latexpdf" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%-latex
cd %BUILDDIR%-latex
make all-pdf
cd %~dp0
echo.
echo.Build finished; the PDF files are in %BUILDDIR%-latex.
goto end
)
:: Common exit point for all targets
:end
# Makefile for Sphinx documentation # Makefile for Sphinx documentation
#
# You can set these variables from the command line. # You can set these variables from the command line
SPHINXOPTS = SPHINXOPTS =
SPHINXBUILD = sphinx-build SPHINXBUILD = sphinx-build
PAPER = PAPER =
BUILDDIR = ../../_build/api BUILDDIR = ../../_build/api/python
# User-friendly check for sphinx-build # User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed.)
endif endif
# Internal variables. # Internal variables
PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
# the i18n builder cannot share the environment and doctrees with the others NUMBER_OF_PROCESSORS:=$(shell getconf _NPROCESSORS_ONLN)
I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
NPROC:=$(shell getconf _NPROCESSORS_ONLN)
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext .PHONY: help clean html debughtml latex latexpdf
help: help:
@echo "Please use \`make <target>' where <target> is one of" @echo "Please use \`make <target>' where <target> is one of"
@echo " html to make standalone HTML files" @echo " html to make standalone HTML files"
@echo " deployhtml to make HTML files copyied to website" @echo " debughtml to make debugging HTML files"
@echo " dirhtml to make HTML files named index.html in directories"
@echo " singlehtml to make a single large HTML file"
@echo " pickle to make pickle files"
@echo " json to make JSON files"
@echo " htmlhelp to make HTML files and a HTML help project"
@echo " qthelp to make HTML files and a qthelp project"
@echo " applehelp to make an Apple Help Book"
@echo " devhelp to make HTML files and a Devhelp project"
@echo " epub to make an epub"
@echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
@echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdf to make LaTeX files and run them through pdflatex"
@echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
@echo " text to make text files"
@echo " man to make manual pages"
@echo " texinfo to make Texinfo files"
@echo " info to make Texinfo files and run them through makeinfo"
@echo " gettext to make PO message catalogs"
@echo " changes to make an overview of all changed/added/deprecated items"
@echo " xml to make Docutils-native XML files"
@echo " pseudoxml to make pseudoxml-XML files for display purposes"
@echo " linkcheck to check all external links for integrity"
@echo " doctest to run all doctests embedded in the documentation (if enabled)"
@echo " coverage to run coverage check of the documentation (if enabled)"
clean: clean:
rm -rf $(BUILDDIR)/* rm -rf $(BUILDDIR)/*
html: html:
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/python $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)
@echo @echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/python." @echo "Build finished. The HTML pages are in $(BUILDDIR)."
debughtml: debughtml:
$(SPHINXBUILD) -b html -j ${NPROC} $(ALLSPHINXOPTS) $(BUILDDIR)/python $(SPHINXBUILD) -b html -j ${NUMBER_OF_PROCESSORS} $(ALLSPHINXOPTS) $(BUILDDIR)
@echo @echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/python." @echo "Build finished. The HTML pages are in $(BUILDDIR)."
dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
singlehtml:
$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
@echo
@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
pickle:
$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
@echo
@echo "Build finished; now you can process the pickle files."
json:
$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
@echo
@echo "Build finished; now you can process the JSON files."
htmlhelp:
$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
@echo
@echo "Build finished; now you can run HTML Help Workshop with the" \
".hhp project file in $(BUILDDIR)/htmlhelp."
qthelp:
$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
@echo
@echo "Build finished; now you can run "qcollectiongenerator" with the" \
".qhcp project file in $(BUILDDIR)/qthelp, like this:"
@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/Dragon.qhcp"
@echo "To view the help file:"
@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/Dragon.qhc"
applehelp:
$(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
@echo
@echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
@echo "N.B. You won't be able to view it unless you put it in" \
"~/Library/Documentation/Help or install it in your application" \
"bundle."
devhelp:
$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
@echo
@echo "Build finished."
@echo "To view the help file:"
@echo "# mkdir -p $$HOME/.local/share/devhelp/Dragon"
@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/Dragon"
@echo "# devhelp"
epub:
$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
@echo
@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
latex: latex:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)-latex
@echo @echo
@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Build finished; the LaTeX files are in $(BUILDDIR)-latex."
@echo "Run \`make' in that directory to run these through (pdf)latex" \ @echo "Run \`make' in that directory to run these through (pdf)latex" \
"(use \`make latexpdf' here to do that automatically)." "(use \`make latexpdf' here to do that automatically)."
latexpdf: latexpdf:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)-latex
@echo "Running LaTeX files through pdflatex..." @echo "Running LaTeX files through pdflatex..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf $(MAKE) -C $(BUILDDIR)/latex all-pdf
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." @echo "pdflatex finished; the PDF files are in $(BUILDDIR)-latex."
latexpdfja:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo "Running LaTeX files through platex and dvipdfmx..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
text:
$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
@echo
@echo "Build finished. The text files are in $(BUILDDIR)/text."
man:
$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
@echo
@echo "Build finished. The manual pages are in $(BUILDDIR)/man."
texinfo:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo
@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
@echo "Run \`make' in that directory to run these through makeinfo" \
"(use \`make info' here to do that automatically)."
info:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo "Running Texinfo files through makeinfo..."
make -C $(BUILDDIR)/texinfo info
@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
gettext:
$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
@echo
@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
changes:
$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
@echo
@echo "The overview file is in $(BUILDDIR)/changes."
linkcheck:
$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
@echo
@echo "Link check complete; look for any errors in the above output " \
"or in $(BUILDDIR)/linkcheck/output.txt."
doctest:
$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."
coverage:
$(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
@echo "Testing of coverage in the sources finished, look at the " \
"results in $(BUILDDIR)/coverage/python.txt."
xml:
$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
@echo
@echo "Build finished. The XML files are in $(BUILDDIR)/xml."
pseudoxml:
$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
@echo
@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Sphinx configuration for Python API.""" """Sphinx configuration for Python API."""
from __future__ import absolute_import from __future__ import absolute_import
...@@ -43,7 +42,9 @@ napoleon_use_rtype = False ...@@ -43,7 +42,9 @@ napoleon_use_rtype = False
# Project # Project
project = 'dragon' project = 'dragon'
copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd' copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd'
author = 'Ting Pan\\\\tingpan@seetatech.com' author = 'SeetaTech'
with open('../../../dragon/version.txt', 'r') as f:
version = f.read().strip()
# HTML # HTML
html_theme = 'seeta' html_theme = 'seeta'
...@@ -60,17 +61,18 @@ html_theme_options = { ...@@ -60,17 +61,18 @@ html_theme_options = {
'navbar_links': { 'navbar_links': {
'Install': path_to('../../install', 1), 'Install': path_to('../../install', 1),
'API': [ 'API': [
('C++', path_to('../cc', 1)), ('master', path_to('../../api/python', 1)),
('Python', path_to('', 1)) ('versions...', path_to('../../versions', 1)),
], ],
'Github': 'https://github.com/seetaresearch/dragon', 'Github': 'https://github.com/seetaresearch/dragon',
}, },
'navbar_logo_link': path_to('../..', 1), 'navbar_logo_link': path_to('../..', 1),
'sidebar_title': 'Python v0.3.0', 'sidebar_title': 'Python v{}'.format(version),
'sidebar_title_link': path_to('../../versions', 1), 'sidebar_title_link': path_to('../../versions', 1),
'breadcrumb_links': [ 'breadcrumb_links': [
('Dragon', path_to('../..', 1)), ('Dragon', path_to('../..', 1)),
('API', path_to('../../versions', 1)), ('API', path_to('../../versions', 1)),
('Dragon v{}'.format(version.replace('a0', '-a0')), path_to('../../api', 1)),
('Python', path_to('', 1)), ('Python', path_to('', 1)),
], ],
} }
......
...@@ -24,7 +24,7 @@ name ...@@ -24,7 +24,7 @@ name
ndim ndim
#### ####
.. autoattribute:: dragon.EagerTensor.name .. autoattribute:: dragon.EagerTensor.ndim
shape shape
##### #####
......
@ECHO OFF :: #########################################################
:: Command file to build on Windows for Sphinx documentation
:: #########################################################
REM Command file for Sphinx documentation @echo off
:: You can set these variables from the command line
if "%SPHINXBUILD%" == "" ( if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build set SPHINXBUILD=sphinx-build
) )
set BUILDDIR=..\..\_build\api set BUILDDIR=..\..\_build\api\python
set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . set ALLSPHINXOPTS=-d %BUILDDIR%\doctrees %SPHINXOPTS% .
set I18NSPHINXOPTS=%SPHINXOPTS% .
if NOT "%PAPER%" == "" ( if NOT "%PAPER%" == "" (
set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
) )
if "%1" == "" goto help if "%1" == "" goto help
...@@ -19,25 +20,9 @@ if "%1" == "help" ( ...@@ -19,25 +20,9 @@ if "%1" == "help" (
:help :help
echo.Please use `make ^<target^>` where ^<target^> is one of echo.Please use `make ^<target^>` where ^<target^> is one of
echo. html to make standalone HTML files echo. html to make standalone HTML files
echo. dirhtml to make HTML files named index.html in directories echo. debughtml to make debugging HTML files
echo. singlehtml to make a single large HTML file
echo. pickle to make pickle files
echo. json to make JSON files
echo. htmlhelp to make HTML files and a HTML help project
echo. qthelp to make HTML files and a qthelp project
echo. devhelp to make HTML files and a Devhelp project
echo. epub to make an epub
echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
echo. text to make text files echo. latexpdf to make LaTeX files and run them through pdflatex
echo. man to make manual pages
echo. texinfo to make Texinfo files
echo. gettext to make PO message catalogs
echo. changes to make an overview over all changed/added/deprecated items
echo. xml to make Docutils-native XML files
echo. pseudoxml to make pseudoxml-XML files for display purposes
echo. linkcheck to check all external links for integrity
echo. doctest to run all doctests embedded in the documentation if enabled
echo. coverage to run coverage check of the documentation if enabled
goto end goto end
) )
...@@ -47,13 +32,7 @@ if "%1" == "clean" ( ...@@ -47,13 +32,7 @@ if "%1" == "clean" (
goto end goto end
) )
if "%2" == "f" ( :: Check if sphinx-build is available and fallback to Python version if any
for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
del /q /s %BUILDDIR%\*
)
REM Check if sphinx-build is available and fallback to Python version if any
%SPHINXBUILD% 2> nul %SPHINXBUILD% 2> nul
if errorlevel 9009 goto sphinx_python if errorlevel 9009 goto sphinx_python
goto sphinx_ok goto sphinx_ok
...@@ -76,192 +55,37 @@ if errorlevel 9009 ( ...@@ -76,192 +55,37 @@ if errorlevel 9009 (
:sphinx_ok :sphinx_ok
if "%1" == "html" ( if "%1" == "html" (
%SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/python %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/python.
goto end
)
if "%1" == "dirhtml" (
%SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
goto end
)
if "%1" == "singlehtml" (
%SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
goto end
)
if "%1" == "pickle" (
%SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can process the pickle files.
goto end
)
if "%1" == "json" (
%SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can process the JSON files.
goto end
)
if "%1" == "htmlhelp" (
%SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can run HTML Help Workshop with the ^
.hhp project file in %BUILDDIR%/htmlhelp.
goto end
)
if "%1" == "qthelp" (
%SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
if errorlevel 1 exit /b 1 if errorlevel 1 exit /b 1
echo. echo.
echo.Build finished; now you can run "qcollectiongenerator" with the ^ echo.Build finished. The HTML pages are in %BUILDDIR%.
.qhcp project file in %BUILDDIR%/qthelp, like this:
echo.^> qcollectiongenerator %BUILDDIR%\qthelp\Dragon.qhcp
echo.To view the help file:
echo.^> assistant -collectionFile %BUILDDIR%\qthelp\Dragon.ghc
goto end goto end
) )
if "%1" == "devhelp" ( if "%1" == "debughtml" (
%SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp %SPHINXBUILD% -b html -j %NUMBER_OF_PROCESSORS% %ALLSPHINXOPTS% %BUILDDIR%
if errorlevel 1 exit /b 1 if errorlevel 1 exit /b 1
echo. echo.
echo.Build finished. echo.Build finished. The HTML pages are in %BUILDDIR%.
goto end
)
if "%1" == "epub" (
%SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The epub file is in %BUILDDIR%/epub.
goto end goto end
) )
if "%1" == "latex" ( if "%1" == "latex" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%-latex
if errorlevel 1 exit /b 1 if errorlevel 1 exit /b 1
echo. echo.
echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. echo.Build finished; the LaTeX files are in %BUILDDIR%-latex.
goto end goto end
) )
if "%1" == "latexpdf" ( if "%1" == "latexpdf" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%-latex
cd %BUILDDIR%/latex cd %BUILDDIR%-latex
make all-pdf make all-pdf
cd %~dp0 cd %~dp0
echo. echo.
echo.Build finished; the PDF files are in %BUILDDIR%/latex. echo.Build finished; the PDF files are in %BUILDDIR%-latex.
goto end
)
if "%1" == "latexpdfja" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
cd %BUILDDIR%/latex
make all-pdf-ja
cd %~dp0
echo.
echo.Build finished; the PDF files are in %BUILDDIR%/latex.
goto end
)
if "%1" == "text" (
%SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The text files are in %BUILDDIR%/text.
goto end
)
if "%1" == "man" (
%SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The manual pages are in %BUILDDIR%/man.
goto end
)
if "%1" == "texinfo" (
%SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
goto end
)
if "%1" == "gettext" (
%SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
goto end
)
if "%1" == "changes" (
%SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
if errorlevel 1 exit /b 1
echo.
echo.The overview file is in %BUILDDIR%/changes.
goto end
)
if "%1" == "linkcheck" (
%SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
if errorlevel 1 exit /b 1
echo.
echo.Link check complete; look for any errors in the above output ^
or in %BUILDDIR%/linkcheck/output.txt.
goto end
)
if "%1" == "doctest" (
%SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
if errorlevel 1 exit /b 1
echo.
echo.Testing of doctests in the sources finished, look at the ^
results in %BUILDDIR%/doctest/output.txt.
goto end
)
if "%1" == "coverage" (
%SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage
if errorlevel 1 exit /b 1
echo.
echo.Testing of coverage in the sources finished, look at the ^
results in %BUILDDIR%/coverage/python.txt.
goto end
)
if "%1" == "xml" (
%SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The XML files are in %BUILDDIR%/xml.
goto end
)
if "%1" == "pseudoxml" (
%SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
goto end goto end
) )
......
...@@ -12,7 +12,7 @@ regularizers ...@@ -12,7 +12,7 @@ regularizers
`class L1L2 <regularizers/L1L2.html>`_ `class L1L2 <regularizers/L1L2.html>`_
: The L1L2 regularizer. : The L1L2 regularizer.
`class L2 <regularizers/L1.html>`_ `class L2 <regularizers/L2.html>`_
: The L1 regularizer. : The L1 regularizer.
`class Regularizer <regularizers/Regularizer.html>`_ `class Regularizer <regularizers/Regularizer.html>`_
......
...@@ -17,15 +17,18 @@ ...@@ -17,15 +17,18 @@
namespace dragon { namespace dragon {
/*!
* \brief The cpu device context.
*/
class DRAGON_API CPUContext { class DRAGON_API CPUContext {
public: public:
/*! \brief Default Constructor */ /*! \brief Default Constructor */
explicit CPUContext() : random_seed_(3) {} CPUContext() : random_seed_(3) {}
/*! \brief Constructor with the specified random seed */ /*! \brief Constructor with the random seed */
explicit CPUContext(unsigned int random_seed) : random_seed_(random_seed) {} explicit CPUContext(unsigned int random_seed) : random_seed_(random_seed) {}
/*! \brief Constructor with the specified device option */ /*! \brief Constructor with the device option */
explicit CPUContext(const DeviceOption& option) explicit CPUContext(const DeviceOption& option)
: random_seed_( : random_seed_(
option.has_random_seed() ? option.random_seed() option.has_random_seed() ? option.random_seed()
...@@ -34,74 +37,74 @@ class DRAGON_API CPUContext { ...@@ -34,74 +37,74 @@ class DRAGON_API CPUContext {
/*! \brief Destructor */ /*! \brief Destructor */
virtual ~CPUContext() {} virtual ~CPUContext() {}
/*! \brief Alloc the memory */ /*! \brief Allocate a block of memory */
static void* New(size_t nbytes) { static void* New(size_t size) {
void* data = malloc(nbytes); void* data = malloc(size);
CHECK(data) << "\nAllocate memory with " << nbytes << " bytes failed."; CHECK(data) << "\nAllocate memory with " << size << " bytes failed.";
return data; return data;
} }
/*! \brief Zero-Reset the memory */ /*! \brief Set a memory block to the given value */
static void Memset(size_t nbytes, void* ptr) { static void Memset(size_t n, void* ptr, int value = 0) {
memset(ptr, 0, nbytes); memset(ptr, value, n);
} }
/*! \brief Copy the memory */ /*! \brief Set a memory block to the given value asynchronously */
template <class DestContext, class SrcContext> void MemsetAsync(size_t n, void* ptr, int value) {
static void Memcpy(size_t nbytes, void* dest, const void* src) { memset(ptr, value, n);
memcpy(dest, src, nbytes);
} }
/*! \brief Free the memory */ /*! \brief Copy a memory block to the destination */
static void Delete(void* data) { template <class DestContext, class SrcContext>
free(data); static void Memcpy(size_t n, void* dest, const void* src) {
memcpy(dest, src, n);
} }
/*! \brief Zero-Reset the memory asynchronously */ /*! \brief Copy a memory block to the destination asynchronously */
void MemsetAsync(size_t nbytes, void* ptr) { template <class DestContext, class SrcContext>
memset(ptr, 0, nbytes); void MemcpyAsync(size_t n, void* dest, const void* src) {
memcpy(dest, src, n);
} }
/*! \brief Copy the memory asynchronously */ /*! \brief Deallocate a memory block */
template <class DestContext, class SrcContext> static void Delete(void* ptr) {
void MemcpyAsync(size_t nbytes, void* dest, const void* src) { free(ptr);
memcpy(dest, src, nbytes);
} }
/*! \brief Switch to the device of this context */ /*! \brief Switch to the device in current thread */
void SwitchToDevice() {} void SwitchToDevice() {}
/*! \brief Switch to the device with the given stream */ /*! \brief Switch to the device and select given stream in current thread */
void SwitchToDevice(const int stream_id) {} void SwitchToDevice(int stream) {}
/*! \brief Copy the memory with given type asynchronously */ /*! \brief Copy a typed memory block to the destination */
template <typename T, class DestContext, class SrcContext> template <typename T, class DestContext, class SrcContext>
void Copy(int n, T* dest, const T* src) { static void Copy(int n, T* dest, const T* src) {
if (dest == src) return; if (dest == src) return;
if (std::is_fundamental<T>::value) { if (std::is_fundamental<T>::value) {
Memcpy<DestContext, SrcContext>( Memcpy<DestContext, SrcContext>(
n * sizeof(T), (void*)dest, (const void*)src); n * sizeof(T), (void*)dest, (const void*)src);
} else { } else {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; ++i) {
dest[i] = src[i]; dest[i] = src[i];
} }
} }
} }
/*! \brief Synchronize the dispatched operations */ /*! \brief Wait for the dispatched computation to complete */
void FinishDeviceComputation() {} void FinishDeviceComputation() {}
/*! \brief Return the device index */ /*! \brief Return the device index */
int device_id() const { int device() const {
return 0; return 0;
} }
/*! \brief Return the stream index */ /*! \brief Return the stream index */
int stream_id() const { int stream() const {
return 0; return 0;
} }
/*! \brief Return the internal random generator */ /*! \brief Return the random generator */
std::mt19937* rand_generator() { std::mt19937* rand_generator() {
if (!rand_generator_.get()) { if (!rand_generator_.get()) {
rand_generator_.reset(new std::mt19937(random_seed_)); rand_generator_.reset(new std::mt19937(random_seed_));
...@@ -110,13 +113,13 @@ class DRAGON_API CPUContext { ...@@ -110,13 +113,13 @@ class DRAGON_API CPUContext {
} }
/*! \brief Set the stream index */ /*! \brief Set the stream index */
void set_stream_id(int stream_id) {} void set_stream(int stream) {}
private: private:
/*! \brief Store the random seed */ /*! \brief The random seed */
unsigned int random_seed_; unsigned int random_seed_;
/*! \brief Store the internal random generator */ /*! \brief The random generator */
unique_ptr<std::mt19937> rand_generator_; unique_ptr<std::mt19937> rand_generator_;
}; };
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
#ifndef DRAGON_CORE_CONTEXT_CNML_H_ #ifndef DRAGON_CORE_CONTEXT_CNML_H_
#define DRAGON_CORE_CONTEXT_CNML_H_ #define DRAGON_CORE_CONTEXT_CNML_H_
/* CAMBRICON CNRT && CNML Environment */
#include "dragon/core/common.h" #include "dragon/core/common.h"
struct cnrtStream; struct cnrtStream;
...@@ -28,11 +26,19 @@ typedef struct cnmlFusionOp* cnmlFusionOp_t; ...@@ -28,11 +26,19 @@ typedef struct cnmlFusionOp* cnmlFusionOp_t;
namespace dragon { namespace dragon {
class CNRTObject; /*!
* \brief The cnml device context.
*/
class CNMLContext { class CNMLContext {
public: public:
/*! \brief Default Constructor */ /*! \brief Default constructor */
CNMLContext() : device_id_(0), random_seed_(DEFAULT_RNG_SEED) {}
/*! \brief Constructor with the device index */
explicit CNMLContext(int device)
: device_id_(device), random_seed_(DEFAULT_RNG_SEED) {}
/*! \brief Constructor with the device option */
explicit CNMLContext(const DeviceOption& option) explicit CNMLContext(const DeviceOption& option)
: device_id_(option.device_id()), : device_id_(option.device_id()),
random_seed_( random_seed_(
...@@ -41,77 +47,63 @@ class CNMLContext { ...@@ -41,77 +47,63 @@ class CNMLContext {
CHECK_EQ(option.device_type(), PROTO_CNML); CHECK_EQ(option.device_type(), PROTO_CNML);
} }
/*! \brief Constructor with the specified device index */ /*! \brief Allocate a block of memory */
explicit CNMLContext(int device_id = 0) static void* New(size_t size) {
: device_id_(device_id), random_seed_(DEFAULT_RNG_SEED) {} return nullptr;
}
/*! \brief Alloc the memory */ /*! \brief Set a memory block to the given value */
static void* New(size_t nbytes); static void Memset(size_t n, void* ptr, int value) {}
/*! \brief Zero-Reset the memory */ /*! \brief Set a memory block to the given value asynchronously */
static void Memset(size_t nbytes, void* ptr); void MemsetAsync(size_t n, void* ptr, int value) {
Memset(n, ptr, value);
}
/*! \brief Copy the memory */ /*! \brief Copy a memory block to the destination */
template <class DestContext, class SrcContext> template <class DestContext, class SrcContext>
static void Memcpy(size_t nbytes, void* dest, const void* src); static void Memcpy(size_t n, void* dest, const void* src) {}
/*! \brief Free the memory */
static void Delete(void* data);
/*! \brief Zero-Reset the memory asynchronously */
void MemsetAsync(size_t nbytes, void* ptr) {
Memset(nbytes, ptr);
}
/*! \brief Copy the memory asynchronously */ /*! \brief Copy a memory block to the destination asynchronously */
template <class DestContext, class SrcContext> template <class DestContext, class SrcContext>
void MemcpyAsync(size_t nbytes, void* dest, const void* src) { void MemcpyAsync(size_t n, void* dest, const void* src) {
Memcpy<DestContext, SrcContext>(dest, src, nbytes); Memcpy<DestContext, SrcContext>(dest, src, n);
} }
/*! \brief Switch to the device with the given stream */ /*! \brief Deallocate a memory block */
void SwitchToDevice(int stream_id) {} static void Delete(void* ptr) {}
/*! \brief Switch to the device of this context */ /*! \brief Switch to the device in current thread */
void SwitchToDevice() { void SwitchToDevice() {
SwitchToDevice(0); SwitchToDevice(0);
} }
/*! \brief Synchronize the dispatched operations */ /*! \brief Switch to the device and select given stream in current thread */
void FinishDeviceComputation() {} void SwitchToDevice(int stream) {}
/*! \brief Return the specified cnrt stream */ /*! \brief Wait for the dispatched computation to complete */
static cnrtStream_t cnrt_stream(int device_id, int stream_id); void FinishDeviceComputation() {}
/*! \brief Return the internal cnrt stream */ /*! \brief Return the cnrt stream */
cnrtStream_t cnrt_stream() { cnrtStream_t cnrt_stream() {
return cnrt_stream(device_id_, stream_id_); return cnrt_stream(device_id_, stream_id_);
} }
/*! \brief Return the specified cnrt stream */
static cnrtStream_t cnrt_stream(int device_id, int stream_id) {
return (cnrtStream_t) nullptr;
}
/*! \brief Return the device index */ /*! \brief Return the device index */
int device_id() const { int device() const {
return device_id_; return device_id_;
} }
/*! \brief Return the stream index */ /*! \brief Return the stream index */
int stream_id() const { int stream() const {
return stream_id_; return stream_id_;
} }
/*! \brief Return the global context locker */
static std::mutex& mutex() {
static std::mutex m;
return m;
}
/*! \brief Return the thread local cnrt object */
static CNRTObject* obj();
/*! \brief Set the stream index */
void set_stream_id(int stream_id) {
stream_id_ = stream_id;
}
private: private:
int device_id_, stream_id_ = 1, random_seed_; int device_id_, stream_id_ = 1, random_seed_;
unique_ptr<std::mt19937> rand_generator_; unique_ptr<std::mt19937> rand_generator_;
......
#include <regex>
#include "dragon/core/graph.h" #include "dragon/core/graph.h"
#include "dragon/core/graph_gradient.h" #include "dragon/core/graph_gradient.h"
#include "dragon/core/graph_optimizer.h" #include "dragon/core/graph_optimizer.h"
...@@ -46,8 +48,8 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws) ...@@ -46,8 +48,8 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
} }
} }
bool Graph::Create(const GraphDef& def, Workspace* ws) { bool Graph::Create(const GraphDef& def) {
this->opt_def_ = def; // Store for debugging this->optimized_def_ = def; // Store for debugging
bool has_device_option = def.has_device_option(); bool has_device_option = def.has_device_option();
for (int i = 0; i < def.op_size(); i++) { for (int i = 0; i < def.op_size(); i++) {
auto op_def(def.op(i)); auto op_def(def.op(i));
...@@ -63,7 +65,7 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) { ...@@ -63,7 +65,7 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) {
arg.set_i(1); arg.set_i(1);
op_def.add_arg()->CopyFrom(arg); op_def.add_arg()->CopyFrom(arg);
} }
cached_ops_.push_back(NewOperator(op_def, ws)); cached_ops_.push_back(NewOperator(op_def, ws_));
cached_ops_.back()->set_output_aliases(output_aliases_); cached_ops_.back()->set_output_aliases(output_aliases_);
} }
return true; return true;
...@@ -71,25 +73,25 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) { ...@@ -71,25 +73,25 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) {
Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
// Apply the optimizations // Apply the optimizations
GraphDef opt_def = def; GraphDef def_v2(def);
GraphOptimizer graph_optim(ws); GraphOptimizer graph_optimizer(ws);
GraphGradientMaker gradient_maker; GraphGradientMaker gradient_maker;
Map<string, vec32_t> subgraph_indices; Map<string, vec32_t> subgraph_indices;
int opt = 3; // defaults: O3 int opt = 3; // default: O3
if (args().count("optimization")) opt = arg("optimization").i(); if (args().count("optimization")) opt = arg("optimization").i();
if (opt >= 1) opt_def = graph_optim.PruneNodes(def); if (opt >= 1) def_v2 = graph_optimizer.PruneNodes(def);
if (opt >= 2) graph_optim.AddInplace(opt_def, output_aliases_); if (opt >= 2) graph_optimizer.AddInplace(def_v2, output_aliases_);
if (opt >= 3) { if (opt >= 3) {
if (phase() == "TRAIN") { if (phase() == "TRAIN") {
opt_def = graph_optim.MirrorStage(opt_def, subgraph_indices); def_v2 = graph_optimizer.MirrorStage(def_v2, subgraph_indices);
opt_def = gradient_maker.Share(opt_def); def_v2 = gradient_maker.Share(def_v2);
} else { } else {
opt_def = graph_optim.SimulateGC(opt_def); def_v2 = graph_optimizer.SimulateGC(def_v2);
} }
} }
// Create // Create
Create(opt_def, ws); Create(def_v2);
// Recomputation and SubGraph // Recomputation and SubGraph
if (subgraph_indices.size() > 0) { if (subgraph_indices.size() > 0) {
...@@ -105,11 +107,14 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { ...@@ -105,11 +107,14 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
} }
} }
bool Graph::Run(const string& include, const string& exclude, int stream) { bool Graph::Run(int stream, const string& include, const string& exclude) {
unique_ptr<std::regex> regex_incl, regex_excl;
if (!include.empty()) regex_incl.reset(new std::regex(include));
if (!exclude.empty()) regex_excl.reset(new std::regex(exclude));
LOG(DEBUG) << "Run Graph: " << name(); LOG(DEBUG) << "Run Graph: " << name();
for (auto* op : cached_ops_) { for (auto* op : cached_ops_) {
if (!include.empty() && !str::find(op->type(), include)) continue; if (regex_incl && !regex_match(op->type(), *regex_incl)) continue;
if (!exclude.empty() && str::find(op->type(), exclude)) continue; if (regex_excl && regex_match(op->type(), *regex_excl)) continue;
op->SwitchToPhase(phase()); op->SwitchToPhase(phase());
LOG(DEBUG) << "Run Op: " << op->name(); LOG(DEBUG) << "Run Op: " << op->name();
op->Run(stream); op->Run(stream);
......
...@@ -18,26 +18,32 @@ ...@@ -18,26 +18,32 @@
namespace dragon { namespace dragon {
/*!
* \brief The base graph class.
*/
class DRAGON_API GraphBase { class DRAGON_API GraphBase {
public: public:
/*! \brief Default constructor */ /*! \brief Constructor with the def and workspace */
GraphBase(const GraphDef&, Workspace*); GraphBase(const GraphDef& def, Workspace* ws);
/*! \brief Default Destructor */ /*! \brief Destructor */
virtual ~GraphBase() {} virtual ~GraphBase() {}
/*! \brief Create a graph from the optimized def */ /*! \brief Create graph in the workspace */
virtual bool Create(const GraphDef&, Workspace*) = 0; virtual bool Create(const GraphDef& def) = 0;
/*! \brief Run the graph once synchronously */ /*! \brief Run graph on the given stream */
virtual bool Run(const string&, const string&, int = 0) = 0; virtual bool Run(
int stream = 0,
const string& include = "",
const string& exclude = "") = 0;
/*! \brief Return the graph name */ /*! \brief Return the graph name */
const string& name() const { const string& name() const {
return name_; return name_;
} }
/*! \brief Return the defined running phase */ /*! \brief Return the executing phase */
const string& phase() const { const string& phase() const {
return phase_; return phase_;
} }
...@@ -47,19 +53,19 @@ class DRAGON_API GraphBase { ...@@ -47,19 +53,19 @@ class DRAGON_API GraphBase {
return *(args_[name]); return *(args_[name]);
} }
/*! \brief Return the argument map */ /*! \brief Return all the arguments */
const Map<string, const Argument*>& args() { const Map<string, const Argument*>& args() {
return args_; return args_;
} }
/*! \brief Return the stored raw def */ /*! \brief Return the graph def */
const GraphDef& def() const { const GraphDef& def() const {
return def_; return def_;
} }
/*! \brief Return the stored opt def */ /*! \brief Return the optimized graph def */
const GraphDef& opt_def() const { const GraphDef& optimized_def() const {
return opt_def_; return optimized_def_;
} }
/*! \brief Return the parent workspace */ /*! \brief Return the parent workspace */
...@@ -68,42 +74,53 @@ class DRAGON_API GraphBase { ...@@ -68,42 +74,53 @@ class DRAGON_API GraphBase {
} }
protected: protected:
/*! \brief Store the name and running phase */ /*! \brief The name and executing phase */
string name_, phase_; string name_, phase_;
/*! \brief Store the defined arguments */ /*! \brief The defined arguments */
Map<string, const Argument*> args_; Map<string, const Argument*> args_;
/*! \brief Store the parent workspace */ /*! \brief The parent workspace */
Workspace* ws_; Workspace* ws_;
/*! \brief Store the graph definition */ /*! \brief The graph definition */
GraphDef def_, opt_def_; GraphDef def_;
/*! \brief The optimized graph definition */
GraphDef optimized_def_;
DISABLE_COPY_AND_ASSIGN(GraphBase);
}; };
/*!
* \brief Graph to execute operators sequentially.
*/
class Graph : public GraphBase { class Graph : public GraphBase {
public: public:
/*! \brief Default constructor */ /*! \brief Constructor with the def and workspace */
Graph(const GraphDef& def, Workspace* ws); Graph(const GraphDef& def, Workspace* ws);
/*! \brief Default Destructor */ /*! \brief Destructor */
virtual ~Graph() { virtual ~Graph() {
for (auto* cached_op : cached_ops_) { for (auto* cached_op : cached_ops_) {
delete cached_op; delete cached_op;
} }
} }
/*! \brief Create a graph from the optimized def */ /*! \brief Create graph in the workspace */
bool Create(const GraphDef&, Workspace*) override; bool Create(const GraphDef& def) override;
/*! \brief Run the graph once synchronously */ /*! \brief Run graph on the given stream */
bool Run(const string&, const string&, int = 0) override; bool Run(
int stream = 0,
const string& include = "",
const string& exclude = "") override;
protected: protected:
/*! \brief The cached operators */ /*! \brief The cached operators */
vector<OperatorBase*> cached_ops_; vector<OperatorBase*> cached_ops_;
/*! \brief Store the candidate output aliases */ /*! \brief The candidate output aliases */
Map<string, Set<string>> output_aliases_; Map<string, Set<string>> output_aliases_;
}; };
......
...@@ -72,6 +72,9 @@ class GraphOptimizer { ...@@ -72,6 +72,9 @@ class GraphOptimizer {
/* \brief Store the count of references */ /* \brief Store the count of references */
Map<string, int> reference_count_; Map<string, int> reference_count_;
private:
DISABLE_COPY_AND_ASSIGN(GraphOptimizer);
}; };
} // namespace dragon } // namespace dragon
......
...@@ -20,44 +20,45 @@ ...@@ -20,44 +20,45 @@
namespace dragon { namespace dragon {
typedef enum { typedef enum {
NCHW, NCHW = 0,
NHWC, NHWC = 1,
} StorageOrder; } StorageOrder;
/*!
* \brief Memory to manage both the host and device data.
*/
class DRAGON_API UnifiedMemory { class DRAGON_API UnifiedMemory {
public: public:
typedef enum { /*!
/*! \brief The initial state */ * \brief The device-aware state for data mutation.
UNINITIALIZED, */
/*! \brief Memory could be modified by CPUContext last time */ enum State {
STATE_AT_CPU, /*! \brief Initial state */
/*! \brief Memory could be modified by CUDAContext last time */ UNINITIALIZED = 0,
STATE_AT_CUDA, /*! \brief Data is mutable to cpu */
/*! \brief Memory could be modified by CNMLContext last time */ STATE_AT_CPU = 1,
STATE_AT_CNML, /*! \brief Data is mutable to cuda */
/*! \brief The synced state */ STATE_AT_CUDA = 2,
SYNCED, /*! \brief Data is mutable to cnml */
} State; STATE_AT_CNML = 3,
/*! \brief Data is synced between host and device */
/*! \brief Default Constructor */ SYNCED = 4,
UnifiedMemory() : cpu_ptr_(nullptr), cuda_ptr_(nullptr), cnml_ptr_(nullptr) {} };
/*! \brief Constructor with the known meta and size */ /*! \brief Default constructor */
UnifiedMemory(const TypeMeta& meta, size_t size) UnifiedMemory() {}
: meta_(meta),
size_(size), /*! \brief Constructor with the type meta and size */
cpu_ptr_(nullptr), UnifiedMemory(const TypeMeta& meta, size_t size) : meta_(meta), size_(size) {}
cuda_ptr_(nullptr),
cnml_ptr_(nullptr) {}
/*! \brief Destructor */ /*! \brief Destructor */
~UnifiedMemory(); ~UnifiedMemory();
/*! \brief Switch to the specified device */ /*! \brief Switch to the given device */
void SwitchToDevice(int device_id); void SwitchToDevice(int device);
/*! \brief Switch to the specified cuda device */ /*! \brief Switch to the given cuda device */
void SwitchToCUDADevice(int device_id); void SwitchToCUDADevice(int device);
/*! \brief Involve the state to CPUContext */ /*! \brief Involve the state to CPUContext */
void ToCPU(size_t size = 0); void ToCPU(size_t size = 0);
...@@ -65,9 +66,9 @@ class DRAGON_API UnifiedMemory { ...@@ -65,9 +66,9 @@ class DRAGON_API UnifiedMemory {
/*! \brief Involve the state to CUDAContext */ /*! \brief Involve the state to CUDAContext */
void ToCUDA(size_t size = 0); void ToCUDA(size_t size = 0);
/*! \brief Return the device index */ /*! \brief Return the memory state */
int device_id() const { State state() const {
return device_id_; return state_;
} }
/*! \brief Return the total number of bytes */ /*! \brief Return the total number of bytes */
...@@ -75,9 +76,9 @@ class DRAGON_API UnifiedMemory { ...@@ -75,9 +76,9 @@ class DRAGON_API UnifiedMemory {
return size_; return size_;
} }
/*! \brief Return the number of chunks */ /*! \brief Return the number of memory chunks */
size_t nchunks() const { size_t num_chunks() const {
return nchunks_; return num_chunks_;
} }
/*! \brief Return the storage order */ /*! \brief Return the storage order */
...@@ -85,30 +86,30 @@ class DRAGON_API UnifiedMemory { ...@@ -85,30 +86,30 @@ class DRAGON_API UnifiedMemory {
return order_; return order_;
} }
/*! \brief Return the memory state */ /*! \brief Return the device index */
State state() const { int device() const {
return state_; return device_id_;
} }
/*! \brief Return a string to describe the internal structure */ /*! \brief Return the data info */
Map<string, string> info() const; Map<string, string> info() const;
/*! \brief Return the const data pointer on CPUContext */ /*! \brief Return the const cpu data */
const void* cpu_data(size_t size = 0); const void* cpu_data(size_t size = 0);
/*! \brief Return the const data pointer on CUDAContext */ /*! \brief Return the const cuda data */
const void* cuda_data(size_t size = 0); const void* cuda_data(size_t size = 0);
/*! \brief Return the const data pointer on CNMLContext */ /*! \brief Return the const cnml data */
const void* cnml_data(); const void* cnml_data();
/*! \brief Return the mutable data pointer on CPUContext */ /*! \brief Return the mutable cpu data */
void* mutable_cpu_data(size_t size = 0); void* mutable_cpu_data(size_t size = 0);
/*! \brief Return the mutable data pointer on CUDAContext */ /*! \brief Return the mutable cuda data */
void* mutable_cuda_data(size_t size = 0); void* mutable_cuda_data(size_t size = 0);
/*! \brief Return the mutable data pointer on CNMLContext */ /*! \brief Return the mutable cnml data */
void* mutable_cnml_data(); void* mutable_cnml_data();
/*! \brief Return the binding cnml cpu tensor */ /*! \brief Return the binding cnml cpu tensor */
...@@ -117,15 +118,15 @@ class DRAGON_API UnifiedMemory { ...@@ -117,15 +118,15 @@ class DRAGON_API UnifiedMemory {
/*! \brief Return the binding cnml mlu tensor */ /*! \brief Return the binding cnml mlu tensor */
cnmlTensor_t& cnml_mlu_tensor(); cnmlTensor_t& cnml_mlu_tensor();
/*! \brief Allocate the mlu device memory */ /*! \brief Allocate the mlu device data */
void* malloc_cnml_data(); void* malloc_cnml_data();
/*! \brief Copy the mlu device memory to the host */ /*! \brief Copy the mlu device data to host */
void fetch_cnml_data(void** data); void fetch_cnml_data(void** data);
/*! \brief Set the chunks of this memory */ /*! \brief Set the number of data chunks */
void set_nchunks(size_t nchunks) { void set_num_chunks(size_t num_chunks) {
nchunks_ = nchunks; num_chunks_ = num_chunks;
} }
/*! \brief Set the storage order */ /*! \brief Set the storage order */
...@@ -133,39 +134,47 @@ class DRAGON_API UnifiedMemory { ...@@ -133,39 +134,47 @@ class DRAGON_API UnifiedMemory {
order_ = order; order_ = order;
} }
/*! \brief Set the cpu data pointer from external context */ /*! \brief Set to use an external block of cpu data */
void set_cpu_data(void* cpu_ptr, size_t size); void set_cpu_data(void* cpu_ptr, size_t size);
/*! \brief Set the cuda data pointer from external context */ /*! \brief Set to use an extenral block of cuda data */
void set_cuda_data(void* cuda_ptr, size_t size, int device_id); void set_cuda_data(void* cuda_ptr, size_t size, int device);
private: private:
/*! \brief The type meta */ /*! \brief The data state */
TypeMeta meta_; State state_ = UNINITIALIZED;
/*! \brief The size and number of chunks */ /*! \brief The size and number of chunks */
size_t size_ = 0, nchunks_ = 1; size_t size_ = 0, num_chunks_ = 1;
/*! \brief The type meta */
TypeMeta meta_;
/*! \brief The storage order */ /*! \brief The storage order */
StorageOrder order_ = NCHW; StorageOrder order_ = NCHW;
/*! \brief The current state */ /*! \brief The device index */
State state_ = UNINITIALIZED; int device_id_ = 0;
/*! \brief The data pointers */ /*! \brief The cpu data pointer */
void *cpu_ptr_, *cuda_ptr_, *cnml_ptr_; void* cpu_ptr_ = nullptr;
/*! \brief The cuda data pointer */
void* cuda_ptr_ = nullptr;
/*! \brief The cnml data pointer */
void* cnml_ptr_ = nullptr;
/*! \brief The ownership of data pointers */ /*! \brief The ownership of data pointers */
int own_cpu_ptr_ = 1, own_cuda_ptr_ = 1; int own_cpu_ptr_ = 1, own_cuda_ptr_ = 1;
/*! \brief The device index */
int device_id_ = 0;
/*! \brief The binding cpu tensor for cnml */ /*! \brief The binding cpu tensor for cnml */
cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr; cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr;
/*! \brief The binding mlu tensor for cnml */ /*! \brief The binding mlu tensor for cnml */
cnmlTensor_t cnml_mlu_tensor_ = nullptr; cnmlTensor_t cnml_mlu_tensor_ = nullptr;
DISABLE_COPY_AND_ASSIGN(UnifiedMemory);
}; };
} // namespace dragon } // namespace dragon
......
...@@ -41,12 +41,6 @@ OperatorBase::OperatorBase(const OperatorDef& def, Workspace* ws) ...@@ -41,12 +41,6 @@ OperatorBase::OperatorBase(const OperatorDef& def, Workspace* ws)
} }
} }
// template <class Context>
// Operator<Context>::Operator(const OperatorDef& def, Workspace* ws)
// : OperatorBase(def, ws),
// ctx_(def.device_option()),
// do_sync_(OpArg<bool>("do_sync", false)) {}
Tensor& OperatorBase::Input(int i) { Tensor& OperatorBase::Input(int i) {
CHECK_LT(i, (int)inputs_.size()); CHECK_LT(i, (int)inputs_.size());
CHECK_GE(i, -(int)inputs_.size()); CHECK_GE(i, -(int)inputs_.size());
...@@ -80,27 +74,17 @@ Tensor* OperatorBase::Buffer(const string& name) { ...@@ -80,27 +74,17 @@ Tensor* OperatorBase::Buffer(const string& name) {
return ws()->CreateTensor("/share/buffer/" + handle_ + "/" + name); return ws()->CreateTensor("/share/buffer/" + handle_ + "/" + name);
} }
string OperatorBase::TypeString(const Tensor& tensor, const Set<string>& types) string OperatorBase::MessageForUnsupported(
const { const string& value,
std::stringstream ss; const vector<string>& support_values,
ss << "Unsupported type of Tensor(" << tensor.name() const string& entry) const {
<< "): " << types::to_string(tensor.meta()) << "\n";
ss << "<" << type() << "Op>"
<< " supports the following types: {\n";
for (auto& type : types)
ss << " * " << type << ",\n";
ss << "}";
return ss.str();
}
string OperatorBase::TypeString(const string& dtype, const Set<string>& types)
const {
std::stringstream ss; std::stringstream ss;
ss << "Unsupported type: " << dtype << "\n"; ss << "Unsupported " << entry << ": " << value << "\n";
ss << "<" << type() << "Op>" ss << "<" << type() << "Op>"
<< " supports the following types: {\n"; << " supports the following " << entry << "(s): {\n";
for (auto& type : types) for (const auto& support_value : support_values) {
ss << " * " << type << ",\n"; ss << " * " << support_value << ",\n";
}
ss << "}"; ss << "}";
return ss.str(); return ss.str();
} }
...@@ -133,7 +117,7 @@ void Operator<Context>::Prepare() { ...@@ -133,7 +117,7 @@ void Operator<Context>::Prepare() {
flag->mutable_data<bool, CPUContext>()[0] = true; flag->mutable_data<bool, CPUContext>()[0] = true;
vector<OperatorBase*>& chain = subgraph()[name]; vector<OperatorBase*>& chain = subgraph()[name];
for (auto* op : chain) { for (auto* op : chain) {
op->Run(ctx()->stream_id()); op->Run(ctx()->stream());
} }
flag->mutable_data<bool, CPUContext>()[0] = false; flag->mutable_data<bool, CPUContext>()[0] = false;
} }
...@@ -156,12 +140,12 @@ template <class Context> ...@@ -156,12 +140,12 @@ template <class Context>
void Operator<Context>::SwitchToDevice() { void Operator<Context>::SwitchToDevice() {
for (auto* tensor : inputs_) { for (auto* tensor : inputs_) {
if (tensor->has_name()) { if (tensor->has_name()) {
tensor->SwitchToDevice(ctx()->device_id()); tensor->SwitchToDevice(ctx()->device());
} }
} }
for (auto* tensor : outputs_) { for (auto* tensor : outputs_) {
if (tensor->has_name()) { if (tensor->has_name()) {
tensor->SwitchToDevice(ctx()->device_id()); tensor->SwitchToDevice(ctx()->device());
} }
} }
} }
......
...@@ -28,30 +28,30 @@ class DRAGON_API OperatorBase { ...@@ -28,30 +28,30 @@ class DRAGON_API OperatorBase {
public: public:
typedef Map<string, vector<OperatorBase*>> SubGraph; typedef Map<string, vector<OperatorBase*>> SubGraph;
/*! \brief Default constructor */ /*! \brief Constructor with the def and workspace */
OperatorBase(const OperatorDef&, Workspace*); OperatorBase(const OperatorDef&, Workspace*);
/*! \brief Default Destructor */ /*! \brief Destructor */
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
/*! \brief Fusion this operator into the specified graph */ /*! \brief Update operator from the given def */
virtual void Fusion(void* graph) { OperatorBase* UpdateFrom(const OperatorDef&);
/*! \brief Fusion operator into the given graph */
virtual void Fuse(void* graph) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
/*! \brief Run operator on the specified stream */ /*! \brief Run operator on the given stream */
virtual void Run(int stream = 0) { virtual void Run(int stream = 0) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
/*! \brief Switch the internal running phase */ /*! \brief Switch to the given executing phase */
void SwitchToPhase(const string& phase) { void SwitchToPhase(const string& phase) {
phase_ = phase; phase_ = phase;
} }
/*! \brief Update operator according to a new def */
OperatorBase* UpdateFrom(const OperatorDef&);
/*! \brief Return the input tensor */ /*! \brief Return the input tensor */
Tensor& Input(int i); Tensor& Input(int i);
...@@ -61,7 +61,7 @@ class DRAGON_API OperatorBase { ...@@ -61,7 +61,7 @@ class DRAGON_API OperatorBase {
/*! \brief Return the output tensor with input aliases */ /*! \brief Return the output tensor with input aliases */
Tensor* Output(int i, const vec32_t& inputs); Tensor* Output(int i, const vec32_t& inputs);
/*! \brief Return the unique named buffer */ /*! \brief Return the buffer tensor */
Tensor* Buffer(const string& name); Tensor* Buffer(const string& name);
/*! \brief Return the number of inputs */ /*! \brief Return the number of inputs */
...@@ -74,31 +74,26 @@ class DRAGON_API OperatorBase { ...@@ -74,31 +74,26 @@ class DRAGON_API OperatorBase {
return (int)outputs_.size(); return (int)outputs_.size();
} }
/*! \brief Return the value of the specified argument */ /*! \brief Return the value of single argument */
template <typename T> template <typename T>
T Arg(const string& name, const T& default_value); T Arg(const string& name, const T& default_value);
/*! \brief Return the values of the specified argument */ /*! \brief Return the value of repeated argument */
template <typename T> template <typename T>
vector<T> Args(const string& name); vector<T> Args(const string& name);
/*! \brief Return the debug string of stored def */ /*! \brief Return the message for supported value */
string DebugString() const { string MessageForUnsupported(
return def_.DebugString(); const string& value,
} const vector<string>& support_values,
const string& entry = "type") const;
/*! \brief Return the debug string of tensor type */
string TypeString(const Tensor&, const Set<string>&) const;
/* \brief Return the debug string of given type */
string TypeString(const string&, const Set<string>&) const;
/*! \brief Return the specified argument */ /*! \brief Return the specified argument */
const Argument& arg(const string& name) { const Argument& arg(const string& name) {
return *(args_[name]); return *(args_[name]);
} }
/*! \brief Return the argument map */ /*! \brief Return all the arguments */
const Map<string, const Argument*>& args() { const Map<string, const Argument*>& args() {
return args_; return args_;
} }
...@@ -113,7 +108,7 @@ class DRAGON_API OperatorBase { ...@@ -113,7 +108,7 @@ class DRAGON_API OperatorBase {
return def_.type(); return def_.type();
} }
/*! \brief Return the current running phase */ /*! \brief Return the running phase */
const string& phase() const { const string& phase() const {
return phase_; return phase_;
} }
...@@ -190,12 +185,17 @@ class DRAGON_API OperatorBase { ...@@ -190,12 +185,17 @@ class DRAGON_API OperatorBase {
/*! \brief Store the defined arguments */ /*! \brief Store the defined arguments */
Map<string, const Argument*> args_; Map<string, const Argument*> args_;
DISABLE_COPY_AND_ASSIGN(OperatorBase);
}; };
/*!
* \brief The base operator class with context.
*/
template <class Context> template <class Context>
class DRAGON_API Operator : public OperatorBase { class DRAGON_API Operator : public OperatorBase {
public: public:
/*! \brief Default constructor */ /*! \brief Constructor with the def and workspace */
Operator(const OperatorDef& def, Workspace* ws) Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(def, ws), : OperatorBase(def, ws),
ctx_(def.device_option()), ctx_(def.device_option()),
...@@ -247,22 +247,21 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*); ...@@ -247,22 +247,21 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
name(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {} \ name(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {} \
virtual ~name() {} virtual ~name() {}
#define USE_OPERATOR_BASE_FUNCTIONS \ #define USE_OPERATOR_BASE_FUNCTIONS \
using OperatorBase::SwitchToPhase; \ using OperatorBase::SwitchToPhase; \
using OperatorBase::Input; \ using OperatorBase::Input; \
using OperatorBase::Output; \ using OperatorBase::Output; \
using OperatorBase::Buffer; \ using OperatorBase::Buffer; \
using OperatorBase::InputSize; \ using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \ using OperatorBase::OutputSize; \
using OperatorBase::DebugString; \ using OperatorBase::MessageForUnsupported; \
using OperatorBase::TypeString; \ using OperatorBase::name; \
using OperatorBase::name; \ using OperatorBase::type; \
using OperatorBase::type; \ using OperatorBase::phase; \
using OperatorBase::phase; \ using OperatorBase::dtype; \
using OperatorBase::dtype; \ using OperatorBase::data_format; \
using OperatorBase::data_format; \ using OperatorBase::handle; \
using OperatorBase::handle; \ using OperatorBase::def; \
using OperatorBase::def; \
using OperatorBase::ws using OperatorBase::ws
#define USE_OPERATOR_FUNCTIONS \ #define USE_OPERATOR_FUNCTIONS \
...@@ -298,41 +297,40 @@ using AllTensorTypes = ...@@ -298,41 +297,40 @@ using AllTensorTypes =
template <typename Sizes, typename... Args> template <typename Sizes, typename... Args>
struct DispatchHelper; struct DispatchHelper;
#define DEFINE_TENSOR_TYPES_DISPATCHER(func) \ #define DEFINE_TENSOR_TYPES_DISPATCHER(func) \
template <typename T, typename... Types, typename... Args> \ template <typename T, typename... Types, typename... Args> \
struct DispatchHelper<TensorTypes<T, Types...>, Args...> { \ struct DispatchHelper<TensorTypes<T, Types...>, Args...> { \
template <typename Op> \ template <typename Op> \
static void Call(Op* op, const TypeMeta& meta, string& types) { \ static void Call(Op* op, const TypeMeta& meta, string& types) { \
if (meta.Match<T>()) return op->template func<T, Args...>(); \ if (meta.Match<T>()) return op->template func<T, Args...>(); \
types += " * " + types::to_string<T>() + ",\n"; \ types += " * " + types::to_string<T>() + ",\n"; \
return DispatchHelper<TensorTypes<Types...>, Args...>::Call( \ return DispatchHelper<TensorTypes<Types...>, Args...>::Call( \
op, meta, types); \ op, meta, types); \
} \ } \
template <typename Op> \ template <typename Op> \
static void Call(Op* op) { \ static void Call(Op* op) { \
string types; \ string types; \
return Call(op, types::to_meta(op->dtype()), types); \ return Call(op, types::to_meta(op->dtype()), types); \
} \ } \
template <typename Op> \ template <typename Op> \
static void Call(Op* op, const Tensor& tensor) { \ static void Call(Op* op, const Tensor& tensor) { \
string types; \ string types; \
return Call(op, tensor.meta(), types); \ return Call(op, tensor.meta(), types); \
} \ } \
}; \ }; \
template <typename... Args> \ template <typename... Args> \
struct DispatchHelper<TensorTypes<>, Args...> { \ struct DispatchHelper<TensorTypes<>, Args...> { \
template <typename Op> \ template <typename Op> \
static void Call(Op* op, const TypeMeta& meta, string& types) { \ static void Call(Op* op, const TypeMeta& meta, string& types) { \
LOG(FATAL) << "Unsupported tensor type: " << types::to_string(meta) \ LOG(FATAL) << "Unsupported type: " << types::to_string(meta) << "\n" \
<< "\n" \ << "<" << op->type() << "Op>" \
<< "<" << op->type() << "Op>" \ << " supports the following type(s): {\n" \
<< " supports the following types: {\n" \ << types << "}"; \
<< types << "}"; \ } \
} \ template <typename Op> \
template <typename Op> \ static void Call(Op* op, const Tensor& tensor) { \
static void Call(Op* op, const Tensor& tensor) { \ return Call(op, tensor.meta(), ""); \
return Call(op, tensor.meta(), ""); \ } \
} \
}; };
DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType); DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
......
...@@ -17,78 +17,93 @@ ...@@ -17,78 +17,93 @@
namespace dragon { namespace dragon {
template <class SrcType, class ObjType, class... Args> /*!
* \brief Registry to create class instances.
*/
template <class KeyType, class ObjectType, class... Args>
class Registry { class Registry {
public: public:
typedef std::function<ObjType*(Args...)> Creator; typedef std::function<ObjectType*(Args...)> Creator;
ObjType* Create(const SrcType& key, Args... args) { /*! \brief Create an instance of specified class */
ObjectType* Create(const KeyType& key, Args... args) {
CHECK(registry_.count(key)) << "\nKey(" << key << ") has not registered."; CHECK(registry_.count(key)) << "\nKey(" << key << ") has not registered.";
return registry_[key](args...); return registry_[key](args...);
} }
bool Has(const SrcType& key) { /*! \brief Return whether the specified class is registered */
bool Has(const KeyType& key) {
return (registry_.count(key)) != 0; return (registry_.count(key)) != 0;
} }
void Register(const SrcType& key, Creator creator) { /*! \brief Register a class with the creator */
void Register(const KeyType& key, Creator creator) {
CHECK(!registry_.count(key)) CHECK(!registry_.count(key))
<< "\nKey(" << key << ") has already registered."; << "\nKey(" << key << ") has already registered.";
registry_[key] = creator; registry_[key] = creator;
} }
vector<SrcType> keys() { /*! \brief Return the key of registered classes */
vector<SrcType> ret; vector<KeyType> keys() {
for (const auto& it : registry_) vector<KeyType> ret;
for (const auto& it : registry_) {
ret.push_back(it.first); ret.push_back(it.first);
}
return ret; return ret;
} }
private: private:
Map<SrcType, Creator> registry_; /*! \brief The registry map */
Map<KeyType, Creator> registry_;
}; };
template <class SrcType, class ObjType, class... Args> /*!
* \brief Register creator into the registry.
*/
template <class KeyType, class ObjectType, class... Args>
class Registerer { class Registerer {
public: public:
/*! \brief Constructor with key and creator */
Registerer( Registerer(
const SrcType& key, const KeyType& key,
Registry<SrcType, ObjType, Args...>* registry, Registry<KeyType, ObjectType, Args...>* registry,
typename Registry<SrcType, ObjType, Args...>::Creator creator, typename Registry<KeyType, ObjectType, Args...>::Creator creator,
const string& help_msg = "") { const string& help_msg = "") {
registry->Register(key, creator); registry->Register(key, creator);
} }
/*! \brief Return the default creator */
template <class DerivedType> template <class DerivedType>
static ObjType* defaultCreator(Args... args) { static ObjectType* DefaultCreator(Args... args) {
return new DerivedType(args...); return new DerivedType(args...);
} }
}; };
// Used in *.h files // Used in *.h files
#define DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjType, ...) \ #define DECLARE_TYPED_REGISTRY(RegistryName, KeyType, ObjectType, ...) \
DRAGON_API Registry<SrcType, ObjType, ##__VA_ARGS__>* RegistryName(); \ DRAGON_API Registry<KeyType, ObjectType, ##__VA_ARGS__>* RegistryName(); \
typedef Registerer<SrcType, ObjType, ##__VA_ARGS__> Registerer##RegistryName; typedef Registerer<KeyType, ObjectType, ##__VA_ARGS__> \
Registerer##RegistryName;
// Used in *.cc files // Used in *.cc files
#define DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjType, ...) \ #define DEFINE_TYPED_REGISTRY(RegistryName, KeyType, ObjectType, ...) \
Registry<SrcType, ObjType, ##__VA_ARGS__>* RegistryName() { \ Registry<KeyType, ObjectType, ##__VA_ARGS__>* RegistryName() { \
static Registry<SrcType, ObjType, ##__VA_ARGS__>* registry = \ static Registry<KeyType, ObjectType, ##__VA_ARGS__>* registry = \
new Registry<SrcType, ObjType, ##__VA_ARGS__>(); \ new Registry<KeyType, ObjectType, ##__VA_ARGS__>(); \
return registry; \ return registry; \
} }
#define DECLARE_REGISTRY(RegistryName, ObjType, ...) \ #define DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
DECLARE_TYPED_REGISTRY(RegistryName, string, ObjType, ##__VA_ARGS__) DECLARE_TYPED_REGISTRY(RegistryName, string, ObjectType, ##__VA_ARGS__)
#define DEFINE_REGISTRY(RegistryName, ObjType, ...) \ #define DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
DEFINE_TYPED_REGISTRY(RegistryName, string, ObjType, ##__VA_ARGS__) DEFINE_TYPED_REGISTRY(RegistryName, string, ObjectType, ##__VA_ARGS__)
#define REGISTER_TYPED_CLASS(RegistryName, key, ...) \ #define REGISTER_TYPED_CLASS(RegistryName, key, ...) \
static Registerer##RegistryName ANONYMOUS_VARIABLE(g_##RegistryName)( \ static Registerer##RegistryName ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, \ key, \
RegistryName(), \ RegistryName(), \
Registerer##RegistryName::defaultCreator<__VA_ARGS__>) Registerer##RegistryName::DefaultCreator<__VA_ARGS__>)
#define REGISTER_CLASS(RegistryName, key, ...) \ #define REGISTER_CLASS(RegistryName, key, ...) \
REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__) REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
......
...@@ -18,28 +18,53 @@ ...@@ -18,28 +18,53 @@
namespace dragon { namespace dragon {
/*!
* \brief The base tensor class, manage memory or not.
*
* Tensor is usually constructed with the shape info:
*
* \code{.cpp}
* auto* a = new dragon::Tensor(std::vector<int64_t>({2, 3}));
* auto* b = dragon::Tensor().Reshape({2, 3}); // Equivalent
* \endcode
*
* To allocate the data, type meta and device context are also required:
*
* \code{.cpp}
* auto meta = dragon::TypeMeta::Make<float>();
* auto* raw_data = a->raw_mutable_data<dragon::CPUContext>(meta);
* auto* data = b->mutable_data<float, dragon::CPUContext>();
* \endcode
*
* Memory will be reset if required bytes is larger than capacity:
* \code{.cpp}
* std::cout << a->nbytes() << " " << a->capacity() << std::endl; // 24, 24
* std::cout << a->Reshape({2, 4})->size() << std::endl; // 8
* std::cout << a->nbytes() << " " << a->capacity() << std::endl; // 32, 0
* a->mutable_data<float, dragon::CPUContext>();
* a->Reshape({2, 3});
* std::cout << a->nbytes() << " " << a->capacity() << std::endl; // 24, 32
* \endcode
*/
class DRAGON_API Tensor { class DRAGON_API Tensor {
public: public:
Tensor(const Tensor&) = delete; /*! \brief Default constructor */
Tensor& operator=(const Tensor&) = delete;
/*! \brief Default Constructor */
Tensor() : name_("") {} Tensor() : name_("") {}
/*! \brief Constructor with the known name */ /*! \brief Constructor with the name */
explicit Tensor(const string& name) : name_(name) {} explicit Tensor(const string& name) : name_(name) {}
/*! \brief Constructor with the known int64 dimensions */ /*! \brief Constructor with the int64 dimensions */
explicit Tensor(const vec64_t& dims) { explicit Tensor(const vec64_t& dims) {
Reshape(dims); Reshape(dims);
} }
/*! \brief Constructor with the known int32 dimensions */ /*! \brief Constructor with the int32 dimensions */
explicit Tensor(const vec32_t& dims) { explicit Tensor(const vec32_t& dims) {
Reshape(vec64_t(dims.begin(), dims.end())); Reshape(vec64_t(dims.begin(), dims.end()));
} }
/*! \brief Constructor with the known meta */ /*! \brief Constructor with the type meta */
explicit Tensor(const TypeMeta& meta) { explicit Tensor(const TypeMeta& meta) {
set_meta(meta); set_meta(meta);
} }
...@@ -54,7 +79,7 @@ class DRAGON_API Tensor { ...@@ -54,7 +79,7 @@ class DRAGON_API Tensor {
} }
} }
/*! \brief Reshape to the given dimensions */ /*! \brief Change the tensor dimensions */
Tensor* Reshape(const vec64_t& dims) { Tensor* Reshape(const vec64_t& dims) {
dims_ = dims; dims_ = dims;
strides_.resize(dims.size()); strides_.resize(dims.size());
...@@ -79,18 +104,18 @@ class DRAGON_API Tensor { ...@@ -79,18 +104,18 @@ class DRAGON_API Tensor {
return this; return this;
} }
/*! \brief Reshape the dimensions like the given tensor */ /*! \brief Change the tensor dimensions as the other */
Tensor* ReshapeLike(const Tensor& other) { Tensor* ReshapeLike(const Tensor& other) {
return Reshape(other.dims_); return Reshape(other.dims_);
} }
/*! \brief Switch the memory to the specific device */ /*! \brief Switch memory to the specific device */
void SwitchToDevice(int device_id) { void SwitchToDevice(int device_id) {
UnifiedMemory* mem = memory(); UnifiedMemory* mem = memory();
if (mem) mem->SwitchToDevice(device_id); if (mem) mem->SwitchToDevice(device_id);
} }
/*! \brief Copy memory from the tensor with context */ /*! \brief Copy memory from a tensor with context */
template <class Context> template <class Context>
Tensor* CopyFrom(const Tensor& other, Context* ctx) { Tensor* CopyFrom(const Tensor& other, Context* ctx) {
if ((void*)&other == (void*)this) return this; if ((void*)&other == (void*)this) return this;
...@@ -102,7 +127,7 @@ class DRAGON_API Tensor { ...@@ -102,7 +127,7 @@ class DRAGON_API Tensor {
return this; return this;
} }
/*! \brief Copy memory from the vector */ /*! \brief Copy memory from a vector */
template <typename TensorType, typename VectorType> template <typename TensorType, typename VectorType>
Tensor* CopyFrom(const vector<VectorType>& other) { Tensor* CopyFrom(const vector<VectorType>& other) {
if (other.size() > 0) { if (other.size() > 0) {
...@@ -115,7 +140,7 @@ class DRAGON_API Tensor { ...@@ -115,7 +140,7 @@ class DRAGON_API Tensor {
return this; return this;
} }
/*! \brief Copy memory to the vector */ /*! \brief Copy memory to a vector */
template <typename TensorType, typename VectorType> template <typename TensorType, typename VectorType>
void CopyTo(vector<VectorType>& dest) { void CopyTo(vector<VectorType>& dest) {
dest.resize(size()); dest.resize(size());
...@@ -141,7 +166,7 @@ class DRAGON_API Tensor { ...@@ -141,7 +166,7 @@ class DRAGON_API Tensor {
own_memory_ = (memory == nullptr); own_memory_ = (memory == nullptr);
} }
/*! \brief Reset all resources */ /*! \brief Reset tensor to release all resources */
void Reset() { void Reset() {
dims_.clear(); dims_.clear();
strides_.clear(); strides_.clear();
...@@ -156,13 +181,13 @@ class DRAGON_API Tensor { ...@@ -156,13 +181,13 @@ class DRAGON_API Tensor {
} }
} }
/*! \brief Whether the data type is matched */ /*! \brief Return whether the data type is matched */
template <typename T> template <typename T>
bool IsType() { bool IsType() {
return meta_.Match<T>(); return meta_.Match<T>();
} }
/*! \brief Return a string formatting the dimensions */ /*! \brief Return a string formatting the given dimensions */
static string DimString(const vector<int64_t>& dims) { static string DimString(const vector<int64_t>& dims) {
if (dims.size() == 0) return "(0,)"; if (dims.size() == 0) return "(0,)";
std::stringstream ss; std::stringstream ss;
...@@ -187,7 +212,7 @@ class DRAGON_API Tensor { ...@@ -187,7 +212,7 @@ class DRAGON_API Tensor {
return name_; return name_;
} }
/*! \brief Return true if tensor name is set */ /*! \brief Return whether the tensor name is set */
bool has_name() const { bool has_name() const {
return !name_.empty(); return !name_.empty();
} }
...@@ -207,7 +232,7 @@ class DRAGON_API Tensor { ...@@ -207,7 +232,7 @@ class DRAGON_API Tensor {
return capacity_; return capacity_;
} }
/*! \brief Return the total number of bytes */ /*! \brief Return the total number of data bytes */
size_t nbytes() const { size_t nbytes() const {
return size_ * meta_.itemsize(); return size_ * meta_.itemsize();
} }
...@@ -231,50 +256,51 @@ class DRAGON_API Tensor { ...@@ -231,50 +256,51 @@ class DRAGON_API Tensor {
return (int)dims_.size(); return (int)dims_.size();
} }
/*! \brief Return the dimension of specified axis */ /*! \brief Return the dimension of given axis */
int64_t dim(int64_t i) const { int64_t dim(int64_t i) const {
return dims_[axis(i)]; return dims_[axis(i)];
} }
/*! \brief Return the stride of specified axis */ /*! \brief Return the stride of given axis */
int64_t stride(int64_t i) const { int64_t stride(int64_t i) const {
return strides_[axis(i)]; return strides_[axis(i)];
} }
/*! \brief Return all the dimensions */ /*! \brief Return the tensor dimensions */
const vec64_t& dims() const { const vec64_t& dims() const {
return dims_; return dims_;
} }
/*! \brief Return all the strides */ /*! \brief Return the tensor strides */
const vec64_t& strides() const { const vec64_t& strides() const {
return strides_; return strides_;
} }
/*! \brief Return the number of elements along the [start, end) axes */ /*! \brief Return the total number of elements */
int64_t count() const {
return (int64_t)size_;
}
/*! \brief Return the number of elements counting along the given axes */
int64_t count(int64_t start, int64_t end) const { int64_t count(int64_t start, int64_t end) const {
int64_t nelements = 1; int64_t nelements = 1;
for (int64_t i = start; i < end; i++) for (int64_t i = start; i < end; i++) {
nelements *= dim(i); nelements *= dim(i);
}
return nelements; return nelements;
} }
/*! \brief Return the total number of elements */ /*! \brief Return the number of elements counting from the given axis */
int64_t count() const {
return (int64_t)size_;
}
/*! \brief Return the number of elements from the start axis */
int64_t count(int64_t start) const { int64_t count(int64_t start) const {
return count(start, ndim()); return count(start, ndim());
} }
/*! \brief Whether this tensor is empty */ /*! \brief Return whether the total number of elements is zero */
bool empty() const { bool empty() const {
return size_ == 0; return size_ == 0;
} }
/*! \brief Whether this tensor holds a valid memory */ /*! \brief Return whether the memory is set */
bool has_memory() const { bool has_memory() const {
return internal_memory_ != nullptr || external_memory_ != nullptr; return internal_memory_ != nullptr || external_memory_ != nullptr;
} }
...@@ -286,12 +312,12 @@ class DRAGON_API Tensor { ...@@ -286,12 +312,12 @@ class DRAGON_API Tensor {
return ptr; return ptr;
} }
/*! \brief Return the state of memory */ /*! \brief Return the memory state */
UnifiedMemory::State memory_state() const { UnifiedMemory::State memory_state() const {
return memory(true)->state(); return memory(true)->state();
} }
/*! \brief Try to get the raw const data pointer */ /*! \brief Try to return the raw const data pointer */
template <class Context> template <class Context>
const void* const_data_ptr() const { const void* const_data_ptr() const {
TypeId ctx_type = TypeMeta::Id<Context>(); TypeId ctx_type = TypeMeta::Id<Context>();
...@@ -307,7 +333,7 @@ class DRAGON_API Tensor { ...@@ -307,7 +333,7 @@ class DRAGON_API Tensor {
} }
} }
/*! \brief Try to get the raw mutable data pointer */ /*! \brief Try to return the raw mutable data pointer */
template <class Context> template <class Context>
void mutable_data_ptr(void** data_ptr) { void mutable_data_ptr(void** data_ptr) {
auto* mem = memory(); auto* mem = memory();
...@@ -327,21 +353,23 @@ class DRAGON_API Tensor { ...@@ -327,21 +353,23 @@ class DRAGON_API Tensor {
} }
} }
/*! \brief Try to allocate the raw data for memory */ /*!
* \brief Return the raw mutable data pointer.
*
* If memory is not set, create to manage it with the given meta.
*/
template <class Context> template <class Context>
void* raw_mutable_data(const TypeMeta& meta) { void* raw_mutable_data(const TypeMeta& meta) {
void* data_ptr; void* data_ptr;
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
// Return the data of memory directly // Return the data pointer directly
if (meta_ == meta && data_ptr) return data_ptr; if (meta_ == meta && data_ptr) return data_ptr;
// Create a new memory with knowned size // Create a new memory created with size and meta
CHECK_GT(size_, 0) << "\nInvalid tensor size."; CHECK_GT(size_, 0) << "\nInvalid tensor size.";
meta_ = meta; meta_ = meta;
capacity_ = size_ * meta.itemsize(); capacity_ = size_ * meta.itemsize();
internal_memory_.reset(new UnifiedMemory(meta_, capacity_)); internal_memory_.reset(new UnifiedMemory(meta_, capacity_));
// Allocate space
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
// Call the constructor if necessary
if (meta_.ctor()) meta_.ctor()(data_ptr, size_); if (meta_.ctor()) meta_.ctor()(data_ptr, size_);
return data_ptr; return data_ptr;
} }
...@@ -360,7 +388,7 @@ class DRAGON_API Tensor { ...@@ -360,7 +388,7 @@ class DRAGON_API Tensor {
return const_data_ptr<Context>(); return const_data_ptr<Context>();
} }
/*! \brief Get the typed mutable data pointer */ /*! \brief Return the typed mutable data pointer */
template <typename T, class Context> template <typename T, class Context>
T* mutable_data() { T* mutable_data() {
void* data_ptr; void* data_ptr;
...@@ -377,7 +405,7 @@ class DRAGON_API Tensor { ...@@ -377,7 +405,7 @@ class DRAGON_API Tensor {
return static_cast<T*>(raw_mutable_data<Context>(TypeMeta::Make<T>())); return static_cast<T*>(raw_mutable_data<Context>(TypeMeta::Make<T>()));
} }
/*! \brief Get the typed const data pointer */ /*! \brief Return the typed const data pointer */
template <typename T, class Context> template <typename T, class Context>
const T* data() const { const T* data() const {
CHECK(meta_.Match<T>()) << "\nThe type of Tensor(" << name() << ") is " CHECK(meta_.Match<T>()) << "\nThe type of Tensor(" << name() << ") is "
...@@ -391,13 +419,13 @@ class DRAGON_API Tensor { ...@@ -391,13 +419,13 @@ class DRAGON_API Tensor {
version_ = version; version_ = version;
} }
/*! \brief Set the meta of data type */ /*! \brief Set the type meta */
Tensor* set_meta(const TypeMeta& meta) { Tensor* set_meta(const TypeMeta& meta) {
meta_ = meta; meta_ = meta;
return this; return this;
} }
/*! \brief Set the internal memory */ /*! \brief Set to manage the memory */
void set_memory(UnifiedMemory* memory) { void set_memory(UnifiedMemory* memory) {
if (memory != internal_memory_.get()) { if (memory != internal_memory_.get()) {
internal_memory_.reset(memory); internal_memory_.reset(memory);
...@@ -429,6 +457,8 @@ class DRAGON_API Tensor { ...@@ -429,6 +457,8 @@ class DRAGON_API Tensor {
/*! \brief The external memory indicator */ /*! \brief The external memory indicator */
bool own_memory_ = true; bool own_memory_ = true;
DISABLE_COPY_AND_ASSIGN(Tensor);
}; };
} // namespace dragon } // namespace dragon
......
...@@ -31,15 +31,41 @@ struct DRAGON_API TypeRegister { ...@@ -31,15 +31,41 @@ struct DRAGON_API TypeRegister {
} }
}; };
/*!
* \brief Metaclass for all types.
*
* TypeMeta is commonly used for type identification:
*
* \code{.cpp}
* auto meta1 = dragon::TypeMeta::Make<float>();
* auto meta2 = dragon::TypeMeta::Make<float>();
* std::cout << (meta1 == meta2) << std::endl; // 1
* std::cout << (meta1.id() == meta2.id()) << std::endl; // 1
* std::cout << meta1.Match<float>() << std::endl; // 1
* std::cout << (meta1.id() == dragon::TypeMeta::Id<float>()) << std::endl; // 1
* \endcode
*
* Default constructor and destructor are available for non-fundamental types:
*
* \code{.cpp}
* auto meta = dragon::TypeMeta::Make<std::string>();
* auto* raw_string_data = malloc(1 * meta.itemsize());
* meta.ctor()(raw_string_data, 1);
* auto* string_data = reinterpret_cast<std::string*>(raw_string_data);
* std::cout << string_data[0].size();
* meta.dtor()(raw_string_data, 1);
* \endcode
*/
class TypeMeta { class TypeMeta {
public: public:
typedef void (*PlacementNew)(void*, size_t); typedef void (*PlacementNew)(void*, size_t);
typedef void (*TypedCopy)(const void*, void*, size_t); typedef void (*TypedCopy)(const void*, void*, size_t);
typedef void (*TypedDestructor)(void*, size_t); typedef void (*TypedDestructor)(void*, size_t);
TypeMeta() /*! \brief Default constructor */
: id_(0), itemsize_(0), ctor_(nullptr), copy_(nullptr), dtor_(nullptr) {} TypeMeta() : id_(0), itemsize_(0) {}
/*! \brief Constructor with the other type meta */
TypeMeta(const TypeMeta& src) TypeMeta(const TypeMeta& src)
: id_(src.id_), : id_(src.id_),
itemsize_(src.itemsize_), itemsize_(src.itemsize_),
...@@ -57,32 +83,38 @@ class TypeMeta { ...@@ -57,32 +83,38 @@ class TypeMeta {
return *this; return *this;
} }
/*! \brief Return whether the two identifications are equal */
bool operator==(const TypeMeta& other) const { bool operator==(const TypeMeta& other) const {
return (id_ == other.id_); return (id_ == other.id_);
} }
/*! \brief Return whether the two identifications are not equal */
bool operator!=(const TypeMeta& other) const { bool operator!=(const TypeMeta& other) const {
return (id_ != other.id_); return (id_ != other.id_);
} }
/*! \brief Return the identification of given type */
template <typename T> template <typename T>
static TypeId Id() { static TypeId Id() {
return TypeRegister<T>::id(); return TypeRegister<T>::id();
} }
/*! \brief Return the item size of given type */
template <typename T> template <typename T>
static size_t Itemsize() { static size_t Itemsize() {
return sizeof(T); return sizeof(T);
} }
/*! \brief Call the constructor for each element */
template <typename T> template <typename T>
static void Ctor(void* ptr, size_t n) { static void Ctor(void* ptr, size_t n) {
T* typed_ptr = static_cast<T*>(ptr); T* typed_ptr = static_cast<T*>(ptr);
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; ++i) {
new (typed_ptr + i) T; new (typed_ptr + i) T;
} }
} }
/*! \brief Call the destructor for each element */
template <typename T> template <typename T>
static void Dtor(void* ptr, size_t n) { static void Dtor(void* ptr, size_t n) {
T* typed_ptr = static_cast<T*>(ptr); T* typed_ptr = static_cast<T*>(ptr);
...@@ -91,55 +123,66 @@ class TypeMeta { ...@@ -91,55 +123,66 @@ class TypeMeta {
} }
} }
/*! \brief Call the copy constructor for each element */
template <typename T> template <typename T>
static void Copy(const void* src, void* dst, size_t n) { static void Copy(const void* src, void* dst, size_t n) {
const T* typed_src = static_cast<const T*>(src); const T* typed_src = static_cast<const T*>(src);
T* typed_dst = static_cast<T*>(dst); T* typed_dst = static_cast<T*>(dst);
for (size_t i = 0; i < n; ++i) for (size_t i = 0; i < n; ++i) {
typed_dst[i] = typed_src[i]; typed_dst[i] = typed_src[i];
}
} }
#define FundMeta std::enable_if<std::is_fundamental<T>::value, TypeMeta>::type #define FundamentalTypeMeta \
std::enable_if<std::is_fundamental<T>::value, TypeMeta>::type
#define StructMeta \ #define StructuralTypeMeta \
std::enable_if< \ std::enable_if< \
!std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, \ !std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, \
TypeMeta>::type TypeMeta>::type
/*! \brief Return a type meta of given type */
template <typename T> template <typename T>
static typename FundMeta Make() { static typename FundamentalTypeMeta Make() {
return TypeMeta(Id<T>(), Itemsize<T>(), nullptr, nullptr, nullptr); return TypeMeta(Id<T>(), Itemsize<T>(), nullptr, nullptr, nullptr);
} }
/*! \brief Return a type meta of given type */
template <typename T> template <typename T>
static typename StructMeta Make() { static typename StructuralTypeMeta Make() {
return TypeMeta(Id<T>(), Itemsize<T>(), Ctor<T>, Copy<T>, Dtor<T>); return TypeMeta(Id<T>(), Itemsize<T>(), Ctor<T>, Copy<T>, Dtor<T>);
} }
#undef FundMeta #undef FundamentalTypeMeta
#undef StructMeta #undef StructuralTypeMeta
/*! \brief Return whether the meta is matched with given type */
template <typename T> template <typename T>
bool Match() const { bool Match() const {
return (id_ == Id<T>()); return (id_ == Id<T>());
} }
/*! \brief Return the type identification */
const TypeId& id() const { const TypeId& id() const {
return id_; return id_;
} }
/*! \brief Return the item size */
const size_t& itemsize() const { const size_t& itemsize() const {
return itemsize_; return itemsize_;
} }
/*! \brief Return the type constructor */
PlacementNew ctor() const { PlacementNew ctor() const {
return ctor_; return ctor_;
} }
/*! \brief Return the type destructor */
TypedDestructor dtor() const { TypedDestructor dtor() const {
return dtor_; return dtor_;
} }
/*! \brief Return the type copy constructor */
TypedCopy copy() const { TypedCopy copy() const {
return copy_; return copy_;
} }
...@@ -156,9 +199,9 @@ class TypeMeta { ...@@ -156,9 +199,9 @@ class TypeMeta {
private: private:
TypeId id_; TypeId id_;
size_t itemsize_; size_t itemsize_;
PlacementNew ctor_; PlacementNew ctor_ = nullptr;
TypedCopy copy_; TypedCopy copy_ = nullptr;
TypedDestructor dtor_; TypedDestructor dtor_ = nullptr;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -128,7 +128,7 @@ void Workspace::RunGraph( ...@@ -128,7 +128,7 @@ void Workspace::RunGraph(
const int stream) { const int stream) {
CHECK(graph_map_.count(name)) CHECK(graph_map_.count(name))
<< "\nGraph(" << name << ") is not in current workspace."; << "\nGraph(" << name << ") is not in current workspace.";
graph_map_[name]->Run(include, exclude, stream); graph_map_[name]->Run(stream, include, exclude);
} }
void Workspace::RegisterAlias(const string& target, const string& alias) { void Workspace::RegisterAlias(const string& target, const string& alias) {
......
...@@ -17,55 +17,58 @@ ...@@ -17,55 +17,58 @@
namespace dragon { namespace dragon {
class Workspace { /*!
* \brief Sandbox to isolate the resources and computations.
*/
class DRAGON_API Workspace {
public: public:
/*! \brief Constructor */ /*! \brief Constructor with the name */
DRAGON_API explicit Workspace(const string& name); explicit Workspace(const string& name);
/*! \brief Merge resources from other */ /*! \brief Merge resources from other */
DRAGON_API void MergeFrom(Workspace*); void MergeFrom(Workspace* other);
/*! \brief Clear the cached resources */ /*! \brief Clear the cached resources */
DRAGON_API void Clear(); void Clear();
/* \brief Return an unique name */ /* \brief Return an unique name */
DRAGON_API string UniqueName( string UniqueName(
const string& name, const string& name,
const string& suffix, const string& suffix,
const string& scope = "", const string& scope = "",
const bool zero_based = false); const bool zero_based = false);
/* \brief Register an alias for the target */ /* \brief Register an alias for the target */
DRAGON_API void RegisterAlias(const string& target, const string& alias); void RegisterAlias(const string& target, const string& alias);
/*! \brief Return whether tensor is existing */ /*! \brief Return whether tensor is existing */
DRAGON_API bool HasTensor(const string& name, bool external = true) const { bool HasTensor(const string& name, bool external = true) const {
return TryGetTensor(name, external) == nullptr ? false : true; return TryGetTensor(name, external) == nullptr ? false : true;
} }
/*! \brief Create the tensor */ /*! \brief Create the tensor */
DRAGON_API Tensor* CreateTensor(const string&, FillerInfo* = nullptr); Tensor* CreateTensor(const string&, FillerInfo* = nullptr);
/*! \brief Try to return the tensor */ /*! \brief Try to return the tensor */
DRAGON_API Tensor* TryGetTensor(const string&, bool = true) const; Tensor* TryGetTensor(const string&, bool = true) const;
/*! \brief Return the tensor */ /*! \brief Return the tensor */
DRAGON_API Tensor* GetTensor(const string&, bool = true) const; Tensor* GetTensor(const string&, bool = true) const;
/*! \brief Reset the tensor */ /*! \brief Reset the tensor */
DRAGON_API void ResetTensor(const string&); void ResetTensor(const string&);
/*! \brief Return the filler info */ /*! \brief Return the filler info */
DRAGON_API FillerInfo* GetFillerInfo(const string&); FillerInfo* GetFillerInfo(const string&);
/*! \brief Run the operator */ /*! \brief Run the operator */
DRAGON_API void RunOperator(const OperatorDef&); void RunOperator(const OperatorDef&);
/*! \brief Create the graph */ /*! \brief Create the graph */
DRAGON_API GraphBase* CreateGraph(const GraphDef&); GraphBase* CreateGraph(const GraphDef&);
/*! \brief Run the graph */ /*! \brief Run the graph */
DRAGON_API void RunGraph( void RunGraph(
const string& graph_name, const string& graph_name,
const string& include = "", const string& include = "",
const string& exclude = "", const string& exclude = "",
...@@ -77,12 +80,12 @@ class Workspace { ...@@ -77,12 +80,12 @@ class Workspace {
} }
/*! \brief Return the name of cached tensors */ /*! \brief Return the name of cached tensors */
DRAGON_API vector<string> tensors() const; vector<string> tensors() const;
/*! \brief Return the name of cached graphs */ /*! \brief Return the name of cached graphs */
DRAGON_API vector<string> graphs() const; vector<string> graphs() const;
/*! \brief Provide a group of the shared byte data */ /*! \brief Return a group of the shared raw data */
template <class Context> template <class Context>
vector<void*> data(const vector<size_t>& segments) { vector<void*> data(const vector<size_t>& segments) {
int64_t nbytes = 0; int64_t nbytes = 0;
...@@ -96,7 +99,7 @@ class Workspace { ...@@ -96,7 +99,7 @@ class Workspace {
return ret; return ret;
} }
/*! \brief Provide a group of shared typed data */ /*! \brief Return a group of shared typed data */
template <typename T, class Context> template <typename T, class Context>
vector<T*> data(const vector<int64_t>& segments) { vector<T*> data(const vector<int64_t>& segments) {
vector<size_t> segments_in_byte; vector<size_t> segments_in_byte;
...@@ -133,6 +136,8 @@ class Workspace { ...@@ -133,6 +136,8 @@ class Workspace {
/*! \brief The cached graphs */ /*! \brief The cached graphs */
Map<string, unique_ptr<GraphBase>> graph_map_; Map<string, unique_ptr<GraphBase>> graph_map_;
DISABLE_COPY_AND_ASSIGN(Workspace);
}; };
} // namespace dragon } // namespace dragon
......
...@@ -69,7 +69,7 @@ class NumpyFetcher : public TensorFetcherBase { ...@@ -69,7 +69,7 @@ class NumpyFetcher : public TensorFetcherBase {
tensor.nbytes(), tensor.nbytes(),
PyArray_DATA(reinterpret_cast<PyArrayObject*>(array)), PyArray_DATA(reinterpret_cast<PyArrayObject*>(array)),
tensor.raw_data<CUDAContext>(), tensor.raw_data<CUDAContext>(),
tensor.memory()->device_id()); tensor.memory()->device());
} else { } else {
CPUContext::Memcpy<CPUContext, CPUContext>( CPUContext::Memcpy<CPUContext, CPUContext>(
tensor.nbytes(), tensor.nbytes(),
......
...@@ -130,7 +130,7 @@ void RegisterModule(py::module& m) { ...@@ -130,7 +130,7 @@ void RegisterModule(py::module& m) {
#ifdef USE_CUDA #ifdef USE_CUDA
if (device_id < 0) device_id = CUDAContext::current_device(); if (device_id < 0) device_id = CUDAContext::current_device();
auto stream = CUDAContext::object()->stream(device_id, stream_id); auto stream = CUDAContext::object()->stream(device_id, stream_id);
CUDAContext::SyncStream(stream); CUDAContext::SynchronizeStream(stream);
#endif #endif
}); });
......
...@@ -50,7 +50,7 @@ class DLPackWrapper { ...@@ -50,7 +50,7 @@ class DLPackWrapper {
} else { } else {
data = memory->mutable_cuda_data(nbytes); data = memory->mutable_cuda_data(nbytes);
} }
ctx.device_id = memory->device_id(); ctx.device_id = memory->device();
ctx.device_type = DLDeviceType::kDLGPU; ctx.device_type = DLDeviceType::kDLGPU;
break; break;
} }
......
...@@ -191,9 +191,12 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -191,9 +191,12 @@ PYBIND11_MODULE(libdragon_python, m) {
auto* graph = self->CreateGraph(graph_def); auto* graph = self->CreateGraph(graph_def);
if (verbose) { if (verbose) {
bool could_be_serialized = true; bool could_be_serialized = true;
const auto& def = graph->opt_def(); const auto& def = graph->optimized_def();
for (auto& op : def.op()) for (auto& op : def.op()) {
if (op.type() == "GivenTensorFill") could_be_serialized = false; if (op.type() == "GivenTensorFill") {
could_be_serialized = false;
}
}
if (could_be_serialized) { if (could_be_serialized) {
auto msg = string("\n") + def.DebugString(); auto msg = string("\n") + def.DebugString();
msg.pop_back(); msg.pop_back();
......
...@@ -3,6 +3,9 @@ message(STATUS "Build module: ${CMAKE_CURRENT_LIST_DIR}") ...@@ -3,6 +3,9 @@ message(STATUS "Build module: ${CMAKE_CURRENT_LIST_DIR}")
# ---[ Defines # ---[ Defines
add_definitions(-DBUILD_RUNTIME) add_definitions(-DBUILD_RUNTIME)
if (USE_MPI)
remove_definitions(-DUSE_MPI)
endif()
# ---[ Sources # ---[ Sources
set(MODULE_INCLUDES "") set(MODULE_INCLUDES "")
......
...@@ -3,28 +3,33 @@ ...@@ -3,28 +3,33 @@
namespace dragon { namespace dragon {
int type_from_string(std::string type) { namespace {
if (type == "CPU") {
int type_from_string(const std::string& device_type) {
if (device_type == "CPU") {
return 0; return 0;
} else if (type == "GPU") { } else if (device_type == "GPU") {
return 1; return 1;
} else if (type == "CUDA") { } else if (device_type == "CUDA") {
return 1; return 1;
} }
LOG(FATAL) << "Unknown device type: " << type << ", " LOG(FATAL) << "Unsupported device type: " << device_type << "\n"
<< "known device types: " << "Following device types are supported: {"
<< "CPU, " << " * CPU\n"
<< "GPU, " << " * GPU\n"
<< "CUDA"; << " * CUDA\n"
<< "}";
return -1; return -1;
} }
} // namespace
Device::Device() : device_type_(0), device_id_(0) {} Device::Device() : device_type_(0), device_id_(0) {}
Device::Device(std::string device_type, int device_id) Device::Device(const std::string& device_type, int device_id)
: device_type_(type_from_string(device_type)), device_id_(device_id) {} : device_type_(type_from_string(device_type)), device_id_(device_id) {}
Device::Device(std::string device_type) Device::Device(const std::string& device_type)
: device_type_(type_from_string(device_type)), device_id_(0) {} : device_type_(type_from_string(device_type)), device_id_(0) {}
} // namespace dragon } // namespace dragon
...@@ -40,8 +40,8 @@ typedef class Workspace* Workspace_t; ...@@ -40,8 +40,8 @@ typedef class Workspace* Workspace_t;
class DRAGON_API Device { class DRAGON_API Device {
public: public:
Device(); Device();
explicit Device(std::string device_type); explicit Device(const std::string& device_type);
Device(std::string device_type, int device_id); Device(const std::string& device_type, int device_id);
const int& device_type() const { const int& device_type() const {
return device_type_; return device_type_;
...@@ -65,7 +65,7 @@ DRAGON_API Workspace_t ResetWorkspace(Workspace_t ws); ...@@ -65,7 +65,7 @@ DRAGON_API Workspace_t ResetWorkspace(Workspace_t ws);
DRAGON_API Workspace_t ResetWorkspace(const std::string& name); DRAGON_API Workspace_t ResetWorkspace(const std::string& name);
DRAGON_API void MoveWorkspace(Workspace_t dst, Workspace_t src); DRAGON_API void MoveWorkspace(Workspace_t dest, Workspace_t src);
DRAGON_API void DestroyWorkspace(Workspace_t ws); DRAGON_API void DestroyWorkspace(Workspace_t ws);
...@@ -76,15 +76,13 @@ DRAGON_API void DestroyWorkspace(const std::string& name); ...@@ -76,15 +76,13 @@ DRAGON_API void DestroyWorkspace(const std::string& name);
*/ */
DRAGON_API std::string DRAGON_API std::string
CreateGraph(const GraphDef_t graph_def, const Device& device, Workspace_t ws); CreateGraph(const GraphDef_t def, const Device& device, Workspace_t ws);
DRAGON_API std::string CreateGraph( DRAGON_API std::string
const std::string& graph_file, CreateGraph(const std::string& file, const Device& device, Workspace_t ws);
const Device& device,
Workspace_t ws);
DRAGON_API void DRAGON_API void
RunGraph(const std::string& graph_name, Workspace_t ws, int stream_id = 0); RunGraph(const std::string& name, Workspace_t ws, int stream = 0);
/*! /*!
* Tensor API * Tensor API
...@@ -111,9 +109,9 @@ DRAGON_API T* FetchTensor( ...@@ -111,9 +109,9 @@ DRAGON_API T* FetchTensor(
* Proto API * Proto API
*/ */
DRAGON_API void CreateGraphDef(GraphDef_t* graph_def); DRAGON_API void CreateGraphDef(GraphDef_t* def);
DRAGON_API void DestroyGraphDef(GraphDef_t graph_def); DRAGON_API void DestroyGraphDef(GraphDef_t def);
/*! /*!
* Model API * Model API
...@@ -121,8 +119,8 @@ DRAGON_API void DestroyGraphDef(GraphDef_t graph_def); ...@@ -121,8 +119,8 @@ DRAGON_API void DestroyGraphDef(GraphDef_t graph_def);
DRAGON_API void LoadONNXModel( DRAGON_API void LoadONNXModel(
const std::string& model_file, const std::string& model_file,
GraphDef_t init_graph, GraphDef_t init_def,
GraphDef_t pred_graph, GraphDef_t pred_def,
std::vector<std::string>& inputs, std::vector<std::string>& inputs,
std::vector<std::string>& outputs); std::vector<std::string>& outputs);
......
#include "dragon/core/common.h" #include "dragon/core/workspace.h"
#include "dragon/modules/runtime/dragon_runtime.h" #include "dragon/modules/runtime/dragon_runtime.h"
#include "dragon/onnx/onnx_backend.h" #include "dragon/onnx/onnx_backend.h"
#include "dragon/utils/proto_utils.h" #include "dragon/utils/proto_utils.h"
...@@ -9,7 +9,7 @@ std::mutex g_mutex; ...@@ -9,7 +9,7 @@ std::mutex g_mutex;
Map<string, unique_ptr<Workspace>> g_workspaces; Map<string, unique_ptr<Workspace>> g_workspaces;
Map<string, vector<string>> sub_workspaces; Map<string, vector<string>> sub_workspaces;
Workspace* CreateWorkspace(const string& name) { Workspace_t CreateWorkspace(const string& name) {
std::unique_lock<std::mutex> lock(g_mutex); std::unique_lock<std::mutex> lock(g_mutex);
LOG(INFO) << "Create the Workspace(" << name << ")."; LOG(INFO) << "Create the Workspace(" << name << ").";
if (g_workspaces.count(name)) return g_workspaces[name].get(); if (g_workspaces.count(name)) return g_workspaces[name].get();
...@@ -19,7 +19,7 @@ Workspace* CreateWorkspace(const string& name) { ...@@ -19,7 +19,7 @@ Workspace* CreateWorkspace(const string& name) {
return g_workspaces[name].get(); return g_workspaces[name].get();
} }
Workspace* ResetWorkspace(const string& name) { Workspace_t ResetWorkspace(const string& name) {
std::unique_lock<std::mutex> lock(g_mutex); std::unique_lock<std::mutex> lock(g_mutex);
CHECK(g_workspaces.count(name)) CHECK(g_workspaces.count(name))
<< "\nWorkspace(" << name << ") does not exist." << "\nWorkspace(" << name << ") does not exist."
...@@ -34,19 +34,19 @@ Workspace* ResetWorkspace(const string& name) { ...@@ -34,19 +34,19 @@ Workspace* ResetWorkspace(const string& name) {
return g_workspaces[name].get(); return g_workspaces[name].get();
} }
Workspace* ResetWorkspace(Workspace_t ws) { Workspace_t ResetWorkspace(Workspace_t ws) {
CHECK(ws) << "\nGiven workspace is invalid."; CHECK(ws) << "\nGiven workspace is invalid.";
return ResetWorkspace(ws->name()); return ResetWorkspace(ws->name());
} }
void MoveWorkspace(Workspace_t dst, Workspace_t src) { void MoveWorkspace(Workspace_t dest, Workspace_t src) {
std::unique_lock<std::mutex> lock(g_mutex); std::unique_lock<std::mutex> lock(g_mutex);
CHECK(src) << "\nGiven source workspace is invalid."; CHECK(src) << "\nGiven source workspace is invalid.";
CHECK(dst) << "\nGiven destination workspace is invalid."; CHECK(dest) << "\nGiven destination workspace is invalid.";
dst->MergeFrom(src); dest->MergeFrom(src);
sub_workspaces[dst->name()].push_back(src->name()); sub_workspaces[dest->name()].push_back(src->name());
LOG(INFO) << "Move the Workspace(" << src->name() << ") " LOG(INFO) << "Move the Workspace(" << src->name() << ") "
<< "into the Workspace(" << dst->name() << ")."; << "into the Workspace(" << dest->name() << ").";
} }
void DestroyWorkspace(const string& name) { void DestroyWorkspace(const string& name) {
...@@ -63,27 +63,25 @@ void DestroyWorkspace(Workspace_t ws) { ...@@ -63,27 +63,25 @@ void DestroyWorkspace(Workspace_t ws) {
return DestroyWorkspace(ws->name()); return DestroyWorkspace(ws->name());
} }
string string CreateGraph(const GraphDef_t def, const Device& device, Workspace_t ws) {
CreateGraph(const GraphDef_t graph_def, const Device& device, Workspace_t ws) { auto def_v2(*def);
auto graph_def_copy(*graph_def); auto* device_option = def_v2.mutable_device_option();
// Overwritten device options
DeviceOption* device_option = graph_def_copy.mutable_device_option();
device_option->set_device_type((DeviceTypeProto)device.device_type()); device_option->set_device_type((DeviceTypeProto)device.device_type());
device_option->set_device_id(device.device_id()); device_option->set_device_id(device.device_id());
auto* graph = ws->CreateGraph(graph_def_copy); auto* graph = ws->CreateGraph(def_v2);
if (!graph) LOG(FATAL) << "Can not create the graph."; if (!graph) LOG(FATAL) << "Can not create the graph.";
return graph->name(); return graph->name();
} }
std::string std::string
CreateGraph(const string& graph_file, const Device& device, Workspace_t ws) { CreateGraph(const string& file, const Device& device, Workspace_t ws) {
GraphDef graph_def; GraphDef graph_def;
ParseProtoFromText(graph_file.c_str(), &graph_def); ParseProtoFromText(file.c_str(), &graph_def);
return CreateGraph(&graph_def, device, ws); return CreateGraph(&graph_def, device, ws);
} }
void RunGraph(const string& graph_name, Workspace_t ws, const int stream_id) { void RunGraph(const string& name, Workspace_t ws, int stream) {
ws->RunGraph(graph_name, "", "", stream_id); ws->RunGraph(name, "", "", stream);
} }
void CreateTensor(const string& name, Workspace_t ws) { void CreateTensor(const string& name, Workspace_t ws) {
...@@ -148,34 +146,38 @@ void FeedTensor( ...@@ -148,34 +146,38 @@ void FeedTensor(
tensor->raw_mutable_data<CPUContext>(), tensor->raw_mutable_data<CPUContext>(),
static_cast<const void*>(data)); static_cast<const void*>(data));
} else { } else {
LOG(FATAL) << "Unknown device type."; LOG(FATAL) << "Unsupported device type.";
} }
} }
DRAGON_API void CreateGraphDef(GraphDef_t* graph_def) { void CreateGraphDef(GraphDef_t* def) {
*graph_def = new GraphDef(); *def = new GraphDef();
} }
DRAGON_API void DestroyGraphDef(GraphDef_t graph_def) { void DestroyGraphDef(GraphDef_t def) {
if (graph_def) delete graph_def; if (def) {
delete def;
}
} }
void LoadONNXModel( void LoadONNXModel(
const string& model_file, const string& model_file,
GraphDef_t init_graph, GraphDef_t init_def,
GraphDef_t pred_graph, GraphDef_t pred_def,
vector<string>& inputs, vector<string>& inputs,
vector<string>& outputs) { vector<string>& outputs) {
LOG(INFO) << "Load Model: " << model_file << "......"; LOG(INFO) << "Load Model: " << model_file << "......";
LOG(INFO) << "Format: ONNX"; LOG(INFO) << "Format: ONNX";
onnx::ONNXBackend onnx_backend; onnx::ONNXBackend onnx_backend;
onnx_backend.Prepare(model_file, init_graph, pred_graph); onnx_backend.Prepare(model_file, init_def, pred_def);
inputs.clear(); inputs.clear();
outputs.clear(); outputs.clear();
for (const auto& e : pred_graph->input()) for (const auto& input : pred_def->input()) {
inputs.emplace_back(e); inputs.push_back(input);
for (const auto& e : pred_graph->output()) }
outputs.emplace_back(e); for (const auto& output : pred_def->output()) {
outputs.push_back(output);
}
} }
#define INSTANTIATE_API(T) \ #define INSTANTIATE_API(T) \
......
...@@ -8,54 +8,56 @@ namespace dragon { ...@@ -8,54 +8,56 @@ namespace dragon {
#define ELIGIBLE_TENSOR_TYPES \ #define ELIGIBLE_TENSOR_TYPES \
{ "bool", "int8", "uint8", "int32", "int64", "float16", "float32", "float64" } { "bool", "int8", "uint8", "int32", "int64", "float16", "float32", "float64" }
#define DEFINE_TYPE_A_TO_B(Ta, type_str, Tb) \ #define DISPATCH_TYPE_TO(InputType, OutputType) \
if (dtype() == type_str) { \ if (dtype() == types::to_string<OutputType>()) { \
if (InputSize() != 0) { \ if (InputSize() != 0) { \
Output(0)->ReshapeLike(Input(0)); \ Output(0)->ReshapeLike(Input(0)); \
auto* x = Input(0).template data<Ta, Context>(); \ auto* x = Input(0).template data<InputType, Context>(); \
auto* y = Output(0)->template mutable_data<Tb, Context>(); \ auto* y = Output(0)->template mutable_data<OutputType, Context>(); \
kernel::Cast(Input(0).count(), x, y, ctx()); \ kernel::Cast(Input(0).count(), x, y, ctx()); \
} else { \ } else { \
auto n = Output(0)->count(); \ auto n = Output(0)->count(); \
auto* x = Output(0)->template data<Ta, Context>(); \ auto* x = Output(0)->template data<InputType, Context>(); \
auto* scratch = ws()->template data<Tb, Context>({n})[0]; \ auto* scratch = ws()->template data<OutputType, Context>({n})[0]; \
kernel::Cast(n, x, scratch, ctx()); \ kernel::Cast(n, x, scratch, ctx()); \
ctx()->FinishDeviceComputation(); \ ctx()->FinishDeviceComputation(); \
auto* y = Output(0)->template mutable_data<Tb, Context>(); \ auto* y = Output(0)->template mutable_data<OutputType, Context>(); \
math::Copy(n, scratch, y, ctx()); \ math::Copy(n, scratch, y, ctx()); \
} \ } \
return; \ return; \
} }
#define DEFINE_TYPE_A_TO_ALL(Ta) \ #define DISPATCH_TYPE_TO_ALL(InputType) \
DEFINE_TYPE_A_TO_B(Ta, "bool", bool); \ DISPATCH_TYPE_TO(InputType, bool); \
DEFINE_TYPE_A_TO_B(Ta, "int8", int8_t); \ DISPATCH_TYPE_TO(InputType, int8_t); \
DEFINE_TYPE_A_TO_B(Ta, "uint8", uint8_t); \ DISPATCH_TYPE_TO(InputType, uint8_t); \
DEFINE_TYPE_A_TO_B(Ta, "int32", int); \ DISPATCH_TYPE_TO(InputType, int); \
DEFINE_TYPE_A_TO_B(Ta, "int64", int64_t); \ DISPATCH_TYPE_TO(InputType, int64_t); \
DEFINE_TYPE_A_TO_B(Ta, "float16", float16); \ DISPATCH_TYPE_TO(InputType, float16); \
DEFINE_TYPE_A_TO_B(Ta, "float32", float); \ DISPATCH_TYPE_TO(InputType, float); \
DEFINE_TYPE_A_TO_B(Ta, "float64", double) DISPATCH_TYPE_TO(InputType, double); \
LOG(FATAL) << MessageForUnsupported(dtype(), ELIGIBLE_TENSOR_TYPES);
#define DISPATCH_WITH_TENSOR(X) \ #define DISPATCH_WITH_TENSOR(X) \
if (XIsType(X, bool)) { \ if (XIsType(X, bool)) { \
DEFINE_TYPE_A_TO_ALL(bool); \ DISPATCH_TYPE_TO_ALL(bool); \
} else if (XIsType(X, int8_t)) { \ } else if (XIsType(X, int8_t)) { \
DEFINE_TYPE_A_TO_ALL(int8_t); \ DISPATCH_TYPE_TO_ALL(int8_t); \
} else if (XIsType(X, uint8_t)) { \ } else if (XIsType(X, uint8_t)) { \
DEFINE_TYPE_A_TO_ALL(uint8_t); \ DISPATCH_TYPE_TO_ALL(uint8_t); \
} else if (XIsType(X, int)) { \ } else if (XIsType(X, int)) { \
DEFINE_TYPE_A_TO_ALL(int); \ DISPATCH_TYPE_TO_ALL(int); \
} else if (XIsType(X, int64_t)) { \ } else if (XIsType(X, int64_t)) { \
DEFINE_TYPE_A_TO_ALL(int64_t); \ DISPATCH_TYPE_TO_ALL(int64_t); \
} else if (XIsType(X, float16)) { \ } else if (XIsType(X, float16)) { \
DEFINE_TYPE_A_TO_ALL(float16); \ DISPATCH_TYPE_TO_ALL(float16); \
} else if (XIsType(X, float)) { \ } else if (XIsType(X, float)) { \
DEFINE_TYPE_A_TO_ALL(float); \ DISPATCH_TYPE_TO_ALL(float); \
} else if (XIsType(X, double)) { \ } else if (XIsType(X, double)) { \
DEFINE_TYPE_A_TO_ALL(double); \ DISPATCH_TYPE_TO_ALL(double); \
} else { \ } else { \
LOG(FATAL) << TypeString(X, ELIGIBLE_TENSOR_TYPES); \ LOG(FATAL) << MessageForUnsupported( \
types::to_string(X.meta()), ELIGIBLE_TENSOR_TYPES); \
} }
template <class Context> template <class Context>
...@@ -101,8 +103,8 @@ OPERATOR_SCHEMA(CastGradient) ...@@ -101,8 +103,8 @@ OPERATOR_SCHEMA(CastGradient)
REGISTER_GRADIENT(Cast, SimpleGradientMaker); REGISTER_GRADIENT(Cast, SimpleGradientMaker);
#undef ELIGIBLE_TENSOR_TYPES #undef ELIGIBLE_TENSOR_TYPES
#undef DEFINE_TYPE_A_TO_B #undef DISPATCH_TYPE_TO
#undef DEFINE_TYPE_A_TO_ALL #undef DISPATCH_TYPE_TO_ALL
#undef DISPATCH_WITH_TENSOR #undef DISPATCH_WITH_TENSOR
} // namespace dragon } // namespace dragon
...@@ -52,7 +52,8 @@ void ChannelNormalizeOp<Context>::DoRunWithType() { ...@@ -52,7 +52,8 @@ void ChannelNormalizeOp<Context>::DoRunWithType() {
} else if (dtype() == "float64") { } else if (dtype() == "float64") {
DoRunWithTypeAndCast<T, double>(); DoRunWithTypeAndCast<T, double>();
} else { } else {
LOG(FATAL) << TypeString(dtype(), {"float16", "float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
dtype(), {"float16", "float32", "float64"});
} }
} }
......
...@@ -60,7 +60,7 @@ void MultinomialOp<Context>::DoRunWithType() { ...@@ -60,7 +60,7 @@ void MultinomialOp<Context>::DoRunWithType() {
template <class Context> template <class Context>
void MultinomialOp<Context>::RunOnDevice() { void MultinomialOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // Enforce the default stream ctx()->set_stream(0); // Enforce the default stream
DispatchHelper<TensorTypes<float, double>>::Call(this, Input(0)); DispatchHelper<TensorTypes<float, double>>::Call(this, Input(0));
} }
......
...@@ -124,23 +124,24 @@ template <class Context> ...@@ -124,23 +124,24 @@ template <class Context>
void CollectiveOp<Context>::RunOnDevice() { void CollectiveOp<Context>::RunOnDevice() {
if (communication_ == "ALLREDUCE") { if (communication_ == "ALLREDUCE") {
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); i++) {
if (XIsType(Input(i), int8_t)) { auto& X = Input(i);
if (XIsType(X, int8_t)) {
AllReduceDispatcher<int8_t>(&Input(i)); AllReduceDispatcher<int8_t>(&Input(i));
} else if (XIsType(Input(i), uint8_t)) { } else if (XIsType(X, uint8_t)) {
AllReduceDispatcher<uint8_t>(&Input(i)); AllReduceDispatcher<uint8_t>(&Input(i));
} else if (XIsType(Input(i), int)) { } else if (XIsType(X, int)) {
AllReduceDispatcher<int>(&Input(i)); AllReduceDispatcher<int>(&Input(i));
} else if (XIsType(Input(i), int64_t)) { } else if (XIsType(X, int64_t)) {
AllReduceDispatcher<int64_t>(&Input(i)); AllReduceDispatcher<int64_t>(&Input(i));
} else if (XIsType(Input(i), float16)) { } else if (XIsType(X, float16)) {
AllReduceDispatcher<float16>(&Input(i)); AllReduceDispatcher<float16>(&Input(i));
} else if (XIsType(Input(i), float)) { } else if (XIsType(X, float)) {
AllReduceDispatcher<float>(&Input(i)); AllReduceDispatcher<float>(&Input(i));
} else if (XIsType(Input(i), double)) { } else if (XIsType(X, double)) {
AllReduceDispatcher<double>(&Input(i)); AllReduceDispatcher<double>(&Input(i));
} else { } else {
LOG(FATAL) << TypeString( LOG(FATAL) << MessageForUnsupported(
Input(i), types::to_string(X.meta()),
{"int8", {"int8",
"uint8", "uint8",
"int32", "int32",
...@@ -152,25 +153,26 @@ void CollectiveOp<Context>::RunOnDevice() { ...@@ -152,25 +153,26 @@ void CollectiveOp<Context>::RunOnDevice() {
} }
} else if (communication_ == "BROADCAST") { } else if (communication_ == "BROADCAST") {
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); i++) {
if (XIsType(Input(i), bool)) { auto& X = Input(i);
if (XIsType(X, bool)) {
BroadcastDispatcher<bool>(&Input(i)); BroadcastDispatcher<bool>(&Input(i));
} else if (XIsType(Input(i), int8_t)) { } else if (XIsType(X, int8_t)) {
BroadcastDispatcher<int8_t>(&Input(i)); BroadcastDispatcher<int8_t>(&Input(i));
} else if (XIsType(Input(i), uint8_t)) { } else if (XIsType(X, uint8_t)) {
BroadcastDispatcher<uint8_t>(&Input(i)); BroadcastDispatcher<uint8_t>(&Input(i));
} else if (XIsType(Input(i), int)) { } else if (XIsType(X, int)) {
BroadcastDispatcher<int>(&Input(i)); BroadcastDispatcher<int>(&Input(i));
} else if (XIsType(Input(i), int64_t)) { } else if (XIsType(X, int64_t)) {
BroadcastDispatcher<int64_t>(&Input(i)); BroadcastDispatcher<int64_t>(&Input(i));
} else if (XIsType(Input(i), float16)) { } else if (XIsType(X, float16)) {
BroadcastDispatcher<float16>(&Input(i)); BroadcastDispatcher<float16>(&Input(i));
} else if (XIsType(Input(i), float)) { } else if (XIsType(X, float)) {
BroadcastDispatcher<float>(&Input(i)); BroadcastDispatcher<float>(&Input(i));
} else if (XIsType(Input(i), double)) { } else if (XIsType(X, double)) {
BroadcastDispatcher<double>(&Input(i)); BroadcastDispatcher<double>(&Input(i));
} else { } else {
LOG(FATAL) << TypeString( LOG(FATAL) << MessageForUnsupported(
Input(i), types::to_string(X.meta()),
{"bool", {"bool",
"int8", "int8",
"uint8", "uint8",
......
...@@ -149,7 +149,7 @@ class CollectiveOpBase : public Operator<Context> { ...@@ -149,7 +149,7 @@ class CollectiveOpBase : public Operator<Context> {
ncclComm_t nccl_comm() { ncclComm_t nccl_comm() {
auto ret = CUDAContext::object()->nccl_comm( auto ret = CUDAContext::object()->nccl_comm(
this->ctx()->template device_id(), this->ctx()->template device(),
group_str_, group_str_,
nullptr, nullptr,
comm_size_, comm_size_,
...@@ -162,7 +162,7 @@ class CollectiveOpBase : public Operator<Context> { ...@@ -162,7 +162,7 @@ class CollectiveOpBase : public Operator<Context> {
} }
Broadcast((uint8_t*)&comm_uuid, sizeof(comm_uuid)); Broadcast((uint8_t*)&comm_uuid, sizeof(comm_uuid));
ret = CUDAContext::object()->nccl_comm( ret = CUDAContext::object()->nccl_comm(
this->ctx()->template device_id(), this->ctx()->template device(),
group_str_, group_str_,
&comm_uuid, &comm_uuid,
comm_size_, comm_size_,
......
...@@ -85,7 +85,8 @@ void CuDNNCTCLossOp<Context>::RunOnDevice() { ...@@ -85,7 +85,8 @@ void CuDNNCTCLossOp<Context>::RunOnDevice() {
CUDNN_CHECK(cudnnSetCTCLossDescriptor(ctc_desc_, CUDNN_DATA_FLOAT)); CUDNN_CHECK(cudnnSetCTCLossDescriptor(ctc_desc_, CUDNN_DATA_FLOAT));
DoRunWithType<float>(); DoRunWithType<float>();
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
} }
} }
......
...@@ -72,7 +72,8 @@ void NLLLossOp<Context>::RunOnDevice() { ...@@ -72,7 +72,8 @@ void NLLLossOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<float, int64_t>(); DoRunWithType<float, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float32", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float32", "int64"});
} }
} else if (XIsType(Input(0), double)) { } else if (XIsType(Input(0), double)) {
if (XIsType(Input(1), double)) { if (XIsType(Input(1), double)) {
...@@ -80,10 +81,12 @@ void NLLLossOp<Context>::RunOnDevice() { ...@@ -80,10 +81,12 @@ void NLLLossOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<double, int64_t>(); DoRunWithType<double, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float64", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float64", "int64"});
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32", "float64"});
} }
} }
...@@ -139,7 +142,8 @@ void NLLLossGradientOp<Context>::RunOnDevice() { ...@@ -139,7 +142,8 @@ void NLLLossGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<float, int64_t>(); DoRunWithType<float, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float32", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float32", "int64"});
} }
} else if (XIsType(Input(0), double)) { } else if (XIsType(Input(0), double)) {
if (XIsType(Input(1), double)) { if (XIsType(Input(1), double)) {
...@@ -147,10 +151,12 @@ void NLLLossGradientOp<Context>::RunOnDevice() { ...@@ -147,10 +151,12 @@ void NLLLossGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<double, int64_t>(); DoRunWithType<double, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float64", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float64", "int64"});
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32", "float64"});
} }
} }
......
...@@ -72,7 +72,8 @@ void SigmoidFocalLossOp<Context>::RunOnDevice() { ...@@ -72,7 +72,8 @@ void SigmoidFocalLossOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<float, int64_t>(); DoRunWithType<float, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float32", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float32", "int64"});
} }
} else if (XIsType(Input(0), double)) { } else if (XIsType(Input(0), double)) {
if (XIsType(Input(1), double)) { if (XIsType(Input(1), double)) {
...@@ -80,10 +81,12 @@ void SigmoidFocalLossOp<Context>::RunOnDevice() { ...@@ -80,10 +81,12 @@ void SigmoidFocalLossOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<double, int64_t>(); DoRunWithType<double, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float64", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float64", "int64"});
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32", "float64"});
} }
} }
...@@ -139,7 +142,8 @@ void SigmoidFocalLossGradientOp<Context>::RunOnDevice() { ...@@ -139,7 +142,8 @@ void SigmoidFocalLossGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<float, int64_t>(); DoRunWithType<float, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float32", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float32", "int64"});
} }
} else if (XIsType(Input(0), double)) { } else if (XIsType(Input(0), double)) {
if (XIsType(Input(1), double)) { if (XIsType(Input(1), double)) {
...@@ -147,10 +151,12 @@ void SigmoidFocalLossGradientOp<Context>::RunOnDevice() { ...@@ -147,10 +151,12 @@ void SigmoidFocalLossGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<double, int64_t>(); DoRunWithType<double, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float64", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float64", "int64"});
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32", "float64"});
} }
} }
......
...@@ -82,7 +82,8 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() { ...@@ -82,7 +82,8 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<float, int64_t>(); DoRunWithType<float, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float32", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float32", "int64"});
} }
} else if (XIsType(Input(0), double)) { } else if (XIsType(Input(0), double)) {
if (XIsType(Input(1), double)) { if (XIsType(Input(1), double)) {
...@@ -90,10 +91,12 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() { ...@@ -90,10 +91,12 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<double, int64_t>(); DoRunWithType<double, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float64", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float64", "int64"});
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32", "float64"});
} }
} }
...@@ -152,7 +155,8 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() { ...@@ -152,7 +155,8 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<float, int64_t>(); DoRunWithType<float, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float32", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float32", "int64"});
} }
} else if (XIsType(Input(0), double)) { } else if (XIsType(Input(0), double)) {
if (XIsType(Input(1), double)) { if (XIsType(Input(1), double)) {
...@@ -160,10 +164,12 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() { ...@@ -160,10 +164,12 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<double, int64_t>(); DoRunWithType<double, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float64", "int64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float64", "int64"});
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32", "float64"});
} }
} }
......
...@@ -45,8 +45,8 @@ void AxpbyOp<Context>::RunOnDevice() { ...@@ -45,8 +45,8 @@ void AxpbyOp<Context>::RunOnDevice() {
} else if (XIsType(X, double)) { } else if (XIsType(X, double)) {
DoRunWithType<double>(&X, Y); DoRunWithType<double>(&X, Y);
} else } else
LOG(FATAL) << TypeString( LOG(FATAL) << MessageForUnsupported(
X, types::to_string(X.meta()),
{"int8", "uint8", "int32", "int64", "float16", "float32", "float64"}); {"int8", "uint8", "int32", "int64", "float16", "float32", "float64"});
} }
} }
......
...@@ -75,8 +75,8 @@ void MomentsOp<Context>::RunOnDevice() { ...@@ -75,8 +75,8 @@ void MomentsOp<Context>::RunOnDevice() {
} else if (XIsType(X, double)) { } else if (XIsType(X, double)) {
DoRunWithType<double, double>(); DoRunWithType<double, double>();
} else { } else {
LOG(FATAL) << TypeString( LOG(FATAL) << MessageForUnsupported(
X, types::to_string(X.meta()),
{"int8", "uint8", "int32", "int64", "float16", "float32", "float64"}); {"int8", "uint8", "int32", "int64", "float16", "float32", "float64"});
} }
} }
......
...@@ -55,7 +55,8 @@ void AccuracyOp<Context>::RunOnDevice() { ...@@ -55,7 +55,8 @@ void AccuracyOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<float, int64_t>(); DoRunWithType<float, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"int64", "float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float32", "int64"});
} }
} else if (XIsType(Input(0), double)) { } else if (XIsType(Input(0), double)) {
if (XIsType(Input(1), double)) { if (XIsType(Input(1), double)) {
...@@ -63,10 +64,12 @@ void AccuracyOp<Context>::RunOnDevice() { ...@@ -63,10 +64,12 @@ void AccuracyOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), int64_t)) { } else if (XIsType(Input(1), int64_t)) {
DoRunWithType<double, int64_t>(); DoRunWithType<double, int64_t>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"int64", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float64", "int64"});
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32", "float64"});
} }
} }
......
...@@ -114,7 +114,8 @@ void BatchNormOp<Context>::RunOnDevice() { ...@@ -114,7 +114,8 @@ void BatchNormOp<Context>::RunOnDevice() {
InferenceImpl<float, float>(); InferenceImpl<float, float>();
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
} }
} }
...@@ -190,7 +191,8 @@ void BatchNormGradientOp<Context>::RunOnDevice() { ...@@ -190,7 +191,8 @@ void BatchNormGradientOp<Context>::RunOnDevice() {
InferenceImpl<float, float>(); InferenceImpl<float, float>();
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
} }
} }
......
...@@ -90,7 +90,8 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() { ...@@ -90,7 +90,8 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() {
} else if (XIsType(Input(0), float16)) { } else if (XIsType(Input(0), float16)) {
DoRunWithType<float16>(); DoRunWithType<float16>();
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float16"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
} }
} }
...@@ -156,10 +157,12 @@ void CuDNNBatchNormGradientOp<Context>::RunOnDevice() { ...@@ -156,10 +157,12 @@ void CuDNNBatchNormGradientOp<Context>::RunOnDevice() {
TrainingImpl<float16>(); TrainingImpl<float16>();
} else { } else {
// We will support it some day -:) // We will support it some day -:)
LOG(FATAL) << TypeString(Input(0), {"float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float16", "float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
} }
} }
......
...@@ -111,7 +111,8 @@ void SyncBatchNormOp<Context>::RunOnDevice() { ...@@ -111,7 +111,8 @@ void SyncBatchNormOp<Context>::RunOnDevice() {
this->template InferenceImpl<float, float>(); this->template InferenceImpl<float, float>();
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
} }
} }
...@@ -195,7 +196,8 @@ void SyncBatchNormGradientOp<Context>::RunOnDevice() { ...@@ -195,7 +196,8 @@ void SyncBatchNormGradientOp<Context>::RunOnDevice() {
this->template InferenceImpl<float, float>(); this->template InferenceImpl<float, float>();
} }
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
} }
} }
......
...@@ -60,7 +60,8 @@ void GroupNormOp<Context>::RunOnDevice() { ...@@ -60,7 +60,8 @@ void GroupNormOp<Context>::RunOnDevice() {
} else if (XIsType(Input(0), float16)) { } else if (XIsType(Input(0), float16)) {
DoRunWithType<float16, float>(); DoRunWithType<float16, float>();
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32", "float16"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
} }
} }
...@@ -101,7 +102,8 @@ void GroupNormGradientOp<Context>::RunOnDevice() { ...@@ -101,7 +102,8 @@ void GroupNormGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(0), float16)) { } else if (XIsType(Input(0), float16)) {
DoRunWithType<float16, float>(); DoRunWithType<float16, float>();
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float16", "float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
} }
} }
......
...@@ -23,7 +23,8 @@ void LSTMCellOp<Context>::RunOnDevice() { ...@@ -23,7 +23,8 @@ void LSTMCellOp<Context>::RunOnDevice() {
if (XIsType(Input(0), float)) { if (XIsType(Input(0), float)) {
DoRunWithType<float>(); DoRunWithType<float>();
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
} }
} }
...@@ -60,7 +61,8 @@ void LSTMCellGradientOp<Context>::RunOnDevice() { ...@@ -60,7 +61,8 @@ void LSTMCellGradientOp<Context>::RunOnDevice() {
if (XIsType(Input(0), float)) { if (XIsType(Input(0), float)) {
DoRunWithType<float>(); DoRunWithType<float>();
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
} }
} }
......
...@@ -100,7 +100,8 @@ void UpdateOpBase<Context>::RunOnDevice() { ...@@ -100,7 +100,8 @@ void UpdateOpBase<Context>::RunOnDevice() {
ComputeUpdate(dX_cast); ComputeUpdate(dX_cast);
ApplyUpdate<float>(dX_cast, X); ApplyUpdate<float>(dX_cast, X);
} else { } else {
LOG(FATAL) << TypeString(dX, {"float16", "float32"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(dX.meta()), {"float16", "float32"});
} }
} }
......
...@@ -199,7 +199,8 @@ void ResizeGradientOp<Context>::RunOnDevice() { ...@@ -199,7 +199,8 @@ void ResizeGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(0), double)) { } else if (XIsType(Input(0), double)) {
DoRunWithTypeAndCast<double>(); DoRunWithTypeAndCast<double>();
} else { } else {
LOG(FATAL) << TypeString(Input(0), {"float16", "float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32", "float64"});
}; };
} }
......
...@@ -95,7 +95,8 @@ void RoiAlignGradientOp<Context>::RunOnDevice() { ...@@ -95,7 +95,8 @@ void RoiAlignGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), double)) { } else if (XIsType(Input(1), double)) {
DoRunWithTypeAndCast<double>(); DoRunWithTypeAndCast<double>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float16", "float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float16", "float32", "float64"});
}; };
} }
......
...@@ -98,7 +98,8 @@ void RoiPoolGradientOp<Context>::RunOnDevice() { ...@@ -98,7 +98,8 @@ void RoiPoolGradientOp<Context>::RunOnDevice() {
} else if (XIsType(Input(1), double)) { } else if (XIsType(Input(1), double)) {
DoRunWithTypeAndCast<double>(); DoRunWithTypeAndCast<double>();
} else { } else {
LOG(FATAL) << TypeString(Input(1), {"float16", "float32", "float64"}); LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(1).meta()), {"float16", "float32", "float64"});
}; };
} }
......
...@@ -52,14 +52,14 @@ def device(device_type, device_index=0): ...@@ -52,14 +52,14 @@ def device(device_type, device_index=0):
def eager_scope(data='${DATA}', graph='${GRAPH}'): def eager_scope(data='${DATA}', graph='${GRAPH}'):
"""Context-manager to nest the domain for eager resources. """Context-manager to nest the namespace for eager resources.
Parameters Parameters
---------- ----------
data : str, optional, default='${DATA}' data : str, optional, default='${DATA}'
The domain for resources traced by python. The namespace for resources traced by python.
graph : str, optional, default='${GRAPH}' graph : str, optional, default='${GRAPH}'
The domain for resources traced by graph. The namespace for resources traced by graph.
""" """
domain_tuple = (graph, data) domain_tuple = (graph, data)
......
...@@ -105,7 +105,7 @@ class Workspace(backend.Workspace): ...@@ -105,7 +105,7 @@ class Workspace(backend.Workspace):
return self._collectors return self._collectors
def as_default(self): def as_default(self):
"""Switch ``self`` as the default workspace. """Switch this workspace as the default.
Call this method with the **with** keyword. Call this method with the **with** keyword.
...@@ -114,7 +114,7 @@ class Workspace(backend.Workspace): ...@@ -114,7 +114,7 @@ class Workspace(backend.Workspace):
Returns Returns
------- -------
dragon.Workspace dragon.Workspace
The ``self``. This workspace.
""" """
return _GLOBAL_DEFAULT_WORKSPACE_STACK.get_controller(self) return _GLOBAL_DEFAULT_WORKSPACE_STACK.get_controller(self)
...@@ -273,7 +273,7 @@ class Workspace(backend.Workspace): ...@@ -273,7 +273,7 @@ class Workspace(backend.Workspace):
Returns Returns
------- -------
dragon.Workspace dragon.Workspace
The ``self``. This workspace.
""" """
self.MergeFrom(other) self.MergeFrom(other)
...@@ -302,7 +302,7 @@ class Workspace(backend.Workspace): ...@@ -302,7 +302,7 @@ class Workspace(backend.Workspace):
The tensor to reset. The tensor to reset.
""" """
return self.ResetTensor(_stringify_object(tensor)) self.ResetTensor(_stringify_object(tensor))
def run_backward( def run_backward(
self, self,
...@@ -487,8 +487,7 @@ _GLOBAL_DEFAULT_WORKSPACE_STACK = _DefaultWorkspaceStack() ...@@ -487,8 +487,7 @@ _GLOBAL_DEFAULT_WORKSPACE_STACK = _DefaultWorkspaceStack()
# Predefined graph executing stages. # Predefined graph executing stages.
_PREDEFINED_GRAPH_EXECUTING_STAGES = { _PREDEFINED_GRAPH_EXECUTING_STAGES = {
'default': {'include': '', 'exclude': ''}, 'default': {'include': '', 'exclude': ''},
'forward': {'include': '', 'exclude': 'Gradient'}, 'forward': {'include': '', 'exclude': '.*Gradient.*'},
'backward': {'include': 'Gradient', 'exclude': 'Generate'}, 'backward': {'include': '.*Gradient.*', 'exclude': 'GradientGenerate'},
'backward_v2': {'include': 'Gradient', 'exclude': ''}, 'backward_v2': {'include': '.*Gradient.*', 'exclude': ''},
'external_grads': {'include': '', 'exclude': 'Generate'},
} }
...@@ -153,7 +153,7 @@ setuptools.setup( ...@@ -153,7 +153,7 @@ setuptools.setup(
package_data={'dragon': find_package_data()}, package_data={'dragon': find_package_data()},
package_dir={'dragon': 'dragon'}, package_dir={'dragon': 'dragon'},
cmdclass={'bdist_wheel': bdist_wheel, 'install': install}, cmdclass={'bdist_wheel': bdist_wheel, 'install': install},
python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*', python_requires='>=3.5',
install_requires=['numpy', 'protobuf', 'kpl-dataset'], install_requires=['numpy', 'protobuf', 'kpl-dataset'],
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', 'Development Status :: 5 - Production/Stable',
...@@ -162,12 +162,12 @@ setuptools.setup( ...@@ -162,12 +162,12 @@ setuptools.setup(
'Intended Audience :: Science/Research', 'Intended Audience :: Science/Research',
'License :: OSI Approved :: BSD License', 'License :: OSI Approved :: BSD License',
'Programming Language :: C++', 'Programming Language :: C++',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
......
...@@ -24,10 +24,19 @@ ...@@ -24,10 +24,19 @@
#define TLS_OBJECT __declspec(thread) #define TLS_OBJECT __declspec(thread)
#endif #endif
// Disable the copy and assignment operator for a class
#define DISABLE_COPY_AND_ASSIGN(classname) \
classname(const classname&) = delete; \
classname& operator=(const classname&) = delete
// Concatenate two strings
#define CONCATENATE_IMPL(s1, s2) s1##s2 #define CONCATENATE_IMPL(s1, s2) s1##s2
#define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1, s2) #define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1, s2)
// Return a anonymous variable name using line number
#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__) #define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__)
#define NOT_IMPLEMENTED \
LOG(FATAL) << "This module has not been implemented yet." // Throw a fatal logging for not implemented function
#define NOT_IMPLEMENTED LOG(FATAL) << "This function is not implemented."
#endif // DRAGON_UTILS_MARCROS_H_ #endif // DRAGON_UTILS_MARCROS_H_
:: ############################################################################# :: ##############################################################
:: Example command to build on Windows for Visual Studio 2013 (VC12). :: Command file to build on Windows for Visual Studio 2013 (VC12)
:: ############################################################################# :: ##############################################################
@echo off @echo off
setlocal setlocal
SET ORIGINAL_DIR=%cd% :: Build variables
SET REPO_ROOT=%~dp0%.. set ORIGINAL_DIR=%cd%
SET DRAGON_ROOT=%REPO_ROOT%\dragon set REPO_ROOT=%~dp0%..
SET THIRD_PARTY_DIR=%REPO_ROOT%\third_party set DRAGON_ROOT=%REPO_ROOT%\dragon
SET CMAKE_GENERATOR="Visual Studio 12 2013 Win64" set THIRD_PARTY_DIR=%REPO_ROOT%\third_party
set CMAKE_GENERATOR="Visual Studio 12 2013 Win64"
:: Build options :: Build options
SET BUILD_PYTHON=ON set BUILD_PYTHON=ON
SET BUILD_RUNTIME=OFF set BUILD_RUNTIME=OFF
:: Optional libraries
set USE_CUDA=ON
set USE_CUDNN=ON
set USE_OPENMP=ON
set USE_AVX=ON
set USE_AVX2=ON
set USE_FMA=ON
:: Protobuf SDK options :: Protobuf SDK options
SET PROTOBUF_SDK_ROOT_DIR=%THIRD_PARTY_DIR%\protobuf set PROTOBUF_SDK_ROOT_DIR=%THIRD_PARTY_DIR%\protobuf
:: Protobuf Compiler options :: Protobuf Compiler options
:: Set the protobuf compiler(i.e., protoc) if necessary :: Set the protobuf compiler(i.e., protoc) if necessary.
:: If not, a compiler in the sdk or environment will be used :: If not, a compiler in the sdk or environment will be used.
SET PROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_SDK_ROOT_DIR%\bin\protoc set PROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_SDK_ROOT_DIR%\bin\protoc
:: Python options :: Python options
:: Set your python "interpreter" if necessary :: Set your python "interpreter" if necessary.
:: If not, a default interpreter will be used :: If not, a default interpreter will be used.
:: SET PYTHON_EXECUTABLE=X:/Anaconda3/python :: set PYTHON_EXECUTABLE=X:/Anaconda3/python
if %BUILD_PYTHON% == ON ( if %BUILD_PYTHON% == ON (
if NOT DEFINED PYTHON_EXECUTABLE ( if NOT DEFINED PYTHON_EXECUTABLE (
for /F %%i in ('python -c "import sys;print(sys.executable)"') do (set PYTHON_EXECUTABLE=%%i) for /F %%i in ('python -c "import sys;print(sys.executable)"') do (set PYTHON_EXECUTABLE=%%i)
...@@ -47,6 +57,12 @@ cmake .. ^ ...@@ -47,6 +57,12 @@ cmake .. ^
-G%CMAKE_GENERATOR% ^ -G%CMAKE_GENERATOR% ^
-DBUILD_PYTHON=%BUILD_PYTHON% ^ -DBUILD_PYTHON=%BUILD_PYTHON% ^
-DBUILD_RUNTIME=%BUILD_RUNTIME% ^ -DBUILD_RUNTIME=%BUILD_RUNTIME% ^
-USE_CUDA==%USE_CUDA% ^
-USE_CUDNN==%USE_CUDNN% ^
-USE_OPENMP==%USE_OPENMP% ^
-USE_AVX==%USE_AVX% ^
-USE_AVX2==%USE_AVX2% ^
-USE_FMA==%USE_FMA% ^
-DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^ -DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^
-DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^ -DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^
-DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^ -DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^
......
:: ############################################################################# :: ##############################################################
:: Example command to build on Windows for Visual Studio 2015 (VC14). :: Command file to build on Windows for Visual Studio 2015 (VC14)
:: ############################################################################# :: ##############################################################
@echo off @echo off
setlocal setlocal
SET ORIGINAL_DIR=%cd% :: Build variables
SET REPO_ROOT=%~dp0%.. set ORIGINAL_DIR=%cd%
SET DRAGON_ROOT=%REPO_ROOT%\dragon set REPO_ROOT=%~dp0%..
SET THIRD_PARTY_DIR=%REPO_ROOT%\third_party set DRAGON_ROOT=%REPO_ROOT%\dragon
SET CMAKE_GENERATOR="Visual Studio 14 2015 Win64" set THIRD_PARTY_DIR=%REPO_ROOT%\third_party
set CMAKE_GENERATOR="Visual Studio 14 2015 Win64"
:: Build options :: Build options
SET BUILD_PYTHON=ON set BUILD_PYTHON=ON
SET BUILD_RUNTIME=OFF set BUILD_RUNTIME=OFF
:: Optional libraries
set USE_CUDA=ON
set USE_CUDNN=ON
set USE_OPENMP=ON
set USE_AVX=ON
set USE_AVX2=ON
set USE_FMA=ON
:: Protobuf SDK options :: Protobuf SDK options
SET PROTOBUF_SDK_ROOT_DIR=%THIRD_PARTY_DIR%\protobuf set PROTOBUF_SDK_ROOT_DIR=%THIRD_PARTY_DIR%\protobuf
:: Protobuf Compiler options :: Protobuf Compiler options
:: Set the protobuf compiler(i.e., protoc) if necessary :: Set the protobuf compiler(i.e., protoc) if necessary.
:: If not, a compiler in the sdk or environment will be used :: If not, a compiler in the sdk or environment will be used.
SET PROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_SDK_ROOT_DIR%\bin\protoc set PROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_SDK_ROOT_DIR%\bin\protoc
:: Python options :: Python options
:: Set your python "interpreter" if necessary :: Set your python "interpreter" if necessary.
:: If not, a default interpreter will be used :: If not, a default interpreter will be used.
:: SET PYTHON_EXECUTABLE=X:/Anaconda3/python :: set PYTHON_EXECUTABLE=X:/Anaconda3/python
if %BUILD_PYTHON% == ON ( if %BUILD_PYTHON% == ON (
if NOT DEFINED PYTHON_EXECUTABLE ( if NOT DEFINED PYTHON_EXECUTABLE (
for /F %%i in ('python -c "import sys;print(sys.executable)"') do (set PYTHON_EXECUTABLE=%%i) for /F %%i in ('python -c "import sys;print(sys.executable)"') do (set PYTHON_EXECUTABLE=%%i)
...@@ -47,6 +57,12 @@ cmake .. ^ ...@@ -47,6 +57,12 @@ cmake .. ^
-G%CMAKE_GENERATOR% ^ -G%CMAKE_GENERATOR% ^
-DBUILD_PYTHON=%BUILD_PYTHON% ^ -DBUILD_PYTHON=%BUILD_PYTHON% ^
-DBUILD_RUNTIME=%BUILD_RUNTIME% ^ -DBUILD_RUNTIME=%BUILD_RUNTIME% ^
-USE_CUDA==%USE_CUDA% ^
-USE_CUDNN==%USE_CUDNN% ^
-USE_OPENMP==%USE_OPENMP% ^
-USE_AVX==%USE_AVX% ^
-USE_AVX2==%USE_AVX2% ^
-USE_FMA==%USE_FMA% ^
-DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^ -DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^
-DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^ -DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^
-DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^ -DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^
......
:: ############################################################################# :: ###############################################################
:: Example command to build on Windows for Visual Studio 2017 (VC141). :: Command file to build on Windows for Visual Studio 2017 (VC141)
:: ############################################################################# :: ###############################################################
@echo off @echo off
setlocal setlocal
SET ORIGINAL_DIR=%cd% :: Build variables
SET REPO_ROOT=%~dp0%.. set ORIGINAL_DIR=%cd%
SET DRAGON_ROOT=%REPO_ROOT%\dragon set REPO_ROOT=%~dp0%..
SET THIRD_PARTY_DIR=%REPO_ROOT%\third_party set DRAGON_ROOT=%REPO_ROOT%\dragon
SET CMAKE_GENERATOR="Visual Studio 15 2017 Win64" set THIRD_PARTY_DIR=%REPO_ROOT%\third_party
set CMAKE_GENERATOR="Visual Studio 15 2017 Win64"
:: Build options :: Build options
SET BUILD_PYTHON=ON set BUILD_PYTHON=ON
SET BUILD_RUNTIME=OFF set BUILD_RUNTIME=OFF
:: Optional libraries
set USE_CUDA=ON
set USE_CUDNN=ON
set USE_OPENMP=ON
set USE_AVX=ON
set USE_AVX2=ON
set USE_FMA=ON
:: Protobuf SDK options :: Protobuf SDK options
SET PROTOBUF_SDK_ROOT_DIR=%THIRD_PARTY_DIR%\protobuf set PROTOBUF_SDK_ROOT_DIR=%THIRD_PARTY_DIR%\protobuf
:: Protobuf Compiler options :: Protobuf Compiler options
:: Set the protobuf compiler(i.e., protoc) if necessary :: Set the protobuf compiler(i.e., protoc) if necessary.
:: If not, a compiler in the sdk or environment will be used :: If not, a compiler in the sdk or environment will be used.
SET PROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_SDK_ROOT_DIR%\bin\protoc set PROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_SDK_ROOT_DIR%\bin\protoc
:: Python options :: Python options
:: Set your python "interpreter" if necessary :: Set your python "interpreter" if necessary.
:: If not, a default interpreter will be used :: If not, a default interpreter will be used.
:: SET PYTHON_EXECUTABLE=X:/Anaconda3/python :: set PYTHON_EXECUTABLE=X:/Anaconda3/python
if %BUILD_PYTHON% == ON ( if %BUILD_PYTHON% == ON (
if NOT DEFINED PYTHON_EXECUTABLE ( if NOT DEFINED PYTHON_EXECUTABLE (
for /F %%i in ('python -c "import sys;print(sys.executable)"') do (set PYTHON_EXECUTABLE=%%i) for /F %%i in ('python -c "import sys;print(sys.executable)"') do (set PYTHON_EXECUTABLE=%%i)
...@@ -47,6 +57,12 @@ cmake .. ^ ...@@ -47,6 +57,12 @@ cmake .. ^
-G%CMAKE_GENERATOR% ^ -G%CMAKE_GENERATOR% ^
-DBUILD_PYTHON=%BUILD_PYTHON% ^ -DBUILD_PYTHON=%BUILD_PYTHON% ^
-DBUILD_RUNTIME=%BUILD_RUNTIME% ^ -DBUILD_RUNTIME=%BUILD_RUNTIME% ^
-USE_CUDA==%USE_CUDA% ^
-USE_CUDNN==%USE_CUDNN% ^
-USE_OPENMP==%USE_OPENMP% ^
-USE_AVX==%USE_AVX% ^
-USE_AVX2==%USE_AVX2% ^
-USE_FMA==%USE_FMA% ^
-DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^ -DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^
-DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^ -DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^
-DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^ -DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!