Commit c5def39b by Ting PAN

Dragon 0.3 Preview

1 parent 84320495
Showing with 1333 additions and 356 deletions
...@@ -4,3 +4,9 @@ ...@@ -4,3 +4,9 @@
[submodule "DragonLair"] [submodule "DragonLair"]
path = DragonLair path = DragonLair
url = https://github.com/seetaresearch/DragonLair.git url = https://github.com/seetaresearch/DragonLair.git
[submodule "ThirdParty/eigen"]
path = ThirdParty/eigen
url = https://github.com/eigenteam/eigen-git-mirror.git
[submodule "ThirdParty/cub"]
path = ThirdParty/cub
url = https://github.com/NVlabs/cub
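With Eigen and CUB now vendored as submodules, a plain clone no longer pulls in the CPU/CUDA math headers. A minimal sketch of fetching them (the recursive form matches the updated Dockerfiles below; the explicit form is for an existing checkout):

```bash
# Clone Dragon together with the new ThirdParty submodules
git clone --recursive https://github.com/seetaresearch/Dragon.git

# Or initialize them inside an existing checkout
git submodule update --init ThirdParty/eigen ThirdParty/cub
```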
This directory holds (*after you download them*):
- msmpi.dll / mpiexec.exe / smpd.exe (for ``mpi``, Windows Only)
- cudnn64_*.dll (For ``cudnn``, Windows Only)
- libopenblas.dll / libquadmath-0.dll / libgfortran-3.dll / libgcc_s_seh-1.dll (For ``cblas``, Windows Only)
This directory holds (*after you download them*):
- mpi/*.h (for ``mpi``, Windows/Linux)
- google/protobuf/*.h (For ``google protobuf``, Windows Only)
- cudnn.h (For ``cudnn``, Windows Only)
- cblas.h and relevant header files (For ``cblas``, Windows/Linux)
- getopt.h and unistd.h (For ``platform-relevant`` header files, Windows Only)
This directory holds (*after you download them*):
- msmpi.lib/libmpi.so (for ``mpi``, Windows/Linux)
- libprotobuf.lib (For ``google protobuf``, Windows Only)
- cudnn.lib (For ``cudnn``, Windows Only)
- libopenblas.lib (For ``cblas``, Windows Only)
- python27.lib/python35.lib/python36.lib (For ``python27/35/36``, Windows Only)
------------------------------------------------------------------------ ------------------------------------------------------------------------
The list of most significant changes made over time in Dragon. The list of most significant changes made over time in Dragon.
Dragon 0.3.0.0 (20190110)
DRAGON_VERSION == 3000
Changes (w.r.t. Dragon 0.2.2.13):
Preview Features:
- New VM framework: ONNX
  We have extended the exporting and importing (runtime) support for ONNX.
- Operator Refactor:
  * <NDArray> Faster implementations for the following multi-axis operators:
    ``Crop``, ``Pad``, ``Tile``, ``Reduce`` and ``Transpose``.
  * <Norm> Faster implementations for the fused norm operators:
    ``BatchNorm``, ``GroupNorm``, ``LayerNorm``, ``InstanceNorm``.
- Use ``Eigen`` as the default CPU math library instead of ``OpenBLAS``.
- Integer data type support for common operators;
  see the documentation for more details.
- A new workspace-local dummy name interface has been introduced,
  which unifies the naming of static and dynamic computation graphs.
Bugs fixed:
- Repair the basic TensorFlow API, following the master branch.
- More reliable shape inference for static computation graphs.
------------------------------------------------------------------------
Dragon 0.2.2.13 (20181204) Dragon 0.2.2.13 (20181204)
DRAGON_VERSION == 2213 DRAGON_VERSION == 2213
......
...@@ -14,7 +14,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ...@@ -14,7 +14,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libprotobuf-dev \ libprotobuf-dev \
protobuf-compiler \ protobuf-compiler \
libopencv-dev \ libopencv-dev \
libopenblas-dev \
libboost-all-dev \ libboost-all-dev \
python3-pip \ python3-pip \
python3-dev \ python3-dev \
...@@ -35,14 +34,13 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel -i https://pypi.tuna. ...@@ -35,14 +34,13 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel -i https://pypi.tuna.
pyyaml \ pyyaml \
cython cython
RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cpu-openblas/3rdparty.zip && \ RUN git clone --recursive https://github.com/seetaresearch/Dragon.git && \
unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \ mv Dragon/ThirdParty ./ && rm -rf Dragon
cd openmpi && ls | grep -v install | xargs rm -r && cp install/bin/mpirun /usr/bin
RUN git clone https://github.com/seetaresearch/Dragon.git && \ RUN git clone https://github.com/seetaresearch/Dragon.git && \
cd Dragon/Dragon && rm CMakeLists.txt && \ cd Dragon/Dragon && mkdir build && cd build && cmake .. -DTHIRD_PARTY_DIR=/ThirdParty \
wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cpu-openblas/CMakeLists.txt && \ -DPYTHON_EXECUTABLE=/usr/bin/python3 -DWITH_CUDA=OFF -DWITH_CUDNN=OFF -DBUILD_CXX_API=ON && \
mkdir build && cd build && cmake .. && make install -j8 && cd .. && rm -rf build && \ make install -j $(nproc) && cd .. && rm -rf build && \
cd python && python3 setup.py install cd python && python3 setup.py install
RUN rm /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && ln -s /usr/bin/pip3 /usr/bin/pip RUN rm /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && ln -s /usr/bin/pip3 /usr/bin/pip
\ No newline at end of file
...@@ -15,7 +15,6 @@ RUN rm /etc/apt/sources.list.d/cuda.list && rm /etc/apt/sources.list.d/nvidia-ml ...@@ -15,7 +15,6 @@ RUN rm /etc/apt/sources.list.d/cuda.list && rm /etc/apt/sources.list.d/nvidia-ml
libprotobuf-dev \ libprotobuf-dev \
protobuf-compiler \ protobuf-compiler \
libopencv-dev \ libopencv-dev \
libopenblas-dev \
libboost-all-dev \ libboost-all-dev \
libnccl2 \ libnccl2 \
libnccl-dev \ libnccl-dev \
...@@ -38,14 +37,15 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel -i https://pypi.tuna. ...@@ -38,14 +37,15 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel -i https://pypi.tuna.
pyyaml \ pyyaml \
cython cython
RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cuda9.0-cudnn7/3rdparty.zip && \ RUN git clone --recursive https://github.com/seetaresearch/Dragon.git && \
unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \ mv Dragon/ThirdParty ./ && rm -rf Dragon
cd openmpi && ls | grep -v install | xargs rm -r && cp install/bin/mpirun /usr/bin
RUN cd ThirdParty/mpi && bash build.sh && rm -rf src *.gz && cp bin/mpirun /usr/bin
RUN git clone https://github.com/seetaresearch/Dragon.git && \ RUN git clone https://github.com/seetaresearch/Dragon.git && \
cd Dragon/Dragon && rm CMakeLists.txt && \ cd Dragon/Dragon && mkdir build && cd build && cmake .. -DTHIRD_PARTY_DIR=/ThirdParty \
wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cuda9.0-cudnn7/CMakeLists.txt && \ -DPYTHON_EXECUTABLE=/usr/bin/python3 -DWITH_MPI=ON -DWITH_NCCL=ON -DBUILD_CXX_API=ON && \
mkdir build && cd build && cmake .. && make install -j8 && cd .. && rm -rf build && \ make install -j $(nproc) && cd .. && rm -rf build && \
cd python && python3 setup.py install cd python && python3 setup.py install
RUN rm /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && ln -s /usr/bin/pip3 /usr/bin/pip RUN rm /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && ln -s /usr/bin/pip3 /usr/bin/pip
\ No newline at end of file
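Both Dockerfiles now build Dragon from the ThirdParty submodules instead of downloading a prebuilt 3rdparty.zip. A hedged usage sketch (the image tags and Dockerfile paths below are illustrative assumptions, not paths from the repo):

```bash
# Build the CPU image
docker build -t dragon:cpu -f Dockerfile.cpu .

# Build the CUDA 9.0 / cuDNN 7 image and run it with the NVIDIA runtime
docker build -t dragon:cuda9-cudnn7 -f Dockerfile.gpu .
docker run --runtime=nvidia -it dragon:cuda9-cudnn7 python3 -c "import dragon"
```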
Building Dragon Documentation
=============================
This page will help you build the following documentation:
Dragon CXX API: http://dragon.seetatech.com/api/cpp/index.html
Dragon Python API: http://dragon.seetatech.com/api/python/index.html
Build Documentation of CXX API
------------------------------
```bash
cd Dragon/Docs/api/cxx
doxygen Doxyfile
```
Then, open `./html/index.html` in your browser.
Build Documentation of Python API
---------------------------------
```bash
pip install sphinx_bootstrap_theme
cd Dragon/Docs/api/python
make html
```
Then, open `./_build/html/index.html` in your browser.
\ No newline at end of file
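If opening the generated pages directly from disk is inconvenient, any static file server works; for example (a generic approach, not specific to Dragon):

```bash
cd Dragon/Docs/api/python/_build/html
python3 -m http.server 8000  # then browse http://localhost:8000
```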
...@@ -791,9 +791,9 @@ WARN_LOGFILE = ...@@ -791,9 +791,9 @@ WARN_LOGFILE =
# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
# Note: If this tag is empty the current directory is searched. # Note: If this tag is empty the current directory is searched.
INPUT = ../../Dragon/include \ INPUT = ../../../Dragon/include \
../../Dragon/src \ ../../../Dragon/src \
../../Dragon/modules ../../../Dragon/modules
# This tag can be used to specify the character encoding of the source files # This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
......
/*-------------------- basics --------------------*/
body {
font-size: 16px;
font-family: Lato, "Helvetica Neue", Helvetica, Arial, sans-serif;
}
table.field-list {
width: 100%;
}
.field-list ul {
padding-left: 0;
}
.field-name {
width: 110px;
}
code.descname,
code.descclassname {
color: #555
}
div.admonition {
padding: 16px;
background-color: #f6f8fa;
border: 1px solid rgb(204, 204, 204);
border-radius: 4px;
}
p.admonition-title {
color: #187dbb;
margin: 0px 10px 5px 0px;
font-weight: bold;
font-size: 16px;
line-height: 1.5em;
}
p.last {
font-size: 16px;
color: #000;
}
p {
margin: 0px 0px 21px 0px
}
.footer {
width: 0
}
h1 {
margin-bottom: 25px;
}
h2 {
margin-bottom: 21px;
}
h3 {
margin-top: 10.5px;
margin-bottom: 30px;
}
h1, h2, h3, h4, h5, h6, .h1, .h2, .h3, .h4, .h5, .h6 {
font-family: "Lato"
}
a {
color: #2786e0;
text-decoration: none;
}
a:hover, a:focus {
color: #902594;
text-decoration: none;
}
.context {
display: none
}
/*-------------------- layout --------------------*/
.container.doc-container {
width: 97%;
}
div.content {
padding: 0px 30px 0px 30px;
float: right;
width: calc(100% - 345px);
}
@media (max-width: 999px) {
div.content {
overflow-x: auto;
width: 100%;
}
}
.section:before {
content: " ";
display: block;
height: 60px;
margin: -60px 0 0;
}
/*-------------------- navbar-brand --------------------*/
.navbar-default .navbar-brand {
font-size: 25px;
font-weight: bold;
}
.navbar-brand {
height: 60px;
}
/*-------------------- navbar-main --------------------*/
.navbar {
/* background-color: #2695ff; */
background: linear-gradient(to right, #4070a0, #5bc0de);
border: 0px;
margin-bottom: 0px;
opacity: 0.9;
}
.navbar-inverse .navbar-collapse {
background-color: transparent;
}
.navbar-nav {
font-size: 18px;
margin-left: 25px;
}
.navbar .container {
width: 97%
}
/*-------------------- navbar-item --------------------*/
.navbar-nav>li {
-webkit-transition: .2s;
transition: .2s;
display: inline-block;
padding: 7px 15px;
}
.navbar-inverse .navbar-nav>li>a {
opacity: .7;
}
.navbar-inverse .navbar-nav>li>a:hover {
opacity: 1;
color: #fff;
text-decoration: none;
}
.navbar-inverse .navbar-nav>li>a:hover,
.navbar-inverse .navbar-nav>li>a:focus {
color: #ffffff;
background-color: transparent
}
.navbar-inverse .navbar-nav>.open>a:active,
.navbar-inverse .navbar-nav>.open>a:hover,
.navbar-inverse .navbar-nav>.open>a:focus {
background-color: transparent;
text-align: center;
}
.navbar-inverse .navbar-nav>.open>a {
background-color: transparent;
}
/*-------------------- navbar-dropdown --------------------*/
.navbar-inverse .navbar-nav>.open .dropdown-menu > li > a,
.navbar-inverse .navbar-nav .open .dropdown-menu>li>a:hover,
.navbar-inverse .navbar-nav .open .dropdown-menu>li>a:focus {
color: #4070a0;
background-color: white;
box-shadow: 0px 1px 1px 1px;
}
.navbar-inverse .navbar-nav .open .dropdown-menu>li>a:hover {
background-color: white;
color: #5bc0de;
}
.navbar-inverse .dropdown-menu {
background-color: #fff;
top: 100%;
border-radius: 3px;
text-align: center;
opacity: 0.95;
}
.navbar-nav .open .dropdown-menu {
float: left;
border-color: rgba(38, 149, 255, 0.32);
}
.navbar .dropdown-menu {
border:none;
}
.navbar-inverse .dropdown-menu>li>a:hover {
color: #0a5655;
background-color: #f5f5f5;
}
.navbar-inverse .navbar-nav>.open>a:hover,
.navbar-inverse .navbar-nav>.open>a:focus {
text-align: left;
}
.navbar .dropdown-menu>li>a, .navbar .dropdown-menu>li>a:focus {
font-size: 16px;
font-weight: 600;
}
.navbar-inverse .navbar-nav>.open .dropdown-menu > li > a {
color: #4070a0;
padding: 8px;
}
.navbar-nav>li>.dropdown-menu {
margin-top: 1px;
}
/*-------------------- navbar-search --------------------*/
.navbar-inverse .navbar-form {
border-color: #2695ff;
padding-top: 7px;
padding-bottom: 1px;
border: none;
}
.navbar-form .form-control {
border-radius: 5px;
}
.navbar-toggle {
padding: 0px 0px;
margin-top: 20px;
}
/*-------------------- code --------------------*/
.xref.py.py-mod.docutils.literal {
color: #103d3e;
background-color: #ffffff;
font-size: 40px;
font-family: Consolas;
padding: 0px;
}
.xref.py.py-mod.docutils.literal:target {
background-color: #e7f2fa;
border-bottom: 3px solid #c7254e;
margin-bottom: -3px;
}
code.docutils.literal {
color: #32577b;
font-family: Lato;
font-weight: bold;
font-size: 14px;
}
code.docutils.literal:hover {
color: #902594;
}
.highlight-python {
margin-bottom: 21px;
}
dt {
font-weight: 700;
background: #e7f2fa;
border-bottom: solid #0079b2;
border-radius: 1px;
margin-bottom: 20px;
}
dt:target, .highlighted {
background-color: #e7f2fa;
border-bottom: 3px solid #c7254e;
}
dt:target:before {
background-color: white;
content: '';
display: block;
height: 65px;
margin: -20px 0 0;
}
dl.method dt {
background: #f0f0f0;
border-bottom: solid #ccc;
}
dt em {
font-weight: normal;
font-style: normal;
font-size: 90%;
}
dd {
margin-top: 3px;
margin-bottom: 10px;
margin-left: 30px;
}
table {
font-size: 16px;
}
table.table tr, td, th {
padding-top: 5px;
padding-bottom: 5px;
padding-left: 10px;
padding-right: 10px;
border: 1px solid rgb(223, 226, 229);
}
table.field-list.table tr {
border: 0px solid rgb(223, 226, 229);
}
blockquote {
border-left-width: 0px;
color: #6f6f6f;
padding: 10.5px 21px;
margin: 0 0 21px;
font-size: 17px;
border-left: 5px solid #dddddd;
background-color: #f6f8fa;
border: 1px solid rgb(204, 204, 204);
border-radius: 4px;
max-width: 750px;
}
ul.simple li {
margin-bottom: 10px;
margin-left: 10px;
}
/*------------------sidebar-----------------------*/
div.sphinxsidebar {
position: fixed;
overflow: auto;
display: none;
height: calc(100% - 40px);
}
div.leftsidebar {
width: 300px;
margin-left: 25px;
background: transparent;
}
@media (min-width: 1000px) {
div.sphinxsidebar {display: block}
}
div.sphinxsidebar ul {
padding-left: 25px;
list-style-type: none !important;
}
div.sphinxsidebar li {
padding-top: 5px;
margin-bottom: 5px;
margin-left: -10px;
}
div.sphinxsidebarwrapper {
padding: 60px 10px 60px 20px;
background: transparent;
}
div.sphinxsidebar ul ul {
margin-left: 0px;
padding-top: 3px;
}
div.sphinxsidebar li.opened .tocToggle:before {
font-family: 'FontAwesome';
content: "\f115";
margin: 0 5px 0 -15px;
color: #2695ff;
}
div.sphinxsidebar li.closed .tocToggle:before {
font-family: 'FontAwesome';
content: "\f0fe";
margin: 0 5px 0 -15px;
color: #2695ff;
}
div.sphinxsidebar li.leaf .tocToggle:before {
font-family: 'FontAwesome';
content: "\f101";
margin: 0 5px 0 -15px;
color: #2695ff;
}
div.sphinxsidebar li.focused .tocToggle:before {
font-family: 'FontAwesome';
content: "\f06e";
margin: 0 5px 0 -15px;
color: #2695ff;
}
/*-------------------- install --------------------*/
.opt-group {
margin-top: 10px;
margin-bottom: 10px;
}
.btn-default {
color: #333;
background-color: #fff;
border-color: #ccc;
}
.btn-default:hover, .btn-default:focus, .btn-default:active, .btn-default.active.focus,
.btn-default.active:focus, .btn-default.active:hover, .btn-default:active.focus, .btn-default:active:focus,
.btn-default:active:hover, .btn-default.active, .open>.dropdown-toggle.btn-default, .btn-default:active:focus {
color: #fff;
background-color: #0079b2;
border-color: #0079b2;
}
\ No newline at end of file
(function ($) {
/**
* Patch TOC list.
*
* Will mutate the underlying span to have a correct ul for nav.
*
* @param $ul: Span containing nested UL's to mutate.
* @param minLevel: Starting level for nested lists. (1: global, 2: local).
*/
var patchToc = function ($ul, minLevel) {
var findA;
// Find all a "internal" tags, traversing recursively.
findA = function ($elem, level) {
level = level || 0;
var $items = $elem.find("> li > a.internal, > ul, > li > ul");
// Iterate everything in order.
$items.each(function (index, item) {
var $item = $(item),
tag = item.tagName.toLowerCase(),
$childrenLi = $item.children("li"),
$parentLi = $($item.parent("li"), $item.parent().parent("li"));
// Add dropdowns if more children and above minimum level.
if (tag === "ul" && level >= minLevel && $childrenLi.length > 0) {
$parentLi
.addClass("dropdown-submenu")
.children("a").first().attr("tabindex", -1);
$item.addClass("dropdown-menu");
}
findA($item, level + 1);
});
};
findA($ul);
};
/**
* Patch all tables to remove ``docutils`` class and add Bootstrap base
* ``table`` class.
*/
var patchTables = function () {
$("table.docutils")
.removeClass("docutils")
.addClass("table")
.attr("border", 0);
};
$(window).load(function () {
/*
* Scroll the window to avoid the topnav bar
* https://github.com/twbs/bootstrap/issues/1768
if ($("#navbar.navbar-fixed-top").length > 0) {
var navHeight = $("#navbar").height(),
shiftWindow = function() { scrollBy(0, -navHeight - 10); };
if (location.hash) {
setTimeout(shiftWindow, 1);
}
window.addEventListener("hashchange", shiftWindow);
}
*/
});
$(document).ready(function () {
// Add styling, structure to TOCs.
$(".dropdown-menu").each(function () {
$(this).find("ul").each(function (index, item){
var $item = $(item);
$item.addClass("unstyled");
});
});
// Global TOC.
if ($("ul.globaltoc li").length) {
patchToc($("ul.globaltoc"), 1);
} else {
// Remove Global TOC.
$(".globaltoc-container").remove();
}
// Local TOC.
$(".bs-sidenav ul").addClass("nav nav-list");
$(".bs-sidenav > ul > li > a").addClass("nav-header");
// back to top
setTimeout(function () {
var $sideBar = $(".bs-sidenav");
var $content = $(".content");
// Enlarge content if sidebar is larger.
if ($sideBar.outerHeight(true) > $content.outerHeight(true)) {
$content.css("min-height", $sideBar.outerHeight(true));
}
$sideBar
// Add affix.
.affix({
offset: {
top: function () {
var offsetTop = $sideBar.offset().top;
var sideBarMargin = parseInt($sideBar.css("margin-top"), 10);
var navOuterHeight = $("#navbar").outerHeight(true);
return (this.top = offsetTop - navOuterHeight);
},
bottom: function () {
return (this.bottom = $(".footer").outerHeight(true));
}
}
})
// Trigger to reset if page content is scrolled to bottom.
.trigger("scroll.bs.affix.data-api");
}, 0);
// Local TOC.
patchToc($("ul.localtoc"), 2);
// Mutate sub-lists (for bs-2.3.0).
$(".dropdown-menu ul").not(".dropdown-menu").each(function () {
var $ul = $(this),
$parent = $ul.parent(),
tag = $parent[0].tagName.toLowerCase(),
$kids = $ul.children().detach();
// Replace list with items if submenu header.
if (tag === "ul") {
$ul.replaceWith($kids);
} else if (tag === "li") {
// Insert into previous list.
$parent.after($kids);
$ul.remove();
}
});
// Add divider in page TOC.
var $localLi = $("ul.localtoc li");
if ($localLi.length > 2) {
$localLi.first().after("<li class=\"divider\"></li>");
}
// Patch tables.
patchTables();
// Add Note, Warning styles. (BS v2,3 compatible).
$(".admonition").addClass("alert alert-info")
.filter(".warning, .caution")
.removeClass("alert-info")
.addClass("alert-warning").end()
.filter(".error, .danger")
.removeClass("alert-info")
.addClass("alert-danger alert-error").end();
// Inline code styles to Bootstrap style.
$("tt.docutils.literal").not(".xref").each(function (i, e) {
// ignore references
if (!$(e).parent().hasClass("reference")) {
$(e).replaceWith(function () {
return $("<code />").html($(this).html());
});
}});
// Update sourcelink to remove outerdiv (fixes appearance in navbar).
var $srcLink = $(".nav #sourcelink");
$srcLink.parent().html($srcLink.html());
});
}(window.$jqTheme || window.jQuery));
\ No newline at end of file
This diff could not be displayed because it is too large.
$(document).ready(function () {
// Convert an option label to its CSS class form, e.g. "Mac OS" -> "mac-os".
function label(lbl) {
return lbl.replace(/[ .]/g, '-').toLowerCase();
}
// Show only the content blocks matching the currently active options.
function showContent() {
$('.opt-group .opt').each(function(){
$('.'+label($(this).text())).hide();
$('.highlight-'+label($(this).text())).hide();
});
$('.opt-group .active').each(function(){
$('.'+label($(this).text())).show();
$('.highlight-'+label($(this).text())).show();
});
}
showContent();
// Mark the clicked option as active and refresh the visible content.
function setContent() {
var el = $(this);
el.siblings().removeClass('active');
el.addClass('active');
showContent();
}
$('.opt-group').on('click', '.opt', setContent);
});
\ No newline at end of file
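This script pairs the install-page option buttons (the .opt-group / .opt / .active classes styled in the CSS above) with content blocks whose class names are the lower-cased, dash-joined option labels. A minimal markup sketch of what it expects (the concrete labels and contents are illustrative assumptions):

```html
<div class="opt-group">
  <button class="opt active">Linux</button>
  <button class="opt">Mac OS</button>
</div>

<!-- label("Mac OS") yields "mac-os", so only the active block stays visible -->
<div class="linux">Instructions for Linux ...</div>
<div class="mac-os">Instructions for Mac OS ...</div>
```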
$(document).ready(function () {
function addToggle(tocClass) {
// Add Title
$(tocClass + " div.sphinxsidebarwrapper").prepend("<h3>Contents</h3>");
var allEntry = $(tocClass + " div.sphinxsidebarwrapper li");
var L1Entry = $(tocClass + " div.sphinxsidebarwrapper").children("ul").first().children("li");
allEntry.each(function () {
$(this).prepend("<span class='tocToggle'></span>");
var childUL = $(this).find("ul");
if (childUL.length && childUL.first().children().length) {
// $(this).addClass('closed');
// $(this).find("ul").first().hide();
} else {
$(this).addClass("leaf");
}
var anchor = $(this).children("a").first();
anchor.click(function() {toggle(anchor); autoExpand(anchor);});
});
// toctree-l1
L1Entry.each(function () {
$(this).removeClass("leaf").addClass('closed');
$(this).find("ul").first().show();
});
}
var toggle = function (elem) {
if ($(elem).parent().hasClass("closed")) {
$(elem).parent().find("ul").first().show();
$(elem).parent().removeClass("closed").addClass("opened");
} else if ($(elem).parent().hasClass("opened")) {
$(elem).parent().find("ul").first().hide();
$(elem).parent().removeClass("opened").addClass("closed");
}
};
function autoExpand(elem) {
if (elem.parent().hasClass("closed")) {
elem.parent().removeClass("closed").addClass("opened");
elem.parent().children("ul").first().show();
} else if (elem.parent().hasClass("opened")) {
elem.parent().removeClass("opened").addClass("closed");
elem.parent().children("ul").first().hide();
}
}
function keepExpand() {
var url = window.location.href, currentEntry;
var entryList = $('.sphinxsidebar li');
for(var i = entryList.length - 1; i >= 0; --i) {
var entryURL = entryList.eq(i).find('a').first().attr('href');
if (entryURL == '#') {
currentEntry = entryList.eq(i);
break;
}
}
var allEntry = $(".leftsidebar div.sphinxsidebarwrapper li");
allEntry.each(function () {
var anchor = $(this).children("a").first();
anchor.click(function () { autoExpand(anchor); });
});
if (!currentEntry) return;
if (!currentEntry.hasClass('leaf')) currentEntry.removeClass("closed").addClass("opened");
else currentEntry.removeClass("opened").addClass("focused");
while(currentEntry.parent().is('ul') && currentEntry.parent().parent().is('li')) {
currentEntry = currentEntry.parent().parent();
var xx = currentEntry.parent().children('li');
xx.each(function () {$(this).removeClass('leaf').addClass('closed');});
currentEntry.removeClass("closed").addClass("opened");
currentEntry.children("ul").first().show();
}
}
addToggle(".leftsidebar");
keepExpand();
});
\ No newline at end of file
...@@ -67,4 +67,4 @@ html_sidebars = {'index': ['localtoc.html'], ...@@ -67,4 +67,4 @@ html_sidebars = {'index': ['localtoc.html'],
# overloads # overloads
def setup(app): def setup(app):
app.config.values['autodoc_member_order'] = ('bysource', True) app.config.values['autodoc_member_order'] = ('bysource', True)
\ No newline at end of file
...@@ -5,16 +5,5 @@ ...@@ -5,16 +5,5 @@
.. toctree:: .. toctree::
:hidden: :hidden:
.. currentmodule:: dragon.core.scope .. automodule:: dragon.core.scope
:members:
.. autoclass:: dragon.core.scope.TensorScope \ No newline at end of file
:members:
.. autoclass:: dragon.core.scope.PhaseScope
:members:
.. autoclass:: dragon.core.scope.DeviceScope
:members:
.. autoclass:: dragon.core.scope.WorkspaceScope
:members:
...@@ -36,8 +36,6 @@ List Brief ...@@ -36,8 +36,6 @@ List Brief
`Tensor.Normal`_ Register as a variable with normal initializer. `Tensor.Normal`_ Register as a variable with normal initializer.
`Tensor.TruncatedNormal`_ Register as a variable with truncated normal initializer. `Tensor.TruncatedNormal`_ Register as a variable with truncated normal initializer.
`Tensor.Gaussian`_ Register as a variable with gaussian initializer. `Tensor.Gaussian`_ Register as a variable with gaussian initializer.
`Tensor.Xavier`_ Register as a variable with xavier initializer.
`Tensor.MSRA`_ Register as a variable with msra initializer.
`Tensor.GlorotUniform`_ Register as a variable with glorot uniform initializer. `Tensor.GlorotUniform`_ Register as a variable with glorot uniform initializer.
`Tensor.GlorotNormal`_ Register as a variable with glorot normal initializer. `Tensor.GlorotNormal`_ Register as a variable with glorot normal initializer.
============================== ============================================================================= ============================== =============================================================================
...@@ -80,8 +78,6 @@ API Reference ...@@ -80,8 +78,6 @@ API Reference
.. _Tensor.Normal: #dragon.core.tensor.Tensor.Normal .. _Tensor.Normal: #dragon.core.tensor.Tensor.Normal
.. _Tensor.TruncatedNormal: #dragon.core.tensor.Tensor.TruncatedNormal .. _Tensor.TruncatedNormal: #dragon.core.tensor.Tensor.TruncatedNormal
.. _Tensor.Gaussian: #dragon.core.tensor.Tensor.Gaussian .. _Tensor.Gaussian: #dragon.core.tensor.Tensor.Gaussian
.. _Tensor.Xavier: #dragon.core.tensor.Tensor.Xavier
.. _Tensor.MSRA: #dragon.core.tensor.Tensor.MSRA
.. _Tensor.GlorotUniform: #dragon.core.tensor.Tensor.GlorotUniform .. _Tensor.GlorotUniform: #dragon.core.tensor.Tensor.GlorotUniform
.. _Tensor.GlorotNormal: #dragon.core.tensor.Tensor.GlorotNormal .. _Tensor.GlorotNormal: #dragon.core.tensor.Tensor.GlorotNormal
......
...@@ -55,7 +55,6 @@ List Brief ...@@ -55,7 +55,6 @@ List Brief
`ResetWorkspace`_ Reset the specific workspace. `ResetWorkspace`_ Reset the specific workspace.
`ClearWorkspace`_ Clear the specific workspace. `ClearWorkspace`_ Clear the specific workspace.
`LogMetaGraph`_ Log the meta graph. `LogMetaGraph`_ Log the meta graph.
`LogOptimizedGraph`_ Log the optimized graph.
`ExportMetaGraph`_ Export the meta graph into a file under specific folder. `ExportMetaGraph`_ Export the meta graph into a file under specific folder.
============================== ============================================================================= ============================== =============================================================================
...@@ -85,7 +84,6 @@ API Reference ...@@ -85,7 +84,6 @@ API Reference
.. _Snapshot: #dragon.core.workspace.Snapshot .. _Snapshot: #dragon.core.workspace.Snapshot
.. _Restore: #dragon.core.workspace.Restore .. _Restore: #dragon.core.workspace.Restore
.. _LogMetaGraph: #dragon.core.workspace.LogMetaGraph .. _LogMetaGraph: #dragon.core.workspace.LogMetaGraph
.. _LogOptimizedGraph: #dragon.core.workspace.LogOptimizedGraph
.. _ExportMetaGraph: #dragon.core.workspace.ExportMetaGraph .. _ExportMetaGraph: #dragon.core.workspace.ExportMetaGraph
.. _theano.function(*args, **kwargs): ../vm/theano/compile.html#dragon.vm.theano.compile.function.function .. _theano.function(*args, **kwargs): ../vm/theano/compile.html#dragon.vm.theano.compile.function.function
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
.. automodule:: dragon.operators.loss .. automodule:: dragon.operators.loss
:members: :members:
.. |smooth_l1_beta| mathmacro:: \, \frac{1}{\sigma^{2}}
.. |l1_loss_function| mathmacro:: \, Loss = \frac{ \sum \left| Weight * (Input - Target) \right|}{ Normalization} .. |l1_loss_function| mathmacro:: \, Loss = \frac{ \sum \left| Weight * (Input - Target) \right|}{ Normalization}
.. |l2_loss_function| mathmacro:: \, Loss = \frac{ \sum \frac{1}{2}\left|\left| Weight * (Input - Target) \right|\right|}{ Normalization} .. |l2_loss_function| mathmacro:: \, Loss = \frac{ \sum \frac{1}{2}\left|\left| Weight * (Input - Target) \right|\right|}{ Normalization}
\ No newline at end of file
...@@ -33,6 +33,6 @@ ...@@ -33,6 +33,6 @@
.. |caffe_moving_average_function| mathmacro:: \\ \, \\ x_{moving} \leftarrow Momentum * x_{moving} + x_{stat} \\ \, .. |caffe_moving_average_function| mathmacro:: \\ \, \\ x_{moving} \leftarrow Momentum * x_{moving} + x_{stat} \\ \,
.. _ops.Scale(*args, **kwargs): arithmetic.html#dragon.operators.arithmetic.Scale .. _ops.Affine(*args, **kwargs): arithmetic.html#dragon.operators.arithmetic.Affine
.. _Caffe: https://github.com/BVLC/caffe/ .. _Caffe: https://github.com/BVLC/caffe/
===================
:mod:`dragon.utils`
===================
Wrapper
-------
.. toctree::
:hidden:
utils/vision/data_batch
=================================== =====================================================================
List Brief
=================================== =====================================================================
`dragon.utils.vision.data_batch`_ Efficient batch data provider based on `LMDB`_.
=================================== =====================================================================
Component
---------
.. toctree::
:hidden:
utils/vision/data_reader
utils/vision/data_transformer
utils/vision/blob_fetcher
========================================== =====================================================================
List Brief
========================================== =====================================================================
`dragon.utils.vision.data_reader`_ Queue encoded strings from `LMDB`_.
`dragon.utils.vision.data_transformer`_ Queue transformed images from `DataReader`_.
`dragon.utils.vision.blob_fetcher`_ Queue blobs from `DataTransformer`_.
========================================== =====================================================================
.. _LMDB: http://lmdb.readthedocs.io/en/release
.. _DataReader: utils/vision/data_reader.html#dragon.utils.vision.data_reader
.. _DataTransformer: utils/vision/data_transformer.html#dragon.utils.vision.data_transformer
.. _dragon.utils.vision.data_batch: utils/vision/data_batch.html
.. _dragon.utils.vision.data_reader: utils/vision/data_reader.html
.. _dragon.utils.vision.data_transformer: utils/vision/data_transformer.html
.. _dragon.utils.vision.blob_fetcher: utils/vision/blob_fetcher.html
\ No newline at end of file
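The three components form a queued pipeline: DataReader pulls encoded records out of LMDB, DataTransformer decodes and augments them, and BlobFetcher packs the results into blobs. A hypothetical usage sketch (the constructor arguments below are assumptions for illustration, not verified against the 0.3 API):

```python
# Hypothetical sketch: argument names are illustrative assumptions.
from dragon.utils.vision.data_batch import DataBatch

batch = DataBatch(source='/data/train_lmdb',  # an LMDB built beforehand
                  batch_size=64)              # images per fetched blob
```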
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
.. toctree:: .. toctree::
:hidden: :hidden:
.. currentmodule:: dragon.io.blob_fetcher .. currentmodule:: dragon.utils.vision.blob_fetcher
.. autoclass:: BlobFetcher .. autoclass:: BlobFetcher
:members: :members:
.. automethod:: __init__ .. automethod:: __init__
.. _DataTransformer: data_transformer.html#dragon.io.data_transformer .. _DataTransformer: data_transformer.html#dragon.utils.vision.data_transformer
\ No newline at end of file \ No newline at end of file
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
.. toctree:: .. toctree::
:hidden: :hidden:
.. currentmodule:: dragon.io.data_batch .. currentmodule:: dragon.utils.vision.data_batch
.. autoclass:: DataBatch .. autoclass:: DataBatch
:members: :members:
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
.. toctree:: .. toctree::
:hidden: :hidden:
.. currentmodule:: dragon.io.data_reader .. currentmodule:: dragon.utils.vision.data_reader
.. autoclass:: DataReader .. autoclass:: DataReader
:members: :members:
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
.. toctree:: .. toctree::
:hidden: :hidden:
.. currentmodule:: dragon.io.data_transformer .. currentmodule:: dragon.utils.vision.data_transformer
.. autoclass:: DataTransformer .. autoclass:: DataTransformer
:members: :members:
......
...@@ -41,7 +41,7 @@ leading the results to be deterministic and reproduceable. ...@@ -41,7 +41,7 @@ leading the results to be deterministic and reproduceable.
|paratitle| **Keras** |paratitle| **Keras**
|para| Referring to **Keras**, whose APIs are designed to wrap existing backends, |para| Referring to **Keras**, whose API is designed to wrap existing backends,
we **IMPLEMENT** most `Layers`_ of `Caffe`_ without any efforts. we **IMPLEMENT** most `Layers`_ of `Caffe`_ without any efforts.
Our backend has provided both the latest and optimized deep learning operators, Our backend has provided both the latest and optimized deep learning operators,
that outperforms the original `Caffe`_ or other forked ones. that outperforms the original `Caffe`_ or other forked ones.
......
...@@ -65,6 +65,7 @@ List Brief ...@@ -65,6 +65,7 @@ List Brief
`EltwiseLayer`_ The implementation of ``EltwiseLayer`` `EltwiseLayer`_ The implementation of ``EltwiseLayer``
`AddLayer`_ The extended implementation of ``EltwiseLayer``. `AddLayer`_ The extended implementation of ``EltwiseLayer``.
`ConcatLayer`_ The implementation of ``ConcatLayer``. `ConcatLayer`_ The implementation of ``ConcatLayer``.
`SliceLayer`_ The implementation of ``SliceLayer``.
`DenseConcatLayer`_ The implementation for `DenseNet`_. `DenseConcatLayer`_ The implementation for `DenseNet`_.
`CropLayer`_ The implementation of ``CropLayer``. `CropLayer`_ The implementation of ``CropLayer``.
`ReshapeLayer`_ The implementation of ``ReshapeLayer``. `ReshapeLayer`_ The implementation of ``ReshapeLayer``.
...@@ -74,7 +75,6 @@ List Brief ...@@ -74,7 +75,6 @@ List Brief
`SoftmaxLayer`_ The implementation of ``SoftmaxLayer``. `SoftmaxLayer`_ The implementation of ``SoftmaxLayer``.
`ArgMaxLayer`_ The implementation of ``ArgMaxLayer``. `ArgMaxLayer`_ The implementation of ``ArgMaxLayer``.
`BatchNormLayer`_ The implementation of ``BatchNormLayer``. `BatchNormLayer`_ The implementation of ``BatchNormLayer``.
`BatchRenormLayer`_ The implementation of ``BatchRenormLayer``.
`GroupNormLayer`_ The implementation of ``GroupNormLayer``. `GroupNormLayer`_ The implementation of ``GroupNormLayer``.
`InstanceNormLayer`_ The implementation of ``InstanceNormLayer``. `InstanceNormLayer`_ The implementation of ``InstanceNormLayer``.
`ScaleLayer`_ The implementation of ``ScaleLayer``. `ScaleLayer`_ The implementation of ``ScaleLayer``.
...@@ -179,6 +179,7 @@ API Reference ...@@ -179,6 +179,7 @@ API Reference
.. _EltwiseLayer: #dragon.vm.caffe.layers.common.EltwiseLayer .. _EltwiseLayer: #dragon.vm.caffe.layers.common.EltwiseLayer
.. _AddLayer: #dragon.vm.caffe.layers.common.AddLayer .. _AddLayer: #dragon.vm.caffe.layers.common.AddLayer
.. _ConcatLayer: #dragon.vm.caffe.layers.common.ConcatLayer .. _ConcatLayer: #dragon.vm.caffe.layers.common.ConcatLayer
.. _SliceLayer: #dragon.vm.caffe.layers.common.SliceLayer
.. _DenseConcatLayer: #dragon.vm.caffe.layers.common.DenseConcatLayer .. _DenseConcatLayer: #dragon.vm.caffe.layers.common.DenseConcatLayer
.. _CropLayer: #dragon.vm.caffe.layers.common.CropLayer .. _CropLayer: #dragon.vm.caffe.layers.common.CropLayer
.. _ReshapeLayer: #dragon.vm.caffe.layers.common.ReshapeLayer .. _ReshapeLayer: #dragon.vm.caffe.layers.common.ReshapeLayer
...@@ -188,7 +189,6 @@ API Reference ...@@ -188,7 +189,6 @@ API Reference
.. _SoftmaxLayer: #dragon.vm.caffe.layers.common.SoftmaxLayer .. _SoftmaxLayer: #dragon.vm.caffe.layers.common.SoftmaxLayer
.. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer .. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
.. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer .. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
.. _BatchRenormLayer: #dragon.vm.caffe.layers.common.BatchRenormLayer
.. _GroupNormLayer: #dragon.vm.caffe.layers.common.GroupNormLayer .. _GroupNormLayer: #dragon.vm.caffe.layers.common.GroupNormLayer
.. _InstanceNormLayer: #dragon.vm.caffe.layers.common.InstanceNormLayer .. _InstanceNormLayer: #dragon.vm.caffe.layers.common.InstanceNormLayer
.. _ScaleLayer: #dragon.vm.caffe.layers.common.ScaleLayer .. _ScaleLayer: #dragon.vm.caffe.layers.common.ScaleLayer
......
...@@ -8,20 +8,20 @@ ...@@ -8,20 +8,20 @@
Quick Shortcut Quick Shortcut
-------------- --------------
==================== ============================================================================= ========================= =============================================================================
List Brief List Brief
==================== ============================================================================= ========================= =============================================================================
`Net.copy_from`_ Copy the parameters from the binary proto file. `Net.copy_from`_ Copy the parameters from the binary proto file.
`Net.forward`_ Forward Pass. `Net.forward`_ Forward Pass.
`Net.backward`_ Backward Pass. `Net.backward`_ Backward Pass.
`Net.function`_ Forward + Backward Pass. `Net.function`_ Forward + Backward Pass.
`Net.save`_ Save the parameters into a binary file. `Net.save`_ Save the parameters into a binary file.
`Net.blobs`_ Return the blobs. `Net.blobs`_ Return the blobs.
`Net.params`_ Return the parameters. `Net.params`_ Return the parameters.
`Net.inputs`_ Return the inputs of net. `Net.trainable_params`_ Return the trainable parameters.
`Net.outputs`_ Return the outputs of net. `Net.inputs`_ Return the inputs of net.
`Net.replace`_ Replace the A as B. `Net.outputs`_ Return the outputs of net.
==================== ============================================================================= ========================= =============================================================================
API Reference API Reference
------------- -------------
...@@ -42,9 +42,9 @@ API Reference ...@@ -42,9 +42,9 @@ API Reference
.. _Net.save: #dragon.vm.caffe.net.Net.save .. _Net.save: #dragon.vm.caffe.net.Net.save
.. _Net.blobs: #dragon.vm.caffe.net.Net.blobs .. _Net.blobs: #dragon.vm.caffe.net.Net.blobs
.. _Net.params: #dragon.vm.caffe.net.Net.params .. _Net.params: #dragon.vm.caffe.net.Net.params
.. _Net.trainable_params: #dragon.vm.caffe.net.Net.trainable_params
.. _Net.inputs: #dragon.vm.caffe.net.Net.inputs .. _Net.inputs: #dragon.vm.caffe.net.Net.inputs
.. _Net.outputs: #dragon.vm.caffe.net.Net.outputs .. _Net.outputs: #dragon.vm.caffe.net.Net.outputs
.. _Net.replace: #dragon.vm.caffe.net.Net.replace
.. _Net.function: #dragon.vm.caffe.net.Net.function .. _Net.function: #dragon.vm.caffe.net.Net.function
.. _NetInit(proto_txt, phase): #dragon.vm.caffe.net.Net.NetInit .. _NetInit(proto_txt, phase): #dragon.vm.caffe.net.Net.NetInit
...@@ -57,7 +57,6 @@ API Reference ...@@ -57,7 +57,6 @@ API Reference
.. _FilterNet(net.cpp, L259): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/net.cpp#L259 .. _FilterNet(net.cpp, L259): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/net.cpp#L259
.. _Init(net.cpp, L44): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/net.cpp#L44 .. _Init(net.cpp, L44): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/net.cpp#L44
.. _ForwardBackward(net.cpp, L85): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/include/caffe/net.hpp#L85 .. _ForwardBackward(net.cpp, L85): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/include/caffe/net.hpp#L85
.. _ShareTrainedLayersWith(net.cpp, L665): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/net.cpp#L665
.. _CopyTrainedLayersFromBinaryProto(net.cpp, L780): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/net.cpp#L780 .. _CopyTrainedLayersFromBinaryProto(net.cpp, L780): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/net.cpp#L780
.. _Net_forward(pycaffe.py, L88): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/python/caffe/pycaffe.py#L88 .. _Net_forward(pycaffe.py, L88): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/python/caffe/pycaffe.py#L88
.. _Net_backward(pycaffe.py, L137): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/python/caffe/pycaffe.py#L137 .. _Net_backward(pycaffe.py, L137): https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/python/caffe/pycaffe.py#L137
......
...@@ -33,7 +33,7 @@ which contributes an efficient `compile`_ module for **Theano**. ...@@ -33,7 +33,7 @@ which contributes an efficient `compile`_ module for **Theano**.
|paratitle| **Keras** |paratitle| **Keras**
|para| `Keras`_ is smart enough to invent new fine-grained apis to unify various backends. |para| `Keras`_ is smart enough to invent new fine-grained API to unify various backends.
|para| We DO NOT follow it because the computations performed by different backends are confused. |para| We DO NOT follow it because the computations performed by different backends are confused.
Besides, the efforts to learn `Keras`_ are also expensive. Besides, the efforts to learn `Keras`_ are also expensive.
......
Dragon - Python APIs Dragon - Python API
==================== ===================
Dragon is a computation graph based distributed deep learning framework. Dragon is a computation graph based distributed deep learning framework.
...@@ -12,7 +12,7 @@ For using it, import as follows: ...@@ -12,7 +12,7 @@ For using it, import as follows:
Style Orientation Style Orientation
----------------- -----------------
However, it will not help you much because Dragon is designed without systemic APIs. However, it will not help you much because Dragon is designed without systemic API.
We have extended it with **FOUR** Programming Styles: We have extended it with **FOUR** Programming Styles:
...@@ -82,7 +82,7 @@ Table of Contents ...@@ -82,7 +82,7 @@ Table of Contents
contents/updaters contents/updaters
contents/memonger contents/memonger
contents/core contents/core
contents/io contents/utils
contents/operators contents/operators
contents/vm contents/vm
contents/tools contents/tools
...@@ -102,7 +102,7 @@ Packages ...@@ -102,7 +102,7 @@ Packages
===================== ===================================================================== ===================== =====================================================================
`dragon.core`_ The core package. `dragon.core`_ The core package.
`dragon.io`_ The io package. `dragon.utils`_ The utils package.
`dragon.operators`_ The operators package. `dragon.operators`_ The operators package.
`dragon.tools`_ The tools package. `dragon.tools`_ The tools package.
`dragon.vm`_ The vm package. `dragon.vm`_ The vm package.
...@@ -113,7 +113,7 @@ Packages ...@@ -113,7 +113,7 @@ Packages
.. _dragon.core: contents/core.html .. _dragon.core: contents/core.html
.. _dragon.core.tensor.Tensor: contents/core/tensor.html .. _dragon.core.tensor.Tensor: contents/core/tensor.html
.. _dragon.core.workspace: contents/core/workspace.html .. _dragon.core.workspace: contents/core/workspace.html
.. _dragon.io: contents/io.html .. _dragon.utils: contents/utils.html
.. _dragon.operators: contents/operators.html .. _dragon.operators: contents/operators.html
.. _dragon.updaters: contents/updaters.html .. _dragon.updaters: contents/updaters.html
.. _dragon.memonger: contents/memonger.html .. _dragon.memonger: contents/memonger.html
......
# ---------------- Welcom To Use Dragon ---------------- # ---------------- Welcom To Use Dragon ----------------
project(dragon) project(dragon)
cmake_minimum_required(VERSION 3.0.0) cmake_minimum_required(VERSION 3.0.2)
# ---------------- Welcom To Use Dragon ---------------- # ---------------- Welcom To Use Dragon ----------------
...@@ -14,16 +14,19 @@ option(BUILD_CXX_API "Set ON to build CXX API" OFF) ...@@ -14,16 +14,19 @@ option(BUILD_CXX_API "Set ON to build CXX API" OFF)
# Set optional libraries # Set optional libraries
option(WITH_CUDA "Set ON to use CUDA" ON) option(WITH_CUDA "Set ON to use CUDA" ON)
option(WITH_CUDNN "Set ON to use CUDNN" ON) option(WITH_CUDNN "Set ON to use CUDNN" ON)
option(WITH_BLAS "Set ON to use BLAS" ON)
option(WITH_OMP "Set ON to use OpenMP" ON) option(WITH_OMP "Set ON to use OpenMP" ON)
option(WITH_SSE "Set ON to use SSE 4.1" ON)
option(WITH_MPI "Set ON to use MPI" OFF) option(WITH_MPI "Set ON to use MPI" OFF)
option(WITH_MPI_CUDA "Set ON to use MPI-CUDA" OFF) option(WITH_NCCL "Set ON to use NCCL" OFF)
option(WITH_MPI_NCCL "Set ON to use MPI-NCCL" OFF)
# Set your 3rdparty # Set your 3rdparty
if (NOT 3RDPARTY_DIR) if (NOT THIRD_PARTY_DIR)
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty) set(THIRD_PARTY_DIR ${PROJECT_SOURCE_DIR}/../ThirdParty)
endif()
# Set your protobuf compiler (protoc) if necessary
# if not, the default "protoc" on the PATH will be used
if (NOT PROTOC_EXECUTABLE)
set(PROTOC_EXECUTABLE protoc)
endif() endif()
# Set your python "interpreter" if necessary # Set your python "interpreter" if necessary
...@@ -43,13 +46,6 @@ set(CUDA_ARCH -gencode arch=compute_30,code=sm_30 ...@@ -43,13 +46,6 @@ set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
-gencode arch=compute_60,code=sm_60 -gencode arch=compute_60,code=sm_60
-gencode arch=compute_70,code=sm_70) -gencode arch=compute_70,code=sm_70)
# Set CUDNN Library Dir if necessary (Linux/OSX Only)
# For Win, Recommend to use ``3RDPARTY_DIR/lib``
if (NOT CUDNN_LIBRARY_DIR)
set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux
# set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib) # OSX
endif()
# ---------------- User Config ---------------- # ---------------- User Config ----------------
...@@ -94,7 +90,8 @@ set(CMAKE_BUILD_TYPE Release CACHE STRING "set build type to release") ...@@ -94,7 +90,8 @@ set(CMAKE_BUILD_TYPE Release CACHE STRING "set build type to release")
set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release" FORCE) set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release" FORCE)
# ---[ Includes # ---[ Includes
include_directories(${3RDPARTY_DIR}/include) include_directories(${THIRD_PARTY_DIR}/eigen)
include_directories(${THIRD_PARTY_DIR}/protobuf/include)
include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/src) include_directories(${PROJECT_SOURCE_DIR}/src)
if (BUILD_PYTHON_API) if (BUILD_PYTHON_API)
...@@ -103,21 +100,35 @@ if (BUILD_PYTHON_API) ...@@ -103,21 +100,35 @@ if (BUILD_PYTHON_API)
endif() endif()
if (WITH_CUDA) if (WITH_CUDA)
include_directories(${CUDA_INCLUDE_DIRS}) include_directories(${CUDA_INCLUDE_DIRS})
include_directories(${THIRD_PARTY_DIR}/cub)
endif()
if (WITH_CUDNN)
include_directories(${THIRD_PARTY_DIR}/cudnn/include)
endif() endif()
if (WITH_MPI) if (WITH_MPI)
include_directories(${3RDPARTY_DIR}/include/mpi) include_directories(${THIRD_PARTY_DIR}/mpi/include)
endif() endif()
# ---[ Lib Directories # ---[ Lib Directories
set(3RDPARTY_LIBS ${3RDPARTY_DIR}/lib) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/protobuf/lib)
link_directories(${3RDPARTY_LIBS}) if (WITH_CUDA)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
endif()
if (WITH_CUDNN) if (WITH_CUDNN)
link_directories(${CUDNN_LIBRARY_DIR}) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib/x64)
endif() endif()
if (WITH_MPI)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/mpi/lib)
endif()
link_directories(${THIRD_PARTY_LIBRARY_DIRS})
# ---[ Install # ---[ Install
set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" FORCE) set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" FORCE)
set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS}) set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${THIRD_PARTY_LIBRARY_DIRS})
# ---[ Defines # ---[ Defines
if (BUILD_PYTHON_API) if (BUILD_PYTHON_API)
...@@ -126,57 +137,38 @@ if (BUILD_PYTHON_API) ...@@ -126,57 +137,38 @@ if (BUILD_PYTHON_API)
message(STATUS "Use Python2 [Optional]") message(STATUS "Use Python2 [Optional]")
elseif (${PYTHON_VERSION_MAJOR} STREQUAL "3") elseif (${PYTHON_VERSION_MAJOR} STREQUAL "3")
message(STATUS "Use Python3 [Optional]") message(STATUS "Use Python3 [Optional]")
ADD_DEFINITIONS(-DWITH_PYTHON3) add_definitions(-DWITH_PYTHON3)
else() else()
message("Invalid version of Python(Detected ${PYTHON_VERSION_STRING})") message("Invalid version of Python(Detected ${PYTHON_VERSION_STRING})")
message(FATAL_ERROR "Do you set PYTHON_EXECUTABLE correctly?") message(FATAL_ERROR "Do you set PYTHON_EXECUTABLE correctly?")
endif() endif()
endif() endif()
if (WITH_CUDA) if (WITH_CUDA)
ADD_DEFINITIONS(-DWITH_CUDA) add_definitions(-DWITH_CUDA)
message(STATUS "Use CUDA [Optional]") message(STATUS "Use CUDA [Optional]")
endif() endif()
if (WITH_CUDNN) if (WITH_CUDNN)
ADD_DEFINITIONS(-DWITH_CUDNN) add_definitions(-DWITH_CUDNN)
message(STATUS "Use CUDNN [Optional]") message(STATUS "Use CUDNN [Optional]")
endif() endif()
if (WITH_BLAS)
ADD_DEFINITIONS(-DWITH_BLAS)
message(STATUS "Use BLAS [Optional]")
else()
message(STATUS "Unuse BLAS [Optional]"
"\n -- > GEMM/GEMV is disabled"
"\n -- > prefer not to run as CPU Mode")
endif()
if (WITH_OMP) if (WITH_OMP)
ADD_DEFINITIONS(-DWITH_OMP) ADD_DEFINITIONS(-DWITH_OMP)
message(STATUS "Use OpenMP [Optional]") message(STATUS "Use OpenMP [Optional]")
endif() endif()
if (WITH_SSE)
ADD_DEFINITIONS(-DWITH_SSE)
message(STATUS "Use SSE [Optional]")
if(UNIX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1")
endif()
endif()
if (WITH_MPI) if (WITH_MPI)
ADD_DEFINITIONS(-DWITH_MPI) add_definitions(-DWITH_MPI)
message(STATUS "Use MPI [Optional]") message(STATUS "Use MPI [Optional]")
endif() endif()
if (WITH_MPI_CUDA) if (WITH_NCCL)
ADD_DEFINITIONS(-DWITH_MPI_CUDA) add_definitions(-DWITH_NCCL)
message(STATUS "Use MPI-CUDA [Optional]") message(STATUS "Use NCCL [Optional]")
endif()
if (WITH_MPI_NCCL)
ADD_DEFINITIONS(-DWITH_MPI_NCCL)
message(STATUS "Use MPI-NCCL [Optional]")
endif() endif()
# ---[ Flags # ---[ Flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
if(WIN32) if(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4819 /wd4244") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4244 /wd4800 /wd4819 /wd4996")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler \"/wd 4819\"") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler \"/wd 4800 /wd 4819\"")
string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
string(REPLACE "/O2" "/Ox" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") string(REPLACE "/O2" "/Ox" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
if (WITH_OMP) if (WITH_OMP)
...@@ -195,10 +187,14 @@ endif() ...@@ -195,10 +187,14 @@ endif()
# ---[ Warnings # ---[ Warnings
# ---[ Commands # ---[ Commands
set (PROTOS_DIR ${PROJECT_SOURCE_DIR}/src/protos)
message(STATUS "Generate Protobuf Files") # ~ Protobuf
execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/caffemodel.proto) set(PROTO_SOURCE_DIR ${PROJECT_SOURCE_DIR}/src/proto)
execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/dragon.proto) file(GLOB_RECURSE PROTO_FILES ${PROTO_SOURCE_DIR}/*.proto)
foreach(PROTO_FILE ${PROTO_FILES})
message(STATUS "Generate Proto Files (ref: " ${PROTO_FILE} ")")
execute_process(COMMAND ${PROTOC_EXECUTABLE} -I=${PROTO_SOURCE_DIR} --cpp_out=${PROTO_SOURCE_DIR} ${PROTO_FILE})
endforeach()
# ---[ Subdirectories # ---[ Subdirectories
if (BUILD_PYTHON_API) if (BUILD_PYTHON_API)
...@@ -207,6 +203,3 @@ endif() ...@@ -207,6 +203,3 @@ endif()
if (BUILD_CXX_API) if (BUILD_CXX_API)
add_subdirectory(modules/cxx) add_subdirectory(modules/cxx)
endif() endif()
# ---[ Utils
file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib)
\ No newline at end of file
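Taken together, the new THIRD_PARTY_DIR / PROTOC_EXECUTABLE / WITH_NCCL options replace the old 3RDPARTY_DIR, BLAS, SSE and MPI-CUDA/MPI-NCCL knobs. A configure line in the spirit of the updated Dockerfiles (the paths are illustrative):

```bash
cd Dragon/Dragon && mkdir build && cd build
cmake .. -DTHIRD_PARTY_DIR=/path/to/ThirdParty \
         -DPYTHON_EXECUTABLE=/usr/bin/python3 \
         -DWITH_CUDA=ON -DWITH_CUDNN=ON \
         -DWITH_MPI=ON -DWITH_NCCL=ON \
         -DBUILD_CXX_API=ON
make install -j $(nproc)
```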
...@@ -15,13 +15,15 @@ ...@@ -15,13 +15,15 @@
#include <ctime> #include <ctime>
#include <cmath> #include <cmath>
#include <random>
#include <climits> #include <climits>
#include <float.h> #include <float.h>
#include <random>
#include <numeric>
#include <memory> #include <memory>
#include <string> #include <string>
#include <queue> #include <queue>
#include <stack> #include <stack>
#include <array>
#include <vector> #include <vector>
#include <set> #include <set>
#include <map> #include <map>
...@@ -32,7 +34,7 @@ ...@@ -32,7 +34,7 @@
#include <functional> #include <functional>
#include "core/types.h" #include "core/types.h"
#include "protos/dragon.pb.h" #include "proto/dragon.pb.h"
#include "utils/logging.h" #include "utils/logging.h"
namespace dragon { namespace dragon {
...@@ -57,11 +59,11 @@ using Set = std::unordered_set<Value>; ...@@ -57,11 +59,11 @@ using Set = std::unordered_set<Value>;
* * * *
* Kernel Version * * Kernel Version *
* * * *
* Major(2) | Minor(2) | Patch(13) * * Major(3) | Minor(0) | Patch(00) *
* * * *
* * * * * * * * * * * * * * * * * * * * */ * * * * * * * * * * * * * * * * * * * * */
#define DRAGON_VERSION 2213 #define DRAGON_VERSION 3000
/* * * * * * * * * * * * * * * * * * * * * /* * * * * * * * * * * * * * * * * * * * *
* * * *
......
...@@ -19,20 +19,32 @@ namespace dragon { ...@@ -19,20 +19,32 @@ namespace dragon {
class CPUContext { class CPUContext {
public: public:
/*! \brief Default Constructor */
CPUContext(): random_seed_(3) {} CPUContext(): random_seed_(3) {}
/*! \brief Constructor with the specified random seed */
CPUContext(unsigned int random_seed) CPUContext(unsigned int random_seed)
: random_seed_(random_seed) {} : random_seed_(random_seed) {}
/*! \brief Constructor with the specified device option */
CPUContext(const DeviceOption& option) CPUContext(const DeviceOption& option)
: random_seed_(option.has_random_seed() ? : random_seed_(option.has_random_seed() ?
option.random_seed() : DEFAULT_RNG_SEED) {} option.random_seed() : DEFAULT_RNG_SEED) {}
/*! \brief Destructor */
virtual ~CPUContext() {} virtual ~CPUContext() {}
inline void SwitchToDevice() {} /*! \brief Switch to the device of this context */
inline void SwitchToDevice(int stream_id) {} void SwitchToDevice() {}
inline void FinishDeviceCompution() {} /*! \brief Switch to the device with the given stream */
void SwitchToDevice(int stream_id) {}
inline static void* New(size_t nbytes) { /*! \brief Synchronize the dispatched operations */
void FinishDeviceCompution() {}
/*! \brief Malloc the memory */
static void* New(size_t nbytes) {
void* data; void* data;
#ifdef WITH_CUDA_HOST_MEM #ifdef WITH_CUDA_HOST_MEM
CUDA_CHECK(cudaMallocHost(&data, nbytes)); CUDA_CHECK(cudaMallocHost(&data, nbytes));
...@@ -43,60 +55,72 @@ class CPUContext { ...@@ -43,60 +55,72 @@ class CPUContext {
return data; return data;
} }
inline static void Memset( /*! \brief Zero-Reset the memory */
static void Memset(
size_t nbytes, size_t nbytes,
void* ptr) { void* ptr) {
memset(ptr, 0, nbytes); memset(ptr, 0, nbytes);
} }
inline void MemsetAsync( /*! \brief Zero-Reset the memory asynchronously */
void MemsetAsync(
size_t nbytes, size_t nbytes,
void* ptr) { void* ptr) {
memset(ptr, 0, nbytes); memset(ptr, 0, nbytes);
} }
/*! \brief Copy the memory */
template<class DstContext, class SrcContext> template<class DstContext, class SrcContext>
inline static void Memcpy( static void Memcpy(
size_t nbytes, size_t nbytes,
void* dst, void* dst,
const void* src) { const void* src) {
memcpy(dst, src, nbytes); memcpy(dst, src, nbytes);
} }
/*! \brief Copy the memory asynchronously */
template<class DstContext, class SrcContext> template<class DstContext, class SrcContext>
inline void MemcpyAsync( void MemcpyAsync(
size_t nbytes, size_t nbytes,
void* dst, void* dst,
const void* src) { const void* src) {
memcpy(dst, src, nbytes); memcpy(dst, src, nbytes);
} }
/*! \brief Copy the memory with the given data type */
template<typename T, class DstContext, class SrcContext> template<typename T, class DstContext, class SrcContext>
inline void Copy( void Copy(
int n, int n,
T* dst, T* dst,
const T* src) { const T* src) {
if (dst == src) return; if (dst == src) return;
// only the basic types(e.g. int/float) can memcpy correctly
if (std::is_fundamental<T>::value) if (std::is_fundamental<T>::value)
Memcpy<DstContext, SrcContext>( Memcpy<DstContext, SrcContext>(
n * sizeof(T), (void*)dst, (const void*)src); n * sizeof(T), (void*)dst, (const void*)src);
else for (int i = 0; i < n; i++) dst[i] = src[i]; else for (int i = 0; i < n; i++) dst[i] = src[i];
} }
inline static void Delete(void* data) { free(data); } /*! \brief Free the memory */
static void Delete(void* data) { free(data); }
/*! \brief Return the device id */
int device_id() const { return 0; }
inline int device_id() const { return 0; } /*! \brief Set the stream id */
inline void set_stream_id(int stream_id) {} void set_stream_id(int stream_id) {}
inline std::mt19937* rand_generator() { /*! \brief Return the internal random generator */
std::mt19937* rand_generator() {
if (!rand_generator_.get()) if (!rand_generator_.get())
rand_generator_.reset(new std::mt19937(random_seed_)); rand_generator_.reset(new std::mt19937(random_seed_));
return rand_generator_.get(); return rand_generator_.get();
} }
private: private:
/*! \brief Store the random seed */
unsigned int random_seed_; unsigned int random_seed_;
/*! \brief Store the internal random generator */
unique_ptr<std::mt19937> rand_generator_; unique_ptr<std::mt19937> rand_generator_;
}; };
......
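A note for readers skimming the diff: the New/Memset/Memcpy/Delete quartet above is a thin facade over the C runtime, and the typed Copy falls back to an element-wise loop for non-fundamental types. A minimal standalone sketch of the same calling pattern (plain C++; illustrative only, not Dragon's actual call sites):

    #include <cstdlib>
    #include <cstring>
    #include <iostream>

    int main() {
      const size_t nbytes = 16 * sizeof(float);
      void* src = std::malloc(nbytes);        // stands in for CPUContext::New(nbytes)
      void* dst = std::malloc(nbytes);
      std::memset(src, 0, nbytes);            // stands in for Memset(nbytes, src)
      std::memcpy(dst, src, nbytes);          // stands in for Memcpy<CPU, CPU>(nbytes, dst, src)
      std::cout << ((float*)dst)[0] << std::endl;  // prints 0: all-zero bits is 0.0f
      std::free(src); std::free(dst);         // stands in for Delete(...)
    }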
...@@ -36,7 +36,7 @@ class CNMLContext {
         : device_id_(option.device_id()),
           random_seed_(option.has_random_seed() ?
               option.random_seed() : DEFAULT_RNG_SEED) {
-        CHECK_EQ(option.device_type(), CNML);
+        CHECK_EQ(option.device_type(), PROTO_CNML);
     }
     CNMLContext(const int device_id = 0)
...
...@@ -41,7 +41,7 @@ class GraphBase {
         const string& exclude,
         const int stream_id = 1) = 0;

-    inline string name() const { return name_; }
+    string name() const { return name_; }

 protected:
     string name_, phase_;
...@@ -64,22 +64,27 @@ class Graph : public GraphBase {
         const int stream_id = 1) override;

     GraphDef Prune(const GraphDef& meta_graph);
-    GraphDef MakeUpdate(const GraphDef& meta_graph);
     GraphDef Share(const GraphDef& optimized_graph);
     void ShareGrads(GraphDef& optimized_graph);
+    GraphDef BuildUpdateOps(const GraphDef& meta_graph);

     void RecomputingAware(
         const GraphDef& optimized_graph,
         Workspace* ws);

-    inline Workspace* ws() const { return ws_; }
+    Workspace* ws() const { return ws_; }

 protected:
-    void ForwardShareDyeing(string u, string ancestor);
-    void ForwardPruneDyeing(
-        string u,
-        string leaf,
-        vector<string> path);
+    void ForwardShareDyeing(
+        const string& u,
+        const string& ancestor);
+    void ForwardPruneDyeing(
+        const string& u,
+        const string& leaf,
+        const vector<string>& path);
     void BackwardPruneDyeing(string v);

     vector<OperatorBase*> ops_;
...
...@@ -28,16 +28,11 @@ class GraphGradientMaker {
     void Share(const string& grads_prefix, GraphDef& graph);

-    inline void SetTerms(
-        const Map<string, string>& terms) { terms_ = terms; }
-    inline void SetOperatorPrefix(
-        const string& prefix) { op_prefix_ = prefix; }
-    inline void SetOperatorSuffix(
-        const string& suffix) { op_suffix_ = suffix; }
-    inline void AddExternalGrad(
-        const string& name) { external_grads_.insert(name); }
-    inline void AddIgnoreGrad(
-        const string& name) { ignore_grads_.insert(name); }
+    void SetTerms(const Map<string, string>& terms) { terms_ = terms; }
+    void SetOperatorPrefix(const string& prefix) { op_prefix_ = prefix; }
+    void SetOperatorSuffix(const string& suffix) { op_suffix_ = suffix; }
+    void AddExternalGrad(const string& name) { external_grads_.insert(name); }
+    void AddIgnoreGrad(const string& name) { ignore_grads_.insert(name); }

 private:
     bool CheckGrad(
...
...@@ -22,73 +22,129 @@ namespace dragon {
 typedef enum {
     NCHW,
     NHWC,
-} DataOrder;
+} StorageOrder;

 class MixedMemory {
  public:
     typedef enum {
+        /*! \brief Initial state */
         UNINITIALIZED,
+        /*! \brief Memory could be modified by CPUContext last time */
         STATE_AT_CPU,
+        /*! \brief Memory could be modified by CUDAContext last time */
         STATE_AT_CUDA,
+        /*! \brief Memory could be modified by CNMLContext last time */
         STATE_AT_CNML,
+        /*! \brief Memory should be copied to another device next time */
         SWITCHED,
+        /*! \brief Host and Device now hold the same contents */
         SYNCED,
     } State;

+    /*! \brief Default Constructor */
     MixedMemory() : cpu_ptr_(nullptr),
         cuda_ptr_(nullptr), cnml_ptr_(nullptr) {}

+    /*! \brief Constructor with the known meta and size */
     MixedMemory(const TypeMeta& meta, const size_t nbytes)
         : meta_(meta), nbytes_(nbytes), cpu_ptr_(nullptr),
           cuda_ptr_(nullptr), cnml_ptr_(nullptr) {}

+    /*! \brief Destructor */
     ~MixedMemory();

+    /*! \brief Return the const data pointer on CPUContext */
     const void* cpu_data();
+    /*! \brief Return the const data pointer on CUDAContext */
     const void* cuda_data();
+    /*! \brief Return the const data pointer on CNMLContext */
     const void* cnml_data();

+    /*! \brief Return the mutable data pointer on CPUContext */
     void* mutable_cpu_data();
+    /*! \brief Return the mutable data pointer on CUDAContext */
     void* mutable_cuda_data();
+    /*! \brief Return the mutable data pointer on CNMLContext */
     void* mutable_cnml_data();

+    /*! \brief Allocate the mlu device memory */
     void* malloc_cnml_data();
+    /*! \brief Copy the mlu device memory to the host */
     void fetch_cnml_data(void** data);

+    /*! \brief Return the binding CNML cpu tensor */
     cnmlCpuTensor_t& cnml_cpu_tensor();
+    /*! \brief Return the binding CNML mlu tensor */
     cnmlTensor_t& cnml_mlu_tensor();

+    /*! \brief Set the cpu data pointer from external context */
     void set_cpu_data(void* cpu_ptr, size_t nbytes);

+    /*! \brief Switch to the device set by Context before */
     void SwitchToDevice();
+    /*! \brief Switch to the specified device */
     void SwitchToCUDADevice(int device_id);

-    inline size_t nbytes() const { return nbytes_; }
-    inline size_t nchunks() const { return nchunks_; }
+    /*! \brief Return the total bytes of this memory */
+    size_t nbytes() const { return nbytes_; }

+    /*! \brief Return the chunks of this memory */
+    size_t nchunks() const { return nchunks_; }

+    /*! \brief Set the chunks of this memory */
     void set_nchunks(size_t nchunks) { nchunks_ = nchunks; }

-    inline State state() const { return state_; }
+    /*! \brief Return the state of this memory */
+    State state() const { return state_; }

-    inline DataOrder order() const { return order_; }
-    inline void set_order(DataOrder order) { order_ = order; }
+    /*! \brief Return the storage order */
+    StorageOrder order() const { return order_; }

+    /*! \brief Set the storage order */
+    void set_order(StorageOrder order) { order_ = order; }

+    /*! \brief Return a string map to describe the internal structure */
     const Map<string, string> info() const;

+    /*! \brief Control the state machine to CPUContext */
     void ToCPU();
+    /*! \brief Control the state machine to CUDAContext */
     void ToCUDA();

  private:
+    /*! \brief The type meta to call the destructor */
     TypeMeta meta_;

+    /*! \brief The number of total bytes */
     size_t nbytes_ = 0, nchunks_ = 1;

-    DataOrder order_ = NCHW;
+    /*! \brief The optional storage order */
+    StorageOrder order_ = NCHW;

+    /*! \brief Current memory status indicator */
     State state_ = UNINITIALIZED;

+    /*! \brief Data pointers */
     void* cpu_ptr_, *cuda_ptr_, *cnml_ptr_;

-    int own_cpu_ptr_ = 1, ptr_device_ = 0;
-    /*! For CAMBRICON's CNML Environment */
+    /*! \brief Whether this memory owns the cpu data pointer */
+    int own_cpu_ptr_ = 1;

+    /*! \brief Store the device id for some data pointers */
+    int ptr_device_ = 0;

+    /*! \brief Binding cpu tensor for CAMBRICON's CNML Library */
     cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr;
+    /*! \brief Binding mlu tensor for CAMBRICON's CNML Library */
     cnmlTensor_t cnml_mlu_tensor_ = nullptr;
 };
...
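The State enum above is easiest to read as a small state machine. A standalone sketch under simplifying assumptions (two devices only; the function names here are hypothetical stand-ins for cpu_data()/mutable_cpu_data(), chosen only to illustrate the transitions):

    #include <iostream>

    enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED };

    // Assumed semantics of a mutable host access: the host is about to write,
    // so any device copy becomes stale and the host owns the freshest data.
    State mutable_cpu(State s) { (void)s; return STATE_AT_CPU; }

    // Assumed semantics of a const host access: if the device owned the
    // freshest copy, sync it down; afterwards both copies agree.
    State const_cpu(State s) { return s == STATE_AT_CUDA ? SYNCED : s; }

    int main() {
      State s = UNINITIALIZED;
      s = mutable_cpu(s);    // host writes: STATE_AT_CPU
      s = STATE_AT_CUDA;     // a device op wrote the device copy
      s = const_cpu(s);      // read on host forces a sync: SYNCED
      std::cout << (s == SYNCED) << std::endl;  // prints 1
    }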
...@@ -21,7 +21,7 @@
 #include "utils/cast.h"

 #ifdef WITH_MPI
-#include <mpi/mpi.h>
+#include <mpi.h>
 #endif

 namespace dragon {
...@@ -30,51 +30,98 @@ class Workspace;
 class OperatorBase {
  public:
+    /*! Default constructor */
     OperatorBase(const OperatorDef& def, Workspace* ws);

+    /*! Default destructor */
     virtual ~OperatorBase() {}

+    /*! \brief Return the specified input tensor */
     Tensor& Input(int idx);
+    /*! \brief Return the specified output tensor */
     Tensor* Output(int idx);

-    inline size_t InputSize() { return inputs_.size(); }
-    inline size_t OutputSize() { return outputs_.size(); }
+    /*! \brief Return the number of inputs */
+    int InputSize() { return (int)inputs_.size(); }

+    /*! \brief Return the number of outputs */
+    int OutputSize() { return (int)outputs_.size(); }

+    /*! \brief Modify this operator according to the given def */
     void MutableOp(const OperatorDef& def);
-    void MutableOp(const vector<string>& inputs,
-                   const vector<string>& outputs,
-                   const string& anchor);

-    inline void SwitchToPhase(const string& phase) { phase_ = phase; }
+    /*! \brief Modify this operator according to the given properties */
+    void MutableOp(
+        const vector<string>& inputs,
+        const vector<string>& outputs,
+        const string& anchor);

+    /*! \brief Switch the internal running phase */
+    void SwitchToPhase(const string& phase) { phase_ = phase; }

+    /*! \brief Run this operator on the specified stream */
     virtual void Run(int stream_id = 1) { NOT_IMPLEMENTED; }

+    /*! \brief Fuse this operator into the specified graph */
     virtual void Fusion(void* graph) { NOT_IMPLEMENTED; }

-    inline const string& name() const { return def_.name(); }
-    inline const string& type() const { return def_.type(); }
-    inline const string& phase() const { return phase_; }
-    inline const string& anchor() { return anchor_; }
-    inline Workspace* ws() const { return ws_; }
+    /*! \brief Return the operator name */
+    const string& name() const { return def_.name(); }

+    /*! \brief Return the operator type */
+    const string& type() const { return def_.type(); }

+    /*! \brief Return the current running phase */
+    const string& phase() const { return phase_; }

+    /*! \brief Return the anchor name of this operator */
+    const string& anchor() const { return anchor_; }

+    /*! \brief Return the mount name in this operator */
+    const string mount_name(const string& name) const {
+        return "/mnt/" + anchor_ + "/" + name;
+    }

+    /*! \brief Return the parent workspace */
+    Workspace* ws() const { return ws_; }

+    /*! \brief Return the value of the specified argument */
     template <typename T>
     T Arg(const string& name, const T& default_value);

+    /*! \brief Return the values of the specified argument */
     template <typename T>
     vector<T> Args(const string& name);

-    inline const Map<std::string, const Argument*>& args() { return args_; }
-    inline const Argument& arg(const string& name) { return *(args_[name]); }
+    /*! \brief Return the argument map of this operator */
+    const Map<std::string, const Argument*>& args() { return args_; }

+    /*! \brief Return the specified argument */
+    const Argument& arg(const string& name) { return *(args_[name]); }

     typedef Map<string, vector<OperatorBase*> > RecomputeMap;
-    inline RecomputeMap& recompute_map() { return recompute_map_; }

+    /*! \brief Return the recomputing map of this operator */
+    RecomputeMap& recompute_map() { return recompute_map_; }

+    /*! \brief Set the given recomputing map */
     void set_recompute_map(RecomputeMap recompute_map) {
         recompute_map_ = recompute_map;
     }

-    inline const OperatorDef& def() const { return def_; }
-    inline string DebugString() const { return def_.DebugString(); }
+    /*! \brief Return the stored operator def */
+    const OperatorDef& def() const { return def_; }

+    /*! \brief Return the debug string of the stored operator def */
+    string DebugString() const { return def_.DebugString(); }

+    /*! \brief Return the debug DType string on given tensor */
     string DTypeHelper(
         const Tensor& tensor,
         const Set<string>& dtypes) const;

+    /*! \brief Return the debug DType string on given type */
     string DTypeHelper(
         const string& dtype,
         const Set<string>& dtypes) const;
...@@ -126,8 +173,8 @@ class Operator : public OperatorBase {
     virtual void RunOnDevice() = 0;

-    inline Context* ctx() { return &ctx_; }
-    inline bool AllowRun() { return allow_run_; }
+    Context* ctx() { return &ctx_; }
+    bool AllowRun() { return allow_run_; }

 protected:
     Context ctx_;
...@@ -165,6 +212,7 @@ OperatorBase* CreateOperator(const OperatorDef& def, Workspace* ws);
     using OperatorBase::type; \
     using OperatorBase::phase; \
     using OperatorBase::anchor; \
+    using OperatorBase::mount_name; \
     using OperatorBase::def; \
     using OperatorBase::InputSize; \
     using OperatorBase::OutputSize; \
...@@ -214,9 +262,8 @@ DECLARE_REGISTRY(
     unique_ptr< Filler<type, Context> > filler( \
         CreateFiller<type, Context>(*ws()->GetFiller(tensor.name()))); \
     filler->Fill(&tensor, ctx()); \
-    ctx()->FinishDeviceCompution(); \
 } else { \
-    TIndex count = 1; \
+    int64_t count = 1; \
     for(int i = 0; i < shape.size(); i++) count *= shape[i]; \
     CHECK_EQ(count, tensor.count()) \
         << "\nModel request " << "Tensor(" << tensor.name() << ")'s " \
...@@ -235,9 +282,8 @@ DECLARE_REGISTRY(
     unique_ptr< Filler<T, Context> > filler( \
         CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \
     filler->Fill(&tensor, ctx()); \
-    ctx()->FinishDeviceCompution(); \
 } else { \
-    TIndex count = 1; \
+    int64_t count = 1; \
     for(int i = 0; i < shape.size(); i++) count *= shape[i]; \
     CHECK_EQ(count, tensor.count()) \
         << "\nModel request " << "Tensor(" << tensor.name() << ")'s " \
...@@ -247,22 +293,17 @@ DECLARE_REGISTRY(
     tensor.Reshape(shape); \
 }

-#define INIT_MULTIPLIER(ptr_tensor, size) { \
-    ptr_tensor = ws()->CreateTensor("/share/multiplier/" \
-        + TypeMetaToString(TypeMeta::Make<T>())); \
-    if (size > ptr_tensor->count()) { \
-        ptr_tensor->Reshape({ size }); \
-        math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
-            ptr_tensor->template mutable_data<T, Context>(), ctx()); \
-    } \
-}

 #define DECLARE_MULTIPLIER(name, size) \
     const T* name; \
     { \
-        Tensor* _auto_multiplier_; \
-        INIT_MULTIPLIER(_auto_multiplier_, size); \
-        name = _auto_multiplier_->template data<T, Context>(); \
+        auto* mp = ws()->CreateTensor("/share/multiplier/" \
+            + TypeMetaToString(TypeMeta::Make<T>())); \
+        if (size > mp->count()) { \
+            mp->Reshape({ size }); \
+            math::Set<T, Context>(size, cast::to<T>(1.f), \
+                mp->template mutable_data<T, Context>(), ctx()); \
+        } \
+        name = mp->template data<T, Context>(); \
     }

 #define DECLARE_ARGUMENT_WITH_DESC(type, argument) \
...@@ -291,7 +332,7 @@ DECLARE_REGISTRY(
     CHECK(argument##_tensor->IsType<type>()) \
         << "\nThe type of " << #argument << " should be " << #type << "."; \
     CHECK_EQ(argument##_tensor->count(), 1) \
-        << "\nThe argument of " << #argument << " should be a scalar"; \
+        << "\nThe argument of " << #argument << " should be a scalar."; \
     return argument##_tensor->template data<type, CPUContext>()[0]; \
 }
...@@ -299,13 +340,19 @@ DECLARE_REGISTRY(
 template <class Context> \
 type classname<Context>::argument(int idx) { \
     if (argument##_desc.empty()) { \
-        CHECK_LT(idx, argument##_value.size()); \
+        CHECK_LT(idx, argument##_value.size()) \
+            << "\nExpected the size of " << #argument \
+            << " > " << idx << ". (Got " \
+            << argument##_value.size() << ")."; \
         return argument##_value[idx]; \
     } \
-    CHECK_LT(idx, argument##_desc.size()); \
+    CHECK_LT(idx, argument##_desc.size()) \
+        << "\nExpected the size of " << #argument \
+        << " > " << idx << ". (Got " \
+        << argument##_desc.size() << ")."; \
     Tensor* argument##_tensor = ws()->GetTensor(argument##_desc[idx]); \
     CHECK(argument##_tensor->IsType<type>()) \
-        << "\nThe type of " << #argument << " should be " << #type; \
+        << "\nThe type of " << #argument << " should be " << #type << "."; \
     CHECK_EQ(argument##_tensor->count(), 1) \
         << "\nThe argument of " << #argument << " at pos(" \
         << idx << ") should be a scalar."; \
...@@ -313,7 +360,7 @@ DECLARE_REGISTRY(
 }

 #define GET_ARGUMENTS_SIZE(argument) \
-    std::max(argument##_value.size(), argument##_desc.size())
+    (int)std::max(argument##_value.size(), argument##_desc.size())

 #define XIsType(x, dtype) \
     x.template IsType<dtype>()
...
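The rewritten DECLARE_MULTIPLIER above maintains one shared all-ones tensor per data type in the workspace; ones vectors are the classic trick for making BLAS-style kernels broadcast or reduce. A standalone sketch of the broadcast use (plain loops standing in for math::Gemm; all names and values here are illustrative, not Dragon code):

    #include <iostream>
    #include <vector>

    int main() {
      const int M = 2, N = 3;                   // batch x channels
      std::vector<float> ones(M, 1.f);          // the shared "multiplier"
      std::vector<float> bias = {1.f, 2.f, 3.f};
      std::vector<float> y(M * N, 0.f);
      // y += ones (Mx1) * bias (1xN): an outer product with ones
      // replicates the bias row across the whole batch.
      for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
          y[i * N + j] += ones[i] * bias[j];
      std::cout << y[0] << " " << y[N] << std::endl;  // prints: 1 1
    }

The same vector dotted against a matrix row sums it, which is why a single cached multiplier serves both bias broadcasting and gradient reduction.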
...@@ -40,11 +40,11 @@ class GradientMakerBase {
         g_inputs_(def.input_size()) {}
     virtual ~GradientMakerBase() {}

-    inline virtual bool CopyDeviceOption() const { return true; }
-    inline virtual bool CopyEngine() const { return true; }
-    inline virtual bool CopyArguments() const { return true; }
+    virtual bool CopyDeviceOption() const { return true; }
+    virtual bool CopyEngine() const { return true; }
+    virtual bool CopyArguments() const { return true; }

-    inline virtual Gradient Make() {
+    virtual Gradient Make() {
         vector<OperatorDef> new_defs = MakeDefs();
         Argument anchor;
         anchor.set_name("anchor"); anchor.set_s(def.name());
...@@ -53,29 +53,40 @@ class GradientMakerBase {
         return Gradient(new_defs, g_inputs_, DefaultValues());
     };

-    virtual inline vector<OperatorDef> MakeDefs() {
+    virtual vector<OperatorDef> MakeDefs() {
         NOT_IMPLEMENTED;
         return vector<OperatorDef>();
     }

-    virtual inline vector<float> DefaultValues() {
+    virtual vector<float> DefaultValues() {
         return vector<float>(g_outputs_.size(), 1.f);
     }

     template <class... Args>
-    inline static vector<OperatorDef> SingleDef(const Args& ... args) {
+    static vector<OperatorDef> SingleDef(const Args& ... args) {
         return vector<OperatorDef> { MakeOperatorDef(args...) };
     }

-    inline string I(const int i) { return def.input(i); }
-    inline string O(const int i) { return def.output(i); }
+    const string I(const int i) const {
+        return i < def.input_size() ?
+            def.input(i) : "ignore";
+    }

+    const string O(const int i) const {
+        return i < def.output_size() ?
+            def.output(i) : "ignore";
+    }

-    inline string GI(const int i) {
+    string GI(const int i) {
         if (i >= g_inputs_.size()) return "ignore";
         g_inputs_[i] = def.input(i) + "_grad";
         return g_inputs_[i];
     }

-    inline string GO(const int i) { return g_outputs_[i]; }
+    const string GO(const int i) const {
+        return i < g_outputs_.size() ?
+            g_outputs_[i] : "ignore";
+    }

 protected:
     const OperatorDef& def;
...@@ -100,6 +111,52 @@ class NoGradient : public GradientMakerBase {
     }
 };

+// Here we define some common gradient makers
+// Reuse them to keep the code cleaner

+class SimpleGradientMaker final : public GradientMakerBase {
+ public:
+    /*!
+     * <SimpleMaker>
+     *
+     * Inputs: X1, X2, ..., Xn, dY
+     * Outputs: dX1, dX2, ..., dXn
+     *
+     */
+    GRADIENT_MAKER_CTOR(SimpleGradientMaker);
+    vector<OperatorDef> MakeDefs() override {
+        vector<string> inputs, outputs;
+        for (const auto& input : def.input()) {
+            inputs.push_back(input);
+        }
+        inputs.push_back(GO(0));
+        for (int i = 0; i < def.input_size(); i++) {
+            outputs.push_back(GI(i));
+        }
+        return SingleDef(def.type() +
+            "Gradient", "", inputs, outputs);
+    }
+};

+class InplaceGradientMaker final : public GradientMakerBase {
+ public:
+    /*!
+     * <InplaceMaker>
+     *
+     * Inputs: Y, dY
+     * Outputs: dX
+     *
+     */
+    GRADIENT_MAKER_CTOR(InplaceGradientMaker);
+    vector<OperatorDef> MakeDefs() override {
+        return SingleDef(
+            def.type() + "Gradient",          /*! OpType */
+            "",                               /*! OpName */
+            vector<string>({ O(0), GO(0) }),  /*! Inputs */
+            vector<string>({ GI(0) }));       /*! Outputs */
+    }
+};

 DECLARE_REGISTRY(
     GradientRegistry,
     GradientMakerBase,
...
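What SimpleGradientMaker yields for a concrete op is easy to trace by hand: for a forward op "Foo(X1, X2) -> Y" it emits "FooGradient" with inputs {X1, X2, Y_grad} and outputs {X1_grad, X2_grad}. A standalone sketch of that mapping (the tiny Def struct is illustrative, not Dragon's real OperatorDef):

    #include <iostream>
    #include <string>
    #include <vector>

    struct Def { std::string type; std::vector<std::string> inputs, outputs; };

    Def MakeSimpleGradient(const Def& fwd) {
      Def grad;
      grad.type = fwd.type + "Gradient";
      grad.inputs = fwd.inputs;                          // X1, X2, ...
      grad.inputs.push_back(fwd.outputs[0] + "_grad");   // GO(0) -> dY
      for (const auto& in : fwd.inputs)
        grad.outputs.push_back(in + "_grad");            // GI(i) -> dXi
      return grad;
    }

    int main() {
      Def fwd{"Foo", {"X1", "X2"}, {"Y"}};
      Def grad = MakeSimpleGradient(fwd);
      std::cout << grad.type << ": " << grad.inputs[2]
                << " -> " << grad.outputs[0] << std::endl;
      // prints: FooGradient: Y_grad -> X1_grad
    }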
...@@ -37,14 +37,14 @@ class OpSchema {
     bool Verify(const OperatorDef& def) const;

-    inline OpSchema& IgnoreVerify() {
+    OpSchema& IgnoreVerify() {
         ignore_verify_ = true;
         return *this;
     }

     OpSchema& Inplace(set<pair<int, int> > inplace);
     std::function<bool(int, int)> CheckInplace;
-    inline bool AllowInplace() const { return allow_inplace_; }
+    bool AllowInplace() const { return allow_inplace_; }

     OpSchema& NumInputs(int n);
     OpSchema& NumInputs(int min_num, int max_num);
...@@ -86,8 +86,8 @@ class OpSchemaRegistry {
     static const OpSchema* Schema(const string& op_type) {
         auto& m = schema_map();
         if (m.count(op_type)) return &m[op_type];
-        else LOG(FATAL) << "OpSchema(" << op_type
+        LOG(WARNING) << "OpSchema(" << op_type
             << ") has not been registered yet.";
         return nullptr;
     }
...
...@@ -13,8 +13,6 @@
 #ifndef DRAGON_CORE_REGISTRY_H_
 #define DRAGON_CORE_REGISTRY_H_

-#include <functional>

 #include "core/common.h"

 namespace dragon {
...@@ -69,14 +67,14 @@ class Registerer {
 // Used in *.h files
 #define DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjType, ...) \
-    dragon::Registry<SrcType, ObjType, ##__VA_ARGS__>* RegistryName(); \
-    typedef dragon::Registerer<SrcType, ObjType, ##__VA_ARGS__> Registerer##RegistryName;
+    Registry<SrcType, ObjType, ##__VA_ARGS__>* RegistryName(); \
+    typedef Registerer<SrcType, ObjType, ##__VA_ARGS__> Registerer##RegistryName;

 // Used in *.cc files
 #define DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjType, ...) \
-    Registry<SrcType, ObjType, ##__VA_ARGS__>* RegistryName() { \
-        static Registry<SrcType, ObjType, ##__VA_ARGS__>* registry = \
-            new Registry<SrcType, ObjType, ##__VA_ARGS__>(); \
+    Registry<SrcType, ObjType, ##__VA_ARGS__>* RegistryName() { \
+        static Registry<SrcType, ObjType, ##__VA_ARGS__>* registry = \
+            new Registry<SrcType, ObjType, ##__VA_ARGS__>(); \
         return registry; \
     }
...@@ -93,6 +91,6 @@ class Registerer {
 #define REGISTER_CLASS(RegistryName, key, ...) \
     REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)

 } // namespace dragon

 #endif // DRAGON_CORE_REGISTRY_H_
\ No newline at end of file
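The registry macros above reduce to a keyed map of creator functions behind a lazily-initialized singleton, populated by static Registerer objects before main() runs. A minimal standalone version of the same pattern (illustrative; Dragon's real Registry also forwards constructor arguments through the variadic parameters):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    struct Op { virtual ~Op() {} virtual const char* name() = 0; };
    struct Relu : Op { const char* name() override { return "Relu"; } };

    // Lazy singleton, playing the role of RegistryName().
    std::map<std::string, std::function<Op*()>>& registry() {
      static std::map<std::string, std::function<Op*()>> r;
      return r;
    }

    // What REGISTER_CLASS conceptually expands to: a static object whose
    // constructor inserts the creator under the given key.
    struct Registerer {
      Registerer(const std::string& key, std::function<Op*()> creator) {
        registry()[key] = creator;
      }
    };
    static Registerer g_relu("Relu", [] { return (Op*)new Relu(); });

    int main() {
      Op* op = registry()["Relu"]();   // create by string key
      std::cout << op->name() << std::endl;
      delete op;
    }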
...@@ -20,42 +20,32 @@
 namespace dragon {

-typedef char int8;
-typedef unsigned char uint8;

 #ifdef _MSC_VER

 typedef struct __declspec(align(2)) {
     unsigned short x;
 } float16;

-typedef struct __declspec(align(4)) {
-    unsigned int x;
-} float32;

 #else

 typedef struct {
     unsigned short x;
 } __attribute__((aligned(2))) float16;

-typedef struct {
-    unsigned int x;
-} __attribute__((aligned(4))) float32;

 #endif

 inline const TypeMeta& TypeStringToMeta(
     const std::string& str_type) {
     static std::unordered_map<std::string, TypeMeta>
         s2m_type_map {
-            { "float32", TypeMeta::Make<float>() },
+            { "bool", TypeMeta::Make<bool>() },
+            { "int8", TypeMeta::Make<int8_t>() },
+            { "uint8", TypeMeta::Make<uint8_t>() },
             { "int32", TypeMeta::Make<int>() },
             { "int64", TypeMeta::Make<int64_t>() },
-            { "float64", TypeMeta::Make<double>() },
             { "float16", TypeMeta::Make<float16>() },
-            { "uint8", TypeMeta::Make<uint8>() },
-            { "int8", TypeMeta::Make<int8>() },
+            { "float32", TypeMeta::Make<float>() },
+            { "float64", TypeMeta::Make<double>() },
         };
     static TypeMeta unknown_type;
     return s2m_type_map.count(str_type) ?
...@@ -66,13 +56,14 @@ inline const std::string TypeMetaToString(
     const TypeMeta& meta) {
     static std::unordered_map<TypeId, std::string>
         m2s_type_map {
-            { TypeMeta::Id<float>(), "float32" },
+            { TypeMeta::Id<bool>(), "bool" },
+            { TypeMeta::Id<int8_t>(), "int8" },
+            { TypeMeta::Id<uint8_t>(), "uint8" },
             { TypeMeta::Id<int>(), "int32" },
             { TypeMeta::Id<int64_t>(), "int64" },
-            { TypeMeta::Id<double>(), "float64" },
             { TypeMeta::Id<float16>(), "float16" },
-            { TypeMeta::Id<uint8>(), "uint8" },
-            { TypeMeta::Id<int8>(), "int8" }
+            { TypeMeta::Id<float>(), "float32" },
+            { TypeMeta::Id<double>(), "float64" },
         };
     return m2s_type_map.count(meta.id()) ?
         m2s_type_map[meta.id()] : "unknown";
...
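TypeStringToMeta and TypeMetaToString form an inverse pair over the table above, so a type name survives a round trip. A standalone sketch of the same idea using std::type_index in place of TypeMeta (illustrative only; the real TypeMeta also carries size and destructor information):

    #include <iostream>
    #include <string>
    #include <typeindex>
    #include <unordered_map>

    int main() {
      std::unordered_map<std::string, std::type_index> s2m = {
        {"float32", typeid(float)}, {"int64", typeid(long long)}};
      std::unordered_map<std::type_index, std::string> m2s = {
        {typeid(float), "float32"}, {typeid(long long), "int64"}};
      // Round trip: "float32" -> typeid(float) -> "float32"
      std::cout << m2s.at(s2m.at("float32")) << std::endl;
    }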
...@@ -33,8 +33,8 @@ class DropoutOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    DECLARE_ARGUMENT_WITH_DESC(float, prob);
     bool use_scale;
+    DECLARE_ARGUMENT_WITH_DESC(float, prob);
 };

 template <class Context>
...@@ -52,8 +52,8 @@ class DropoutGradientOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    DECLARE_ARGUMENT_WITH_DESC(float, prob);
     bool use_scale;
+    DECLARE_ARGUMENT_WITH_DESC(float, prob);
 };

 DEFINE_ARGUMENT_WITH_DESC(float, DropoutOp, prob);
...@@ -86,12 +86,12 @@ public:
     template <typename T> void RunWithType();

 protected:
-    DECLARE_ARGUMENT_WITH_DESC(float, prob);
     bool use_scale, states_initialized;
     cudnnTensorDescriptor_t input_desc;
     cudnnDropoutDescriptor_t dropout_desc;
     size_t states_size, reserve_space_size;
     unsigned long long random_seed;
+    DECLARE_ARGUMENT_WITH_DESC(float, prob);
 };

 template <class Context>
...@@ -117,12 +117,12 @@ public:
     template <typename T> void RunWithType();

 protected:
-    DECLARE_ARGUMENT_WITH_DESC(float, prob);
     bool use_scale, states_initialized;
     cudnnTensorDescriptor_t input_desc;
     cudnnDropoutDescriptor_t dropout_desc;
     size_t states_size, reserve_space_size;
     unsigned long long random_seed;
+    DECLARE_ARGUMENT_WITH_DESC(float, prob);
 };

 DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutOp, prob);
...
...@@ -30,7 +30,7 @@ class PReluOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    TIndex channel_shared, channels, dim;
+    int64_t channel_shared, channels, dim;
     string data_format;
 };
...@@ -47,7 +47,7 @@ class PReluGradientOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    TIndex channel_shared, channels, dim;
+    int64_t channel_shared, channels, dim;
     string data_format;
 };
...
...@@ -22,14 +22,14 @@ class SoftmaxOp final : public Operator<Context> {
 public:
     SoftmaxOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)) {}
+          axis(OperatorBase::Arg<int64_t>("axis", 1)) {}
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

 protected:
-    TIndex axis, outer_dim, inner_dim;
+    int64_t axis, outer_dim, inner_dim;
 };

 template <class Context>
...@@ -37,14 +37,14 @@ class SoftmaxGradientOp final : public Operator<Context> {
 public:
     SoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)) {}
+          axis(OperatorBase::Arg<int64_t>("axis", 1)) {}
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

 protected:
-    TIndex axis, outer_dim, inner_dim;
+    int64_t axis, outer_dim, inner_dim;
 };

 #ifdef WITH_CUDNN
...@@ -54,7 +54,7 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
 public:
     CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)) {
+          axis(OperatorBase::Arg<int64_t>("axis", 1)) {
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
     }
...@@ -69,7 +69,7 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    TIndex axis, outer_dim, inner_dim;
+    int64_t axis, outer_dim, inner_dim;
     cudnnTensorDescriptor_t input_desc, output_desc;
 };
...@@ -78,7 +78,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> {
 public:
     CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)) {
+          axis(OperatorBase::Arg<int64_t>("axis", 1)) {
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
     }
...@@ -93,7 +93,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    TIndex axis, outer_dim, inner_dim;
+    int64_t axis, outer_dim, inner_dim;
     cudnnTensorDescriptor_t input_desc, output_desc;
 };
...
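All four softmax variants share the same axis bookkeeping: the input shape is split into outer_dim x dims[axis] x inner_dim, and one softmax runs per (outer, inner) pair over the dims[axis] entries. A standalone sketch of that split (shape values illustrative):

    #include <iostream>
    #include <vector>

    int main() {
      std::vector<long long> dims = {2, 3, 4, 5};  // e.g. an NCHW tensor
      int axis = 1;                                 // softmax over C
      long long outer = 1, inner = 1;
      for (int i = 0; i < axis; ++i) outer *= dims[i];
      for (int i = axis + 1; i < (int)dims.size(); ++i) inner *= dims[i];
      // 2 x 3 x 20: 40 independent softmaxes, each over 3 entries
      std::cout << outer << " x " << dims[axis] << " x " << inner << std::endl;
    }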
...@@ -22,16 +22,16 @@ class AffineOp final : public Operator<Context> {
 public:
     AffineOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)),
-          num_axes(OperatorBase::Arg<int>("num_axes", 1)) {}
+          axis(OperatorBase::Arg<int64_t>("axis", 1)),
+          num_axes(OperatorBase::Arg<int64_t>("num_axes", 1)) {}
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

 protected:
-    TIndex axis, start_axis, num_axes;
-    TIndex outer_dim, scale_dim, inner_dim;
+    int64_t axis, num_axes;
+    int64_t outer_dim, scale_dim, inner_dim;
 };

 template <class Context>
...@@ -39,8 +39,8 @@ class AffineGradientOp final : public Operator<Context> {
 public:
     AffineGradientOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)),
-          num_axes(OperatorBase::Arg<int>("num_axes", -1)) {}
+          axis(OperatorBase::Arg<int64_t>("axis", 1)),
+          num_axes(OperatorBase::Arg<int64_t>("num_axes", 1)) {}
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
...@@ -49,8 +49,8 @@ class AffineGradientOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    TIndex axis, start_axis, num_axes;
-    TIndex outer_dim, inner_dim, scale_dim, sum_dim, dim;
+    int64_t axis, num_axes;
+    int64_t outer_dim, inner_dim, scale_dim, sum_dim, dim;
     Tensor sum_result;
 };
...@@ -63,8 +63,8 @@ class CuDNNAffineOpBase : public Operator<Context> {
 public:
     CuDNNAffineOpBase(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)),
-          num_axes(OperatorBase::Arg<int>("num_axes", -1)) {
+          axis(OperatorBase::Arg<int64_t>("axis", 1)),
+          num_axes(OperatorBase::Arg<int64_t>("num_axes", 1)) {
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc));
         CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&mul_desc));
...@@ -81,31 +81,9 @@ class CuDNNAffineOpBase : public Operator<Context> {
     }

     template <typename T>
-    void ResetDesc(const Tensor& X) {
-        // Determine the range of affine
-        start_axis = axis;
-        if (start_axis < 0) start_axis += (int)X.ndim();
-        if (num_axes == -1) num_axes = (int)X.ndim() - start_axis;
-        else if (num_axes == 0) num_axes = 1;
-        end_axis = start_axis + num_axes;
-        CHECK_LT(start_axis, (int)X.ndim());
-        CHECK_LE(start_axis + num_axes, (int)X.ndim());
-        // Determine the input desc
-        vector<TIndex> input_dims = X.dims();
-        // CuDNN requires the number of dimensions to be in [4, 5]
-        if (input_dims.size() < 4) input_dims.resize(4, 1);
-        else if (input_dims.size() > 5)
-            LOG(FATAL) << "CuDNN Affine supports up to 5 dimensions only.";
-        cudnnSetTensorDesc<T>(&input_desc, input_dims);
-        // Determine the scale desc
-        vector<TIndex> param_dims(input_dims.size(), 1);
-        for (int i = start_axis; i < end_axis; i++)
-            param_dims[i] = input_dims[i];
-        cudnnSetTensorDesc<T>(&param_desc, param_dims);
-    }
+    void ResetDesc(const Tensor& X);

-    TIndex axis, start_axis, end_axis, num_axes;
+    int64_t axis, num_axes;
     cudnnTensorDescriptor_t input_desc, param_desc;
     cudnnOpTensorDescriptor_t mul_desc, add_desc;
     cudnnReduceTensorDescriptor_t reduce_desc;
...@@ -113,7 +91,7 @@ class CuDNNAffineOpBase : public Operator<Context> {
 #define USE_CUDNN_AFFINE_FUCNTIONS \
     USE_OPERATOR_FUNCTIONS; \
-    using CuDNNAffineOpBase<Context>::start_axis; \
+    using CuDNNAffineOpBase<Context>::axis; \
     using CuDNNAffineOpBase<Context>::num_axes; \
     using CuDNNAffineOpBase<Context>::input_desc; \
     using CuDNNAffineOpBase<Context>::param_desc; \
...@@ -131,7 +109,7 @@ class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> {
     template <typename DT, typename CT> void RunWithType();

 protected:
     USE_CUDNN_AFFINE_FUCNTIONS;
 };

 template <class Context>
...@@ -155,9 +133,9 @@ public:
     template <typename DT, typename CT> void RunWithType();

 protected:
     USE_CUDNN_AFFINE_FUCNTIONS;
-    TIndex outer_dim, inner_dim, scale_dim, sum_dim, dim;
+    int64_t outer_dim, inner_dim, scale_dim, dim, sum_dim;
     Tensor sum_result;
 };
...
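The ResetDesc body moved out of the header, but its shape logic is worth keeping in mind: the scale/bias parameters cover the axes in [axis, axis + num_axes) of the input shape and are 1 everywhere else. A standalone sketch of that computation (shape values illustrative):

    #include <iostream>
    #include <vector>

    int main() {
      std::vector<long long> x_dims = {2, 3, 4, 4};   // e.g. NCHW
      long long axis = 1, num_axes = 1;               // affine over C
      if (axis < 0) axis += (long long)x_dims.size();
      if (num_axes < 0) num_axes = (long long)x_dims.size() - axis;
      std::vector<long long> param_dims(x_dims.size(), 1);
      for (long long i = axis; i < axis + num_axes; ++i)
        param_dims[i] = x_dims[i];
      for (auto d : param_dims) std::cout << d << " ";  // prints: 1 3 1 1
      std::cout << std::endl;
    }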
...@@ -30,7 +30,7 @@ class ClipOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    float low, high;
+    float low, high, lowT, highT;
 };
...@@ -46,7 +46,7 @@ class ClipGradientOp final : public Operator<Context> {
     template <typename T> void RunWithType();

 protected:
-    float low, high;
+    float low, high, lowT, highT;
 };

 } // namespace dragon
...
...@@ -22,8 +22,8 @@ class DotOp final : public Operator<Context> {
 public:
     DotOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          TransA(OperatorBase::Arg<bool>("TransA", false)),
-          TransB(OperatorBase::Arg<bool>("TransB", false)) {}
+          transA(OperatorBase::Arg<bool>("transA", false)),
+          transB(OperatorBase::Arg<bool>("transB", false)) {}
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
...@@ -32,7 +32,8 @@ class DotOp final : public Operator<Context> {
     template <typename T> void GemvRunWithType();

 protected:
-    TIndex TransA, TransB, M, K1, K2, N1, N2;
+    int64_t M1, N1, M2, N2;
+    int64_t transA, transB, M, K1, K2, N;
 };

 template <class Context>
...@@ -40,8 +41,8 @@ class DotGradientOp final : public Operator<Context> {
 public:
     DotGradientOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
-          TransA(OperatorBase::Arg<bool>("TransA", false)),
-          TransB(OperatorBase::Arg<bool>("TransB", false)) {}
+          transA(OperatorBase::Arg<bool>("transA", false)),
+          transB(OperatorBase::Arg<bool>("transB", false)) {}
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
...@@ -50,9 +51,10 @@ class DotGradientOp final : public Operator<Context> {
     template <typename T> void GemvRunWithType();

 protected:
-    TIndex TransA, TransB, M, K1, K2, N1, N2;
+    int64_t M1, N1, M2, N2;
+    int64_t transA, transB, M, K1, K2, N;
 };

 } // namespace dragon

 #endif // DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
\ No newline at end of file
...@@ -23,21 +23,26 @@ class EltwiseOp final : public Operator<Context> {
     EltwiseOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
           operation(OperatorBase::Arg<string>("operation", "SUM")),
-          coeffs(OperatorBase::Args<float>("coeffs")) {
+          coeffs(OperatorBase::Args<float>("coefficients")) {
+        // Check the number of coefficients
         if (coeffs.size() > 0) {
             CHECK_EQ(coeffs.size(), InputSize())
                 << "\nOp has " << InputSize() << " inputs, "
                 << "but provided " << coeffs.size() << " coeffs.";
-        } else coeffs.resize(InputSize(), float(1));
+        } else coeffs.resize(InputSize(), 1.f);
+        // Compute the alpha for the product operation
+        for (auto e : coeffs) { if (e != 1.f) alpha *= e; }
     }

     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
+    template <typename T> void RunWithType();
     template <typename T> void SumRunWithType();
     template <typename T> void ProdRunWithType();

 protected:
     string operation;
+    float alpha = 1.f;
     vector<float> coeffs;
 };
...@@ -47,21 +52,25 @@ class EltwiseGradientOp final : public Operator<Context> {
     EltwiseGradientOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
           operation(OperatorBase::Arg<string>("operation", "SUM")),
-          coeffs(OperatorBase::Args<float>("coeff")) {
+          coeffs(OperatorBase::Args<float>("coefficients")) {
         if (coeffs.size() > 0) {
-            CHECK_EQ(coeffs.size(), InputSize())
-                << "\nop has " << InputSize() << " inputs, "
+            CHECK_EQ(coeffs.size(), OutputSize())
+                << "\nOp has " << OutputSize() << " outputs, "
                 << "but provided " << coeffs.size() << " coeffs.";
-        } else coeffs.resize(InputSize(), float(1));
+        } else coeffs.resize(InputSize(), 1.f);
+        // Compute the alpha for the product operation
+        for (auto e : coeffs) { if (e != 1.f) alpha *= e; }
     }

     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
+    template <typename T> void RunWithType();
     template <typename T> void SumRunWithType();
     template <typename T> void ProdRunWithType();

 protected:
     string operation;
+    float alpha = 1.f;
     vector<float> coeffs;
 };
...
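The new alpha member precomputes the product of the coefficients, so the PROD path can apply one scale after the product instead of scaling every input, using the identity prod(c_i * x_i) = (prod c_i) * (prod x_i). A standalone sketch of that identity (values illustrative):

    #include <iostream>
    #include <vector>

    int main() {
      std::vector<float> coeffs = {2.f, 0.5f, 1.f};
      float alpha = 1.f;
      for (auto e : coeffs) { if (e != 1.f) alpha *= e; }  // alpha = 1
      std::vector<float> x = {3.f, 4.f, 5.f};
      float prod = 1.f;
      for (auto v : x) prod *= v;                           // prod = 60
      // prod(c_i * x_i) == alpha * prod(x_i)
      std::cout << alpha * prod << std::endl;               // prints 60
    }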
...@@ -10,21 +10,21 @@
  * ------------------------------------------------------------
  */

-#ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
-#define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
+#ifndef DRAGON_OPERATORS_ARITHMETIC_FULLY_CONNECTED_OP_H_
+#define DRAGON_OPERATORS_ARITHMETIC_FULLY_CONNECTED_OP_H_

 #include "core/operator.h"

 namespace dragon {

 template <class Context>
-class InnerProductOp final : public Operator<Context> {
+class FullyConnectedOp final : public Operator<Context> {
 public:
-    InnerProductOp(const OperatorDef& def, Workspace *ws)
+    FullyConnectedOp(const OperatorDef& def, Workspace *ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)),
-          num_output(OperatorBase::Arg<int>("num_output", 0)),
-          TransW(OperatorBase::Arg<bool>("TransW", true)) {}
+          axis(OperatorBase::Arg<int64_t>("axis", 1)),
+          N(OperatorBase::Arg<int64_t>("num_output", 0)),
+          transW(OperatorBase::Arg<bool>("transW", true)) {}
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice();
...@@ -32,26 +32,26 @@ class InnerProductOp final : public Operator<Context> {
     template <typename T> void NoTransRunWithType();

 protected:
-    TIndex axis, num_output, TransW, M, K;
+    int64_t axis, transW, M, K, N;
 };

 template <class Context>
-class InnerProductGradientOp final : public Operator<Context> {
+class FullyConnectedGradientOp final : public Operator<Context> {
 public:
-    InnerProductGradientOp(const OperatorDef& def, Workspace *ws)
+    FullyConnectedGradientOp(const OperatorDef& def, Workspace *ws)
         : Operator<Context>(def, ws),
-          axis(OperatorBase::Arg<int>("axis", 1)),
-          num_output(OperatorBase::Arg<int>("num_output", 0)),
-          TransW(OperatorBase::Arg<bool>("TransW", true)) {}
+          axis(OperatorBase::Arg<int64_t>("axis", 1)),
+          N(OperatorBase::Arg<int64_t>("num_output", 0)),
+          transW(OperatorBase::Arg<bool>("transW", true)) {}
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

 protected:
-    TIndex axis, num_output, TransW, M, K;
+    int64_t axis, transW, M, K, N;
 };

 } // namespace dragon

-#endif // DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
+#endif // DRAGON_OPERATORS_ARITHMETIC_FULLY_CONNECTED_OP_H_
\ No newline at end of file
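The rename from InnerProduct to FullyConnected keeps the usual M/K/N bookkeeping: axis splits the input shape into a batch part M and a flattened feature part K, and N is num_output. A standalone sketch of that shape computation (shape values illustrative):

    #include <iostream>

    int main() {
      long long dims[] = {8, 64, 7, 7};   // e.g. a conv feature map
      int axis = 1, ndim = 4;
      long long num_output = 10;
      long long M = 1, K = 1;
      for (int i = 0; i < axis; ++i) M *= dims[i];     // batch part
      for (int i = axis; i < ndim; ++i) K *= dims[i];  // flattened features
      std::cout << M << " x " << K << " -> "
                << M << " x " << num_output << std::endl;
      // prints: 8 x 3136 -> 8 x 10
    }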