Commit 84320495 by Ting PAN

Dismantle Op Kernels

1 parent 96f7277e
Showing with 1762 additions and 1315 deletions
------------------------------------------------------------------------ ------------------------------------------------------------------------
The list of most significant changes made over time in Dragon. The list of most significant changes made over time in Dragon.
Dragon 0.2.2.13 (20181204)
DRAGON_VERSION == 2213
Changes (w.r.t. Dragon 0.2.2.12):
Preview Features:
- Dismantled op kernels.
- Added FP16 support for ``NNResizeOp``, ``ProposalOp``,
``ROIPoolingOp``, ``ROIAlignOp``, (R-CNN components).
- Added ``DepthwiseConv2dOp``.
- [PyTorch] Added ``nn.DepthwiseConv2d``
- [PyCaffe] Added ``DepthwiseConvolutionLayer``.
Bugs fixed:
- Fixed the cuda issue that incorrect results in ``AsTypeOp``
under ``float32 -> float16``.
- Removed the support for group convolution implemented with im2col-nhwc.
- Changed the default ``NHWC`` pack format of filter from
``[filter_height, filter_width, in_channels, out_channels]``(TensorFlow) to
``[out_channels, filter_height, filter_width, in_channels]``(CuDNN).
- Changed the VC Runtime from ``MD`` to ``MT``.
------------------------------------------------------------------------
Dragon 0.2.2.12 (20181120) Dragon 0.2.2.12 (20181120)
DRAGON_VERSION == 2212 DRAGON_VERSION == 2212
...@@ -60,6 +98,7 @@ Preview Features: ...@@ -60,6 +98,7 @@ Preview Features:
- [PyCaffe] Added ``DropBlockLayer``. - [PyCaffe] Added ``DropBlockLayer``.
Bugs fixed: Bugs fixed:
- Fixed the uncomputed output in ``BiasAddGradientOp``. - Fixed the uncomputed output in ``BiasAddGradientOp``.
......
This diff could not be displayed because it is too large.
<!-- HTML footer for doxygen 1.8.14-->
<!-- start footer part -->
<!--BEGIN GENERATE_TREEVIEW-->
<div id="nav-path" class="navpath"><!-- id is needed for treeview function! -->
<ul>
$navpath
<li class="footer">$generatedby
<a href="http://www.doxygen.org/index.html">
<img class="footer" src="$relpath^doxygen.png" alt="doxygen"/></a> $doxygenversion </li>
</ul>
</div>
<!--END GENERATE_TREEVIEW-->
<!--BEGIN !GENERATE_TREEVIEW-->
<hr class="footer"/><address class="footer"><small>
$generatedby &#160;<a href="http://www.doxygen.org/index.html">
<img class="footer" src="$relpath^doxygen.png" alt="doxygen"/>
</a> $doxygenversion
</small></address>
<!--END !GENERATE_TREEVIEW-->
</body>
</html>
\ No newline at end of file
<!-- HTML header for doxygen 1.8.14-->
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
<meta name="generator" content="Doxygen $doxygenversion"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<!--BEGIN PROJECT_NAME--><title>$projectname: $title</title><!--END PROJECT_NAME-->
<!--BEGIN !PROJECT_NAME--><title>$title</title><!--END !PROJECT_NAME-->
<link href="$relpath^tabs.css" rel="stylesheet" type="text/css"/>
<link rel="icon" href="/static/favicon.png" type="image/x-icon">
<script type="text/javascript" src="$relpath^jquery.js"></script>
<script type="text/javascript" src="$relpath^dynsections.js"></script>
$treeview
$search
$mathjax
<link href="$relpath^$stylesheet" rel="stylesheet" type="text/css" />
$extrastylesheet
</head>
<body>
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
<!--BEGIN TITLEAREA-->
<div id="titlearea">
<table cellspacing="0" cellpadding="0">
<tbody>
<tr style="height: 56px;">
<!--BEGIN PROJECT_LOGO-->
<td id="projectlogo" width="112"><a href="http://dragon.seetatech.com"><img alt="Logo" src="$relpath^$projectlogo"/></a></td>
<!--END PROJECT_LOGO-->
<!--BEGIN PROJECT_NAME-->
<td id="projectalign" style="padding-left: 1.5em;">
<div id="projectname">$projectname
<!--BEGIN PROJECT_NUMBER-->&#160;<span id="projectnumber">$projectnumber</span><!--END PROJECT_NUMBER-->
</div>
<!--BEGIN PROJECT_BRIEF--><div id="projectbrief">$projectbrief</div><!--END PROJECT_BRIEF-->
</td>
<!--END PROJECT_NAME-->
<!--BEGIN !PROJECT_NAME-->
<!--BEGIN PROJECT_BRIEF-->
<td style="padding-left: 0.5em;">
<div id="projectbrief">$projectbrief</div>
</td>
<!--END PROJECT_BRIEF-->
<!--END !PROJECT_NAME-->
<!--BEGIN DISABLE_INDEX-->
<!--BEGIN SEARCHENGINE-->
<td>$searchbox</td>
<!--END SEARCHENGINE-->
<!--END DISABLE_INDEX-->
</tr>
</tbody>
</table>
</div>
<!--END TITLEAREA-->
<!-- end header part -->
\ No newline at end of file
<doxygenlayout version="1.0">
<!-- Generated by doxygen 1.8.14 -->
<!-- Navigation index tabs for HTML output -->
<navindex>
<tab type="mainpage" visible="no" title=""/>
<tab type="pages" visible="no" title="" intro=""/>
<tab type="modules" visible="no" title="" intro=""/>
<tab type="namespaces" visible="no" title="">
<tab type="namespacelist" visible="no" title="" intro="/"/>
<tab type="namespacemembers" visible="no" title="" intro=""/>
</tab>
<tab type="classes" visible="no" title="">
<tab type="classlist" visible="yes" title="" intro=""/>
<tab type="classindex" visible="$ALPHABETICAL_INDEX" title=""/>
<tab type="hierarchy" visible="yes" title="" intro=""/>
<tab type="classmembers" visible="yes" title="" intro=""/>
</tab>
<tab type="files" visible="yes" title="">
<tab type="filelist" visible="yes" title="" intro=""/>
<tab type="globals" visible="yes" title="" intro=""/>
</tab>
<tab type="examples" visible="yes" title="" intro=""/>
<tab type="user" url="namespaces.html" title="Namespaces"/>
<tab type="user" url="classes.html" title="Classes"/>
<tab type="user" url="https://github.com/seetaresearch/Dragon" title="GitHub"/>
</navindex>
<!-- Layout definition for a class page -->
<class>
<briefdescription visible="yes"/>
<includes visible="$SHOW_INCLUDE_FILES"/>
<inheritancegraph visible="$CLASS_GRAPH"/>
<collaborationgraph visible="$COLLABORATION_GRAPH"/>
<memberdecl>
<nestedclasses visible="yes" title=""/>
<publictypes title=""/>
<services title=""/>
<interfaces title=""/>
<publicslots title=""/>
<signals title=""/>
<publicmethods title=""/>
<publicstaticmethods title=""/>
<publicattributes title=""/>
<publicstaticattributes title=""/>
<protectedtypes title=""/>
<protectedslots title=""/>
<protectedmethods title=""/>
<protectedstaticmethods title=""/>
<protectedattributes title=""/>
<protectedstaticattributes title=""/>
<packagetypes title=""/>
<packagemethods title=""/>
<packagestaticmethods title=""/>
<packageattributes title=""/>
<packagestaticattributes title=""/>
<properties title=""/>
<events title=""/>
<privatetypes title=""/>
<privateslots title=""/>
<privatemethods title=""/>
<privatestaticmethods title=""/>
<privateattributes title=""/>
<privatestaticattributes title=""/>
<friends title=""/>
<related title="" subtitle=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<typedefs title=""/>
<enums title=""/>
<services title=""/>
<interfaces title=""/>
<constructors title=""/>
<functions title=""/>
<related title=""/>
<variables title=""/>
<properties title=""/>
<events title=""/>
</memberdef>
<allmemberslink visible="yes"/>
<usedfiles visible="$SHOW_USED_FILES"/>
<authorsection visible="yes"/>
</class>
<!-- Layout definition for a namespace page -->
<namespace>
<briefdescription visible="yes"/>
<memberdecl>
<nestednamespaces visible="yes" title=""/>
<constantgroups visible="yes" title=""/>
<classes visible="yes" title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
</memberdef>
<authorsection visible="yes"/>
</namespace>
<!-- Layout definition for a file page -->
<file>
<briefdescription visible="yes"/>
<includes visible="$SHOW_FILES"/>
<includegraph visible="$INCLUDE_GRAPH"/>
<includedbygraph visible="$INCLUDED_BY_GRAPH"/>
<sourcelink visible="yes"/>
<memberdecl>
<classes visible="yes" title=""/>
<namespaces visible="yes" title=""/>
<constantgroups visible="yes" title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
</memberdef>
<authorsection/>
</file>
<!-- Layout definition for a group page -->
<group>
<briefdescription visible="yes"/>
<groupgraph visible="$GROUP_GRAPHS"/>
<memberdecl>
<nestedgroups visible="yes" title=""/>
<dirs visible="yes" title=""/>
<files visible="yes" title=""/>
<namespaces visible="yes" title=""/>
<classes visible="yes" title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<enumvalues title=""/>
<functions title=""/>
<variables title=""/>
<signals title=""/>
<publicslots title=""/>
<protectedslots title=""/>
<privateslots title=""/>
<events title=""/>
<properties title=""/>
<friends title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<pagedocs/>
<inlineclasses title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<enumvalues title=""/>
<functions title=""/>
<variables title=""/>
<signals title=""/>
<publicslots title=""/>
<protectedslots title=""/>
<privateslots title=""/>
<events title=""/>
<properties title=""/>
<friends title=""/>
</memberdef>
<authorsection visible="yes"/>
</group>
<!-- Layout definition for a directory page -->
<directory>
<briefdescription visible="yes"/>
<directorygraph visible="yes"/>
<memberdecl>
<dirs visible="yes"/>
<files visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
</directory>
</doxygenlayout>
\ No newline at end of file
...@@ -7,39 +7,48 @@ cmake_minimum_required(VERSION 3.0.0) ...@@ -7,39 +7,48 @@ cmake_minimum_required(VERSION 3.0.0)
# ---------------- User Config ---------------- # ---------------- User Config ----------------
# Set optional buildings
option(BUILD_PYTHON_API "Set ON to build PYTHON API" ON)
option(BUILD_CXX_API "Set ON to build CXX API" OFF)
# Set optional libraries # Set optional libraries
option(WITH_PYTHON "Set ON to use PYTHON" ON) option(WITH_CUDA "Set ON to use CUDA" ON)
option(WITH_CUDA "Set ON to use CUDA" ON) option(WITH_CUDNN "Set ON to use CUDNN" ON)
option(WITH_CUDNN "Set ON to use CUDNN" ON) option(WITH_BLAS "Set ON to use BLAS" ON)
option(WITH_BLAS "Set ON to use BLAS" ON) option(WITH_OMP "Set ON to use OpenMP" ON)
option(WITH_OMP "Set ON to use OpenMP" ON) option(WITH_SSE "Set ON to use SSE 4.1" ON)
option(WITH_SSE "Set ON to use SSE 4.1" ON) option(WITH_MPI "Set ON to use MPI" OFF)
option(WITH_MPI "Set ON to use MPI" OFF) option(WITH_MPI_CUDA "Set ON to use MPI-CUDA" OFF)
option(WITH_MPI_CUDA "Set ON to use MPI-CUDA" OFF) option(WITH_MPI_NCCL "Set ON to use MPI-NCCL" OFF)
option(WITH_MPI_NCCL "Set ON to use MPI-NCCL" OFF)
# Set your 3rdparty # Set your 3rdparty
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty) if (NOT 3RDPARTY_DIR)
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
endif()
# Set your python "interpreter" if necessary # Set your python "interpreter" if necessary
# if not, a default interpreter will be used # if not, a default interpreter will be used
# here, provide several examples: # here, provide several examples:
# set(PYTHON_EXECUTABLE /usr/bin/python) # Linux & OSX, Builtin Python if (NOT PYTHON_EXECUTABLE)
# set(PYTHON_EXECUTABLE /X/anaconda/bin/python) # Linux & OSX, Anaconda # set(PYTHON_EXECUTABLE /usr/bin/python) # Linux && OSX, Builtin Python
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda # set(PYTHON_EXECUTABLE /X/anaconda/bin/python) # Linux && OSX, Anaconda
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda
endif()
# Set CUDA compiling architecture # Set CUDA compiling architecture
# Remove "compute_70/sm_70" if using CUDA 8.0 # Remove "compute_70/sm_70" if using CUDA 8.0
set(CUDA_ARCH -gencode arch=compute_30,code=sm_30 set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
-gencode arch=compute_35,code=sm_35 -gencode arch=compute_35,code=sm_35
-gencode arch=compute_50,code=sm_50 -gencode arch=compute_50,code=sm_50
-gencode arch=compute_60,code=sm_60 -gencode arch=compute_60,code=sm_60
-gencode arch=compute_70,code=sm_70) -gencode arch=compute_70,code=sm_70)
# Set CUDNN Library Dir if necessary (Linux/OSX Only) # Set CUDNN Library Dir if necessary (Linux/OSX Only)
# For Win, Recommend to use ``3RDPARTY_DIR/lib`` # For Win, Recommend to use ``3RDPARTY_DIR/lib``
set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux if (NOT CUDNN_LIBRARY_DIR)
# set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib) # OSX set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux
# set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib) # OSX
endif()
# ---------------- User Config ---------------- # ---------------- User Config ----------------
...@@ -68,7 +77,7 @@ set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux ...@@ -68,7 +77,7 @@ set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux
# ---[ Dependencies # ---[ Dependencies
if (WITH_PYTHON) if (BUILD_PYTHON_API)
include(${PROJECT_SOURCE_DIR}/../CMake/FindPythonLibs.cmake) include(${PROJECT_SOURCE_DIR}/../CMake/FindPythonLibs.cmake)
include(${PROJECT_SOURCE_DIR}/../CMake/FindNumPy.cmake) include(${PROJECT_SOURCE_DIR}/../CMake/FindNumPy.cmake)
endif() endif()
...@@ -88,7 +97,7 @@ set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release" ...@@ -88,7 +97,7 @@ set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release"
include_directories(${3RDPARTY_DIR}/include) include_directories(${3RDPARTY_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/src) include_directories(${PROJECT_SOURCE_DIR}/src)
if (WITH_PYTHON) if (BUILD_PYTHON_API)
include_directories(${PYTHON_INCLUDE_DIRS}) include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${NUMPY_INCLUDE_DIR}) include_directories(${NUMPY_INCLUDE_DIR})
endif() endif()
...@@ -111,7 +120,7 @@ set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" ...@@ -111,7 +120,7 @@ set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix"
set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS}) set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS})
# ---[ Defines # ---[ Defines
if (WITH_PYTHON) if (BUILD_PYTHON_API)
ADD_DEFINITIONS(-DWITH_PYTHON) ADD_DEFINITIONS(-DWITH_PYTHON)
if (${PYTHON_VERSION_MAJOR} STREQUAL "2") if (${PYTHON_VERSION_MAJOR} STREQUAL "2")
message(STATUS "Use Python2 [Optional]") message(STATUS "Use Python2 [Optional]")
...@@ -166,7 +175,10 @@ endif() ...@@ -166,7 +175,10 @@ endif()
# ---[ Flags # ---[ Flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
if(WIN32) if(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /O2") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4819 /wd4244")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler \"/wd 4819\"")
string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
string(REPLACE "/O2" "/Ox" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
if (WITH_OMP) if (WITH_OMP)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp")
endif() endif()
...@@ -189,8 +201,12 @@ execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS ...@@ -189,8 +201,12 @@ execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS
execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/dragon.proto) execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/dragon.proto)
# ---[ Subdirectories # ---[ Subdirectories
add_subdirectory(modules/python) if (BUILD_PYTHON_API)
#add_subdirectory(modules/cxx) # Compile CXX module if necessary add_subdirectory(modules/python)
endif()
if (BUILD_CXX_API)
add_subdirectory(modules/cxx)
endif()
# ---[ Utils # ---[ Utils
file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib) file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib)
\ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_COMMON_H_ #ifndef DRAGON_CORE_COMMON_H_
#define DRAGON_CORE_COMMON_H_ #define DRAGON_CORE_COMMON_H_
#include <ctime> #include <ctime>
#include <cmath>
#include <random> #include <random>
#include <climits> #include <climits>
#include <float.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <queue> #include <queue>
...@@ -48,17 +51,17 @@ template <typename Key, typename Value> ...@@ -48,17 +51,17 @@ template <typename Key, typename Value>
using Map = std::unordered_map<Key, Value>; using Map = std::unordered_map<Key, Value>;
template <typename Value> template <typename Value>
using Set = std::unordered_set<Value> ; using Set = std::unordered_set<Value>;
/* * * * * * * * * * * * * * * * * * * * * /* * * * * * * * * * * * * * * * * * * * *
* * * *
* Kernel Version * * Kernel Version *
* * * *
* Major(2) | Minor(2) | Patch(12) * * Major(2) | Minor(2) | Patch(13) *
* * * *
* * * * * * * * * * * * * * * * * * * * */ * * * * * * * * * * * * * * * * * * * * */
#define DRAGON_VERSION 2212 #define DRAGON_VERSION 2213
/* * * * * * * * * * * * * * * * * * * * * /* * * * * * * * * * * * * * * * * * * * *
* * * *
...@@ -74,7 +77,7 @@ using Set = std::unordered_set<Value> ; ...@@ -74,7 +77,7 @@ using Set = std::unordered_set<Value> ;
* * * *
* * * * * * * * * * * * * * * * * * * * */ * * * * * * * * * * * * * * * * * * * * */
// avoid using of "thread_local" for VS2013 or older Xcode // Avoid using of "thread_local" for VS2013 or older Xcode
#if defined(__clang__) || defined(__GNUC__) #if defined(__clang__) || defined(__GNUC__)
#define TLS_OBJECT __thread #define TLS_OBJECT __thread
#else #else
...@@ -86,6 +89,6 @@ using Set = std::unordered_set<Value> ; ...@@ -86,6 +89,6 @@ using Set = std::unordered_set<Value> ;
#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__) #define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__)
#define NOT_IMPLEMENTED LOG(FATAL) << "This module has not been implemented yet." #define NOT_IMPLEMENTED LOG(FATAL) << "This module has not been implemented yet."
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_COMMON_H_ #endif // DRAGON_CORE_COMMON_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_H_ #ifndef DRAGON_CORE_CONTEXT_H_
#define DRAGON_CORE_CONTEXT_H_ #define DRAGON_CORE_CONTEXT_H_
...@@ -102,6 +103,6 @@ class CPUContext { ...@@ -102,6 +103,6 @@ class CPUContext {
#define CPU_FP16_NOT_SUPPORTED \ #define CPU_FP16_NOT_SUPPORTED \
LOG(FATAL) << "FP16 is unsupported for CPUContext."; LOG(FATAL) << "FP16 is unsupported for CPUContext.";
} // namepsace dragon } // namepsace dragon
#endif // DRAGON_CORE_CONTEXT_H_ #endif // DRAGON_CORE_CONTEXT_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_CNML_H_ #ifndef DRAGON_CORE_CONTEXT_CNML_H_
#define DRAGON_CORE_CONTEXT_CNML_H_ #define DRAGON_CORE_CONTEXT_CNML_H_
/* CAMBRICON's CNRT && CNML Environment */ /*! CAMBRICON's CNRT && CNML Environment */
#include "core/common.h" #include "core/common.h"
...@@ -90,13 +91,13 @@ class CNMLContext { ...@@ -90,13 +91,13 @@ class CNMLContext {
static std::mutex& mutex() { static std::mutex m; return m; } static std::mutex& mutex() { static std::mutex m; return m; }
static thread_local CNRTObject cnrt_object_; static CNRTObject* cuda_object();
private: private:
int device_id_, stream_id_ = 1, random_seed_; int device_id_, stream_id_ = 1, random_seed_;
unique_ptr<std::mt19937> rand_generator_; unique_ptr<std::mt19937> rand_generator_;
}; };
} // namepsace dragon } // namepsace dragon
#endif // DRAGON_CORE_CONTEXT_CNML_H_ #endif // DRAGON_CORE_CONTEXT_CNML_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_CUDA_H_ #ifndef DRAGON_CORE_CONTEXT_CUDA_H_
#define DRAGON_CORE_CONTEXT_CUDA_H_ #define DRAGON_CORE_CONTEXT_CUDA_H_
/* NVIDIA's CUDA Environment */ /*! NVIDIA's CUDA Environment */
#include "core/common.h" #include "core/common.h"
#include "utils/cuda_device.h" #include "utils/cuda_device.h"
...@@ -38,9 +39,10 @@ class CUDAObject { ...@@ -38,9 +39,10 @@ class CUDAObject {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) { for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
for (int j = 0; j < cuda_streams[i].size(); j++) { for (int j = 0; j < cuda_streams[i].size(); j++) {
auto& stream = cuda_streams[i][j]; auto& stream = cuda_streams[i][j];
// follow the caffe2, do not check the stream destroying /*!
// Error code 29 (driver shutting down) is inevitable * Do not check the stream destroying,
// TODO(PhyscalX): Can someone solve this issue? * error code 29 (driver shutting down) is inevitable.
*/
if (stream) cudaStreamDestroy(stream); if (stream) cudaStreamDestroy(stream);
} }
for (auto& handle : cublas_handles[i]) for (auto& handle : cublas_handles[i])
...@@ -52,14 +54,20 @@ class CUDAObject { ...@@ -52,14 +54,20 @@ class CUDAObject {
} }
} }
// follow the caffe2, /*!
// each device takes a group of non-blocking streams * Follow the caffe2,
// the stream 0 is reserved for default stream, * Each device takes a group of non-blocking streams.
// as some computations really require it, *
// e.g. cublas.asum() and mixed cpu/cuda operations * The stream 0 is reserved for default stream,
// besides, somes calls, such as cudnn.conv() and cudnn.rnn(), * as some computations really require it,
// produce wrong results if running them on non-blocking streams * e.g. cublas.asum() and mixed cpu/cuda operations.
// note that caffe2 also uses default streams (within CuDNNState) *
* Besides, somes calls, such as cudnn.conv() and cudnn.rnn(),
* and even the simple fp16 conversion, produce wrong results
* if running them on non-blocking streams.
*
* Note that caffe2 also uses default streams (within CuDNNState).
*/
cudaStream_t GetStream(int device_id, int stream_id) { cudaStream_t GetStream(int device_id, int stream_id) {
vector<cudaStream_t>& dev_streams = cuda_streams[device_id]; vector<cudaStream_t>& dev_streams = cuda_streams[device_id];
if (dev_streams.size() <= (unsigned)stream_id) if (dev_streams.size() <= (unsigned)stream_id)
...@@ -333,8 +341,8 @@ class CUDAContext { ...@@ -333,8 +341,8 @@ class CUDAContext {
inline void set_stream_id(int stream_id) {} inline void set_stream_id(int stream_id) {}
}; };
#endif // WITH_CUDA #endif // WITH_CUDA
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_CONTEXT_CUDA_H_ #endif // DRAGON_CORE_CONTEXT_CUDA_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_GRAPH_H_ #ifndef DRAGON_CORE_GRAPH_H_
#define DRAGON_CORE_GRAPH_H_ #define DRAGON_CORE_GRAPH_H_
...@@ -101,6 +102,6 @@ DECLARE_REGISTRY( ...@@ -101,6 +102,6 @@ DECLARE_REGISTRY(
#define REGISTER_GRAPH(name, ...) \ #define REGISTER_GRAPH(name, ...) \
REGISTER_CLASS(GraphRegistry, name, __VA_ARGS__) REGISTER_CLASS(GraphRegistry, name, __VA_ARGS__)
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_GRAPH_H_ #endif // DRAGON_CORE_GRAPH_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_GRAPH_GRADIENT_H_ #ifndef DRAGON_CORE_GRAPH_GRADIENT_H_
#define DRAGON_CORE_GRAPH_GRADIENT_H_ #define DRAGON_CORE_GRAPH_GRADIENT_H_
...@@ -52,6 +53,6 @@ class GraphGradientMaker { ...@@ -52,6 +53,6 @@ class GraphGradientMaker {
int cur_op_idx_; int cur_op_idx_;
}; };
} // namespace dragon } // namespace dragon
#endif #endif // DRAGON_CORE_GRAPH_GRADIENT_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_MIXEDMEM_H_ #ifndef DRAGON_CORE_MIXEDMEM_H_
#define DRAGON_CORE_MIXEDMEM_H_ #define DRAGON_CORE_MIXEDMEM_H_
...@@ -86,11 +87,11 @@ class MixedMemory { ...@@ -86,11 +87,11 @@ class MixedMemory {
void* cpu_ptr_, *cuda_ptr_, *cnml_ptr_; void* cpu_ptr_, *cuda_ptr_, *cnml_ptr_;
int own_cpu_ptr_ = 1, ptr_device_ = 0; int own_cpu_ptr_ = 1, ptr_device_ = 0;
/* For CAMBRICON's CNML Environment */ /*! For CAMBRICON's CNML Environment */
cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr; cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr;
cnmlTensor_t cnml_mlu_tensor_ = nullptr; cnmlTensor_t cnml_mlu_tensor_ = nullptr;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_MIXEDMEM_H_ #endif // DRAGON_CORE_MIXEDMEM_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_H_ #ifndef DRAGON_CORE_OPERATOR_H_
#define DRAGON_CORE_OPERATOR_H_ #define DRAGON_CORE_OPERATOR_H_
...@@ -188,7 +189,7 @@ DECLARE_REGISTRY( ...@@ -188,7 +189,7 @@ DECLARE_REGISTRY(
const OperatorDef&, const OperatorDef&,
Workspace*); Workspace*);
/* NVIDIA's Accelerated Library - CUDNN */ /*! NVIDIA's Accelerated Library - CUDNN */
DECLARE_REGISTRY( DECLARE_REGISTRY(
CUDNNOperatorRegistry, CUDNNOperatorRegistry,
...@@ -196,7 +197,7 @@ DECLARE_REGISTRY( ...@@ -196,7 +197,7 @@ DECLARE_REGISTRY(
const OperatorDef&, const OperatorDef&,
Workspace*); Workspace*);
/* CAMBRICON's Accelerated Library - CNML */ /*! CAMBRICON's Accelerated Library - CNML */
DECLARE_REGISTRY( DECLARE_REGISTRY(
CNMLOperatorRegistry, CNMLOperatorRegistry,
...@@ -247,7 +248,8 @@ DECLARE_REGISTRY( ...@@ -247,7 +248,8 @@ DECLARE_REGISTRY(
} }
#define INIT_MULTIPLIER(ptr_tensor, size) { \ #define INIT_MULTIPLIER(ptr_tensor, size) { \
ptr_tensor = ws()->CreateTensor("/share/multiplier"); \ ptr_tensor = ws()->CreateTensor("/share/multiplier/" \
+ TypeMetaToString(TypeMeta::Make<T>())); \
if (size > ptr_tensor->count()) { \ if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape({ size }); \ ptr_tensor->Reshape({ size }); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \ math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
...@@ -358,6 +360,6 @@ DECLARE_REGISTRY( ...@@ -358,6 +360,6 @@ DECLARE_REGISTRY(
REGISTER_CNML_OPERATOR(name, CnML##name##Op<CNMLContext>); \ REGISTER_CNML_OPERATOR(name, CnML##name##Op<CNMLContext>); \
INSTANTIATE_CNML_OPERATOR(name); INSTANTIATE_CNML_OPERATOR(name);
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_OPERATOR_H_ #endif // DRAGON_CORE_OPERATOR_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_ #ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_
#define DRAGON_CORE_OPERATOR_GRADIENT_H_ #define DRAGON_CORE_OPERATOR_GRADIENT_H_
...@@ -58,7 +59,7 @@ class GradientMakerBase { ...@@ -58,7 +59,7 @@ class GradientMakerBase {
} }
virtual inline vector<float> DefaultValues() { virtual inline vector<float> DefaultValues() {
return vector<float>(g_outputs_.size(), 1.0); return vector<float>(g_outputs_.size(), 1.f);
} }
template <class... Args> template <class... Args>
...@@ -82,7 +83,7 @@ class GradientMakerBase { ...@@ -82,7 +83,7 @@ class GradientMakerBase {
const vector<string>& g_outputs_; const vector<string>& g_outputs_;
}; };
// implemented in operator.cc // Implemented in operator.cc
Gradient MakeGradientForOp( Gradient MakeGradientForOp(
const OperatorDef& op_def, const OperatorDef& op_def,
const vector<string>& g_outputs); const vector<string>& g_outputs);
...@@ -111,7 +112,7 @@ DECLARE_REGISTRY( ...@@ -111,7 +112,7 @@ DECLARE_REGISTRY(
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&);
// define in the operator.cc // Defined in the operator.cc
#define REGISTER_GRADIENT(name, ...) \ #define REGISTER_GRADIENT(name, ...) \
REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__) REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
...@@ -119,6 +120,6 @@ DECLARE_REGISTRY( ...@@ -119,6 +120,6 @@ DECLARE_REGISTRY(
REGISTER_GRADIENT(name, NoGradient); \ REGISTER_GRADIENT(name, NoGradient); \
REGISTER_CLASS(NoGradientRegistry, name, NoGradient) REGISTER_CLASS(NoGradientRegistry, name, NoGradient)
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_OPERATOR_GRADIENT_H_ #endif // DRAGON_CORE_OPERATOR_GRADIENT_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_ #ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_
#define DRAGON_CORE_OPERATOR_SCHEMA_H_ #define DRAGON_CORE_OPERATOR_SCHEMA_H_
...@@ -101,6 +102,6 @@ class OpSchemaRegistry { ...@@ -101,6 +102,6 @@ class OpSchemaRegistry {
static OpSchema& ANONYMOUS_VARIABLE(name) = \ static OpSchema& ANONYMOUS_VARIABLE(name) = \
OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__)
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_OPERATOR_SCHEMA_H_ #endif // DRAGON_CORE_OPERATOR_SCHEMA_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_REGISTRY_H_ #ifndef DRAGON_CORE_REGISTRY_H_
#define DRAGON_CORE_REGISTRY_H_ #define DRAGON_CORE_REGISTRY_H_
...@@ -66,12 +67,12 @@ class Registerer { ...@@ -66,12 +67,12 @@ class Registerer {
} }
}; };
// use in *.h files // Used in *.h files
#define DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjType,...) \ #define DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjType,...) \
dragon::Registry<SrcType, ObjType,##__VA_ARGS__>* RegistryName(); \ dragon::Registry<SrcType, ObjType,##__VA_ARGS__>* RegistryName(); \
typedef dragon::Registerer<SrcType,ObjType,##__VA_ARGS__> Registerer##RegistryName; typedef dragon::Registerer<SrcType,ObjType,##__VA_ARGS__> Registerer##RegistryName;
// use in *.cc files // Used in *.cc files
#define DEFINE_TYPED_REGISTRY(RegistryName,SrcType, ObjType,...) \ #define DEFINE_TYPED_REGISTRY(RegistryName,SrcType, ObjType,...) \
Registry<SrcType,ObjType,##__VA_ARGS__>* RegistryName() { \ Registry<SrcType,ObjType,##__VA_ARGS__>* RegistryName() { \
static Registry<SrcType,ObjType,##__VA_ARGS__>* registry = \ static Registry<SrcType,ObjType,##__VA_ARGS__>* registry = \
...@@ -94,4 +95,4 @@ class Registerer { ...@@ -94,4 +95,4 @@ class Registerer {
} // namepsace dragon } // namepsace dragon
#endif //DRAGON_CORE_REGISTRY_H_ #endif //DRAGON_CORE_REGISTRY_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TENSOR_H_ #ifndef DRAGON_CORE_TENSOR_H_
#define DRAGON_CORE_TENSOR_H_ #define DRAGON_CORE_TENSOR_H_
...@@ -40,9 +41,8 @@ class Tensor { ...@@ -40,9 +41,8 @@ class Tensor {
capacity_ = 0; capacity_ = 0;
} }
} else { } else {
if (ex_memory_ && !is_shared_ && if (ex_memory_ && !is_shared_ &&
TIndex(ex_memory_->nbytes()) < capacity_ < TIndex(new_size * meta_.itemsize())) {
TIndex(new_size * meta_.itemsize())) {
delete ex_memory_; delete ex_memory_;
ex_memory_ = nullptr; ex_memory_ = nullptr;
capacity_ = 0; capacity_ = 0;
...@@ -194,22 +194,24 @@ class Tensor { ...@@ -194,22 +194,24 @@ class Tensor {
void* raw_mutable_data(const TypeMeta& meta) { void* raw_mutable_data(const TypeMeta& meta) {
void* data_ptr; void* data_ptr;
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
// Return the memory directly
if (meta_ == meta && data_ptr) return data_ptr; if (meta_ == meta && data_ptr) return data_ptr;
if (meta_ != meta && data_ptr && !own_mem_) delete ex_memory_; // Return the new memory
meta_ = meta; meta_ = meta;
CHECK_GT(size_, 0); CHECK_GT(size_, 0);
if (own_mem_) { if (own_mem_) {
memory_.reset(new MixedMemory( memory_.reset(new MixedMemory(
meta, size_* meta_.itemsize())); meta_, size_* meta_.itemsize()));
} else { } else {
if (data_ptr) delete ex_memory_;
ex_memory_ = new MixedMemory( ex_memory_ = new MixedMemory(
meta, size_* meta_.itemsize()); meta_, size_* meta_.itemsize());
} }
// malloc memory // Malloc
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
// call the constructors // Call the constructors
if (meta.ctor()) meta_.ctor()(data_ptr, size_); if (meta_.ctor()) meta_.ctor()(data_ptr, size_);
capacity_ = size_ * meta.itemsize(), require_init_ = true; capacity_ = size_ * meta_.itemsize(), require_init_ = true;
return data_ptr; return data_ptr;
} }
...@@ -274,6 +276,7 @@ class Tensor { ...@@ -274,6 +276,7 @@ class Tensor {
TypeMeta::Make<float>(), 4); TypeMeta::Make<float>(), 4);
require_init_ = true; require_init_ = true;
} own_mem_ = false; } own_mem_ = false;
capacity_ = (TIndex)ex_memory_->nbytes();
} }
inline void Share(MixedMemory* mem) { inline void Share(MixedMemory* mem) {
...@@ -290,7 +293,7 @@ class Tensor { ...@@ -290,7 +293,7 @@ class Tensor {
} }
std::function<void()> DECREFPyArray; std::function<void()> DECREFPyArray;
~Tensor() { /* DO NOT CALL DECREFARRAY */ } ~Tensor() { /*! DO NOT CALL DECREFARRAY */ }
private: private:
vector<TIndex> dims_; vector<TIndex> dims_;
...@@ -303,6 +306,6 @@ class Tensor { ...@@ -303,6 +306,6 @@ class Tensor {
bool own_mem_ = true, require_init_ = true; bool own_mem_ = true, require_init_ = true;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_TENSOR_H_ #endif // DRAGON_CORE_TENSOR_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TYPEID_H_ #ifndef DRAGON_CORE_TYPEID_H_
#define DRAGON_CORE_TYPEID_H_ #define DRAGON_CORE_TYPEID_H_
...@@ -37,7 +38,7 @@ class TypeMeta { ...@@ -37,7 +38,7 @@ class TypeMeta {
: id_(0), itemsize_(0), : id_(0), itemsize_(0),
ctor_(nullptr), copy_(nullptr), dtor_(nullptr) {} ctor_(nullptr), copy_(nullptr), dtor_(nullptr) {}
TypeMeta(const TypeMeta& src) TypeMeta(const TypeMeta& src)
: id_(src.id_), itemsize_(src.itemsize_), : id_(src.id_), itemsize_(src.itemsize_),
ctor_(src.ctor_), copy_(src.copy_), dtor_(src.dtor_) {} ctor_(src.ctor_), copy_(src.copy_), dtor_(src.dtor_) {}
...@@ -102,7 +103,7 @@ class TypeMeta { ...@@ -102,7 +103,7 @@ class TypeMeta {
} }
#define FundMeta std::enable_if<std::is_fundamental<T>::value,TypeMeta>::type #define FundMeta std::enable_if<std::is_fundamental<T>::value,TypeMeta>::type
#define StructMeta std::enable_if<!std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, TypeMeta>::type #define StructMeta std::enable_if<!std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, TypeMeta>::type
template <typename T> template <typename T>
static typename FundMeta Make() { static typename FundMeta Make() {
...@@ -134,6 +135,6 @@ class TypeMeta { ...@@ -134,6 +135,6 @@ class TypeMeta {
TypedDestructor dtor_; TypedDestructor dtor_;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_TYPEID_H_ #endif // DRAGON_CORE_TYPEID_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TYPES_H_ #ifndef DRAGON_CORE_TYPES_H_
#define DRAGON_CORE_TYPES_H_ #define DRAGON_CORE_TYPES_H_
#include <cstdint>
#include <unordered_map> #include <unordered_map>
#include "core/typeid.h" #include "core/typeid.h"
...@@ -76,6 +78,6 @@ inline const std::string TypeMetaToString( ...@@ -76,6 +78,6 @@ inline const std::string TypeMetaToString(
m2s_type_map[meta.id()] : "unknown"; m2s_type_map[meta.id()] : "unknown";
} }
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_TYPES_H_ #endif // DRAGON_CORE_TYPES_H_
\ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_WORKSPACE_H_ #ifndef DRAGON_CORE_WORKSPACE_H_
#define DRAGON_CORE_WORKSPACE_H_ #define DRAGON_CORE_WORKSPACE_H_
...@@ -62,7 +63,7 @@ class Workspace { ...@@ -62,7 +63,7 @@ class Workspace {
} }
inline void ClearWorkspace() { inline void ClearWorkspace() {
// clear tensors & buffers & re-initialization // Clear tensors, then re-initialization
for (auto& kv : tensor_map_) kv.second->Reset(); for (auto& kv : tensor_map_) kv.second->Reset();
InitWorkspace(); InitWorkspace();
} }
...@@ -79,11 +80,11 @@ class Workspace { ...@@ -79,11 +80,11 @@ class Workspace {
const string& name, const string& name,
bool use_remote = true) { bool use_remote = true) {
string query = GetTensorName(name); string query = GetTensorName(name);
// search local workspace // Search local workspace
if (tensor_map_.count(query) > 0) if (tensor_map_.count(query) > 0)
return tensor_map_[query].get(); return tensor_map_[query].get();
if (use_remote) { if (use_remote) {
// search remote workspace // Search remote workspace
for (auto& it : ws_map_) { for (auto& it : ws_map_) {
if (it.second->HasTensor(query)) if (it.second->HasTensor(query))
return it.second->GetTensor(query); return it.second->GetTensor(query);
...@@ -125,10 +126,10 @@ class Workspace { ...@@ -125,10 +126,10 @@ class Workspace {
vector<string> GetTensors() { vector<string> GetTensors() {
vector<string> names; vector<string> names;
// search local workspace // Search local workspace
for (auto& it : tensor_map_) for (auto& it : tensor_map_)
names.push_back(it.first); names.push_back(it.first);
// serach remote workspace // Serach remote workspace
for (auto& it : ws_map_) { for (auto& it : ws_map_) {
vector<string> sub_names = it.second->GetTensors(); vector<string> sub_names = it.second->GetTensors();
names.insert(names.end(), names.insert(names.end(),
...@@ -142,11 +143,11 @@ class Workspace { ...@@ -142,11 +143,11 @@ class Workspace {
inline bool HasFiller( inline bool HasFiller(
const string& name, const string& name,
bool use_remote = true) { bool use_remote = true) {
// search local workspace // Search local workspace
bool result = filler_map_.count(name) > 0; bool result = filler_map_.count(name) > 0;
if (!use_remote) return result; if (!use_remote) return result;
// search remote workspace // Search remote workspace
for (auto& it : ws_map_) for (auto& it : ws_map_)
result |= it.second->HasFiller(name); result |= it.second->HasFiller(name);
return result; return result;
...@@ -162,11 +163,11 @@ class Workspace { ...@@ -162,11 +163,11 @@ class Workspace {
inline const TensorFiller* GetFiller( inline const TensorFiller* GetFiller(
const string& name) { const string& name) {
// search local workspace // Search local workspace
if (filler_map_.count(name) > 0) if (filler_map_.count(name) > 0)
return &filler_map_[name]; return &filler_map_[name];
// search remote workspace // Search remote workspace
for (auto& it : ws_map_) { for (auto& it : ws_map_) {
if (it.second->HasFiller(name)) if (it.second->HasFiller(name))
return it.second->GetFiller(name); return it.second->GetFiller(name);
...@@ -238,11 +239,11 @@ class Workspace { ...@@ -238,11 +239,11 @@ class Workspace {
persistent_key = arg.s(); persistent_key = arg.s();
} }
if (persistent_key.empty()) { if (persistent_key.empty()) {
// run op in the "ONCE" mode // Run op in the "ONCE" mode
unique_ptr<OperatorBase> op(CreateOperator(meta_op, this)); unique_ptr<OperatorBase> op(CreateOperator(meta_op, this));
op->Run(); op->Run();
} else { } else {
// run op in the "PERSISTENT" mode // Run op in the "PERSISTENT" mode
if (!op_map_.count(persistent_key)) if (!op_map_.count(persistent_key))
op_map_[persistent_key] = unique_ptr<OperatorBase>( op_map_[persistent_key] = unique_ptr<OperatorBase>(
CreateOperator(meta_op, this)); CreateOperator(meta_op, this));
...@@ -294,6 +295,6 @@ class Workspace { ...@@ -294,6 +295,6 @@ class Workspace {
ProxyMap proxy_map_; ProxyMap proxy_map_;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_WORKSPACE_H_ #endif // DRAGON_CORE_WORKSPACE_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
...@@ -129,8 +130,8 @@ DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutGradientOp, prob); ...@@ -129,8 +130,8 @@ DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutGradientOp, prob);
#endif #endif
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
...@@ -21,7 +22,7 @@ class EluOp : public Operator<Context> { ...@@ -21,7 +22,7 @@ class EluOp : public Operator<Context> {
public: public:
EluOp(const OperatorDef& def, Workspace* ws) EluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.0)) {} alpha(OperatorBase::Arg<float>("alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -36,7 +37,7 @@ class EluGradientOp : public Operator<Context> { ...@@ -36,7 +37,7 @@ class EluGradientOp : public Operator<Context> {
public: public:
EluGradientOp(const OperatorDef& def, Workspace* ws) EluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.0)) {} alpha(OperatorBase::Arg<float>("alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -106,8 +107,8 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> { ...@@ -106,8 +107,8 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> {
#endif #endif
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
...@@ -50,6 +51,6 @@ class PReluGradientOp final : public Operator<Context> { ...@@ -50,6 +51,6 @@ class PReluGradientOp final : public Operator<Context> {
string data_format; string data_format;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
...@@ -102,8 +103,8 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> { ...@@ -102,8 +103,8 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
...@@ -36,6 +37,6 @@ class SEluGradientOp final : public Operator<Context> { ...@@ -36,6 +37,6 @@ class SEluGradientOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP #ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
#define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP #define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
...@@ -92,8 +93,8 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> { ...@@ -92,8 +93,8 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP #endif // DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
...@@ -96,8 +97,8 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> { ...@@ -96,8 +97,8 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> {
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc, output_desc;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
} }
#endif // DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
...@@ -92,8 +93,8 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> { ...@@ -92,8 +93,8 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
...@@ -81,7 +82,7 @@ class CuDNNAffineOpBase : public Operator<Context> { ...@@ -81,7 +82,7 @@ class CuDNNAffineOpBase : public Operator<Context> {
template <typename T> template <typename T>
void ResetDesc(const Tensor& X) { void ResetDesc(const Tensor& X) {
// determine the range of affine // Determine the range of affine
start_axis = axis; start_axis = axis;
if (start_axis < 0) start_axis += (int)X.ndim(); if (start_axis < 0) start_axis += (int)X.ndim();
if (num_axes == -1) num_axes = (int)X.ndim() - start_axis; if (num_axes == -1) num_axes = (int)X.ndim() - start_axis;
...@@ -89,14 +90,14 @@ class CuDNNAffineOpBase : public Operator<Context> { ...@@ -89,14 +90,14 @@ class CuDNNAffineOpBase : public Operator<Context> {
end_axis = start_axis + num_axes; end_axis = start_axis + num_axes;
CHECK_LT(start_axis, (int)X.ndim()); CHECK_LT(start_axis, (int)X.ndim());
CHECK_LE(start_axis + num_axes, (int)X.ndim()); CHECK_LE(start_axis + num_axes, (int)X.ndim());
// determine the input desc // Determine the input desc
vector<TIndex> input_dims = X.dims(); vector<TIndex> input_dims = X.dims();
// cudnn requires ndimensions range from [4, 5] // CuDNN requires ndimensions range from [4, 5]
if (input_dims.size() < 4) input_dims.resize(4, 1); if (input_dims.size() < 4) input_dims.resize(4, 1);
else if (input_dims.size() > 5) else if (input_dims.size() > 5)
LOG(FATAL) << "CuDNN Affine the dimensions up to 5."; LOG(FATAL) << "CuDNN Affine the dimensions up to 5.";
cudnnSetTensorDesc<T>(&input_desc, input_dims); cudnnSetTensorDesc<T>(&input_desc, input_dims);
// determine the scale desc // Determine the scale desc
vector<TIndex> param_dims(input_dims.size(), 1); vector<TIndex> param_dims(input_dims.size(), 1);
for (int i = start_axis; i < end_axis; i++) for (int i = start_axis; i < end_axis; i++)
param_dims[i] = input_dims[i]; param_dims[i] = input_dims[i];
...@@ -127,24 +128,32 @@ class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> { ...@@ -127,24 +128,32 @@ class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> {
: CuDNNAffineOpBase<Context>(def, ws) {} : CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename DT, typename CT> void RunWithType();
protected: protected:
USE_CUDNN_AFFINE_FUCNTIONS; USE_CUDNN_AFFINE_FUCNTIONS;
}; };
template <class Context> template <class Context>
class CuDNNAffineGradientOp final : public CuDNNAffineOpBase<Context> { class CuDNNAffineGradientOp final
: public CuDNNAffineOpBase<Context> {
public: public:
CuDNNAffineGradientOp(const OperatorDef& def, Workspace* ws) CuDNNAffineGradientOp(
const OperatorDef& def,
Workspace* ws)
: CuDNNAffineOpBase<Context>(def, ws) {} : CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void ComputeScaleGradient(T* dYxX, T* dA);
template <typename DT, typename CT>
void ComputeScaleGradient(DT* dYxX, DT* dA);
template <typename DT, typename CT>
void ComputeBiasGradient(const DT* dY, DT* dB);
template <typename T> void ComputeScaleGradient_v2(T* dYxX, T* dA); template <typename T> void ComputeScaleGradient_v2(T* dYxX, T* dA);
template <typename T> void ComputeBiasGradient(const T* dY, T* dB);
template <typename T> void ComputeBiasGradient_v2(const T* dY, T* dB); template <typename T> void ComputeBiasGradient_v2(const T* dY, T* dB);
template <typename T> void RunWithType();
template <typename DT, typename CT> void RunWithType();
protected: protected:
USE_CUDNN_AFFINE_FUCNTIONS; USE_CUDNN_AFFINE_FUCNTIONS;
...@@ -154,8 +163,8 @@ protected: ...@@ -154,8 +163,8 @@ protected:
#endif #endif
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#include <float.h>
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
...@@ -49,6 +49,6 @@ class ClipGradientOp final : public Operator<Context> { ...@@ -49,6 +49,6 @@ class ClipGradientOp final : public Operator<Context> {
float low, high; float low, high;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
...@@ -52,6 +53,6 @@ class DotGradientOp final : public Operator<Context> { ...@@ -52,6 +53,6 @@ class DotGradientOp final : public Operator<Context> {
TIndex TransA, TransB, M, K1, K2, N1, N2; TIndex TransA, TransB, M, K1, K2, N1, N2;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
...@@ -64,6 +65,6 @@ class EltwiseGradientOp final : public Operator<Context> { ...@@ -64,6 +65,6 @@ class EltwiseGradientOp final : public Operator<Context> {
vector<float> coeffs; vector<float> coeffs;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
...@@ -36,6 +37,6 @@ class ExpGradientOp final : public Operator<Context> { ...@@ -36,6 +37,6 @@ class ExpGradientOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_
...@@ -297,6 +298,6 @@ class RDivGradientOp final : public Operator<Context> { ...@@ -297,6 +298,6 @@ class RDivGradientOp final : public Operator<Context> {
<< X2->DimString(); \ << X2->DimString(); \
} }
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
...@@ -48,6 +49,6 @@ class GramMatrixGradientOp final : public Operator<Context> { ...@@ -48,6 +49,6 @@ class GramMatrixGradientOp final : public Operator<Context> {
TIndex x_offset, y_offset; TIndex x_offset, y_offset;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
...@@ -51,6 +52,6 @@ class InnerProductGradientOp final : public Operator<Context> { ...@@ -51,6 +52,6 @@ class InnerProductGradientOp final : public Operator<Context> {
TIndex axis, num_output, TransW, M, K; TIndex axis, num_output, TransW, M, K;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
...@@ -36,6 +37,6 @@ class LogGradientOp final : public Operator<Context> { ...@@ -36,6 +37,6 @@ class LogGradientOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
...@@ -50,6 +51,6 @@ class MatmulGradientOp final : public Operator<Context> { ...@@ -50,6 +51,6 @@ class MatmulGradientOp final : public Operator<Context> {
TIndex n, x1_offset, x2_offset, y_offset; TIndex n, x1_offset, x2_offset, y_offset;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_
...@@ -38,6 +39,6 @@ class MaximumGradientOp final : public Operator<Context> { ...@@ -38,6 +39,6 @@ class MaximumGradientOp final : public Operator<Context> {
template <typename T> void BroadcastRunWithType(); template <typename T> void BroadcastRunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_
...@@ -38,6 +39,6 @@ class MinimumGradientOp final : public Operator<Context> { ...@@ -38,6 +39,6 @@ class MinimumGradientOp final : public Operator<Context> {
template <typename T> void BroadcastRunWithType(); template <typename T> void BroadcastRunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
...@@ -21,9 +22,9 @@ class PowOp final : public Operator<Context> { ...@@ -21,9 +22,9 @@ class PowOp final : public Operator<Context> {
public: public:
PowOp(const OperatorDef& def, Workspace* ws) PowOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.0)), scale(OperatorBase::Arg<float>("scale", 1.f)),
shift(OperatorBase::Arg<float>("shift", 0.0)), shift(OperatorBase::Arg<float>("shift", 0.f)),
power(OperatorBase::Arg<float>("power", 1.0)) { power(OperatorBase::Arg<float>("power", 1.f)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -40,9 +41,9 @@ class PowGradientOp final : public Operator<Context> { ...@@ -40,9 +41,9 @@ class PowGradientOp final : public Operator<Context> {
public: public:
PowGradientOp(const OperatorDef& def, Workspace* ws) PowGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.0)), scale(OperatorBase::Arg<float>("scale", 1.f)),
shift(OperatorBase::Arg<float>("shift", 0.0)), shift(OperatorBase::Arg<float>("shift", 0.f)),
power(OperatorBase::Arg<float>("power", 1.0)) { power(OperatorBase::Arg<float>("power", 1.f)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -54,6 +55,6 @@ class PowGradientOp final : public Operator<Context> { ...@@ -54,6 +55,6 @@ class PowGradientOp final : public Operator<Context> {
float scale, shift, power, power_scale; float scale, shift, power, power_scale;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
...@@ -36,6 +37,6 @@ public: ...@@ -36,6 +37,6 @@ public:
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_ #endif // DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
...@@ -31,6 +32,6 @@ class CompareOp final : public Operator<Context> { ...@@ -31,6 +32,6 @@ class CompareOp final : public Operator<Context> {
string operation; string operation;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_ #endif // DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
...@@ -26,6 +27,6 @@ class CopyOp final : public Operator<Context> { ...@@ -26,6 +27,6 @@ class CopyOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_ #endif // DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
...@@ -82,6 +83,6 @@ class ScanGradientOp final: public Operator<Context> { ...@@ -82,6 +83,6 @@ class ScanGradientOp final: public Operator<Context> {
string step_type, step_tensor; string step_type, step_tensor;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_ #endif // DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
...@@ -82,8 +83,8 @@ class CuDNNCTCLossOp final : public Operator<Context> { ...@@ -82,8 +83,8 @@ class CuDNNCTCLossOp final : public Operator<Context> {
#endif #endif
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
...@@ -50,6 +51,6 @@ class L1LossGradientOp final : public Operator<Context> { ...@@ -50,6 +51,6 @@ class L1LossGradientOp final : public Operator<Context> {
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
...@@ -50,6 +51,6 @@ class L2LossGradientOp final : public Operator<Context> { ...@@ -50,6 +51,6 @@ class L2LossGradientOp final : public Operator<Context> {
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_
...@@ -72,6 +73,6 @@ class NLLLossGradientOp : public Operator<Context> { ...@@ -72,6 +73,6 @@ class NLLLossGradientOp : public Operator<Context> {
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
#include "core/operator.h" #include "core/operator.h"
...@@ -56,6 +57,6 @@ class SigmoidCrossEntropyGradientOp ...@@ -56,6 +57,6 @@ class SigmoidCrossEntropyGradientOp
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_
...@@ -74,6 +75,6 @@ class SigmoidFocalLossGradientOp ...@@ -74,6 +75,6 @@ class SigmoidFocalLossGradientOp
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
...@@ -21,7 +22,7 @@ class SmoothL1LossOp final : public Operator<Context> { ...@@ -21,7 +22,7 @@ class SmoothL1LossOp final : public Operator<Context> {
public: public:
SmoothL1LossOp(const OperatorDef& def, Workspace* ws) SmoothL1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.0)), beta(OperatorBase::Arg<float>("beta", 1.f)),
normalization(OperatorBase::Arg<string>( normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {} "normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -40,7 +41,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> { ...@@ -40,7 +41,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
public: public:
SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws) SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.0)), beta(OperatorBase::Arg<float>("beta", 1.f)),
normalization(OperatorBase::Arg<string>( normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {} "normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -54,6 +55,6 @@ class SmoothL1LossGradientOp final : public Operator<Context> { ...@@ -54,6 +55,6 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_ #define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
...@@ -63,6 +64,6 @@ class SoftmaxCrossEntropyGradientOp ...@@ -63,6 +64,6 @@ class SoftmaxCrossEntropyGradientOp
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_
#include "operators/loss/sparse_softmax_cross_entropy_op.h" #include "operators/loss/sparse_softmax_ce_loss_op.h"
namespace dragon { namespace dragon {
...@@ -74,6 +75,6 @@ class SoftmaxFocalLossGradientOp ...@@ -74,6 +75,6 @@ class SoftmaxFocalLossGradientOp
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_ #define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
...@@ -76,6 +77,6 @@ class SparseSoftmaxCrossEntropyGradientOp ...@@ -76,6 +77,6 @@ class SparseSoftmaxCrossEntropyGradientOp
string normalization; string normalization;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CE_LOSS_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_ACCURACY_OP_H_ #ifndef DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
#define DRAGON_OPERATORS_MISC_ACCURACY_OP_H_ #define DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
...@@ -40,6 +41,6 @@ class AccuracyOp final : public Operator<Context> { ...@@ -40,6 +41,6 @@ class AccuracyOp final : public Operator<Context> {
Tensor ignore; Tensor ignore;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_MISC_ACCURACY_OP_H_ #endif // DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------ * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_ASTYPE_OP_H_ #ifndef DRAGON_OPERATORS_MISC_ASTYPE_OP_H_
#define DRAGON_OPERATORS_MISC_ASTYPE_OP_H_ #define DRAGON_OPERATORS_MISC_ASTYPE_OP_H_
...@@ -32,6 +33,6 @@ class AsTypeOp final : public Operator<Context> { ...@@ -32,6 +33,6 @@ class AsTypeOp final : public Operator<Context> {
bool inplace; bool inplace;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_MISC_ASTYPE_OP_H_ #endif // DRAGON_OPERATORS_MISC_ASTYPE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_GRADIENT_OP_H_ #ifndef DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
#define DRAGON_OPERATORS_MISC_GRADIENT_OP_H_ #define DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
...@@ -60,6 +61,6 @@ class StopGradientOp final : public Operator<Context> { ...@@ -60,6 +61,6 @@ class StopGradientOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_MISC_GRADIENT_OP_H_ #endif // DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_ #ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_ #define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
...@@ -52,6 +53,6 @@ class ImageDataOp final : public Operator<Context> { ...@@ -52,6 +53,6 @@ class ImageDataOp final : public Operator<Context> {
Tensor mean, std; Tensor mean, std;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_ #endif // DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
#define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
...@@ -44,7 +45,7 @@ class FillOp final : public Operator<Context> { ...@@ -44,7 +45,7 @@ class FillOp final : public Operator<Context> {
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
shape_desc(OperatorBase::Arg<string>("shape", "")), shape_desc(OperatorBase::Arg<string>("shape", "")),
dtype(OperatorBase::Arg<string>("dtype", "float32")), dtype(OperatorBase::Arg<string>("dtype", "float32")),
value(OperatorBase::Arg<float>("value", 0.0)) { value(OperatorBase::Arg<float>("value", 0.f)) {
GET_ARGUMENTS_WITH_DESC(int, dims); GET_ARGUMENTS_WITH_DESC(int, dims);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -64,8 +65,8 @@ public: ...@@ -64,8 +65,8 @@ public:
RandomUniformOp(const OperatorDef& def, Workspace* ws) RandomUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("uniform"); this->filler.set_type("uniform");
this->filler.set_low(OperatorBase::Arg<float>("low", -1.0)); this->filler.set_low(OperatorBase::Arg<float>("low", -1.f));
this->filler.set_high(OperatorBase::Arg<float>("high", 1.0)); this->filler.set_high(OperatorBase::Arg<float>("high", 1.f));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
...@@ -76,8 +77,8 @@ public: ...@@ -76,8 +77,8 @@ public:
RandomNormalOp(const OperatorDef& def, Workspace* ws) RandomNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("normal"); this->filler.set_type("normal");
this->filler.set_mean(OperatorBase::Arg<float>("mean", 0.0)); this->filler.set_mean(OperatorBase::Arg<float>("mean", 0.f));
this->filler.set_std(OperatorBase::Arg<float>("std", 1.0)); this->filler.set_std(OperatorBase::Arg<float>("std", 1.f));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
...@@ -88,8 +89,8 @@ public: ...@@ -88,8 +89,8 @@ public:
TruncatedNormalOp(const OperatorDef& def, Workspace* ws) TruncatedNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("truncated_normal"); this->filler.set_type("truncated_normal");
float mu = OperatorBase::Arg<float>("mean", 0.0); float mu = OperatorBase::Arg<float>("mean", 0.f);
float sigma = OperatorBase::Arg<float>("std", 1.0); float sigma = OperatorBase::Arg<float>("std", 1.f);
this->filler.set_mean(mu); this->filler.set_mean(mu);
this->filler.set_std(sigma); this->filler.set_std(sigma);
this->filler.set_low(mu - 2 * sigma); this->filler.set_low(mu - 2 * sigma);
...@@ -104,7 +105,7 @@ public: ...@@ -104,7 +105,7 @@ public:
GlorotUniformOp(const OperatorDef& def, Workspace* ws) GlorotUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in"); string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::Arg<float>("scale", 3.0); float scale = OperatorBase::Arg<float>("scale", 3.f);
this->filler.set_type("xavier"); this->filler.set_type("xavier");
if (mode == "fan_avg") { if (mode == "fan_avg") {
...@@ -125,7 +126,7 @@ public: ...@@ -125,7 +126,7 @@ public:
GlorotNormalOp(const OperatorDef& def, Workspace* ws) GlorotNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in"); string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::Arg<float>("scale", 2.0); float scale = OperatorBase::Arg<float>("scale", 2.f);
this->filler.set_type("msra"); this->filler.set_type("msra");
if (mode == "fan_avg") { if (mode == "fan_avg") {
...@@ -143,6 +144,6 @@ public: ...@@ -143,6 +144,6 @@ public:
DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims); DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims);
DEFINE_ARGUMENTS_WITH_DESC(int, FillOp, dims); DEFINE_ARGUMENTS_WITH_DESC(int, FillOp, dims);
} // namespace } // namespace
#endif // DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #endif // DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_PYTHON_OP_H_ #ifndef DRAGON_OPERATORS_MISC_PYTHON_OP_H_
#define DRAGON_OPERATORS_MISC_PYTHON_OP_H_ #define DRAGON_OPERATORS_MISC_PYTHON_OP_H_
...@@ -52,8 +53,8 @@ public: ...@@ -52,8 +53,8 @@ public:
void RunOnDevice() override; void RunOnDevice() override;
}; };
} // namespace dragon } // namespace dragon
#endif // WITH_PYTHON #endif // WITH_PYTHON
#endif // DRAGON_OPERATORS_MISC_PYTHON_OP_H_ #endif // DRAGON_OPERATORS_MISC_PYTHON_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_ #ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_ #define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
...@@ -56,16 +57,16 @@ class ModelMPIBase : public Operator<Context> { ...@@ -56,16 +57,16 @@ class ModelMPIBase : public Operator<Context> {
string dtype; string dtype;
}; };
#define USE_MPIMODEL_FUNCTIONS(context) \ #define USE_MODEL_MPI_FUNCTIONS \
using ModelMPIBase<context>::comm; \ using ModelMPIBase<Context>::comm; \
using ModelMPIBase<context>::mpi_dtype; \ using ModelMPIBase<Context>::mpi_dtype; \
using ModelMPIBase<context>::comm_size; \ using ModelMPIBase<Context>::comm_size; \
using ModelMPIBase<context>::comm_rank; \ using ModelMPIBase<Context>::comm_rank; \
using ModelMPIBase<context>::comm_root; \ using ModelMPIBase<Context>::comm_root; \
using ModelMPIBase<context>::dtype using ModelMPIBase<Context>::dtype
} // namespace dragon } // namespace dragon
#endif // WITH_MPI #endif // WITH_MPI
#endif // DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_ #endif // DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_ #ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_ #define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
...@@ -24,7 +25,7 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> { ...@@ -24,7 +25,7 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
MPIBroadcastOp(const OperatorDef& def, Workspace* ws) MPIBroadcastOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -36,14 +37,14 @@ public: ...@@ -36,14 +37,14 @@ public:
MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws) MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // WITH_MPI #endif // WITH_MPI
#endif //DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_ #endif //DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_ #ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_ #define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
...@@ -24,7 +25,7 @@ class MPIGatherOp final : public ModelMPIBase<Context> { ...@@ -24,7 +25,7 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
MPIGatherOp(const OperatorDef& def, Workspace *ws) MPIGatherOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -36,14 +37,14 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> { ...@@ -36,14 +37,14 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> {
MPIGatherGradientOp(const OperatorDef& def, Workspace *ws) MPIGatherGradientOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // WITH_MPI #endif // WITH_MPI
#endif // DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_ #endif // DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
...@@ -42,6 +43,6 @@ DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, start); ...@@ -42,6 +43,6 @@ DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, start);
DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, stop); DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, stop);
DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, step); DEFINE_ARGUMENT_WITH_DESC(int, ArangeOp, step);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_ARANGE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_ARANGE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_
...@@ -36,6 +37,6 @@ class ArgReduceOp final : public Operator<Context> { ...@@ -36,6 +37,6 @@ class ArgReduceOp final : public Operator<Context> {
TIndex axis, axis_dim, top_k, count, inner_dim; TIndex axis, axis_dim, top_k, count, inner_dim;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
...@@ -52,6 +53,6 @@ class ConcatGradientOp : public Operator<Context> { ...@@ -52,6 +53,6 @@ class ConcatGradientOp : public Operator<Context> {
vector<TIndex> concat_dims; vector<TIndex> concat_dims;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_CROP_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CROP_OP_H_ #define DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
...@@ -66,6 +67,6 @@ class CropGradientOp final : public Operator<Context> { ...@@ -66,6 +67,6 @@ class CropGradientOp final : public Operator<Context> {
Tensor* dest, *source, navigator; Tensor* dest, *source, navigator;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_CROP_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_
#define DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_ #define DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_
...@@ -152,6 +153,6 @@ public: ...@@ -152,6 +153,6 @@ public:
DEFINE_DIMENSION_GRADIENT_OP(Squeeze); DEFINE_DIMENSION_GRADIENT_OP(Squeeze);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_ #define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
...@@ -49,6 +50,6 @@ class GatherGradientOp final : public Operator<Context> { ...@@ -49,6 +50,6 @@ class GatherGradientOp final : public Operator<Context> {
bool acc_grad; bool acc_grad;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
...@@ -33,6 +34,6 @@ class OneHotOp final : public Operator < Context > { ...@@ -33,6 +34,6 @@ class OneHotOp final : public Operator < Context > {
TIndex depth, on_value, off_value; TIndex depth, on_value, off_value;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
#define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_ #define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
...@@ -24,7 +25,7 @@ class PadOp final : public Operator<Context> { ...@@ -24,7 +25,7 @@ class PadOp final : public Operator<Context> {
pad_l(OperatorBase::Args<int>("pad_l")), pad_l(OperatorBase::Args<int>("pad_l")),
pad_r(OperatorBase::Args<int>("pad_r")), pad_r(OperatorBase::Args<int>("pad_r")),
mode(OperatorBase::Arg<string>("mode", "CONSTANT")), mode(OperatorBase::Arg<string>("mode", "CONSTANT")),
value(OperatorBase::Arg<float>("value", 0.0f)) { value(OperatorBase::Arg<float>("value", 0.f)) {
if (pad_r.size() == 0) pad_r = pad_l; if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size()) else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length."; << "The pad_l and pad_r should have the same length.";
...@@ -85,6 +86,6 @@ class PadGradientOp final : public Operator<Context> { ...@@ -85,6 +86,6 @@ class PadGradientOp final : public Operator<Context> {
Tensor* dest, *source, navigator; Tensor* dest, *source, navigator;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_PAD_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_ #define DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
...@@ -52,6 +53,6 @@ protected: ...@@ -52,6 +53,6 @@ protected:
Tensor* pick_indices; Tensor* pick_indices;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
...@@ -53,6 +54,6 @@ class ReduceGradientOp final : public Operator<Context> { ...@@ -53,6 +54,6 @@ class ReduceGradientOp final : public Operator<Context> {
string operation; string operation;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
...@@ -55,6 +56,6 @@ class RepeatGradientOp final : public Operator<Context> { ...@@ -55,6 +56,6 @@ class RepeatGradientOp final : public Operator<Context> {
DEFINE_ARGUMENT_WITH_DESC(int, RepeatOp, repeats); DEFINE_ARGUMENT_WITH_DESC(int, RepeatOp, repeats);
DEFINE_ARGUMENT_WITH_DESC(int, RepeatGradientOp, repeats); DEFINE_ARGUMENT_WITH_DESC(int, RepeatGradientOp, repeats);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
...@@ -24,6 +25,6 @@ class ShapeOp final : public Operator<Context> { ...@@ -24,6 +25,6 @@ class ShapeOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
}; };
} // namespace dragon } // namespace dragon
#endif //DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_ #endif //DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
...@@ -54,6 +55,6 @@ class SliceGradientOp final : public Operator<Context> { ...@@ -54,6 +55,6 @@ class SliceGradientOp final : public Operator<Context> {
vector<TIndex> slice_dims; vector<TIndex> slice_dims;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_STACK_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_STACK_OP_H_ #define DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
...@@ -52,6 +53,6 @@ class StackGradientOp final : public Operator<Context> { ...@@ -52,6 +53,6 @@ class StackGradientOp final : public Operator<Context> {
vector<TIndex> concat_dims; vector<TIndex> concat_dims;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_STACK_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_TILE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TILE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
...@@ -55,6 +56,6 @@ class TileGradientOp final : public Operator<Context> { ...@@ -55,6 +56,6 @@ class TileGradientOp final : public Operator<Context> {
DEFINE_ARGUMENTS_WITH_DESC(int, TileOp, multiples); DEFINE_ARGUMENTS_WITH_DESC(int, TileOp, multiples);
DEFINE_ARGUMENTS_WITH_DESC(int, TileGradientOp, multiples); DEFINE_ARGUMENTS_WITH_DESC(int, TileGradientOp, multiples);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_TILE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
...@@ -49,6 +50,6 @@ class TransposeGradientOp final : public Operator<Context> { ...@@ -49,6 +50,6 @@ class TransposeGradientOp final : public Operator<Context> {
Tensor* order, *old_steps, *new_steps; Tensor* order, *old_steps, *new_steps;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
...@@ -208,6 +209,6 @@ class CuDNNBatchNormGradientOp final ...@@ -208,6 +209,6 @@ class CuDNNBatchNormGradientOp final
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_ #endif // DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_ #define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
...@@ -80,6 +81,6 @@ class BatchRenormGradientOp final : public Operator<Context> { ...@@ -80,6 +81,6 @@ class BatchRenormGradientOp final : public Operator<Context> {
bool use_global_stats; bool use_global_stats;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_ #endif // DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
...@@ -108,6 +109,6 @@ class FusedGroupNormGradientOp final : public Operator<Context> { ...@@ -108,6 +109,6 @@ class FusedGroupNormGradientOp final : public Operator<Context> {
string data_format; string data_format;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_ #endif // DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
...@@ -64,6 +65,6 @@ class InstanceNormGradientOp final : public Operator<Context> { ...@@ -64,6 +65,6 @@ class InstanceNormGradientOp final : public Operator<Context> {
string data_format; string data_format;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_ #endif // DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_ #ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_
#define DRAGON_OPERATORS_NORM_L2_NORM_H_ #define DRAGON_OPERATORS_NORM_L2_NORM_H_
...@@ -58,6 +59,6 @@ class L2NormGradientOp final : public Operator<Context> { ...@@ -58,6 +59,6 @@ class L2NormGradientOp final : public Operator<Context> {
TIndex outer_dim, dim, inner_dim; TIndex outer_dim, dim, inner_dim;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NORM_L2_NORM_H_ #endif // DRAGON_OPERATORS_NORM_L2_NORM_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
...@@ -52,7 +53,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> { ...@@ -52,7 +53,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
hidden_size(OperatorBase::Arg<int>("hidden_size", 0)), hidden_size(OperatorBase::Arg<int>("hidden_size", 0)),
num_layers(OperatorBase::Arg<int>("num_layers", 1)), num_layers(OperatorBase::Arg<int>("num_layers", 1)),
bidirectional(OperatorBase::Arg<bool>("bidirectional", false)), bidirectional(OperatorBase::Arg<bool>("bidirectional", false)),
dropout_ratio(OperatorBase::Arg<float>("dropout_ratio", 1.0)), dropout_ratio(OperatorBase::Arg<float>("dropout_ratio", 1.f)),
random_seed(def.device_option().random_seed()) { random_seed(def.device_option().random_seed()) {
// determine the rnn direction // determine the rnn direction
rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
...@@ -154,8 +155,8 @@ class CuDNNRecurrentGradientOp final : public CuDNNRecurrentOpBase<Context> { ...@@ -154,8 +155,8 @@ class CuDNNRecurrentGradientOp final : public CuDNNRecurrentOpBase<Context> {
#endif #endif
#endif // WITH_CUDNN #endif // WITH_CUDNN
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #endif // DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_ #define DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
...@@ -39,6 +40,6 @@ class LSTMCellGradientOp final : public Operator<Context> { ...@@ -39,6 +40,6 @@ class LSTMCellGradientOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_ #endif // DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_ #define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
...@@ -40,6 +41,6 @@ public: ...@@ -40,6 +41,6 @@ public:
void RunOnDevice() override {} void RunOnDevice() override {}
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_RECURRENT_OP_H_ #endif // DRAGON_OPERATORS_RECURRENT_RECURRENT_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
#define DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_ #define DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
...@@ -48,6 +49,6 @@ class RNNParamSetOp final : public Operator<Context> { ...@@ -48,6 +49,6 @@ class RNNParamSetOp final : public Operator<Context> {
TIndex layer_id, param_id; TIndex layer_id, param_id;
}; };
} // namespace dragon } // namespace dragon
#endif #endif
\ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
...@@ -31,6 +32,6 @@ class AdamUpdateOp final : public UpdateOpBase<Context> { ...@@ -31,6 +32,6 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
int t; float lr, beta1, beta2, eps; int t; float lr, beta1, beta2, eps;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_ #endif // DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
...@@ -78,8 +79,8 @@ class CollectiveUpdateOp final : public Operator<Context> { ...@@ -78,8 +79,8 @@ class CollectiveUpdateOp final : public Operator<Context> {
#endif #endif
}; };
#endif // WITH_MPI #endif // WITH_MPI
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_ #endif // DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
#define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_ #define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
...@@ -21,7 +22,7 @@ class MovingAverageOp final : public Operator<Context> { ...@@ -21,7 +22,7 @@ class MovingAverageOp final : public Operator<Context> {
public: public:
MovingAverageOp(const OperatorDef& def, Workspace* ws) MovingAverageOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
decay(OperatorBase::Arg<float>("decay", 1.0)) {} decay(OperatorBase::Arg<float>("decay", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
...@@ -31,6 +32,6 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> { ...@@ -31,6 +32,6 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> {
float lr, momentum; float lr, momentum;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_ #endif // DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
...@@ -31,6 +32,6 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> { ...@@ -31,6 +32,6 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
float lr, decay, eps; float lr, decay, eps;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_ #endif // DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
...@@ -32,6 +33,6 @@ class SGDUpdateOp final : public UpdateOpBase<Context> { ...@@ -32,6 +33,6 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
float old_lr, lr, momentum, correction; float old_lr, lr, momentum, correction;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_ #endif // DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
\ No newline at end of file \ No newline at end of file
// ------------------------------------------------------------ /*!
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// *
// Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See, * along with the software. If not, See,
// *
// <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
// *
// ------------------------------------------------------------- * ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_ #ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
#define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_ #define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
...@@ -21,8 +22,8 @@ class UpdateOpBase : public Operator<Context> { ...@@ -21,8 +22,8 @@ class UpdateOpBase : public Operator<Context> {
public: public:
UpdateOpBase(const OperatorDef& def, Workspace* ws) UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
lr_mult(OperatorBase::Arg<float>("lr_mult", 1.0)), lr_mult(OperatorBase::Arg<float>("lr_mult", 1.f)),
decay_mult(OperatorBase::Arg<float>("decay_mult", 1.0)), decay_mult(OperatorBase::Arg<float>("decay_mult", 1.f)),
slot(OperatorBase::Arg<string>("slot", "")), slot(OperatorBase::Arg<string>("slot", "")),
zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) { zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) {
CHECK(!slot.empty()) << "\nRequired a non-empty slot"; CHECK(!slot.empty()) << "\nRequired a non-empty slot";
...@@ -52,6 +53,6 @@ class UpdateOpBase : public Operator<Context> { ...@@ -52,6 +53,6 @@ class UpdateOpBase : public Operator<Context> {
using UpdateOpBase<context>::Param; \ using UpdateOpBase<context>::Param; \
using UpdateOpBase<context>::Slot using UpdateOpBase<context>::Slot
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_ #endif // DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
\ No newline at end of file \ No newline at end of file
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!