Commit 84320495 by Ting PAN

Dismantle Op Kernels

1 parent 96f7277e
Showing with 1472 additions and 1027 deletions
------------------------------------------------------------------------
The list of most significant changes made over time in Dragon.
Dragon 0.2.2.13 (20181204)
DRAGON_VERSION == 2213
Changes (w.r.t. Dragon 0.2.2.12):
Preview Features:
- Dismantled op kernels.
- Added FP16 support for ``NNResizeOp``, ``ProposalOp``,
``ROIPoolingOp``, ``ROIAlignOp``, (R-CNN components).
- Added ``DepthwiseConv2dOp``.
- [PyTorch] Added ``nn.DepthwiseConv2d``
- [PyCaffe] Added ``DepthwiseConvolutionLayer``.
Bugs fixed:
- Fixed the cuda issue that incorrect results in ``AsTypeOp``
under ``float32 -> float16``.
- Removed the support for group convolution implemented with im2col-nhwc.
- Changed the default ``NHWC`` pack format of filter from
``[filter_height, filter_width, in_channels, out_channels]``(TensorFlow) to
``[out_channels, filter_height, filter_width, in_channels]``(CuDNN).
- Changed the VC Runtime from ``MD`` to ``MT``.
------------------------------------------------------------------------
Dragon 0.2.2.12 (20181120)
DRAGON_VERSION == 2212
......@@ -60,6 +98,7 @@ Preview Features:
- [PyCaffe] Added ``DropBlockLayer``.
Bugs fixed:
- Fixed the uncomputed output in ``BiasAddGradientOp``.
......
This diff could not be displayed because it is too large.
<!-- HTML footer for doxygen 1.8.14-->
<!-- start footer part -->
<!--BEGIN GENERATE_TREEVIEW-->
<div id="nav-path" class="navpath"><!-- id is needed for treeview function! -->
<ul>
$navpath
<li class="footer">$generatedby
<a href="http://www.doxygen.org/index.html">
<img class="footer" src="$relpath^doxygen.png" alt="doxygen"/></a> $doxygenversion </li>
</ul>
</div>
<!--END GENERATE_TREEVIEW-->
<!--BEGIN !GENERATE_TREEVIEW-->
<hr class="footer"/><address class="footer"><small>
$generatedby &#160;<a href="http://www.doxygen.org/index.html">
<img class="footer" src="$relpath^doxygen.png" alt="doxygen"/>
</a> $doxygenversion
</small></address>
<!--END !GENERATE_TREEVIEW-->
</body>
</html>
\ No newline at end of file
<!-- HTML header for doxygen 1.8.14-->
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
<meta name="generator" content="Doxygen $doxygenversion"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<!--BEGIN PROJECT_NAME--><title>$projectname: $title</title><!--END PROJECT_NAME-->
<!--BEGIN !PROJECT_NAME--><title>$title</title><!--END !PROJECT_NAME-->
<link href="$relpath^tabs.css" rel="stylesheet" type="text/css"/>
<link rel="icon" href="/static/favicon.png" type="image/x-icon">
<script type="text/javascript" src="$relpath^jquery.js"></script>
<script type="text/javascript" src="$relpath^dynsections.js"></script>
$treeview
$search
$mathjax
<link href="$relpath^$stylesheet" rel="stylesheet" type="text/css" />
$extrastylesheet
</head>
<body>
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
<!--BEGIN TITLEAREA-->
<div id="titlearea">
<table cellspacing="0" cellpadding="0">
<tbody>
<tr style="height: 56px;">
<!--BEGIN PROJECT_LOGO-->
<td id="projectlogo" width="112"><a href="http://dragon.seetatech.com"><img alt="Logo" src="$relpath^$projectlogo"/></a></td>
<!--END PROJECT_LOGO-->
<!--BEGIN PROJECT_NAME-->
<td id="projectalign" style="padding-left: 1.5em;">
<div id="projectname">$projectname
<!--BEGIN PROJECT_NUMBER-->&#160;<span id="projectnumber">$projectnumber</span><!--END PROJECT_NUMBER-->
</div>
<!--BEGIN PROJECT_BRIEF--><div id="projectbrief">$projectbrief</div><!--END PROJECT_BRIEF-->
</td>
<!--END PROJECT_NAME-->
<!--BEGIN !PROJECT_NAME-->
<!--BEGIN PROJECT_BRIEF-->
<td style="padding-left: 0.5em;">
<div id="projectbrief">$projectbrief</div>
</td>
<!--END PROJECT_BRIEF-->
<!--END !PROJECT_NAME-->
<!--BEGIN DISABLE_INDEX-->
<!--BEGIN SEARCHENGINE-->
<td>$searchbox</td>
<!--END SEARCHENGINE-->
<!--END DISABLE_INDEX-->
</tr>
</tbody>
</table>
</div>
<!--END TITLEAREA-->
<!-- end header part -->
\ No newline at end of file
<doxygenlayout version="1.0">
<!-- Generated by doxygen 1.8.14 -->
<!-- Navigation index tabs for HTML output -->
<navindex>
<tab type="mainpage" visible="no" title=""/>
<tab type="pages" visible="no" title="" intro=""/>
<tab type="modules" visible="no" title="" intro=""/>
<tab type="namespaces" visible="no" title="">
<tab type="namespacelist" visible="no" title="" intro="/"/>
<tab type="namespacemembers" visible="no" title="" intro=""/>
</tab>
<tab type="classes" visible="no" title="">
<tab type="classlist" visible="yes" title="" intro=""/>
<tab type="classindex" visible="$ALPHABETICAL_INDEX" title=""/>
<tab type="hierarchy" visible="yes" title="" intro=""/>
<tab type="classmembers" visible="yes" title="" intro=""/>
</tab>
<tab type="files" visible="yes" title="">
<tab type="filelist" visible="yes" title="" intro=""/>
<tab type="globals" visible="yes" title="" intro=""/>
</tab>
<tab type="examples" visible="yes" title="" intro=""/>
<tab type="user" url="namespaces.html" title="Namespaces"/>
<tab type="user" url="classes.html" title="Classes"/>
<tab type="user" url="https://github.com/seetaresearch/Dragon" title="GitHub"/>
</navindex>
<!-- Layout definition for a class page -->
<class>
<briefdescription visible="yes"/>
<includes visible="$SHOW_INCLUDE_FILES"/>
<inheritancegraph visible="$CLASS_GRAPH"/>
<collaborationgraph visible="$COLLABORATION_GRAPH"/>
<memberdecl>
<nestedclasses visible="yes" title=""/>
<publictypes title=""/>
<services title=""/>
<interfaces title=""/>
<publicslots title=""/>
<signals title=""/>
<publicmethods title=""/>
<publicstaticmethods title=""/>
<publicattributes title=""/>
<publicstaticattributes title=""/>
<protectedtypes title=""/>
<protectedslots title=""/>
<protectedmethods title=""/>
<protectedstaticmethods title=""/>
<protectedattributes title=""/>
<protectedstaticattributes title=""/>
<packagetypes title=""/>
<packagemethods title=""/>
<packagestaticmethods title=""/>
<packageattributes title=""/>
<packagestaticattributes title=""/>
<properties title=""/>
<events title=""/>
<privatetypes title=""/>
<privateslots title=""/>
<privatemethods title=""/>
<privatestaticmethods title=""/>
<privateattributes title=""/>
<privatestaticattributes title=""/>
<friends title=""/>
<related title="" subtitle=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<typedefs title=""/>
<enums title=""/>
<services title=""/>
<interfaces title=""/>
<constructors title=""/>
<functions title=""/>
<related title=""/>
<variables title=""/>
<properties title=""/>
<events title=""/>
</memberdef>
<allmemberslink visible="yes"/>
<usedfiles visible="$SHOW_USED_FILES"/>
<authorsection visible="yes"/>
</class>
<!-- Layout definition for a namespace page -->
<namespace>
<briefdescription visible="yes"/>
<memberdecl>
<nestednamespaces visible="yes" title=""/>
<constantgroups visible="yes" title=""/>
<classes visible="yes" title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
</memberdef>
<authorsection visible="yes"/>
</namespace>
<!-- Layout definition for a file page -->
<file>
<briefdescription visible="yes"/>
<includes visible="$SHOW_FILES"/>
<includegraph visible="$INCLUDE_GRAPH"/>
<includedbygraph visible="$INCLUDED_BY_GRAPH"/>
<sourcelink visible="yes"/>
<memberdecl>
<classes visible="yes" title=""/>
<namespaces visible="yes" title=""/>
<constantgroups visible="yes" title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<inlineclasses title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<functions title=""/>
<variables title=""/>
</memberdef>
<authorsection/>
</file>
<!-- Layout definition for a group page -->
<group>
<briefdescription visible="yes"/>
<groupgraph visible="$GROUP_GRAPHS"/>
<memberdecl>
<nestedgroups visible="yes" title=""/>
<dirs visible="yes" title=""/>
<files visible="yes" title=""/>
<namespaces visible="yes" title=""/>
<classes visible="yes" title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<enumvalues title=""/>
<functions title=""/>
<variables title=""/>
<signals title=""/>
<publicslots title=""/>
<protectedslots title=""/>
<privateslots title=""/>
<events title=""/>
<properties title=""/>
<friends title=""/>
<membergroups visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
<memberdef>
<pagedocs/>
<inlineclasses title=""/>
<defines title=""/>
<typedefs title=""/>
<enums title=""/>
<enumvalues title=""/>
<functions title=""/>
<variables title=""/>
<signals title=""/>
<publicslots title=""/>
<protectedslots title=""/>
<privateslots title=""/>
<events title=""/>
<properties title=""/>
<friends title=""/>
</memberdef>
<authorsection visible="yes"/>
</group>
<!-- Layout definition for a directory page -->
<directory>
<briefdescription visible="yes"/>
<directorygraph visible="yes"/>
<memberdecl>
<dirs visible="yes"/>
<files visible="yes"/>
</memberdecl>
<detaileddescription title=""/>
</directory>
</doxygenlayout>
\ No newline at end of file
......@@ -7,8 +7,11 @@ cmake_minimum_required(VERSION 3.0.0)
# ---------------- User Config ----------------
# Set optional buildings
option(BUILD_PYTHON_API "Set ON to build PYTHON API" ON)
option(BUILD_CXX_API "Set ON to build CXX API" OFF)
# Set optional libraries
option(WITH_PYTHON "Set ON to use PYTHON" ON)
option(WITH_CUDA "Set ON to use CUDA" ON)
option(WITH_CUDNN "Set ON to use CUDNN" ON)
option(WITH_BLAS "Set ON to use BLAS" ON)
......@@ -19,14 +22,18 @@ option(WITH_MPI_CUDA "Set ON to use MPI-CUDA" OFF)
option(WITH_MPI_NCCL "Set ON to use MPI-NCCL" OFF)
# Set your 3rdparty
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
if (NOT 3RDPARTY_DIR)
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
endif()
# Set your python "interpreter" if necessary
# if not, a default interpreter will be used
# here, provide several examples:
# set(PYTHON_EXECUTABLE /usr/bin/python) # Linux & OSX, Builtin Python
# set(PYTHON_EXECUTABLE /X/anaconda/bin/python) # Linux & OSX, Anaconda
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda
if (NOT PYTHON_EXECUTABLE)
# set(PYTHON_EXECUTABLE /usr/bin/python) # Linux && OSX, Builtin Python
# set(PYTHON_EXECUTABLE /X/anaconda/bin/python) # Linux && OSX, Anaconda
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda
endif()
# Set CUDA compiling architecture
# Remove "compute_70/sm_70" if using CUDA 8.0
......@@ -38,8 +45,10 @@ set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
# Set CUDNN Library Dir if necessary (Linux/OSX Only)
# For Win, Recommend to use ``3RDPARTY_DIR/lib``
set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux
# set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib) # OSX
if (NOT CUDNN_LIBRARY_DIR)
set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux
# set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib) # OSX
endif()
# ---------------- User Config ----------------
......@@ -68,7 +77,7 @@ set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) # Linux
# ---[ Dependencies
if (WITH_PYTHON)
if (BUILD_PYTHON_API)
include(${PROJECT_SOURCE_DIR}/../CMake/FindPythonLibs.cmake)
include(${PROJECT_SOURCE_DIR}/../CMake/FindNumPy.cmake)
endif()
......@@ -88,7 +97,7 @@ set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release"
include_directories(${3RDPARTY_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/src)
if (WITH_PYTHON)
if (BUILD_PYTHON_API)
include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${NUMPY_INCLUDE_DIR})
endif()
......@@ -111,7 +120,7 @@ set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix"
set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS})
# ---[ Defines
if (WITH_PYTHON)
if (BUILD_PYTHON_API)
ADD_DEFINITIONS(-DWITH_PYTHON)
if (${PYTHON_VERSION_MAJOR} STREQUAL "2")
message(STATUS "Use Python2 [Optional]")
......@@ -166,7 +175,10 @@ endif()
# ---[ Flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
if(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /O2")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4819 /wd4244")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler \"/wd 4819\"")
string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
string(REPLACE "/O2" "/Ox" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
if (WITH_OMP)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp")
endif()
......@@ -189,8 +201,12 @@ execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS
execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/dragon.proto)
# ---[ Subdirectories
add_subdirectory(modules/python)
#add_subdirectory(modules/cxx) # Compile CXX module if necessary
if (BUILD_PYTHON_API)
add_subdirectory(modules/python)
endif()
if (BUILD_CXX_API)
add_subdirectory(modules/cxx)
endif()
# ---[ Utils
file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib)
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_COMMON_H_
#define DRAGON_CORE_COMMON_H_
#include <ctime>
#include <cmath>
#include <random>
#include <climits>
#include <float.h>
#include <memory>
#include <string>
#include <queue>
......@@ -48,17 +51,17 @@ template <typename Key, typename Value>
using Map = std::unordered_map<Key, Value>;
template <typename Value>
using Set = std::unordered_set<Value> ;
using Set = std::unordered_set<Value>;
/* * * * * * * * * * * * * * * * * * * * *
* *
* Kernel Version *
* *
* Major(2) | Minor(2) | Patch(12) *
* Major(2) | Minor(2) | Patch(13) *
* *
* * * * * * * * * * * * * * * * * * * * */
#define DRAGON_VERSION 2212
#define DRAGON_VERSION 2213
/* * * * * * * * * * * * * * * * * * * * *
* *
......@@ -74,7 +77,7 @@ using Set = std::unordered_set<Value> ;
* *
* * * * * * * * * * * * * * * * * * * * */
// avoid using of "thread_local" for VS2013 or older Xcode
// Avoid using of "thread_local" for VS2013 or older Xcode
#if defined(__clang__) || defined(__GNUC__)
#define TLS_OBJECT __thread
#else
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_H_
#define DRAGON_CORE_CONTEXT_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_CNML_H_
#define DRAGON_CORE_CONTEXT_CNML_H_
/* CAMBRICON's CNRT && CNML Environment */
/*! CAMBRICON's CNRT && CNML Environment */
#include "core/common.h"
......@@ -90,7 +91,7 @@ class CNMLContext {
static std::mutex& mutex() { static std::mutex m; return m; }
static thread_local CNRTObject cnrt_object_;
static CNRTObject* cuda_object();
private:
int device_id_, stream_id_ = 1, random_seed_;
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_CUDA_H_
#define DRAGON_CORE_CONTEXT_CUDA_H_
/* NVIDIA's CUDA Environment */
/*! NVIDIA's CUDA Environment */
#include "core/common.h"
#include "utils/cuda_device.h"
......@@ -38,9 +39,10 @@ class CUDAObject {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
for (int j = 0; j < cuda_streams[i].size(); j++) {
auto& stream = cuda_streams[i][j];
// follow the caffe2, do not check the stream destroying
// Error code 29 (driver shutting down) is inevitable
// TODO(PhyscalX): Can someone solve this issue?
/*!
* Do not check the stream destroying,
* error code 29 (driver shutting down) is inevitable.
*/
if (stream) cudaStreamDestroy(stream);
}
for (auto& handle : cublas_handles[i])
......@@ -52,14 +54,20 @@ class CUDAObject {
}
}
// follow the caffe2,
// each device takes a group of non-blocking streams
// the stream 0 is reserved for default stream,
// as some computations really require it,
// e.g. cublas.asum() and mixed cpu/cuda operations
// besides, somes calls, such as cudnn.conv() and cudnn.rnn(),
// produce wrong results if running them on non-blocking streams
// note that caffe2 also uses default streams (within CuDNNState)
/*!
* Follow the caffe2,
* Each device takes a group of non-blocking streams.
*
* The stream 0 is reserved for default stream,
* as some computations really require it,
* e.g. cublas.asum() and mixed cpu/cuda operations.
*
* Besides, somes calls, such as cudnn.conv() and cudnn.rnn(),
* and even the simple fp16 conversion, produce wrong results
* if running them on non-blocking streams.
*
* Note that caffe2 also uses default streams (within CuDNNState).
*/
cudaStream_t GetStream(int device_id, int stream_id) {
vector<cudaStream_t>& dev_streams = cuda_streams[device_id];
if (dev_streams.size() <= (unsigned)stream_id)
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_GRAPH_H_
#define DRAGON_CORE_GRAPH_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_GRAPH_GRADIENT_H_
#define DRAGON_CORE_GRAPH_GRADIENT_H_
......@@ -54,4 +55,4 @@ class GraphGradientMaker {
} // namespace dragon
#endif
\ No newline at end of file
#endif // DRAGON_CORE_GRAPH_GRADIENT_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_MIXEDMEM_H_
#define DRAGON_CORE_MIXEDMEM_H_
......@@ -86,7 +87,7 @@ class MixedMemory {
void* cpu_ptr_, *cuda_ptr_, *cnml_ptr_;
int own_cpu_ptr_ = 1, ptr_device_ = 0;
/* For CAMBRICON's CNML Environment */
/*! For CAMBRICON's CNML Environment */
cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr;
cnmlTensor_t cnml_mlu_tensor_ = nullptr;
};
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_H_
#define DRAGON_CORE_OPERATOR_H_
......@@ -188,7 +189,7 @@ DECLARE_REGISTRY(
const OperatorDef&,
Workspace*);
/* NVIDIA's Accelerated Library - CUDNN */
/*! NVIDIA's Accelerated Library - CUDNN */
DECLARE_REGISTRY(
CUDNNOperatorRegistry,
......@@ -196,7 +197,7 @@ DECLARE_REGISTRY(
const OperatorDef&,
Workspace*);
/* CAMBRICON's Accelerated Library - CNML */
/*! CAMBRICON's Accelerated Library - CNML */
DECLARE_REGISTRY(
CNMLOperatorRegistry,
......@@ -247,7 +248,8 @@ DECLARE_REGISTRY(
}
#define INIT_MULTIPLIER(ptr_tensor, size) { \
ptr_tensor = ws()->CreateTensor("/share/multiplier"); \
ptr_tensor = ws()->CreateTensor("/share/multiplier/" \
+ TypeMetaToString(TypeMeta::Make<T>())); \
if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape({ size }); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_
#define DRAGON_CORE_OPERATOR_GRADIENT_H_
......@@ -58,7 +59,7 @@ class GradientMakerBase {
}
virtual inline vector<float> DefaultValues() {
return vector<float>(g_outputs_.size(), 1.0);
return vector<float>(g_outputs_.size(), 1.f);
}
template <class... Args>
......@@ -82,7 +83,7 @@ class GradientMakerBase {
const vector<string>& g_outputs_;
};
// implemented in operator.cc
// Implemented in operator.cc
Gradient MakeGradientForOp(
const OperatorDef& op_def,
const vector<string>& g_outputs);
......@@ -111,7 +112,7 @@ DECLARE_REGISTRY(
const OperatorDef&,
const vector<string>&);
// define in the operator.cc
// Defined in the operator.cc
#define REGISTER_GRADIENT(name, ...) \
REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_
#define DRAGON_CORE_OPERATOR_SCHEMA_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_REGISTRY_H_
#define DRAGON_CORE_REGISTRY_H_
......@@ -66,12 +67,12 @@ class Registerer {
}
};
// use in *.h files
// Used in *.h files
#define DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjType,...) \
dragon::Registry<SrcType, ObjType,##__VA_ARGS__>* RegistryName(); \
typedef dragon::Registerer<SrcType,ObjType,##__VA_ARGS__> Registerer##RegistryName;
// use in *.cc files
// Used in *.cc files
#define DEFINE_TYPED_REGISTRY(RegistryName,SrcType, ObjType,...) \
Registry<SrcType,ObjType,##__VA_ARGS__>* RegistryName() { \
static Registry<SrcType,ObjType,##__VA_ARGS__>* registry = \
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TENSOR_H_
#define DRAGON_CORE_TENSOR_H_
......@@ -41,8 +42,7 @@ class Tensor {
}
} else {
if (ex_memory_ && !is_shared_ &&
TIndex(ex_memory_->nbytes()) <
TIndex(new_size * meta_.itemsize())) {
capacity_ < TIndex(new_size * meta_.itemsize())) {
delete ex_memory_;
ex_memory_ = nullptr;
capacity_ = 0;
......@@ -194,22 +194,24 @@ class Tensor {
void* raw_mutable_data(const TypeMeta& meta) {
void* data_ptr;
mutable_data_ptr<Context>(&data_ptr);
// Return the memory directly
if (meta_ == meta && data_ptr) return data_ptr;
if (meta_ != meta && data_ptr && !own_mem_) delete ex_memory_;
// Return the new memory
meta_ = meta;
CHECK_GT(size_, 0);
if (own_mem_) {
memory_.reset(new MixedMemory(
meta, size_* meta_.itemsize()));
meta_, size_* meta_.itemsize()));
} else {
if (data_ptr) delete ex_memory_;
ex_memory_ = new MixedMemory(
meta, size_* meta_.itemsize());
meta_, size_* meta_.itemsize());
}
// malloc memory
// Malloc
mutable_data_ptr<Context>(&data_ptr);
// call the constructors
if (meta.ctor()) meta_.ctor()(data_ptr, size_);
capacity_ = size_ * meta.itemsize(), require_init_ = true;
// Call the constructors
if (meta_.ctor()) meta_.ctor()(data_ptr, size_);
capacity_ = size_ * meta_.itemsize(), require_init_ = true;
return data_ptr;
}
......@@ -274,6 +276,7 @@ class Tensor {
TypeMeta::Make<float>(), 4);
require_init_ = true;
} own_mem_ = false;
capacity_ = (TIndex)ex_memory_->nbytes();
}
inline void Share(MixedMemory* mem) {
......@@ -290,7 +293,7 @@ class Tensor {
}
std::function<void()> DECREFPyArray;
~Tensor() { /* DO NOT CALL DECREFARRAY */ }
~Tensor() { /*! DO NOT CALL DECREFARRAY */ }
private:
vector<TIndex> dims_;
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TYPEID_H_
#define DRAGON_CORE_TYPEID_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_TYPES_H_
#define DRAGON_CORE_TYPES_H_
#include <cstdint>
#include <unordered_map>
#include "core/typeid.h"
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_WORKSPACE_H_
#define DRAGON_CORE_WORKSPACE_H_
......@@ -62,7 +63,7 @@ class Workspace {
}
inline void ClearWorkspace() {
// clear tensors & buffers & re-initialization
// Clear tensors, then re-initialization
for (auto& kv : tensor_map_) kv.second->Reset();
InitWorkspace();
}
......@@ -79,11 +80,11 @@ class Workspace {
const string& name,
bool use_remote = true) {
string query = GetTensorName(name);
// search local workspace
// Search local workspace
if (tensor_map_.count(query) > 0)
return tensor_map_[query].get();
if (use_remote) {
// search remote workspace
// Search remote workspace
for (auto& it : ws_map_) {
if (it.second->HasTensor(query))
return it.second->GetTensor(query);
......@@ -125,10 +126,10 @@ class Workspace {
vector<string> GetTensors() {
vector<string> names;
// search local workspace
// Search local workspace
for (auto& it : tensor_map_)
names.push_back(it.first);
// serach remote workspace
// Serach remote workspace
for (auto& it : ws_map_) {
vector<string> sub_names = it.second->GetTensors();
names.insert(names.end(),
......@@ -142,11 +143,11 @@ class Workspace {
inline bool HasFiller(
const string& name,
bool use_remote = true) {
// search local workspace
// Search local workspace
bool result = filler_map_.count(name) > 0;
if (!use_remote) return result;
// search remote workspace
// Search remote workspace
for (auto& it : ws_map_)
result |= it.second->HasFiller(name);
return result;
......@@ -162,11 +163,11 @@ class Workspace {
inline const TensorFiller* GetFiller(
const string& name) {
// search local workspace
// Search local workspace
if (filler_map_.count(name) > 0)
return &filler_map_[name];
// search remote workspace
// Search remote workspace
for (auto& it : ws_map_) {
if (it.second->HasFiller(name))
return it.second->GetFiller(name);
......@@ -238,11 +239,11 @@ class Workspace {
persistent_key = arg.s();
}
if (persistent_key.empty()) {
// run op in the "ONCE" mode
// Run op in the "ONCE" mode
unique_ptr<OperatorBase> op(CreateOperator(meta_op, this));
op->Run();
} else {
// run op in the "PERSISTENT" mode
// Run op in the "PERSISTENT" mode
if (!op_map_.count(persistent_key))
op_map_[persistent_key] = unique_ptr<OperatorBase>(
CreateOperator(meta_op, this));
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
......@@ -21,7 +22,7 @@ class EluOp : public Operator<Context> {
public:
EluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.0)) {}
alpha(OperatorBase::Arg<float>("alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -36,7 +37,7 @@ class EluGradientOp : public Operator<Context> {
public:
EluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.0)) {}
alpha(OperatorBase::Arg<float>("alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
#define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
......@@ -81,7 +82,7 @@ class CuDNNAffineOpBase : public Operator<Context> {
template <typename T>
void ResetDesc(const Tensor& X) {
// determine the range of affine
// Determine the range of affine
start_axis = axis;
if (start_axis < 0) start_axis += (int)X.ndim();
if (num_axes == -1) num_axes = (int)X.ndim() - start_axis;
......@@ -89,14 +90,14 @@ class CuDNNAffineOpBase : public Operator<Context> {
end_axis = start_axis + num_axes;
CHECK_LT(start_axis, (int)X.ndim());
CHECK_LE(start_axis + num_axes, (int)X.ndim());
// determine the input desc
// Determine the input desc
vector<TIndex> input_dims = X.dims();
// cudnn requires ndimensions range from [4, 5]
// CuDNN requires ndimensions range from [4, 5]
if (input_dims.size() < 4) input_dims.resize(4, 1);
else if (input_dims.size() > 5)
LOG(FATAL) << "CuDNN Affine the dimensions up to 5.";
cudnnSetTensorDesc<T>(&input_desc, input_dims);
// determine the scale desc
// Determine the scale desc
vector<TIndex> param_dims(input_dims.size(), 1);
for (int i = start_axis; i < end_axis; i++)
param_dims[i] = input_dims[i];
......@@ -127,24 +128,32 @@ class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> {
: CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename DT, typename CT> void RunWithType();
protected:
USE_CUDNN_AFFINE_FUCNTIONS;
};
template <class Context>
class CuDNNAffineGradientOp final : public CuDNNAffineOpBase<Context> {
class CuDNNAffineGradientOp final
: public CuDNNAffineOpBase<Context> {
public:
CuDNNAffineGradientOp(const OperatorDef& def, Workspace* ws)
CuDNNAffineGradientOp(
const OperatorDef& def,
Workspace* ws)
: CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override;
template <typename T> void ComputeScaleGradient(T* dYxX, T* dA);
template <typename DT, typename CT>
void ComputeScaleGradient(DT* dYxX, DT* dA);
template <typename DT, typename CT>
void ComputeBiasGradient(const DT* dY, DT* dB);
template <typename T> void ComputeScaleGradient_v2(T* dYxX, T* dA);
template <typename T> void ComputeBiasGradient(const T* dY, T* dB);
template <typename T> void ComputeBiasGradient_v2(const T* dY, T* dB);
template <typename T> void RunWithType();
template <typename DT, typename CT> void RunWithType();
protected:
USE_CUDNN_AFFINE_FUCNTIONS;
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#include <float.h>
#include "core/operator.h"
namespace dragon {
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_FUNDAMENTAL_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MAXIMUM_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MINIMUM_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
......@@ -21,9 +22,9 @@ class PowOp final : public Operator<Context> {
public:
PowOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.0)),
shift(OperatorBase::Arg<float>("shift", 0.0)),
power(OperatorBase::Arg<float>("power", 1.0)) {
scale(OperatorBase::Arg<float>("scale", 1.f)),
shift(OperatorBase::Arg<float>("shift", 0.f)),
power(OperatorBase::Arg<float>("power", 1.f)) {
power_scale = power * scale;
}
USE_OPERATOR_FUNCTIONS;
......@@ -40,9 +41,9 @@ class PowGradientOp final : public Operator<Context> {
public:
PowGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.0)),
shift(OperatorBase::Arg<float>("shift", 0.0)),
power(OperatorBase::Arg<float>("power", 1.0)) {
scale(OperatorBase::Arg<float>("scale", 1.f)),
shift(OperatorBase::Arg<float>("shift", 0.f)),
power(OperatorBase::Arg<float>("power", 1.f)) {
power_scale = power * scale;
}
USE_OPERATOR_FUNCTIONS;
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_NLL_LOSS_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
#include "core/operator.h"
......@@ -58,4 +59,4 @@ class SigmoidCrossEntropyGradientOp
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_LOSS_SIGMOID_CE_LOSS_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_FOCAL_LOSS_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
......@@ -21,7 +22,7 @@ class SmoothL1LossOp final : public Operator<Context> {
public:
SmoothL1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.0)),
beta(OperatorBase::Arg<float>("beta", 1.f)),
normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS;
......@@ -40,7 +41,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
public:
SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.0)),
beta(OperatorBase::Arg<float>("beta", 1.f)),
normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS;
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_FOCAL_LOSS_OP_H_
#include "operators/loss/sparse_softmax_cross_entropy_op.h"
#include "operators/loss/sparse_softmax_ce_loss_op.h"
namespace dragon {
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
......@@ -78,4 +79,4 @@ class SparseSoftmaxCrossEntropyGradientOp
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CE_LOSS_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
#define DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_ASTYPE_OP_H_
#define DRAGON_OPERATORS_MISC_ASTYPE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
#define DRAGON_OPERATORS_MISC_GRADIENT_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
#define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
......@@ -44,7 +45,7 @@ class FillOp final : public Operator<Context> {
: Operator<Context>(def, ws),
shape_desc(OperatorBase::Arg<string>("shape", "")),
dtype(OperatorBase::Arg<string>("dtype", "float32")),
value(OperatorBase::Arg<float>("value", 0.0)) {
value(OperatorBase::Arg<float>("value", 0.f)) {
GET_ARGUMENTS_WITH_DESC(int, dims);
}
USE_OPERATOR_FUNCTIONS;
......@@ -64,8 +65,8 @@ public:
RandomUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
this->filler.set_type("uniform");
this->filler.set_low(OperatorBase::Arg<float>("low", -1.0));
this->filler.set_high(OperatorBase::Arg<float>("high", 1.0));
this->filler.set_low(OperatorBase::Arg<float>("low", -1.f));
this->filler.set_high(OperatorBase::Arg<float>("high", 1.f));
}
USE_OPERATOR_FUNCTIONS;
};
......@@ -76,8 +77,8 @@ public:
RandomNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
this->filler.set_type("normal");
this->filler.set_mean(OperatorBase::Arg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::Arg<float>("std", 1.0));
this->filler.set_mean(OperatorBase::Arg<float>("mean", 0.f));
this->filler.set_std(OperatorBase::Arg<float>("std", 1.f));
}
USE_OPERATOR_FUNCTIONS;
};
......@@ -88,8 +89,8 @@ public:
TruncatedNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
this->filler.set_type("truncated_normal");
float mu = OperatorBase::Arg<float>("mean", 0.0);
float sigma = OperatorBase::Arg<float>("std", 1.0);
float mu = OperatorBase::Arg<float>("mean", 0.f);
float sigma = OperatorBase::Arg<float>("std", 1.f);
this->filler.set_mean(mu);
this->filler.set_std(sigma);
this->filler.set_low(mu - 2 * sigma);
......@@ -104,7 +105,7 @@ public:
GlorotUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::Arg<float>("scale", 3.0);
float scale = OperatorBase::Arg<float>("scale", 3.f);
this->filler.set_type("xavier");
if (mode == "fan_avg") {
......@@ -125,7 +126,7 @@ public:
GlorotNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::Arg<float>("scale", 2.0);
float scale = OperatorBase::Arg<float>("scale", 2.f);
this->filler.set_type("msra");
if (mode == "fan_avg") {
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MISC_PYTHON_OP_H_
#define DRAGON_OPERATORS_MISC_PYTHON_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
......@@ -56,13 +57,13 @@ class ModelMPIBase : public Operator<Context> {
string dtype;
};
#define USE_MPIMODEL_FUNCTIONS(context) \
using ModelMPIBase<context>::comm; \
using ModelMPIBase<context>::mpi_dtype; \
using ModelMPIBase<context>::comm_size; \
using ModelMPIBase<context>::comm_rank; \
using ModelMPIBase<context>::comm_root; \
using ModelMPIBase<context>::dtype
#define USE_MODEL_MPI_FUNCTIONS \
using ModelMPIBase<Context>::comm; \
using ModelMPIBase<Context>::mpi_dtype; \
using ModelMPIBase<Context>::comm_size; \
using ModelMPIBase<Context>::comm_rank; \
using ModelMPIBase<Context>::comm_root; \
using ModelMPIBase<Context>::dtype
} // namespace dragon
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
......@@ -24,7 +25,7 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
MPIBroadcastOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -36,7 +37,7 @@ public:
MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
......@@ -24,7 +25,7 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
MPIGatherOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -36,7 +37,7 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> {
MPIGatherGradientOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
USE_MODEL_MPI_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGREDUCE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_
#define DRAGON_OPERATORS_NDARRAY_DIMENSION_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
#define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
......@@ -24,7 +25,7 @@ class PadOp final : public Operator<Context> {
pad_l(OperatorBase::Args<int>("pad_l")),
pad_r(OperatorBase::Args<int>("pad_r")),
mode(OperatorBase::Arg<string>("mode", "CONSTANT")),
value(OperatorBase::Arg<float>("value", 0.0f)) {
value(OperatorBase::Arg<float>("value", 0.f)) {
if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length.";
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_
#define DRAGON_OPERATORS_NORM_L2_NORM_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
......@@ -52,7 +53,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
hidden_size(OperatorBase::Arg<int>("hidden_size", 0)),
num_layers(OperatorBase::Arg<int>("num_layers", 1)),
bidirectional(OperatorBase::Arg<bool>("bidirectional", false)),
dropout_ratio(OperatorBase::Arg<float>("dropout_ratio", 1.0)),
dropout_ratio(OperatorBase::Arg<float>("dropout_ratio", 1.f)),
random_seed(def.device_option().random_seed()) {
// determine the rnn direction
rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
#define DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
#define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
......@@ -21,7 +22,7 @@ class MovingAverageOp final : public Operator<Context> {
public:
MovingAverageOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
decay(OperatorBase::Arg<float>("decay", 1.0)) {}
decay(OperatorBase::Arg<float>("decay", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
#define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
......@@ -21,8 +22,8 @@ class UpdateOpBase : public Operator<Context> {
public:
UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
lr_mult(OperatorBase::Arg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::Arg<float>("decay_mult", 1.0)),
lr_mult(OperatorBase::Arg<float>("lr_mult", 1.f)),
decay_mult(OperatorBase::Arg<float>("decay_mult", 1.f)),
slot(OperatorBase::Arg<string>("slot", "")),
zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) {
CHECK(!slot.empty()) << "\nRequired a non-empty slot";
......
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!