Commit 84320495 by Ting PAN

Dismantle Op Kernels

1 parent 96f7277e
Showing with 1472 additions and 1027 deletions
------------------------------------------------------------------------ ------------------------------------------------------------------------
The list of most significant changes made over time in Dragon. The list of most significant changes made over time in Dragon.
Dragon 0.2.2.13 (20181204)
DRAGON_VERSION == 2213
Changes (w.r.t. Dragon 0.2.2.12):
Preview Features:
- Dismantled op kernels.
- Added FP16 support for ``NNResizeOp``, ``ProposalOp``,
``ROIPoolingOp``, ``ROIAlignOp``, (R-CNN components).
- Added ``DepthwiseConv2dOp``.
- [PyTorch] Added ``nn.DepthwiseConv2d``
- [PyCaffe] Added ``DepthwiseConvolutionLayer``.
Bugs fixed:
- Fixed the cuda issue that incorrect results in ``AsTypeOp``
under ``float32 -> float16``.
- Removed the support for group convolution implemented with im2col-nhwc.
- Changed the default ``NHWC`` pack format of filter from
``[filter_height, filter_width, in_channels, out_channels]``(TensorFlow) to
``[out_channels, filter_height, filter_width, in_channels]``(CuDNN).
- Changed the VC Runtime from ``MD`` to ``MT``.
------------------------------------------------------------------------
Dragon 0.2.2.12 (20181120) Dragon 0.2.2.12 (20181120)
DRAGON_VERSION == 2212 DRAGON_VERSION == 2212
...@@ -60,6 +98,7 @@ Preview Features: ...@@ -60,6 +98,7 @@ Preview Features:
- [PyCaffe] Added ``DropBlockLayer``. - [PyCaffe] Added ``DropBlockLayer``.
Bugs fixed: Bugs fixed:
- Fixed the uncomputed output in ``BiasAddGradientOp``. - Fixed the uncomputed output in ``BiasAddGradientOp``.
......
This diff could not be displayed because it is too large.
<!-- HTML footer for doxygen 1.8.14-->
<!-- start footer part -->
<!--BEGIN GENERATE_TREEVIEW-->
<div id="nav-path" class="navpath"><!-- id is needed for treeview function! -->
<ul>
$navpath
<li class="footer">$generatedby
<a href="http://www.doxygen.org/index.html">
<img class="footer" src="$relpath^doxygen.png" alt="doxygen"/></a> $doxygenversion </li>
</ul>
</div>
<!--END GENERATE_TREEVIEW-->
<!--BEGIN !GENERATE_TREEVIEW-->
<hr class="footer"/><address class="footer"><small>
$generatedby &#160;<a href="http://www.doxygen.org/index.html">
<img class="footer" src="$relpath^doxygen.png" alt="doxygen"/>
</a> $doxygenversion
</small></address>
<!--END !GENERATE_TREEVIEW-->
</body>
</html>
\ No newline at end of file
<!-- HTML header for doxygen 1.8.14-->
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
<meta name="generator" content="Doxygen $doxygenversion"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<!--BEGIN PROJECT_NAME--><title>$projectname: $title</title><!--END PROJECT_NAME-->
<!--BEGIN !PROJECT_NAME--><title>$title</title><!--END !PROJECT_NAME-->
<link href="$relpath^tabs.css" rel="stylesheet" type="text/css"/>
<link rel="icon" href="/static/favicon.png" type="image/x-icon">
<script type="text/javascript" src="$relpath^jquery.js"></script>
<script type="text/javascript" src="$relpath^dynsections.js"></script>
$treeview
$search
$mathjax
<link href="$relpath^$stylesheet" rel="stylesheet" type="text/css" />
$extrastylesheet
</head>
<body>
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
<!--BEGIN TITLEAREA-->
<div id="titlearea">
<table cellspacing="0" cellpadding="0">
<tbody>
<tr style="height: 56px;">
<!--BEGIN PROJECT_LOGO-->
<td id="projectlogo" width="112"><a href="http://dragon.seetatech.com"><img alt="Logo" src="$relpath^$projectlogo"/></a></td>
<!--END PROJECT_LOGO-->
<!--BEGIN PROJECT_NAME-->
<td id="projectalign" style="padding-left: 1.5em;">
<div id="projectname">$projectname
<!--BEGIN PROJECT_NUMBER-->&#160;<span id="projectnumber">$projectnumber</span><!--END PROJECT_NUMBER-->
</div>
<!--BEGIN PROJECT_BRIEF--><div id="projectbrief">$projectbrief</div><!--END PROJECT_BRIEF-->
</td>
<!--END PROJECT_NAME-->
<!--BEGIN !PROJECT_NAME-->
<!--BEGIN PROJECT_BRIEF-->
<td style="padding-left: 0.5em;">
<div id="projectbrief">$projectbrief</div>
</td>
<!--END PROJECT_BRIEF-->
<!--END !PROJECT_NAME-->
<!--BEGIN DISABLE_INDEX-->
<!--BEGIN SEARCHENGINE-->
<td>$searchbox</td>
<!--END SEARCHENGINE-->
<!--END DISABLE_INDEX-->
</tr>
</tbody>
</table>
</div>
<!--END TITLEAREA-->
<!-- end header part -->
\ No newline at end of file
<doxygenlayout version="1.0">
<!-- Generated by doxygen 1.8.14 -->
<!-- Navigation index tabs for HTML output -->
<navindex>
<tab type="mainpage" visible="no" title=""/>
<tab type="pages" visible="no" title="" intro=""/>
<tab type="modules" visible="no" title="" intro=""/>
<tab type="namespaces" visible="no" title="">
<tab type="namespacelist" visible="no" title="" intro="/"/>
<tab type="namespacemembers" visible="no" title="" intro=""/>
</tab>
<tab type="classes" visible="no" title="">
<tab type="classlist" visible="yes" title="" intro=""/>
<tab type="classindex" visible="$ALPHABETICAL_INDEX" title=""/>
<tab type="hierarchy" visible="yes" title="" intro=""/>
<tab type="classmembers" visible="yes" title="" intro=""/>
</tab>
<tab type="files" visible="yes" title="">
<tab type="filelist" visible="yes" title="" intro=""/>
<tab type="globals" visible="yes" title="" intro=""/>
</tab>
<tab type="examples" visible="yes" title="" intro=""/>
<tab type="user" url="namespaces.html" title="Namespaces"/>
<tab type="user" url="classes.html" title="Classes"/>
<tab type="user" url="https://github.com/seetaresearch/Dragon" title="GitHub"/>
</navindex>
<!-- Layout definition for a class page -->
<class>
<briefdescription visible="yes"/>
<includes visible="$SHOW_INCLUDE_FILES"/>
<inheritancegraph visible="$CLASS_GRAPH"/>
<collaborationgraph visible="$COLLABORATION_GRAPH"/>
<memberdecl>
<nestedclasses visible="yes" title=""/>
<publictypes title=""/>
<services title=""/>
<interfaces title=""/>
<publicslots title=""/>
<signals title=""/>
<publicmethods title=""/>
<publicstaticmethods title=""/>
<publicattributes title=""/>
<publicstaticattributes title=""/>
<protectedtypes title=""/>
<protectedslots title=""/>
<protectedmethods title=""/>
<protectedstaticmethods title=""/>
<protectedattributes title=""/>
<protectedstaticattributes title=""/>
<packagetypes title=""/>
<packagemethods title=""/>
<packagestaticmethods title=""/>
<packageattributes title=""/>
<packagestaticattributes title=""/>
<properties title=""/>
<events title=""/>
<privatetypes title=""/>
<privateslots title=""/>
<privatemethods title=""/>
<privatestaticmethods title=""/>
<privateattributes title=""/>
<privatestaticattributes title=""/>
<friends title=""/>
<related title="" subtitle=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<typedefs title=""/>
<enums title=""/>
<services title=""/>
<interfaces title=""/>
<constructors title=""/>
<functions title=""/>
<related title=""/>
<variables title=""/>
<properties title=""/>
<events title=""/>
</memberdef>
<allmemberslink visible="yes"/>
<usedfiles visible="$SHOW_USED_FILES"/>
<authorsection visible="yes"/>
</class>
<!-- Layout definition for a namespace page -->
<namespace>
<briefdescription visible="yes"/>
<memberdecl>
<nestednamespaces visible="yes" title=""/>
<constantgroups visible="yes" title=""/>
<classes visible="yes" title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
</memberdef>
<authorsection visible="yes"/>
</namespace>
<!-- Layout definition for a file page -->
<file>
<briefdescription visible="yes"/>
<includes visible="$SHOW_FILES"/>
<includegraph visible="$INCLUDE_GRAPH"/>
<includedbygraph visible="$INCLUDED_BY_GRAPH"/>
<sourcelink visible="yes"/>
<memberdecl>
<classes visible="yes" title=""/>
<namespaces visible="yes" title=""/>
<constantgroups visible="yes" title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
</memberdef>
<authorsection/>
</file>
<!-- Layout definition for a group page -->
<group>
<briefdescription visible="yes"/>
<groupgraph visible="$GROUP_GRAPHS"/>
<memberdecl>
<nestedgroups visible="yes" title=""/>
<dirs visible="yes" title=""/>
<files visible="yes" title=""/>
<namespaces visible="yes" title=""/>
<classes visible="yes" title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<enumvalues title=""/>
<functions title=""/>
<variables title=""/>
<signals title=""/>
<publicslots title=""/>
<protectedslots title=""/>
<privateslots title=""/>
<events title=""/>
<properties title=""/>
<friends title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<pagedocs/>
<inlineclasses title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<enumvalues title=""/>
<functions title=""/>
<variables title=""/>
<signals title=""/>
<publicslots title=""/>
<protectedslots title=""/>
<privateslots title=""/>
<events title=""/>
<properties title=""/>
<friends title=""/>
</memberdef>
<authorsection visible="yes"/>
</group>
<!-- Layout definition for a directory page -->
<directory>
<briefdescription visible="yes"/>
<directorygraph visible="yes"/>
<memberdecl>
<dirs visible="yes"/>
<files visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
</directory>
</doxygenlayout>
\ No newline at end of file
...@@ -7,8 +7,11 @@ cmake_minimum_required(VERSION 3.0.0) ...@@ -7,8 +7,11 @@ cmake_minimum_required(VERSION 3.0.0)
# ---------------- User Config ---------------- # ---------------- User Config ----------------
# Set optional buildings
option(BUILD_PYTHON_API "Set ON to build PYTHON API" ON)
option(BUILD_CXX_API "Set ON to build CXX API" OFF)
# Set optional libraries # Set optional libraries
option(WITH_PYTHON "Set ON to use PYTHON" ON)
option(WITH_CUDA "Set ON to use CUDA" ON) option(WITH_CUDA "Set ON to use CUDA" ON)
option(WITH_CUDNN "Set ON to use CUDNN" ON) option(WITH_CUDNN "Set ON to use CUDNN" ON)
option(WITH_BLAS "Set ON to use BLAS" ON) option(WITH_BLAS "Set ON to use BLAS" ON)
...@@ -19,14 +22,18 @@ option(WITH_MPI_CUDA "Set ON to use MPI-CUDA" OFF) ...@@ -19,14 +22,18 @@ option(WITH_MPI_CUDA "Set ON to use MPI-CUDA" OFF)
option(WITH_MPI_NCCL "Set ON to use MPI-NCCL" OFF) option(WITH_MPI_NCCL "Set ON to use MPI-NCCL" OFF)
# Set your 3rdparty # Set your 3rdparty
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty) if (NOT 3RDPARTY_DIR)
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
endif()
# Set your python "interpreter" if necessary # Set your python "interpreter" if necessary
# if not, a default interpreter will be used # if not, a default interpreter will be used
# here, provide several examples: # here, provide several examples:
# set(PYTHON_EXECUTABLE /usr/bin/python) # Linux & OSX, Builtin Python if (NOT PYTHON_EXECUTABLE)
# set(PYTHON_EXECUTABLE /X/anaconda/bin/python) # Linux & OSX, Anaconda # set(PYTHON_EXECUTABLE /usr/bin/python) # Linux && OSX, Builtin Python
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda # set(PYTHON_EXECUTABLE /X/anaconda/bin/python) # Linux && OSX, Anaconda
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda
endif()
# Set CUDA compiling architecture # Set CUDA compiling architecture
# Remove "compute_70/sm_70" if using CUDA 8.0 # Remove "compute_70/sm_70" if using CUDA 8.0
...@@ -38,8 +45,10 @@ set(CUDA_ARCH -gencode arch=compute_30,code=sm_30 ...@@ -38,8 +45,10 @@ set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
# Set CUDNN Library Dir if necessary (Linux/OSX Only) # Set CUDNN Library Dir if necessary (Linux/OSX Only)
# For Win, Recommend to use ``3RDPARTY_DIR/lib`` # For Win, Recommend to use ``3RDPARTY_DIR/lib``
set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux if (NOT CUDNN_LIBRARY_DIR)
# set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib) # OSX set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux
# set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib) # OSX
endif()
# ---------------- User Config ---------------- # ---------------- User Config ----------------
...@@ -68,7 +77,7 @@ set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux ...@@ -68,7 +77,7 @@ set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux
# ---[ Dependencies # ---[ Dependencies
if (WITH_PYTHON) if (BUILD_PYTHON_API)
include(${PROJECT_SOURCE_DIR}/../CMake/FindPythonLibs.cmake) include(${PROJECT_SOURCE_DIR}/../CMake/FindPythonLibs.cmake)
include(${PROJECT_SOURCE_DIR}/../CMake/FindNumPy.cmake) include(${PROJECT_SOURCE_DIR}/../CMake/FindNumPy.cmake)
endif() endif()
...@@ -88,7 +97,7 @@ set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release" ...@@ -88,7 +97,7 @@ set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release"
include_directories(${3RDPARTY_DIR}/include) include_directories(${3RDPARTY_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/src) include_directories(${PROJECT_SOURCE_DIR}/src)
if (WITH_PYTHON) if (BUILD_PYTHON_API)
include_directories(${PYTHON_INCLUDE_DIRS}) include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${NUMPY_INCLUDE_DIR}) include_directories(${NUMPY_INCLUDE_DIR})
endif() endif()
...@@ -111,7 +120,7 @@ set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" ...@@ -111,7 +120,7 @@ set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix"
set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS}) set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS})
# ---[ Defines # ---[ Defines
if (WITH_PYTHON) if (BUILD_PYTHON_API)
ADD_DEFINITIONS(-DWITH_PYTHON) ADD_DEFINITIONS(-DWITH_PYTHON)
if (${PYTHON_VERSION_MAJOR} STREQUAL "2") if (${PYTHON_VERSION_MAJOR} STREQUAL "2")
message(STATUS "Use Python2 [Optional]") message(STATUS "Use Python2 [Optional]")
...@@ -166,7 +175,10 @@ endif() ...@@ -166,7 +175,10 @@ endif()
# ---[ Flags # ---[ Flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
if(WIN32) if(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /O2") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4819 /wd4244")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler \"/wd 4819\"")
string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
string(REPLACE "/O2" "/Ox" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
if (WITH_OMP) if (WITH_OMP)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp")
endif() endif()
...@@ -189,8 +201,12 @@ execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS ...@@ -189,8 +201,12 @@ execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS
execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/dragon.proto) execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/dragon.proto)
# ---[ Subdirectories # ---[ Subdirectories
add_subdirectory(modules/python) if (BUILD_PYTHON_API)
#add_subdirectory(modules/cxx) # Compile CXX module if necessary add_subdirectory(modules/python)
endif()
if (BUILD_CXX_API)
add_subdirectory(modules/cxx)
endif()
# ---[ Utils # ---[ Utils
file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib) file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib)
\ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_COMMON_H_ #ifndef DRAGON_CORE_COMMON_H_
#define DRAGON_CORE_COMMON_H_ #define DRAGON_CORE_COMMON_H_
#include <ctime> #include <ctime>
#include <cmath>
#include <random> #include <random>
#include <climits> #include <climits>
#include <float.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <queue> #include <queue>
...@@ -48,17 +51,17 @@ template <typename Key, typename Value> ...@@ -48,17 +51,17 @@ template <typename Key, typename Value>
using Map = std::unordered_map<Key, Value>; using Map = std::unordered_map<Key, Value>;
template <typename Value> template <typename Value>
using Set = std::unordered_set<Value> ; using Set = std::unordered_set<Value>;
/* * * * * * * * * * * * * * * * * * * * * /* * * * * * * * * * * * * * * * * * * * *
* * * *
* Kernel Version * * Kernel Version *
* * * *
* Major(2) | Minor(2) | Patch(12) * * Major(2) | Minor(2) | Patch(13) *
* * * *
* * * * * * * * * * * * * * * * * * * * */ * * * * * * * * * * * * * * * * * * * * */
#define DRAGON_VERSION 2212 #define DRAGON_VERSION 2213
/* * * * * * * * * * * * * * * * * * * * * /* * * * * * * * * * * * * * * * * * * * *
* * * *
...@@ -74,7 +77,7 @@ using Set = std::unordered_set<Value> ; ...@@ -74,7 +77,7 @@ using Set = std::unordered_set<Value> ;
* * * *
* * * * * * * * * * * * * * * * * * * * */ * * * * * * * * * * * * * * * * * * * * */
// avoid using of "thread_local" for VS2013 or older Xcode // Avoid using of "thread_local" for VS2013 or older Xcode
#if defined(__clang__) || defined(__GNUC__) #if defined(__clang__) || defined(__GNUC__)
#define TLS_OBJECT __thread #define TLS_OBJECT __thread
#else #else
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_H_ #ifndef DRAGON_CORE_CONTEXT_H_
#define DRAGON_CORE_CONTEXT_H_ #define DRAGON_CORE_CONTEXT_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_CNML_H_ #ifndef DRAGON_CORE_CONTEXT_CNML_H_
#define DRAGON_CORE_CONTEXT_CNML_H_ #define DRAGON_CORE_CONTEXT_CNML_H_
/* CAMBRICON's CNRT && CNML Environment */ /*! CAMBRICON's CNRT && CNML Environment */
#include "core/common.h" #include "core/common.h"
...@@ -90,7 +91,7 @@ class CNMLContext { ...@@ -90,7 +91,7 @@ class CNMLContext {
static std::mutex& mutex() { static std::mutex m; return m; } static std::mutex& mutex() { static std::mutex m; return m; }
static thread_local CNRTObject cnrt_object_; static CNRTObject* cuda_object();
private: private:
int device_id_, stream_id_ = 1, random_seed_; int device_id_, stream_id_ = 1, random_seed_;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_CUDA_H_ #ifndef DRAGON_CORE_CONTEXT_CUDA_H_
#define DRAGON_CORE_CONTEXT_CUDA_H_ #define DRAGON_CORE_CONTEXT_CUDA_H_
/* NVIDIA's CUDA Environment */ /*! NVIDIA's CUDA Environment */
#include "core/common.h" #include "core/common.h"
#include "utils/cuda_device.h" #include "utils/cuda_device.h"
...@@ -38,9 +39,10 @@ class CUDAObject { ...@@ -38,9 +39,10 @@ class CUDAObject {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) { for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
for (int j = 0; j < cuda_streams[i].size(); j++) { for (int j = 0; j < cuda_streams[i].size(); j++) {
auto& stream = cuda_streams[i][j]; auto& stream = cuda_streams[i][j];
// follow the caffe2, do not check the stream destroying /*!
// Error code 29 (driver shutting down) is inevitable * Do not check the stream destroying,
// TODO(PhyscalX): Can someone solve this issue? * error code 29 (driver shutting down) is inevitable.
*/
if (stream) cudaStreamDestroy(stream); if (stream) cudaStreamDestroy(stream);
} }
for (auto& handle : cublas_handles[i]) for (auto& handle : cublas_handles[i])
...@@ -52,14 +54,20 @@ class CUDAObject { ...@@ -52,14 +54,20 @@ class CUDAObject {
} }
} }
// follow the caffe2, /*!
// each device takes a group of non-blocking streams * Follow the caffe2,
// the stream 0 is reserved for default stream, * Each device takes a group of non-blocking streams.
// as some computations really require it, *
// e.g. cublas.asum() and mixed cpu/cuda operations * The stream 0 is reserved for default stream,
// besides, somes calls, such as cudnn.conv() and cudnn.rnn(), * as some computations really require it,
// produce wrong results if running them on non-blocking streams * e.g. cublas.asum() and mixed cpu/cuda operations.
// note that caffe2 also uses default streams (within CuDNNState) *
* Besides, somes calls, such as cudnn.conv() and cudnn.rnn(),
* and even the simple fp16 conversion, produce wrong results
* if running them on non-blocking streams.
*
* Note that caffe2 also uses default streams (within CuDNNState).
*/
cudaStream_t GetStream(int device_id, int stream_id) { cudaStream_t GetStream(int device_id, int stream_id) {
vector<cudaStream_t>& dev_streams = cuda_streams[device_id]; vector<cudaStream_t>& dev_streams = cuda_streams[device_id];
if (dev_streams.size() <= (unsigned)stream_id) if (dev_streams.size() <= (unsigned)stream_id)
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_GRAPH_H_ #ifndef DRAGON_CORE_GRAPH_H_
#define DRAGON_CORE_GRAPH_H_ #define DRAGON_CORE_GRAPH_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_GRAPH_GRADIENT_H_ #ifndef DRAGON_CORE_GRAPH_GRADIENT_H_
#define DRAGON_CORE_GRAPH_GRADIENT_H_ #define DRAGON_CORE_GRAPH_GRADIENT_H_
...@@ -54,4 +55,4 @@ class GraphGradientMaker { ...@@ -54,4 +55,4 @@ class GraphGradientMaker {
} // namespace dragon } // namespace dragon
#endif #endif // DRAGON_CORE_GRAPH_GRADIENT_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_MIXEDMEM_H_ #ifndef DRAGON_CORE_MIXEDMEM_H_
#define DRAGON_CORE_MIXEDMEM_H_ #define DRAGON_CORE_MIXEDMEM_H_
...@@ -86,7 +87,7 @@ class MixedMemory { ...@@ -86,7 +87,7 @@ class MixedMemory {
void* cpu_ptr_, *cuda_ptr_, *cnml_ptr_; void* cpu_ptr_, *cuda_ptr_, *cnml_ptr_;
int own_cpu_ptr_ = 1, ptr_device_ = 0; int own_cpu_ptr_ = 1, ptr_device_ = 0;
/* For CAMBRICON's CNML Environment */ /*! For CAMBRICON's CNML Environment */
cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr; cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr;
cnmlTensor_t cnml_mlu_tensor_ = nullptr; cnmlTensor_t cnml_mlu_tensor_ = nullptr;
}; };
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_H_ #ifndef DRAGON_CORE_OPERATOR_H_
#define DRAGON_CORE_OPERATOR_H_ #define DRAGON_CORE_OPERATOR_H_
...@@ -188,7 +189,7 @@ DECLARE_REGISTRY( ...@@ -188,7 +189,7 @@ DECLARE_REGISTRY(
const OperatorDef&, const OperatorDef&,
Workspace*); Workspace*);
/* NVIDIA's Accelerated Library - CUDNN */ /*! NVIDIA's Accelerated Library - CUDNN */
DECLARE_REGISTRY( DECLARE_REGISTRY(
CUDNNOperatorRegistry, CUDNNOperatorRegistry,
...@@ -196,7 +197,7 @@ DECLARE_REGISTRY( ...@@ -196,7 +197,7 @@ DECLARE_REGISTRY(
const OperatorDef&, const OperatorDef&,
Workspace*); Workspace*);
/* CAMBRICON's Accelerated Library - CNML */ /*! CAMBRICON's Accelerated Library - CNML */
DECLARE_REGISTRY( DECLARE_REGISTRY(
CNMLOperatorRegistry, CNMLOperatorRegistry,
...@@ -247,7 +248,8 @@ DECLARE_REGISTRY( ...@@ -247,7 +248,8 @@ DECLARE_REGISTRY(
} }
#define INIT_MULTIPLIER(ptr_tensor, size) { \ #define INIT_MULTIPLIER(ptr_tensor, size) { \
ptr_tensor = ws()->CreateTensor("/share/multiplier"); \ ptr_tensor = ws()->CreateTensor("/share/multiplier/" \
+ TypeMetaToString(TypeMeta::Make<T>())); \
if (size > ptr_tensor->count()) { \ if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape({ size }); \ ptr_tensor->Reshape({ size }); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \ math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_ #ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_
#define DRAGON_CORE_OPERATOR_GRADIENT_H_ #define DRAGON_CORE_OPERATOR_GRADIENT_H_
...@@ -58,7 +59,7 @@ class GradientMakerBase { ...@@ -58,7 +59,7 @@ class GradientMakerBase {
} }
virtual inline vector<float> DefaultValues() { virtual inline vector<float> DefaultValues() {
return vector<float>(g_outputs_.size(), 1.0); return vector<float>(g_outputs_.size(), 1.f);
} }
template <class... Args> template <class... Args>
...@@ -82,7 +83,7 @@ class GradientMakerBase { ...@@ -82,7 +83,7 @@ class GradientMakerBase {
const vector<string>& g_outputs_; const vector<string>& g_outputs_;
}; };
// implemented in operator.cc // Implemented in operator.cc
Gradient MakeGradientForOp( Gradient MakeGradientForOp(
const OperatorDef& op_def, const OperatorDef& op_def,
const vector<string>& g_outputs); const vector<string>& g_outputs);
...@@ -111,7 +112,7 @@ DECLARE_REGISTRY( ...@@ -111,7 +112,7 @@ DECLARE_REGISTRY(
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&);
// define in the operator.cc // Defined in the operator.cc
#define REGISTER_GRADIENT(name, ...) \ #define REGISTER_GRADIENT(name, ...) \
REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__) REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_ #ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_
#define DRAGON_CORE_OPERATOR_SCHEMA_H_ #define DRAGON_CORE_OPERATOR_SCHEMA_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_REGISTRY_H_ #ifndef DRAGON_CORE_REGISTRY_H_
#define DRAGON_CORE_REGISTRY_H_ #define DRAGON_CORE_REGISTRY_H_
...@@ -66,12 +67,12 @@ class Registerer { ...@@ -66,12 +67,12 @@ class Registerer {
} }
}; };
// use in *.h files // Used in *.h files
#define DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjType,...) \ #define DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjType,...) \
dragon::Registry<SrcType, ObjType,##__VA_ARGS__>* RegistryName(); \ dragon::Registry<SrcType, ObjType,##__VA_ARGS__>* RegistryName(); \
typedef dragon::Registerer<SrcType,ObjType,##__VA_ARGS__> Registerer##RegistryName; typedef dragon::Registerer<SrcType,ObjType,##__VA_ARGS__> Registerer##RegistryName;
// use in *.cc files // Used in *.cc files
#define DEFINE_TYPED_REGISTRY(RegistryName,SrcType, ObjType,...) \ #define DEFINE_TYPED_REGISTRY(RegistryName,SrcType, ObjType,...) \
Registry<SrcType,ObjType,##__VA_ARGS__>* RegistryName() { \ Registry<SrcType,ObjType,##__VA_ARGS__>* RegistryName() { \
static Registry<SrcType,ObjType,##__VA_ARGS__>* registry = \ static Registry<SrcType,ObjType,##__VA_ARGS__>* registry = \
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TENSOR_H_ #ifndef DRAGON_CORE_TENSOR_H_
#define DRAGON_CORE_TENSOR_H_ #define DRAGON_CORE_TENSOR_H_
...@@ -41,8 +42,7 @@ class Tensor { ...@@ -41,8 +42,7 @@ class Tensor {
} }
} else { } else {
if (ex_memory_ && !is_shared_ && if (ex_memory_ && !is_shared_ &&
TIndex(ex_memory_->nbytes()) < capacity_ < TIndex(new_size * meta_.itemsize())) {
TIndex(new_size * meta_.itemsize())) {
delete ex_memory_; delete ex_memory_;
ex_memory_ = nullptr; ex_memory_ = nullptr;
capacity_ = 0; capacity_ = 0;
...@@ -194,22 +194,24 @@ class Tensor { ...@@ -194,22 +194,24 @@ class Tensor {
void* raw_mutable_data(const TypeMeta& meta) { void* raw_mutable_data(const TypeMeta& meta) {
void* data_ptr; void* data_ptr;
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
// Return the memory directly
if (meta_ == meta && data_ptr) return data_ptr; if (meta_ == meta && data_ptr) return data_ptr;
if (meta_ != meta && data_ptr && !own_mem_) delete ex_memory_; // Return the new memory
meta_ = meta; meta_ = meta;
CHECK_GT(size_, 0); CHECK_GT(size_, 0);
if (own_mem_) { if (own_mem_) {
memory_.reset(new MixedMemory( memory_.reset(new MixedMemory(
meta, size_* meta_.itemsize())); meta_, size_* meta_.itemsize()));
} else { } else {
if (data_ptr) delete ex_memory_;
ex_memory_ = new MixedMemory( ex_memory_ = new MixedMemory(
meta, size_* meta_.itemsize()); meta_, size_* meta_.itemsize());
} }
// malloc memory // Malloc
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
// call the constructors // Call the constructors
if (meta.ctor()) meta_.ctor()(data_ptr, size_); if (meta_.ctor()) meta_.ctor()(data_ptr, size_);
capacity_ = size_ * meta.itemsize(), require_init_ = true; capacity_ = size_ * meta_.itemsize(), require_init_ = true;
return data_ptr; return data_ptr;
} }
...@@ -274,6 +276,7 @@ class Tensor { ...@@ -274,6 +276,7 @@ class Tensor {
TypeMeta::Make<float>(), 4); TypeMeta::Make<float>(), 4);
require_init_ = true; require_init_ = true;
} own_mem_ = false; } own_mem_ = false;
capacity_ = (TIndex)ex_memory_->nbytes();
} }
inline void Share(MixedMemory* mem) { inline void Share(MixedMemory* mem) {
...@@ -290,7 +293,7 @@ class Tensor { ...@@ -290,7 +293,7 @@ class Tensor {
} }
std::function<void()> DECREFPyArray; std::function<void()> DECREFPyArray;
~Tensor() { /* DO NOT CALL DECREFARRAY */ } ~Tensor() { /*! DO NOT CALL DECREFARRAY */ }
private: private:
vector<TIndex> dims_; vector<TIndex> dims_;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TYPEID_H_ #ifndef DRAGON_CORE_TYPEID_H_
#define DRAGON_CORE_TYPEID_H_ #define DRAGON_CORE_TYPEID_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TYPES_H_ #ifndef DRAGON_CORE_TYPES_H_
#define DRAGON_CORE_TYPES_H_ #define DRAGON_CORE_TYPES_H_
#include <cstdint>
#include <unordered_map> #include <unordered_map>
#include "core/typeid.h" #include "core/typeid.h"
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_WORKSPACE_H_ #ifndef DRAGON_CORE_WORKSPACE_H_
#define DRAGON_CORE_WORKSPACE_H_ #define DRAGON_CORE_WORKSPACE_H_
...@@ -62,7 +63,7 @@ class Workspace { ...@@ -62,7 +63,7 @@ class Workspace {
} }
inline void ClearWorkspace() { inline void ClearWorkspace() {
// clear tensors & buffers & re-initialization // Clear tensors, then re-initialization
for (auto& kv : tensor_map_) kv.second->Reset(); for (auto& kv : tensor_map_) kv.second->Reset();
InitWorkspace(); InitWorkspace();
} }
...@@ -79,11 +80,11 @@ class Workspace { ...@@ -79,11 +80,11 @@ class Workspace {
const string& name, const string& name,
bool use_remote = true) { bool use_remote = true) {
string query = GetTensorName(name); string query = GetTensorName(name);
// search local workspace // Search local workspace
if (tensor_map_.count(query) > 0) if (tensor_map_.count(query) > 0)
return tensor_map_[query].get(); return tensor_map_[query].get();
if (use_remote) { if (use_remote) {
// search remote workspace // Search remote workspace
for (auto& it : ws_map_) { for (auto& it : ws_map_) {
if (it.second->HasTensor(query)) if (it.second->HasTensor(query))
return it.second->GetTensor(query); return it.second->GetTensor(query);
...@@ -125,10 +126,10 @@ class Workspace { ...@@ -125,10 +126,10 @@ class Workspace {
vector<string> GetTensors() { vector<string> GetTensors() {
vector<string> names; vector<string> names;
// search local workspace // Search local workspace
for (auto& it : tensor_map_) for (auto& it : tensor_map_)
names.push_back(it.first); names.push_back(it.first);
// serach remote workspace // Serach remote workspace
for (auto& it : ws_map_) { for (auto& it : ws_map_) {
vector<string> sub_names = it.second->GetTensors(); vector<string> sub_names = it.second->GetTensors();
names.insert(names.end(), names.insert(names.end(),
...@@ -142,11 +143,11 @@ class Workspace { ...@@ -142,11 +143,11 @@ class Workspace {
inline bool HasFiller( inline bool HasFiller(
const string& name, const string& name,
bool use_remote = true) { bool use_remote = true) {
// search local workspace // Search local workspace
bool result = filler_map_.count(name) > 0; bool result = filler_map_.count(name) > 0;
if (!use_remote) return result; if (!use_remote) return result;
// search remote workspace // Search remote workspace
for (auto& it : ws_map_) for (auto& it : ws_map_)
result |= it.second->HasFiller(name); result |= it.second->HasFiller(name);
return result; return result;
...@@ -162,11 +163,11 @@ class Workspace { ...@@ -162,11 +163,11 @@ class Workspace {
inline const TensorFiller* GetFiller( inline const TensorFiller* GetFiller(
const string& name) { const string& name) {
// search local workspace // Search local workspace
if (filler_map_.count(name) > 0) if (filler_map_.count(name) > 0)
return &filler_map_[name]; return &filler_map_[name];
// search remote workspace // Search remote workspace
for (auto& it : ws_map_) { for (auto& it : ws_map_) {
if (it.second->HasFiller(name)) if (it.second->HasFiller(name))
return it.second->GetFiller(name); return it.second->GetFiller(name);
...@@ -238,11 +239,11 @@ class Workspace { ...@@ -238,11 +239,11 @@ class Workspace {
persistent_key = arg.s(); persistent_key = arg.s();
} }
if (persistent_key.empty()) { if (persistent_key.empty()) {
// run op in the "ONCE" mode // Run op in the "ONCE" mode
unique_ptr<OperatorBase> op(CreateOperator(meta_op, this)); unique_ptr<OperatorBase> op(CreateOperator(meta_op, this));
op->Run(); op->Run();
} else { } else {
// run op in the "PERSISTENT" mode // Run op in the "PERSISTENT" mode
if (!op_map_.count(persistent_key)) if (!op_map_.count(persistent_key))
op_map_[persistent_key] = unique_ptr<OperatorBase>( op_map_[persistent_key] = unique_ptr<OperatorBase>(
CreateOperator(meta_op, this)); CreateOperator(meta_op, this));
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
...@@ -21,7 +22,7 @@ class EluOp : public Operator<Context> { ...@@ -21,7 +22,7 @@ class EluOp : public Operator<Context> {
public: public:
EluOp(const OperatorDef& def, Workspace* ws) EluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.0)) {} alpha(OperatorBase::Arg<float>("alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -36,7 +37,7 @@ class EluGradientOp : public Operator<Context> { ...@@ -36,7 +37,7 @@ class EluGradientOp : public Operator<Context> {
public: public:
EluGradientOp(const OperatorDef& def, Workspace* ws) EluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.0)) {} alpha(OperatorBase::Arg<float>("alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP #ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
#define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP #define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
...@@ -81,7 +82,7 @@ class CuDNNAffineOpBase : public Operator<Context> { ...@@ -81,7 +82,7 @@ class CuDNNAffineOpBase : public Operator<Context> {
template <typename T> template <typename T>
void ResetDesc(const Tensor& X) { void ResetDesc(const Tensor& X) {
// determine the range of affine // Determine the range of affine
start_axis = axis; start_axis = axis;
if (start_axis < 0) start_axis += (int)X.ndim(); if (start_axis < 0) start_axis += (int)X.ndim();
if (num_axes == -1) num_axes = (int)X.ndim() - start_axis; if (num_axes == -1) num_axes = (int)X.ndim() - start_axis;
...@@ -89,14 +90,14 @@ class CuDNNAffineOpBase : public Operator<Context> { ...@@ -89,14 +90,14 @@ class CuDNNAffineOpBase : public Operator<Context> {
end_axis = start_axis + num_axes; end_axis = start_axis + num_axes;
CHECK_LT(start_axis, (int)X.ndim()); CHECK_LT(start_axis, (int)X.ndim());
CHECK_LE(start_axis + num_axes, (int)X.ndim()); CHECK_LE(start_axis + num_axes, (int)X.ndim());
// determine the input desc // Determine the input desc
vector<TIndex> input_dims = X.dims(); vector<TIndex> input_dims = X.dims();
// cudnn requires ndimensions range from [4, 5] // CuDNN requires ndimensions range from [4, 5]
if (input_dims.size() < 4) input_dims.resize(4, 1); if (input_dims.size() < 4) input_dims.resize(4, 1);
else if (input_dims.size() > 5) else if (input_dims.size() > 5)
LOG(FATAL) << "CuDNN Affine the dimensions up to 5."; LOG(FATAL) << "CuDNN Affine the dimensions up to 5.";
cudnnSetTensorDesc<T>(&input_desc, input_dims); cudnnSetTensorDesc<T>(&input_desc, input_dims);
// determine the scale desc // Determine the scale desc
vector<TIndex> param_dims(input_dims.size(), 1); vector<TIndex> param_dims(input_dims.size(), 1);
for (int i = start_axis; i < end_axis; i++) for (int i = start_axis; i < end_axis; i++)
param_dims[i] = input_dims[i]; param_dims[i] = input_dims[i];
...@@ -127,24 +128,32 @@ class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> { ...@@ -127,24 +128,32 @@ class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> {
: CuDNNAffineOpBase<Context>(def, ws) {} : CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename DT, typename CT> void RunWithType();
protected: protected:
USE_CUDNN_AFFINE_FUCNTIONS; USE_CUDNN_AFFINE_FUCNTIONS;
}; };
template <class Context> template <class Context>
class CuDNNAffineGradientOp final : public CuDNNAffineOpBase<Context> { class CuDNNAffineGradientOp final
: public CuDNNAffineOpBase<Context> {
public: public:
CuDNNAffineGradientOp(const OperatorDef& def, Workspace* ws) CuDNNAffineGradientOp(
const OperatorDef& def,
Workspace* ws)
: CuDNNAffineOpBase<Context>(def, ws) {} : CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void ComputeScaleGradient(T* dYxX, T* dA);
template <typename DT, typename CT>
void ComputeScaleGradient(DT* dYxX, DT* dA);
template <typename DT, typename CT>
void ComputeBiasGradient(const DT* dY, DT* dB);
template <typename T> void ComputeScaleGradient_v2(T* dYxX, T* dA); template <typename T> void ComputeScaleGradient_v2(T* dYxX, T* dA);
template <typename T> void ComputeBiasGradient(const T* dY, T* dB);
template <typename T> void ComputeBiasGradient_v2(const T* dY, T* dB); template <typename T> void ComputeBiasGradient_v2(const T* dY, T* dB);
template <typename T> void RunWithType();
template <typename DT, typename CT> void RunWithType();
protected: protected:
USE_CUDNN_AFFINE_FUCNTIONS; USE_CUDNN_AFFINE_FUCNTIONS;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#include <float.h>
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
...@@ -21,9 +22,9 @@ class PowOp final : public Operator<Context> { ...@@ -21,9 +22,9 @@ class PowOp final : public Operator<Context> {
public: public:
PowOp(const OperatorDef& def, Workspace* ws) PowOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.0)), scale(OperatorBase::Arg<float>("scale", 1.f)),
shift(OperatorBase::Arg<float>("shift", 0.0)), shift(OperatorBase::Arg<float>("shift", 0.f)),
power(OperatorBase::Arg<float>("power", 1.0)) { power(OperatorBase::Arg<float>("power", 1.f)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -40,9 +41,9 @@ class PowGradientOp final : public Operator<Context> { ...@@ -40,9 +41,9 @@ class PowGradientOp final : public Operator<Context> {
public: public:
PowGradientOp(const OperatorDef& def, Workspace* ws) PowGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.0)), scale(OperatorBase::Arg<float>("scale", 1.f)),
shift(OperatorBase::Arg<float>("shift", 0.0)), shift(OperatorBase::Arg<float>("shift", 0.f)),
power(OperatorBase::Arg<float>("power", 1.0)) { power(OperatorBase::Arg<float>("power", 1.f)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
#include "core/operator.h" #include "core/operator.h"
...@@ -58,4 +59,4 @@ class SigmoidCrossEntropyGradientOp ...@@ -58,4 +59,4 @@ class SigmoidCrossEntropyGradientOp
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
...@@ -21,7 +22,7 @@ class SmoothL1LossOp final : public Operator<Context> { ...@@ -21,7 +22,7 @@ class SmoothL1LossOp final : public Operator<Context> {
public: public:
SmoothL1LossOp(const OperatorDef& def, Workspace* ws) SmoothL1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.0)), beta(OperatorBase::Arg<float>("beta", 1.f)),
normalization(OperatorBase::Arg<string>( normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {} "normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -40,7 +41,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> { ...@@ -40,7 +41,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
public: public:
SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws) SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.0)), beta(OperatorBase::Arg<float>("beta", 1.f)),
normalization(OperatorBase::Arg<string>( normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {} "normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_ #define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_
#include "operators/loss/sparse_softmax_cross_entropy_op.h" #include "operators/loss/sparse_softmax_ce_loss_op.h"
namespace dragon { namespace dragon {
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_ #define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
...@@ -78,4 +79,4 @@ class SparseSoftmaxCrossEntropyGradientOp ...@@ -78,4 +79,4 @@ class SparseSoftmaxCrossEntropyGradientOp
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CE_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_ACCURACY_OP_H_ #ifndef DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
#define DRAGON_OPERATORS_MISC_ACCURACY_OP_H_ #define DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_ASTYPE_OP_H_ #ifndef DRAGON_OPERATORS_MISC_ASTYPE_OP_H_
#define DRAGON_OPERATORS_MISC_ASTYPE_OP_H_ #define DRAGON_OPERATORS_MISC_ASTYPE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_GRADIENT_OP_H_ #ifndef DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
#define DRAGON_OPERATORS_MISC_GRADIENT_OP_H_ #define DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_ #ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_ #define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
#define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
...@@ -44,7 +45,7 @@ class FillOp final : public Operator<Context> { ...@@ -44,7 +45,7 @@ class FillOp final : public Operator<Context> {
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
shape_desc(OperatorBase::Arg<string>("shape", "")), shape_desc(OperatorBase::Arg<string>("shape", "")),
dtype(OperatorBase::Arg<string>("dtype", "float32")), dtype(OperatorBase::Arg<string>("dtype", "float32")),
value(OperatorBase::Arg<float>("value", 0.0)) { value(OperatorBase::Arg<float>("value", 0.f)) {
GET_ARGUMENTS_WITH_DESC(int, dims); GET_ARGUMENTS_WITH_DESC(int, dims);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -64,8 +65,8 @@ public: ...@@ -64,8 +65,8 @@ public:
RandomUniformOp(const OperatorDef& def, Workspace* ws) RandomUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("uniform"); this->filler.set_type("uniform");
this->filler.set_low(OperatorBase::Arg<float>("low", -1.0)); this->filler.set_low(OperatorBase::Arg<float>("low", -1.f));
this->filler.set_high(OperatorBase::Arg<float>("high", 1.0)); this->filler.set_high(OperatorBase::Arg<float>("high", 1.f));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
...@@ -76,8 +77,8 @@ public: ...@@ -76,8 +77,8 @@ public:
RandomNormalOp(const OperatorDef& def, Workspace* ws) RandomNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("normal"); this->filler.set_type("normal");
this->filler.set_mean(OperatorBase::Arg<float>("mean", 0.0)); this->filler.set_mean(OperatorBase::Arg<float>("mean", 0.f));
this->filler.set_std(OperatorBase::Arg<float>("std", 1.0)); this->filler.set_std(OperatorBase::Arg<float>("std", 1.f));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
...@@ -88,8 +89,8 @@ public: ...@@ -88,8 +89,8 @@ public:
TruncatedNormalOp(const OperatorDef& def, Workspace* ws) TruncatedNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("truncated_normal"); this->filler.set_type("truncated_normal");
float mu = OperatorBase::Arg<float>("mean", 0.0); float mu = OperatorBase::Arg<float>("mean", 0.f);
float sigma = OperatorBase::Arg<float>("std", 1.0); float sigma = OperatorBase::Arg<float>("std", 1.f);
this->filler.set_mean(mu); this->filler.set_mean(mu);
this->filler.set_std(sigma); this->filler.set_std(sigma);
this->filler.set_low(mu - 2 * sigma); this->filler.set_low(mu - 2 * sigma);
...@@ -104,7 +105,7 @@ public: ...@@ -104,7 +105,7 @@ public:
GlorotUniformOp(const OperatorDef& def, Workspace* ws) GlorotUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in"); string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::Arg<float>("scale", 3.0); float scale = OperatorBase::Arg<float>("scale", 3.f);
this->filler.set_type("xavier"); this->filler.set_type("xavier");
if (mode == "fan_avg") { if (mode == "fan_avg") {
...@@ -125,7 +126,7 @@ public: ...@@ -125,7 +126,7 @@ public:
GlorotNormalOp(const OperatorDef& def, Workspace* ws) GlorotNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in"); string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::Arg<float>("scale", 2.0); float scale = OperatorBase::Arg<float>("scale", 2.f);
this->filler.set_type("msra"); this->filler.set_type("msra");
if (mode == "fan_avg") { if (mode == "fan_avg") {
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_PYTHON_OP_H_ #ifndef DRAGON_OPERATORS_MISC_PYTHON_OP_H_
#define DRAGON_OPERATORS_MISC_PYTHON_OP_H_ #define DRAGON_OPERATORS_MISC_PYTHON_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_ #ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_ #define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
...@@ -56,13 +57,13 @@ class ModelMPIBase : public Operator<Context> { ...@@ -56,13 +57,13 @@ class ModelMPIBase : public Operator<Context> {
string dtype; string dtype;
}; };
#define USE_MPIMODEL_FUNCTIONS(context) \ #define USE_MODEL_MPI_FUNCTIONS \
using ModelMPIBase<context>::comm; \ using ModelMPIBase<Context>::comm; \
using ModelMPIBase<context>::mpi_dtype; \ using ModelMPIBase<Context>::mpi_dtype; \
using ModelMPIBase<context>::comm_size; \ using ModelMPIBase<Context>::comm_size; \
using ModelMPIBase<context>::comm_rank; \ using ModelMPIBase<Context>::comm_rank; \
using ModelMPIBase<context>::comm_root; \ using ModelMPIBase<Context>::comm_root; \
using ModelMPIBase<context>::dtype using ModelMPIBase<Context>::dtype
} // namespace dragon } // namespace dragon
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_ #ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_ #define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
...@@ -24,7 +25,7 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> { ...@@ -24,7 +25,7 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
MPIBroadcastOp(const OperatorDef& def, Workspace* ws) MPIBroadcastOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -36,7 +37,7 @@ public: ...@@ -36,7 +37,7 @@ public:
MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws) MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_ #ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_ #define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
...@@ -24,7 +25,7 @@ class MPIGatherOp final : public ModelMPIBase<Context> { ...@@ -24,7 +25,7 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
MPIGatherOp(const OperatorDef& def, Workspace *ws) MPIGatherOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -36,7 +37,7 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> { ...@@ -36,7 +37,7 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> {
MPIGatherGradientOp(const OperatorDef& def, Workspace *ws) MPIGatherGradientOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_CROP_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CROP_OP_H_ #define DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_
#define DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_ #define DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_ #define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
#define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_ #define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
...@@ -24,7 +25,7 @@ class PadOp final : public Operator<Context> { ...@@ -24,7 +25,7 @@ class PadOp final : public Operator<Context> {
pad_l(OperatorBase::Args<int>("pad_l")), pad_l(OperatorBase::Args<int>("pad_l")),
pad_r(OperatorBase::Args<int>("pad_r")), pad_r(OperatorBase::Args<int>("pad_r")),
mode(OperatorBase::Arg<string>("mode", "CONSTANT")), mode(OperatorBase::Arg<string>("mode", "CONSTANT")),
value(OperatorBase::Arg<float>("value", 0.0f)) { value(OperatorBase::Arg<float>("value", 0.f)) {
if (pad_r.size() == 0) pad_r = pad_l; if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size()) else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length."; << "The pad_l and pad_r should have the same length.";
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_ #define DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_STACK_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_STACK_OP_H_ #define DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_TILE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TILE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_ #define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_ #ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_
#define DRAGON_OPERATORS_NORM_L2_NORM_H_ #define DRAGON_OPERATORS_NORM_L2_NORM_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
...@@ -52,7 +53,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> { ...@@ -52,7 +53,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
hidden_size(OperatorBase::Arg<int>("hidden_size", 0)), hidden_size(OperatorBase::Arg<int>("hidden_size", 0)),
num_layers(OperatorBase::Arg<int>("num_layers", 1)), num_layers(OperatorBase::Arg<int>("num_layers", 1)),
bidirectional(OperatorBase::Arg<bool>("bidirectional", false)), bidirectional(OperatorBase::Arg<bool>("bidirectional", false)),
dropout_ratio(OperatorBase::Arg<float>("dropout_ratio", 1.0)), dropout_ratio(OperatorBase::Arg<float>("dropout_ratio", 1.f)),
random_seed(def.device_option().random_seed()) { random_seed(def.device_option().random_seed()) {
// determine the rnn direction // determine the rnn direction
rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_ #define DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
#define DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_ #define DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
#define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_ #define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
...@@ -21,7 +22,7 @@ class MovingAverageOp final : public Operator<Context> { ...@@ -21,7 +22,7 @@ class MovingAverageOp final : public Operator<Context> {
public: public:
MovingAverageOp(const OperatorDef& def, Workspace* ws) MovingAverageOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
decay(OperatorBase::Arg<float>("decay", 1.0)) {} decay(OperatorBase::Arg<float>("decay", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_ #ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
#define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_ #define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
...@@ -21,8 +22,8 @@ class UpdateOpBase : public Operator<Context> { ...@@ -21,8 +22,8 @@ class UpdateOpBase : public Operator<Context> {
public: public:
UpdateOpBase(const OperatorDef& def, Workspace* ws) UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
lr_mult(OperatorBase::Arg<float>("lr_mult", 1.0)), lr_mult(OperatorBase::Arg<float>("lr_mult", 1.f)),
decay_mult(OperatorBase::Arg<float>("decay_mult", 1.0)), decay_mult(OperatorBase::Arg<float>("decay_mult", 1.f)),
slot(OperatorBase::Arg<string>("slot", "")), slot(OperatorBase::Arg<string>("slot", "")),
zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) { zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) {
CHECK(!slot.empty()) << "\nRequired a non-empty slot"; CHECK(!slot.empty()) << "\nRequired a non-empty slot";
......
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!