Commit 9de0f1a3 by Ting PAN

Fix the bug of missing defs in blending assign operators

Summary:

This commit attaches the input and output together in the assign operators,
which fixes the input defs that went missing because of the identity from input to output.
1 parent 746f2cbb
Showing with 155 additions and 411 deletions
...@@ -39,14 +39,14 @@ html: ...@@ -39,14 +39,14 @@ html:
@echo "Build finished. The HTML pages are in $(BUILDDIR)." @echo "Build finished. The HTML pages are in $(BUILDDIR)."
latex: latex:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)-latex $(SPHINXBUILD) -b latex -j ${NUMBER_OF_PROCESSORS} $(ALLSPHINXOPTS) $(BUILDDIR)-latex
@echo @echo
@echo "Build finished; the LaTeX files are in $(BUILDDIR)-latex." @echo "Build finished; the LaTeX files are in $(BUILDDIR)-latex."
@echo "Run \`make' in that directory to run these through (pdf)latex" \ @echo "Run \`make' in that directory to run these through (pdf)latex" \
"(use \`make latexpdf' here to do that automatically)." "(use \`make latexpdf' here to do that automatically)."
latexpdf: latexpdf:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)-latex $(SPHINXBUILD) -b latex -j ${NUMBER_OF_PROCESSORS} $(ALLSPHINXOPTS) $(BUILDDIR)-latex
@echo "Running LaTeX files through pdflatex..." @echo "Running LaTeX files through pdflatex..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf $(MAKE) -C $(BUILDDIR)-latex all-pdf
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)-latex." @echo "pdflatex finished; the PDF files are in $(BUILDDIR)-latex."
...@@ -39,7 +39,7 @@ extensions = ['sphinx.ext.autodoc', 'sphinxcontrib.katex', 'breathe'] ...@@ -39,7 +39,7 @@ extensions = ['sphinx.ext.autodoc', 'sphinxcontrib.katex', 'breathe']
# Project # Project
project = 'dragon' project = 'dragon'
copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd' copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd'
author = 'SeetaTech' author = 'SeetaTech, Co.,Ltd'
with open('../../../dragon/version.txt', 'r') as f: with open('../../../dragon/version.txt', 'r') as f:
version = f.read().strip() version = f.read().strip()
...@@ -114,6 +114,7 @@ latex_elements = { ...@@ -114,6 +114,7 @@ latex_elements = {
\fvset{breaklines=true, breakanywhere=true} \fvset{breaklines=true, breakanywhere=true}
\setlength{\headheight}{13.6pt} \setlength{\headheight}{13.6pt}
\setlength{\itemindent}{-1pt} \setlength{\itemindent}{-1pt}
\addto\captionsenglish{\renewcommand{\chaptername}{}}
\makeatletter \makeatletter
\renewcommand*\l@subsection{\@dottedtocline{2}{3.8em}{3.8em}} \renewcommand*\l@subsection{\@dottedtocline{2}{3.8em}{3.8em}}
\fancypagestyle{normal}{ \fancypagestyle{normal}{
...@@ -146,13 +147,18 @@ latex_elements = { ...@@ -146,13 +147,18 @@ latex_elements = {
\vspace*{40mm} \vspace*{40mm}
\LARGE \@author \LARGE \@author
\vspace*{40mm}
\LARGE \today
\end{titlepage} \end{titlepage}
\makeatother \makeatother
\pagenumbering{arabic} \pagenumbering{arabic}
''', ''',
'pointsize': '10pt', 'pointsize': '10pt',
'classoptions': ',oneside',
'figure_align': 'H', 'figure_align': 'H',
'fncychap': '\\usepackage[Sonny]{fncychap}',
'printindex': '', 'printindex': '',
'sphinxsetup': ' \ 'sphinxsetup': ' \
hmargin={0.75in,0.75in}, \ hmargin={0.75in,0.75in}, \
......
...@@ -65,7 +65,7 @@ if "%1" == "doxygen" ( ...@@ -65,7 +65,7 @@ if "%1" == "doxygen" (
) )
if "%1" == "html" ( if "%1" == "html" (
%SPHINXBUILD% -b html -j %NUMBER_OF_PROCESSORS% %ALLSPHINXOPTS% %BUILDDIR% %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%
if errorlevel 1 exit /b 1 if errorlevel 1 exit /b 1
echo. echo.
echo.Build finished. The HTML pages are in %BUILDDIR%. echo.Build finished. The HTML pages are in %BUILDDIR%.
......
...@@ -34,14 +34,14 @@ html: ...@@ -34,14 +34,14 @@ html:
@echo "Build finished. The HTML pages are in $(BUILDDIR)." @echo "Build finished. The HTML pages are in $(BUILDDIR)."
latex: latex:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)-latex $(SPHINXBUILD) -b latex -j ${NUMBER_OF_PROCESSORS} $(ALLSPHINXOPTS) $(BUILDDIR)-latex
@echo @echo
@echo "Build finished; the LaTeX files are in $(BUILDDIR)-latex." @echo "Build finished; the LaTeX files are in $(BUILDDIR)-latex."
@echo "Run \`make' in that directory to run these through (pdf)latex" \ @echo "Run \`make' in that directory to run these through (pdf)latex" \
"(use \`make latexpdf' here to do that automatically)." "(use \`make latexpdf' here to do that automatically)."
latexpdf: latexpdf:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)-latex $(SPHINXBUILD) -b latex -j ${NUMBER_OF_PROCESSORS} $(ALLSPHINXOPTS) $(BUILDDIR)-latex
@echo "Running LaTeX files through pdflatex..." @echo "Running LaTeX files through pdflatex..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf $(MAKE) -C $(BUILDDIR)-latex all-pdf
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)-latex." @echo "pdflatex finished; the PDF files are in $(BUILDDIR)-latex."
...@@ -41,7 +41,7 @@ napoleon_use_rtype = False ...@@ -41,7 +41,7 @@ napoleon_use_rtype = False
# Project # Project
project = 'dragon' project = 'dragon'
copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd' copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd'
author = 'SeetaTech' author = 'SeetaTech, Co.,Ltd'
with open('../../../dragon/version.txt', 'r') as f: with open('../../../dragon/version.txt', 'r') as f:
version = f.read().strip() version = f.read().strip()
...@@ -122,6 +122,7 @@ latex_elements = { ...@@ -122,6 +122,7 @@ latex_elements = {
\fvset{breaklines=true, breakanywhere=true} \fvset{breaklines=true, breakanywhere=true}
\setlength{\headheight}{13.6pt} \setlength{\headheight}{13.6pt}
\setlength{\itemindent}{-1pt} \setlength{\itemindent}{-1pt}
\addto\captionsenglish{\renewcommand{\chaptername}{}}
\makeatletter \makeatletter
\renewcommand*\l@subsection{\@dottedtocline{2}{3.8em}{3.8em}} \renewcommand*\l@subsection{\@dottedtocline{2}{3.8em}{3.8em}}
\fancypagestyle{normal}{ \fancypagestyle{normal}{
...@@ -154,13 +155,18 @@ latex_elements = { ...@@ -154,13 +155,18 @@ latex_elements = {
\vspace*{40mm} \vspace*{40mm}
\LARGE \@author \LARGE \@author
\vspace*{40mm}
\LARGE \today
\end{titlepage} \end{titlepage}
\makeatother \makeatother
\pagenumbering{arabic} \pagenumbering{arabic}
''', ''',
'pointsize': '10pt', 'pointsize': '10pt',
'classoptions': ',oneside',
'figure_align': 'H', 'figure_align': 'H',
'fncychap': '\\usepackage[Sonny]{fncychap}',
'printindex': '', 'printindex': '',
'sphinxsetup': ' \ 'sphinxsetup': ' \
hmargin={0.75in,0.75in}, \ hmargin={0.75in,0.75in}, \
......
...@@ -48,9 +48,6 @@ dragon ...@@ -48,9 +48,6 @@ dragon
`constant(...) <dragon/constant.html>`_ `constant(...) <dragon/constant.html>`_
: Return a tensor initialized from the value. : Return a tensor initialized from the value.
`copy(...) <dragon/copy.html>`_
: Copy the input.
`create_function(...) <dragon/create_function.html>`_ `create_function(...) <dragon/create_function.html>`_
: Create a callable graph from the specified outputs. : Create a callable graph from the specified outputs.
...@@ -93,6 +90,9 @@ dragon ...@@ -93,6 +90,9 @@ dragon
`graph_mode(...) <dragon/graph_mode.html>`_ `graph_mode(...) <dragon/graph_mode.html>`_
: Context-manager set the graph execution mode. : Context-manager set the graph execution mode.
`identity(...) <dragon/identity.html>`_
: Return a tensor copied from the input.
`index_select(...) <dragon/index_select.html>`_ `index_select(...) <dragon/index_select.html>`_
: Select the elements according to the index along the given axis. : Select the elements according to the index along the given axis.
...@@ -199,7 +199,6 @@ dragon ...@@ -199,7 +199,6 @@ dragon
dragon/channel_shuffle dragon/channel_shuffle
dragon/concat dragon/concat
dragon/constant dragon/constant
dragon/copy
dragon/create_function dragon/create_function
dragon/device dragon/device
dragon/eager_mode dragon/eager_mode
...@@ -214,6 +213,7 @@ dragon ...@@ -214,6 +213,7 @@ dragon
dragon/get_workspace dragon/get_workspace
dragon/gradients dragon/gradients
dragon/graph_mode dragon/graph_mode
dragon/identity
dragon/index_select dragon/index_select
dragon/linspace dragon/linspace
dragon/load_library dragon/load_library
......
...@@ -174,8 +174,8 @@ __truediv__ ...@@ -174,8 +174,8 @@ __truediv__
.. _dragon.assign(...): assign.html .. _dragon.assign(...): assign.html
.. _dragon.cast(...): cast.html .. _dragon.cast(...): cast.html
.. _dragon.copy(...): copy.html
.. _dragon.fill(...): fill.html .. _dragon.fill(...): fill.html
.. _dragon.identity(...): identity.html
.. _dragon.masked_assign(...): masked_assign.html .. _dragon.masked_assign(...): masked_assign.html
.. _dragon.masked_select(...): masked_select.html .. _dragon.masked_select(...): masked_select.html
.. _dragon.math.add(...): math/add.html .. _dragon.math.add(...): math/add.html
......
...@@ -154,8 +154,8 @@ __truediv__ ...@@ -154,8 +154,8 @@ __truediv__
.. _dragon.assign(...): assign.html .. _dragon.assign(...): assign.html
.. _dragon.cast(...): cast.html .. _dragon.cast(...): cast.html
.. _dragon.copy(...): copy.html
.. _dragon.fill(...): fill.html .. _dragon.fill(...): fill.html
.. _dragon.identity(...): identity.html
.. _dragon.masked_assign(...): masked_assign.html .. _dragon.masked_assign(...): masked_assign.html
.. _dragon.masked_select(...): masked_select.html .. _dragon.masked_select(...): masked_select.html
.. _dragon.math.add(...): math/add.html .. _dragon.math.add(...): math/add.html
......
copy identity
==== ========
.. autofunction:: dragon.copy .. autofunction:: dragon.identity
.. raw:: html .. raw:: html
......
...@@ -55,7 +55,7 @@ if errorlevel 9009 ( ...@@ -55,7 +55,7 @@ if errorlevel 9009 (
:sphinx_ok :sphinx_ok
if "%1" == "html" ( if "%1" == "html" (
%SPHINXBUILD% -b html -j %NUMBER_OF_PROCESSORS% %ALLSPHINXOPTS% %BUILDDIR% %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%
if errorlevel 1 exit /b 1 if errorlevel 1 exit /b 1
echo. echo.
echo.Build finished. The HTML pages are in %BUILDDIR%. echo.Build finished. The HTML pages are in %BUILDDIR%.
......
...@@ -65,7 +65,7 @@ Name Supported Reference ...@@ -65,7 +65,7 @@ Name Supported Reference
`Greater`_ |v| :func:`dragon.math.greater` `Greater`_ |v| :func:`dragon.math.greater`
`HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid` `HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid`
`Hardmax`_ `Hardmax`_
`Identity`_ `Identity`_ |v| :func:`dragon.identity`
`If`_ `If`_
`InstanceNormalization`_ |v| :func:`dragon.nn.instance_norm` `InstanceNormalization`_ |v| :func:`dragon.nn.instance_norm`
`IsInf`_ |v| :func:`dragon.math.is_inf` `IsInf`_ |v| :func:`dragon.math.is_inf`
......
...@@ -58,7 +58,7 @@ vm.tensorflow ...@@ -58,7 +58,7 @@ vm.tensorflow
: Compute the symbolic derivatives of ``ys`` w.r.t. ``xs`` . : Compute the symbolic derivatives of ``ys`` w.r.t. ``xs`` .
`identity(...) <tensorflow/identity.html>`_ `identity(...) <tensorflow/identity.html>`_
: Return a new tensor copying the content of input. : Return a tensor copied from the input.
`linspace(...) <tensorflow/linspace.html>`_ `linspace(...) <tensorflow/linspace.html>`_
: Generate evenly spaced values within intervals along the given axis. : Generate evenly spaced values within intervals along the given axis.
......
...@@ -68,11 +68,9 @@ void _EluGrad<float16>( ...@@ -68,11 +68,9 @@ void _EluGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -170,10 +170,8 @@ void EluGrad<float16, CUDAContext>( ...@@ -170,10 +170,8 @@ void EluGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -83,11 +83,9 @@ void _HardSigmoidGrad<float16>( ...@@ -83,11 +83,9 @@ void _HardSigmoidGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -126,10 +126,8 @@ void HardSigmoidGrad<float16, CUDAContext>( ...@@ -126,10 +126,8 @@ void HardSigmoidGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -88,11 +88,9 @@ void _HardSwishGrad<float16>( ...@@ -88,11 +88,9 @@ void _HardSwishGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -153,10 +153,8 @@ void HardSwishGrad<float16, CUDAContext>( ...@@ -153,10 +153,8 @@ void HardSwishGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -243,10 +243,8 @@ void PReluWGrad<float16, CPUContext>( ...@@ -243,10 +243,8 @@ void PReluWGrad<float16, CPUContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -601,10 +601,8 @@ void PReluWGrad<float16, CUDAContext>( ...@@ -601,10 +601,8 @@ void PReluWGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -124,11 +124,9 @@ void _ReluNGrad<float16>( ...@@ -124,11 +124,9 @@ void _ReluNGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -401,10 +401,8 @@ void ReluNGrad<float16, CUDAContext>( ...@@ -401,10 +401,8 @@ void ReluNGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -85,11 +85,9 @@ void _SeluGrad<float16>( ...@@ -85,11 +85,9 @@ void _SeluGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -191,10 +191,8 @@ void SeluGrad<float16, CUDAContext>( ...@@ -191,10 +191,8 @@ void SeluGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -56,11 +56,9 @@ void _SigmoidGrad<float16>( ...@@ -56,11 +56,9 @@ void _SigmoidGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -135,10 +135,8 @@ void SigmoidGrad<float16, CUDAContext>( ...@@ -135,10 +135,8 @@ void SigmoidGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -122,11 +122,9 @@ void _SoftmaxGrad<float16>( ...@@ -122,11 +122,9 @@ void _SoftmaxGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -257,10 +257,8 @@ void SoftmaxGrad<float16, CUDAContext>( ...@@ -257,10 +257,8 @@ void SoftmaxGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -60,11 +60,9 @@ void _SwishGrad<float16>( ...@@ -60,11 +60,9 @@ void _SwishGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -121,10 +121,8 @@ void SwishGrad<float16, CUDAContext>( ...@@ -121,10 +121,8 @@ void SwishGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -55,11 +55,9 @@ void _TanhGrad<float16>( ...@@ -55,11 +55,9 @@ void _TanhGrad<float16>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -124,10 +124,8 @@ void TanhGrad<float16, CUDAContext>( ...@@ -124,10 +124,8 @@ void TanhGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -93,7 +93,6 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -93,7 +93,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -206,7 +206,6 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -206,7 +206,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -105,14 +105,12 @@ DEFINE_KERNEL_LAUNCHER(float, float); ...@@ -105,14 +105,12 @@ DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(float, double); DEFINE_KERNEL_LAUNCHER(float, double);
DEFINE_KERNEL_LAUNCHER(double, float); DEFINE_KERNEL_LAUNCHER(double, float);
DEFINE_KERNEL_LAUNCHER(double, double); DEFINE_KERNEL_LAUNCHER(double, double);
DEFINE_FP16_KERNEL_LAUNCHER(int8_t); DEFINE_FP16_KERNEL_LAUNCHER(int8_t);
DEFINE_FP16_KERNEL_LAUNCHER(uint8_t); DEFINE_FP16_KERNEL_LAUNCHER(uint8_t);
DEFINE_FP16_KERNEL_LAUNCHER(int); DEFINE_FP16_KERNEL_LAUNCHER(int);
DEFINE_FP16_KERNEL_LAUNCHER(int64_t); DEFINE_FP16_KERNEL_LAUNCHER(int64_t);
DEFINE_FP16_KERNEL_LAUNCHER(float); DEFINE_FP16_KERNEL_LAUNCHER(float);
DEFINE_FP16_KERNEL_LAUNCHER(double); DEFINE_FP16_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_FP16_KERNEL_LAUNCHER #undef DEFINE_FP16_KERNEL_LAUNCHER
......
...@@ -268,14 +268,12 @@ DEFINE_KERNEL_LAUNCHER(float, float); ...@@ -268,14 +268,12 @@ DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(float, double); DEFINE_KERNEL_LAUNCHER(float, double);
DEFINE_KERNEL_LAUNCHER(double, float); DEFINE_KERNEL_LAUNCHER(double, float);
DEFINE_KERNEL_LAUNCHER(double, double); DEFINE_KERNEL_LAUNCHER(double, double);
DEFINE_FP16_KERNEL_LAUNCHER(int8_t); DEFINE_FP16_KERNEL_LAUNCHER(int8_t);
DEFINE_FP16_KERNEL_LAUNCHER(uint8_t); DEFINE_FP16_KERNEL_LAUNCHER(uint8_t);
DEFINE_FP16_KERNEL_LAUNCHER(int); DEFINE_FP16_KERNEL_LAUNCHER(int);
DEFINE_FP16_KERNEL_LAUNCHER(int64_t); DEFINE_FP16_KERNEL_LAUNCHER(int64_t);
DEFINE_FP16_KERNEL_LAUNCHER(float); DEFINE_FP16_KERNEL_LAUNCHER(float);
DEFINE_FP16_KERNEL_LAUNCHER(double); DEFINE_FP16_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_FP16_KERNEL_LAUNCHER #undef DEFINE_FP16_KERNEL_LAUNCHER
......
...@@ -51,7 +51,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -51,7 +51,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -57,7 +57,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -57,7 +57,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -38,7 +38,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -38,7 +38,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -75,7 +75,6 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -75,7 +75,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -77,7 +77,6 @@ DEFINE_KERNEL_LAUNCHER(IndexSelect, int64_t); ...@@ -77,7 +77,6 @@ DEFINE_KERNEL_LAUNCHER(IndexSelect, int64_t);
DEFINE_KERNEL_LAUNCHER(IndexSelect, float16); DEFINE_KERNEL_LAUNCHER(IndexSelect, float16);
DEFINE_KERNEL_LAUNCHER(IndexSelect, float); DEFINE_KERNEL_LAUNCHER(IndexSelect, float);
DEFINE_KERNEL_LAUNCHER(IndexSelect, double); DEFINE_KERNEL_LAUNCHER(IndexSelect, double);
DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, int8_t); DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, int8_t);
DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, uint8_t); DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, uint8_t);
DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, int); DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, int);
...@@ -85,7 +84,6 @@ DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, int64_t); ...@@ -85,7 +84,6 @@ DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, int64_t);
DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, float16); DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, float16);
DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, float); DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, float);
DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, double); DEFINE_KERNEL_LAUNCHER(IndexSelectGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -112,11 +112,9 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -112,11 +112,9 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -80,14 +80,12 @@ DEFINE_KERNEL_LAUNCHER(int64_t, int64_t); ...@@ -80,14 +80,12 @@ DEFINE_KERNEL_LAUNCHER(int64_t, int64_t);
DEFINE_KERNEL_LAUNCHER(int64_t, float16); DEFINE_KERNEL_LAUNCHER(int64_t, float16);
DEFINE_KERNEL_LAUNCHER(int64_t, float); DEFINE_KERNEL_LAUNCHER(int64_t, float);
DEFINE_KERNEL_LAUNCHER(int64_t, double); DEFINE_KERNEL_LAUNCHER(int64_t, double);
DEFINE_GRAD_KERNEL_LAUNCHER(int, float16); DEFINE_GRAD_KERNEL_LAUNCHER(int, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(int, float); DEFINE_GRAD_KERNEL_LAUNCHER(int, float);
DEFINE_GRAD_KERNEL_LAUNCHER(int, double); DEFINE_GRAD_KERNEL_LAUNCHER(int, double);
DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, float16); DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, float); DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, float);
DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, double); DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -85,14 +85,12 @@ DEFINE_KERNEL_LAUNCHER(int64_t, int64_t); ...@@ -85,14 +85,12 @@ DEFINE_KERNEL_LAUNCHER(int64_t, int64_t);
DEFINE_KERNEL_LAUNCHER(int64_t, float16); DEFINE_KERNEL_LAUNCHER(int64_t, float16);
DEFINE_KERNEL_LAUNCHER(int64_t, float); DEFINE_KERNEL_LAUNCHER(int64_t, float);
DEFINE_KERNEL_LAUNCHER(int64_t, double); DEFINE_KERNEL_LAUNCHER(int64_t, double);
DEFINE_GRAD_KERNEL_LAUNCHER(int, float16); DEFINE_GRAD_KERNEL_LAUNCHER(int, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(int, float); DEFINE_GRAD_KERNEL_LAUNCHER(int, float);
DEFINE_GRAD_KERNEL_LAUNCHER(int, double); DEFINE_GRAD_KERNEL_LAUNCHER(int, double);
DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, float16); DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, float); DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, float);
DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, double); DEFINE_GRAD_KERNEL_LAUNCHER(int64_t, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -42,7 +42,6 @@ void _OneHot( ...@@ -42,7 +42,6 @@ void _OneHot(
DEFINE_KERNEL_LAUNCHER(int); DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -42,7 +42,6 @@ __global__ void _OneHot( ...@@ -42,7 +42,6 @@ __global__ void _OneHot(
DEFINE_KERNEL_LAUNCHER(int); DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -132,7 +132,6 @@ DEFINE_CONST_KERNEL_LAUNCHER(int64_t); ...@@ -132,7 +132,6 @@ DEFINE_CONST_KERNEL_LAUNCHER(int64_t);
DEFINE_CONST_KERNEL_LAUNCHER(float16); DEFINE_CONST_KERNEL_LAUNCHER(float16);
DEFINE_CONST_KERNEL_LAUNCHER(float); DEFINE_CONST_KERNEL_LAUNCHER(float);
DEFINE_CONST_KERNEL_LAUNCHER(double); DEFINE_CONST_KERNEL_LAUNCHER(double);
DEFINE_KERNEL_LAUNCHER(ReflectPad, bool); DEFINE_KERNEL_LAUNCHER(ReflectPad, bool);
DEFINE_KERNEL_LAUNCHER(ReflectPad, int8_t); DEFINE_KERNEL_LAUNCHER(ReflectPad, int8_t);
DEFINE_KERNEL_LAUNCHER(ReflectPad, uint8_t); DEFINE_KERNEL_LAUNCHER(ReflectPad, uint8_t);
...@@ -141,7 +140,6 @@ DEFINE_KERNEL_LAUNCHER(ReflectPad, int64_t); ...@@ -141,7 +140,6 @@ DEFINE_KERNEL_LAUNCHER(ReflectPad, int64_t);
DEFINE_KERNEL_LAUNCHER(ReflectPad, float16); DEFINE_KERNEL_LAUNCHER(ReflectPad, float16);
DEFINE_KERNEL_LAUNCHER(ReflectPad, float); DEFINE_KERNEL_LAUNCHER(ReflectPad, float);
DEFINE_KERNEL_LAUNCHER(ReflectPad, double); DEFINE_KERNEL_LAUNCHER(ReflectPad, double);
DEFINE_KERNEL_LAUNCHER(EdgePad, bool); DEFINE_KERNEL_LAUNCHER(EdgePad, bool);
DEFINE_KERNEL_LAUNCHER(EdgePad, int8_t); DEFINE_KERNEL_LAUNCHER(EdgePad, int8_t);
DEFINE_KERNEL_LAUNCHER(EdgePad, uint8_t); DEFINE_KERNEL_LAUNCHER(EdgePad, uint8_t);
...@@ -150,9 +148,8 @@ DEFINE_KERNEL_LAUNCHER(EdgePad, int64_t); ...@@ -150,9 +148,8 @@ DEFINE_KERNEL_LAUNCHER(EdgePad, int64_t);
DEFINE_KERNEL_LAUNCHER(EdgePad, float16); DEFINE_KERNEL_LAUNCHER(EdgePad, float16);
DEFINE_KERNEL_LAUNCHER(EdgePad, float); DEFINE_KERNEL_LAUNCHER(EdgePad, float);
DEFINE_KERNEL_LAUNCHER(EdgePad, double); DEFINE_KERNEL_LAUNCHER(EdgePad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_CONST_KERNEL_LAUNCHER #undef DEFINE_CONST_KERNEL_LAUNCHER
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -151,7 +151,6 @@ DEFINE_CONST_KERNEL_LAUNCHER(int64_t); ...@@ -151,7 +151,6 @@ DEFINE_CONST_KERNEL_LAUNCHER(int64_t);
DEFINE_CONST_KERNEL_LAUNCHER(float16); DEFINE_CONST_KERNEL_LAUNCHER(float16);
DEFINE_CONST_KERNEL_LAUNCHER(float); DEFINE_CONST_KERNEL_LAUNCHER(float);
DEFINE_CONST_KERNEL_LAUNCHER(double); DEFINE_CONST_KERNEL_LAUNCHER(double);
DEFINE_KERNEL_LAUNCHER(ReflectPad, bool); DEFINE_KERNEL_LAUNCHER(ReflectPad, bool);
DEFINE_KERNEL_LAUNCHER(ReflectPad, int8_t); DEFINE_KERNEL_LAUNCHER(ReflectPad, int8_t);
DEFINE_KERNEL_LAUNCHER(ReflectPad, uint8_t); DEFINE_KERNEL_LAUNCHER(ReflectPad, uint8_t);
...@@ -160,7 +159,6 @@ DEFINE_KERNEL_LAUNCHER(ReflectPad, int64_t); ...@@ -160,7 +159,6 @@ DEFINE_KERNEL_LAUNCHER(ReflectPad, int64_t);
DEFINE_KERNEL_LAUNCHER(ReflectPad, float16); DEFINE_KERNEL_LAUNCHER(ReflectPad, float16);
DEFINE_KERNEL_LAUNCHER(ReflectPad, float); DEFINE_KERNEL_LAUNCHER(ReflectPad, float);
DEFINE_KERNEL_LAUNCHER(ReflectPad, double); DEFINE_KERNEL_LAUNCHER(ReflectPad, double);
DEFINE_KERNEL_LAUNCHER(EdgePad, bool); DEFINE_KERNEL_LAUNCHER(EdgePad, bool);
DEFINE_KERNEL_LAUNCHER(EdgePad, int8_t); DEFINE_KERNEL_LAUNCHER(EdgePad, int8_t);
DEFINE_KERNEL_LAUNCHER(EdgePad, uint8_t); DEFINE_KERNEL_LAUNCHER(EdgePad, uint8_t);
...@@ -169,9 +167,8 @@ DEFINE_KERNEL_LAUNCHER(EdgePad, int64_t); ...@@ -169,9 +167,8 @@ DEFINE_KERNEL_LAUNCHER(EdgePad, int64_t);
DEFINE_KERNEL_LAUNCHER(EdgePad, float16); DEFINE_KERNEL_LAUNCHER(EdgePad, float16);
DEFINE_KERNEL_LAUNCHER(EdgePad, float); DEFINE_KERNEL_LAUNCHER(EdgePad, float);
DEFINE_KERNEL_LAUNCHER(EdgePad, double); DEFINE_KERNEL_LAUNCHER(EdgePad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_CONST_KERNEL_LAUNCHER #undef DEFINE_CONST_KERNEL_LAUNCHER
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -63,7 +63,6 @@ void ReduceSumGrad<float16, CPUContext>( ...@@ -63,7 +63,6 @@ void ReduceSumGrad<float16, CPUContext>(
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -128,7 +128,6 @@ void ReduceSumGrad<float16, CUDAContext>( ...@@ -128,7 +128,6 @@ void ReduceSumGrad<float16, CUDAContext>(
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -75,11 +75,9 @@ DEFINE_KERNEL_LAUNCHER(Repeat, int64_t); ...@@ -75,11 +75,9 @@ DEFINE_KERNEL_LAUNCHER(Repeat, int64_t);
DEFINE_KERNEL_LAUNCHER(Repeat, float16); DEFINE_KERNEL_LAUNCHER(Repeat, float16);
DEFINE_KERNEL_LAUNCHER(Repeat, float); DEFINE_KERNEL_LAUNCHER(Repeat, float);
DEFINE_KERNEL_LAUNCHER(Repeat, double); DEFINE_KERNEL_LAUNCHER(Repeat, double);
DEFINE_KERNEL_LAUNCHER(RepeatGrad, float16); DEFINE_KERNEL_LAUNCHER(RepeatGrad, float16);
DEFINE_KERNEL_LAUNCHER(RepeatGrad, float); DEFINE_KERNEL_LAUNCHER(RepeatGrad, float);
DEFINE_KERNEL_LAUNCHER(RepeatGrad, double); DEFINE_KERNEL_LAUNCHER(RepeatGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -139,10 +139,8 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -139,10 +139,8 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -76,7 +76,6 @@ DEFINE_KERNEL_LAUNCHER(Slice, int64_t); ...@@ -76,7 +76,6 @@ DEFINE_KERNEL_LAUNCHER(Slice, int64_t);
DEFINE_KERNEL_LAUNCHER(Slice, float16); DEFINE_KERNEL_LAUNCHER(Slice, float16);
DEFINE_KERNEL_LAUNCHER(Slice, float); DEFINE_KERNEL_LAUNCHER(Slice, float);
DEFINE_KERNEL_LAUNCHER(Slice, double); DEFINE_KERNEL_LAUNCHER(Slice, double);
DEFINE_KERNEL_LAUNCHER(SliceGrad, bool); DEFINE_KERNEL_LAUNCHER(SliceGrad, bool);
DEFINE_KERNEL_LAUNCHER(SliceGrad, int8_t); DEFINE_KERNEL_LAUNCHER(SliceGrad, int8_t);
DEFINE_KERNEL_LAUNCHER(SliceGrad, uint8_t); DEFINE_KERNEL_LAUNCHER(SliceGrad, uint8_t);
...@@ -85,7 +84,6 @@ DEFINE_KERNEL_LAUNCHER(SliceGrad, int64_t); ...@@ -85,7 +84,6 @@ DEFINE_KERNEL_LAUNCHER(SliceGrad, int64_t);
DEFINE_KERNEL_LAUNCHER(SliceGrad, float16); DEFINE_KERNEL_LAUNCHER(SliceGrad, float16);
DEFINE_KERNEL_LAUNCHER(SliceGrad, float); DEFINE_KERNEL_LAUNCHER(SliceGrad, float);
DEFINE_KERNEL_LAUNCHER(SliceGrad, double); DEFINE_KERNEL_LAUNCHER(SliceGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -85,7 +85,6 @@ DEFINE_KERNEL_LAUNCHER(Slice, int64_t); ...@@ -85,7 +85,6 @@ DEFINE_KERNEL_LAUNCHER(Slice, int64_t);
DEFINE_KERNEL_LAUNCHER(Slice, float16); DEFINE_KERNEL_LAUNCHER(Slice, float16);
DEFINE_KERNEL_LAUNCHER(Slice, float); DEFINE_KERNEL_LAUNCHER(Slice, float);
DEFINE_KERNEL_LAUNCHER(Slice, double); DEFINE_KERNEL_LAUNCHER(Slice, double);
DEFINE_KERNEL_LAUNCHER(SliceGrad, bool); DEFINE_KERNEL_LAUNCHER(SliceGrad, bool);
DEFINE_KERNEL_LAUNCHER(SliceGrad, int8_t); DEFINE_KERNEL_LAUNCHER(SliceGrad, int8_t);
DEFINE_KERNEL_LAUNCHER(SliceGrad, uint8_t); DEFINE_KERNEL_LAUNCHER(SliceGrad, uint8_t);
...@@ -94,7 +93,6 @@ DEFINE_KERNEL_LAUNCHER(SliceGrad, int64_t); ...@@ -94,7 +93,6 @@ DEFINE_KERNEL_LAUNCHER(SliceGrad, int64_t);
DEFINE_KERNEL_LAUNCHER(SliceGrad, float16); DEFINE_KERNEL_LAUNCHER(SliceGrad, float16);
DEFINE_KERNEL_LAUNCHER(SliceGrad, float); DEFINE_KERNEL_LAUNCHER(SliceGrad, float);
DEFINE_KERNEL_LAUNCHER(SliceGrad, double); DEFINE_KERNEL_LAUNCHER(SliceGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -85,11 +85,9 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -85,11 +85,9 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -141,10 +141,8 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -141,10 +141,8 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -73,11 +73,9 @@ DEFINE_KERNEL_LAUNCHER(Transpose, int64_t); ...@@ -73,11 +73,9 @@ DEFINE_KERNEL_LAUNCHER(Transpose, int64_t);
DEFINE_KERNEL_LAUNCHER(Transpose, float16); DEFINE_KERNEL_LAUNCHER(Transpose, float16);
DEFINE_KERNEL_LAUNCHER(Transpose, float); DEFINE_KERNEL_LAUNCHER(Transpose, float);
DEFINE_KERNEL_LAUNCHER(Transpose, double); DEFINE_KERNEL_LAUNCHER(Transpose, double);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, float16); DEFINE_KERNEL_LAUNCHER(TransposeGrad, float16);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, float); DEFINE_KERNEL_LAUNCHER(TransposeGrad, float);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, double); DEFINE_KERNEL_LAUNCHER(TransposeGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -81,11 +81,9 @@ DEFINE_KERNEL_LAUNCHER(Transpose, int64_t); ...@@ -81,11 +81,9 @@ DEFINE_KERNEL_LAUNCHER(Transpose, int64_t);
DEFINE_KERNEL_LAUNCHER(Transpose, float16); DEFINE_KERNEL_LAUNCHER(Transpose, float16);
DEFINE_KERNEL_LAUNCHER(Transpose, float); DEFINE_KERNEL_LAUNCHER(Transpose, float);
DEFINE_KERNEL_LAUNCHER(Transpose, double); DEFINE_KERNEL_LAUNCHER(Transpose, double);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, float16); DEFINE_KERNEL_LAUNCHER(TransposeGrad, float16);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, float); DEFINE_KERNEL_LAUNCHER(TransposeGrad, float);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, double); DEFINE_KERNEL_LAUNCHER(TransposeGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -75,6 +75,7 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -75,6 +75,7 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -54,7 +54,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -54,7 +54,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -65,7 +65,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -65,7 +65,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -108,10 +108,8 @@ void BroadcastLossGrad<float16, CPUContext>( ...@@ -108,10 +108,8 @@ void BroadcastLossGrad<float16, CPUContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -231,10 +231,8 @@ void BroadcastLossGrad<float16, CUDAContext>( ...@@ -231,10 +231,8 @@ void BroadcastLossGrad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -88,12 +88,10 @@ DEFINE_KERNEL_LAUNCHER(NLLLoss, float, float); ...@@ -88,12 +88,10 @@ DEFINE_KERNEL_LAUNCHER(NLLLoss, float, float);
DEFINE_KERNEL_LAUNCHER(NLLLoss, float, int64_t); DEFINE_KERNEL_LAUNCHER(NLLLoss, float, int64_t);
DEFINE_KERNEL_LAUNCHER(NLLLoss, double, double); DEFINE_KERNEL_LAUNCHER(NLLLoss, double, double);
DEFINE_KERNEL_LAUNCHER(NLLLoss, double, int64_t); DEFINE_KERNEL_LAUNCHER(NLLLoss, double, int64_t);
DEFINE_KERNEL_LAUNCHER(NLLLossGrad, float, float); DEFINE_KERNEL_LAUNCHER(NLLLossGrad, float, float);
DEFINE_KERNEL_LAUNCHER(NLLLossGrad, float, int64_t); DEFINE_KERNEL_LAUNCHER(NLLLossGrad, float, int64_t);
DEFINE_KERNEL_LAUNCHER(NLLLossGrad, double, double); DEFINE_KERNEL_LAUNCHER(NLLLossGrad, double, double);
DEFINE_KERNEL_LAUNCHER(NLLLossGrad, double, int64_t); DEFINE_KERNEL_LAUNCHER(NLLLossGrad, double, int64_t);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -87,12 +87,10 @@ DEFINE_KERNEL_LAUNCHER(NLLLoss, float, float); ...@@ -87,12 +87,10 @@ DEFINE_KERNEL_LAUNCHER(NLLLoss, float, float);
DEFINE_KERNEL_LAUNCHER(NLLLoss, float, int64_t); DEFINE_KERNEL_LAUNCHER(NLLLoss, float, int64_t);
DEFINE_KERNEL_LAUNCHER(NLLLoss, double, double); DEFINE_KERNEL_LAUNCHER(NLLLoss, double, double);
DEFINE_KERNEL_LAUNCHER(NLLLoss, double, int64_t); DEFINE_KERNEL_LAUNCHER(NLLLoss, double, int64_t);
DEFINE_KERNEL_LAUNCHER(NLLLossGrad, float, float); DEFINE_KERNEL_LAUNCHER(NLLLossGrad, float, float);
DEFINE_KERNEL_LAUNCHER(NLLLossGrad, float, int64_t); DEFINE_KERNEL_LAUNCHER(NLLLossGrad, float, int64_t);
DEFINE_KERNEL_LAUNCHER(NLLLossGrad, double, double); DEFINE_KERNEL_LAUNCHER(NLLLossGrad, double, double);
DEFINE_KERNEL_LAUNCHER(NLLLossGrad, double, int64_t); DEFINE_KERNEL_LAUNCHER(NLLLossGrad, double, int64_t);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -68,10 +68,8 @@ void _SigmoidCrossEntropyGrad( ...@@ -68,10 +68,8 @@ void _SigmoidCrossEntropyGrad(
DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropy, float); DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropy, float);
DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropy, double); DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropy, double);
DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropyGrad, float); DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropyGrad, float);
DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropyGrad, double); DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropyGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -64,10 +64,8 @@ __global__ void _SigmoidCrossEntropyGrad( ...@@ -64,10 +64,8 @@ __global__ void _SigmoidCrossEntropyGrad(
DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropy, float); DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropy, float);
DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropy, double); DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropy, double);
DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropyGrad, float); DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropyGrad, float);
DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropyGrad, double); DEFINE_KERNEL_LAUNCHER(SigmoidCrossEntropyGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -137,12 +137,10 @@ DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, float, float); ...@@ -137,12 +137,10 @@ DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, float, float);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, float, int64_t); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, float, int64_t);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, double, double); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, double, double);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, double, int64_t); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, double, int64_t);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, float, float); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, float, float);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, float, int64_t); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, float, int64_t);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, double, double); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, double, double);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, double, int64_t); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, double, int64_t);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -132,12 +132,10 @@ DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, float, float); ...@@ -132,12 +132,10 @@ DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, float, float);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, float, int64_t); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, float, int64_t);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, double, double); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, double, double);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, double, int64_t); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLoss, double, int64_t);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, float, float); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, float, float);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, float, int64_t); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, float, int64_t);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, double, double); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, double, double);
DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, double, int64_t); DEFINE_KERNEL_LAUNCHER(SigmoidFocalLossGrad, double, int64_t);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -57,10 +57,8 @@ void SmoothL1Grad<float16, CPUContext>( ...@@ -57,10 +57,8 @@ void SmoothL1Grad<float16, CPUContext>(
DEFINE_KERNEL_LAUNCHER(SmoothL1, float); DEFINE_KERNEL_LAUNCHER(SmoothL1, float);
DEFINE_KERNEL_LAUNCHER(SmoothL1, double); DEFINE_KERNEL_LAUNCHER(SmoothL1, double);
DEFINE_KERNEL_LAUNCHER(SmoothL1Grad, float); DEFINE_KERNEL_LAUNCHER(SmoothL1Grad, float);
DEFINE_KERNEL_LAUNCHER(SmoothL1Grad, double); DEFINE_KERNEL_LAUNCHER(SmoothL1Grad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -90,10 +90,8 @@ void SmoothL1Grad<float16, CUDAContext>( ...@@ -90,10 +90,8 @@ void SmoothL1Grad<float16, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(SmoothL1, float); DEFINE_KERNEL_LAUNCHER(SmoothL1, float);
DEFINE_KERNEL_LAUNCHER(SmoothL1, double); DEFINE_KERNEL_LAUNCHER(SmoothL1, double);
DEFINE_KERNEL_LAUNCHER(SmoothL1Grad, float); DEFINE_KERNEL_LAUNCHER(SmoothL1Grad, float);
DEFINE_KERNEL_LAUNCHER(SmoothL1Grad, double); DEFINE_KERNEL_LAUNCHER(SmoothL1Grad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -37,7 +37,6 @@ void _SoftmaxCrossEntropy( ...@@ -37,7 +37,6 @@ void _SoftmaxCrossEntropy(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -41,7 +41,6 @@ __global__ void _SoftmaxCrossEntropy( ...@@ -41,7 +41,6 @@ __global__ void _SoftmaxCrossEntropy(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -95,12 +95,10 @@ DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, float, float); ...@@ -95,12 +95,10 @@ DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, float, float);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, float, int64_t); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, float, int64_t);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, double, double); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, double, double);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, double, int64_t); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, double, int64_t);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, float, float); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, float, float);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, float, int64_t); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, float, int64_t);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, double, double); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, double, double);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, double, int64_t); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, double, int64_t);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -93,12 +93,10 @@ DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, float, float); ...@@ -93,12 +93,10 @@ DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, float, float);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, float, int64_t); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, float, int64_t);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, double, double); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, double, double);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, double, int64_t); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropy, double, int64_t);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, float, float); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, float, float);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, float, int64_t); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, float, int64_t);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, double, double); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, double, double);
DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, double, int64_t); DEFINE_KERNEL_LAUNCHER(SparseSoftmaxCrossEntropyGrad, double, int64_t);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -104,11 +104,9 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -104,11 +104,9 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -160,10 +160,8 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -160,10 +160,8 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -155,7 +155,7 @@ void Moments<float16, float, CPUContext>( ...@@ -155,7 +155,7 @@ void Moments<float16, float, CPUContext>(
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
#define DEFINE_MOMENTS_KERNEL_LAUNCHER(Tx, Ty) \ #define DEFINE_KERNEL_LAUNCHER(Tx, Ty) \
template <> \ template <> \
void Moments<Tx, Ty, CPUContext>( \ void Moments<Tx, Ty, CPUContext>( \
const int num_dims, \ const int num_dims, \
...@@ -169,14 +169,13 @@ void Moments<float16, float, CPUContext>( ...@@ -169,14 +169,13 @@ void Moments<float16, float, CPUContext>(
_Moments(num_dims, dims, num_axes, axes, x, mean, var, ctx); \ _Moments(num_dims, dims, num_axes, axes, x, mean, var, ctx); \
} }
DEFINE_MOMENTS_KERNEL_LAUNCHER(int8_t, float); DEFINE_KERNEL_LAUNCHER(int8_t, float);
DEFINE_MOMENTS_KERNEL_LAUNCHER(uint8_t, float); DEFINE_KERNEL_LAUNCHER(uint8_t, float);
DEFINE_MOMENTS_KERNEL_LAUNCHER(int, float); DEFINE_KERNEL_LAUNCHER(int, float);
DEFINE_MOMENTS_KERNEL_LAUNCHER(int64_t, float); DEFINE_KERNEL_LAUNCHER(int64_t, float);
DEFINE_MOMENTS_KERNEL_LAUNCHER(float, float); DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_MOMENTS_KERNEL_LAUNCHER(double, double); DEFINE_KERNEL_LAUNCHER(double, double);
#undef DEFINE__KERNEL_LAUNCHER
#undef DEFINE_MOMENTS_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -316,7 +316,6 @@ DEFINE_KERNEL_LAUNCHER(int, float); ...@@ -316,7 +316,6 @@ DEFINE_KERNEL_LAUNCHER(int, float);
DEFINE_KERNEL_LAUNCHER(int64_t, float); DEFINE_KERNEL_LAUNCHER(int64_t, float);
DEFINE_KERNEL_LAUNCHER(float, float); DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(double, double); DEFINE_KERNEL_LAUNCHER(double, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -177,12 +177,6 @@ void L2NormalizeGrad<float16, CPUContext>( ...@@ -177,12 +177,6 @@ void L2NormalizeGrad<float16, CPUContext>(
_##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, x, y); \ _##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, x, y); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void name<T, CPUContext>( \ void name<T, CPUContext>( \
...@@ -198,10 +192,15 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double); ...@@ -198,10 +192,15 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
_##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, dy, x, dx); \ _##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, dy, x, dx); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -170,14 +170,6 @@ __global__ void _L2NormalizeGrad( ...@@ -170,14 +170,6 @@ __global__ void _L2NormalizeGrad(
reinterpret_cast<ScalarT*>(y)); \ reinterpret_cast<ScalarT*>(y)); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float16, half, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, float, float, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double, double, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float16, half, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float, float, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double, double, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T, ScalarT, AccT) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(name, T, ScalarT, AccT) \
template <> \ template <> \
void name<T, CUDAContext>( \ void name<T, CUDAContext>( \
...@@ -203,12 +195,19 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double, double, double); ...@@ -203,12 +195,19 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double, double, double);
reinterpret_cast<ScalarT*>(dx)); \ reinterpret_cast<ScalarT*>(dx)); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float16, half, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, float, float, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double, double, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float16, half, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float, float, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float16, half, float); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float, float, float); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double, double, double); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16, half, float); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float, float, float); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double, double, double); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double, double, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -293,10 +293,8 @@ void _AvgPool2dGradNHWC( ...@@ -293,10 +293,8 @@ void _AvgPool2dGradNHWC(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -325,10 +325,8 @@ __global__ void _AvgPool2dGradNHWC( ...@@ -325,10 +325,8 @@ __global__ void _AvgPool2dGradNHWC(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -76,7 +76,6 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -76,7 +76,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -103,7 +103,6 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -103,7 +103,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -304,7 +304,6 @@ void _Col2Im2dNHWC( ...@@ -304,7 +304,6 @@ void _Col2Im2dNHWC(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -356,7 +356,6 @@ __global__ void _Col2Im2dNHWC( ...@@ -356,7 +356,6 @@ __global__ void _Col2Im2dNHWC(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -243,6 +243,7 @@ void DepthwiseConv2d<float, CPUContext>( ...@@ -243,6 +243,7 @@ void DepthwiseConv2d<float, CPUContext>(
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
#undef DISPATCH_DATA_KERNEL
} // namespace kernel } // namespace kernel
......
...@@ -12,9 +12,9 @@ namespace kernel { ...@@ -12,9 +12,9 @@ namespace kernel {
namespace { namespace {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
#define LOAD(x, i) __ldg(x + i) #define LDG(x, i) __ldg(x + i)
#else #else
#define LOAD(x, i) x[i] #define LDG(x, i) x[i]
#endif #endif
template <typename T, typename AccT, int KKH, int KKW> template <typename T, typename AccT, int KKH, int KKW>
...@@ -60,7 +60,7 @@ __global__ void _DepthwiseConv2dNCHW( ...@@ -60,7 +60,7 @@ __global__ void _DepthwiseConv2dNCHW(
iw = iw_start + kw * dilation_w; iw = iw_start + kw * dilation_w;
if (ih >= 0 && ih < H && iw >= 0 && iw < W) { if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
xi = x_start + ih * W + iw; xi = x_start + ih * W + iw;
sum_val += convert::To<AccT>(Multiplies(LOAD(x, xi), LOAD(w, wi))); sum_val += convert::To<AccT>(Multiplies(LDG(x, xi), LDG(w, wi)));
} }
++wi; ++wi;
} // End kw } // End kw
...@@ -112,7 +112,7 @@ __global__ void _DepthwiseConv2dNHWC( ...@@ -112,7 +112,7 @@ __global__ void _DepthwiseConv2dNHWC(
iw = iw_start + kw * dilation_w; iw = iw_start + kw * dilation_w;
if (ih >= 0 && ih < H && iw >= 0 && iw < W) { if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
xi = ((x_start + ih) * W + iw) * C + c; xi = ((x_start + ih) * W + iw) * C + c;
sum_val += convert::To<AccT>(Multiplies(LOAD(x, xi), LOAD(w, wi))); sum_val += convert::To<AccT>(Multiplies(LDG(x, xi), LDG(w, wi)));
} }
++wi; ++wi;
} // End kw } // End kw
...@@ -164,7 +164,7 @@ __global__ void _DepthwiseConv2dGradNCHW( ...@@ -164,7 +164,7 @@ __global__ void _DepthwiseConv2dGradNCHW(
ow = ow / stride_w; ow = ow / stride_w;
if (oh >= 0 && oh < out_h && ow >= 0 && ow < out_w) { if (oh >= 0 && oh < out_h && ow >= 0 && ow < out_w) {
yi = y_start + oh * out_w + ow; yi = y_start + oh * out_w + ow;
sum_val += convert::To<AccT>(Multiplies(LOAD(dy, yi), LOAD(w, wi))); sum_val += convert::To<AccT>(Multiplies(LDG(dy, yi), LDG(w, wi)));
} }
} }
++wi; ++wi;
...@@ -217,7 +217,7 @@ __global__ void _DepthwiseConv2dGradNHWC( ...@@ -217,7 +217,7 @@ __global__ void _DepthwiseConv2dGradNHWC(
ow = ow / stride_w; ow = ow / stride_w;
if (oh >= 0 && oh < out_h && ow >= 0 && ow < out_w) { if (oh >= 0 && oh < out_h && ow >= 0 && ow < out_w) {
yi = ((y_start + oh) * out_w + ow) * C + c; yi = ((y_start + oh) * out_w + ow) * C + c;
sum_val += convert::To<AccT>(Multiplies(LOAD(dy, yi), LOAD(w, wi))); sum_val += convert::To<AccT>(Multiplies(LDG(dy, yi), LDG(w, wi)));
} }
} }
++wi; ++wi;
...@@ -267,7 +267,7 @@ __global__ void _DepthwiseConv2dWGradNCHW( ...@@ -267,7 +267,7 @@ __global__ void _DepthwiseConv2dWGradNCHW(
if (ih >= 0 && iw >= 0 && ih < H && iw < W) { if (ih >= 0 && iw >= 0 && ih < H && iw < W) {
xi = ((i * C + c) * H + ih) * W + iw; xi = ((i * C + c) * H + ih) * W + iw;
yi = (i * C + c) * out_h * out_w + j; yi = (i * C + c) * out_h * out_w + j;
sum_val += convert::To<AccT>(Multiplies(LOAD(dy, yi), LOAD(x, xi))); sum_val += convert::To<AccT>(Multiplies(LDG(dy, yi), LDG(x, xi)));
} }
} }
} }
...@@ -320,7 +320,7 @@ __global__ void _DepthwiseConv2dWGradNHWC( ...@@ -320,7 +320,7 @@ __global__ void _DepthwiseConv2dWGradNHWC(
if (ih >= 0 && iw >= 0 && ih < H && iw < W) { if (ih >= 0 && iw >= 0 && ih < H && iw < W) {
xi = ((i * H + ih) * W + iw) * C + c; xi = ((i * H + ih) * W + iw) * C + c;
yi = (i * ohw + j) * C + c; yi = (i * ohw + j) * C + c;
sum_val += convert::To<AccT>(Multiplies(LOAD(dy, yi), LOAD(x, xi))); sum_val += convert::To<AccT>(Multiplies(LDG(dy, yi), LDG(x, xi)));
} }
} }
} }
...@@ -333,7 +333,7 @@ __global__ void _DepthwiseConv2dWGradNHWC( ...@@ -333,7 +333,7 @@ __global__ void _DepthwiseConv2dWGradNHWC(
} }
} }
#undef LOAD #undef LDG
} // namespace } // namespace
...@@ -528,10 +528,8 @@ __global__ void _DepthwiseConv2dWGradNHWC( ...@@ -528,10 +528,8 @@ __global__ void _DepthwiseConv2dWGradNHWC(
DEFINE_KERNEL_LAUNCHER(float16, half, float); DEFINE_KERNEL_LAUNCHER(float16, half, float);
DEFINE_KERNEL_LAUNCHER(float, float, float); DEFINE_KERNEL_LAUNCHER(float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float); DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float); DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float);
#undef DISPATCH_DATA_KERNEL #undef DISPATCH_DATA_KERNEL
#undef DISPATCH_WEIGHT_KERNEL #undef DISPATCH_WEIGHT_KERNEL
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
......
...@@ -287,10 +287,8 @@ void _MaxPool2dGradNHWC( ...@@ -287,10 +287,8 @@ void _MaxPool2dGradNHWC(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -335,10 +335,8 @@ __global__ void _MaxPool2dGradNHWC( ...@@ -335,10 +335,8 @@ __global__ void _MaxPool2dGradNHWC(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -272,10 +272,8 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -272,10 +272,8 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -181,10 +181,8 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -181,10 +181,8 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -185,11 +185,9 @@ void RoiAlign<float16, CPUContext>( ...@@ -185,11 +185,9 @@ void RoiAlign<float16, CPUContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -149,11 +149,9 @@ void RoiPool<float16, CPUContext>( ...@@ -149,11 +149,9 @@ void RoiPool<float16, CPUContext>(
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16); DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -9,6 +9,12 @@ namespace kernel { ...@@ -9,6 +9,12 @@ namespace kernel {
namespace { namespace {
#if __CUDA_ARCH__ >= 350
#define LDG(x, i) __ldg(x + i)
#else
#define LDG(x, i) x[i]
#endif
template <typename T> template <typename T>
__global__ void _RoiPool( __global__ void _RoiPool(
const int nthreads, const int nthreads,
...@@ -22,6 +28,7 @@ __global__ void _RoiPool( ...@@ -22,6 +28,7 @@ __global__ void _RoiPool(
const float* rois, const float* rois,
int* mask, int* mask,
T* y) { T* y) {
auto Greater = math::GreaterFunctor<T>();
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int ow = yi % out_w; const int ow = yi % out_w;
const int oh = (yi / out_w) % out_h; const int oh = (yi / out_w) % out_h;
...@@ -32,7 +39,7 @@ __global__ void _RoiPool( ...@@ -32,7 +39,7 @@ __global__ void _RoiPool(
const int batch_ind = roi[0]; const int batch_ind = roi[0];
if (batch_ind < 0) { if (batch_ind < 0) {
y[yi] = T(0); y[yi] = convert::To<T>(0.f);
mask[yi] = -1; mask[yi] = -1;
continue; continue;
} }
...@@ -60,119 +67,22 @@ __global__ void _RoiPool( ...@@ -60,119 +67,22 @@ __global__ void _RoiPool(
int max_idx = empty ? -1 : 0; int max_idx = empty ? -1 : 0;
const T* offset_x = x + (batch_ind * C + c) * H * W; const T* offset_x = x + (batch_ind * C + c) * H * W;
T val = empty ? T(0) : offset_x[0]; T val = empty ? convert::To<T>(0.f) : offset_x[0];
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
const int xi = h * W + w;
#if __CUDA_ARCH__ >= 350
if (__ldg(offset_x + xi) > val) {
val = __ldg(offset_x + xi);
max_idx = xi;
}
#else
if (offset_x[xi] > val) {
val = offset_x[xi];
max_idx = xi;
}
#endif
}
}
y[yi] = val;
mask[yi] = max_idx;
}
}
template <>
__global__ void _RoiPool<half>(
const int nthreads,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const float spatial_scale,
const half* x,
const float* rois,
int* mask,
half* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int ow = yi % out_w;
const int oh = (yi / out_w) % out_h;
const int c = (yi / out_w / out_h) % C;
const int n = yi / out_w / out_h / C;
const float* roi = rois + n * 5;
const int batch_ind = roi[0];
if (batch_ind < 0) {
y[yi] = __float2half(0.f);
mask[yi] = -1;
continue;
}
const int roi_start_w = round(roi[1] * spatial_scale);
const int roi_start_h = round(roi[2] * spatial_scale);
const int roi_end_w = round(roi[3] * spatial_scale);
const int roi_end_h = round(roi[4] * spatial_scale);
const int roi_w = max(roi_end_w - roi_start_w + 1, 1);
const int roi_h = max(roi_end_h - roi_start_h + 1, 1);
const float bin_h = (float)roi_h / (float)out_h;
const float bin_w = (float)roi_w / (float)out_w;
int hstart = floor(bin_h * oh);
int wstart = floor(bin_w * ow);
int hend = ceil(bin_h * (oh + 1));
int wend = ceil(bin_w * (ow + 1));
hstart = min(max(hstart + roi_start_h, 0), H);
hend = min(max(hend + roi_start_h, 0), H);
wstart = min(max(wstart + roi_start_w, 0), W);
wend = min(max(wend + roi_start_w, 0), W);
const bool empty = (hend <= hstart) || (wend <= wstart);
int max_idx = empty ? -1 : 0;
const half* offset_x = x + ((batch_ind * C + c) * H * W);
#if __CUDA_ARCH__ >= 530
half val = empty ? __float2half(0.f) : __ldg(offset_x);
#else
float val = empty ? 0.f : __half2float(*offset_x);
#endif
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
const int xi = h * W + w; const int xi = h * W + w;
#if __CUDA_ARCH__ >= 530 if (Greater(LDG(offset_x, xi), val)) {
if (__hgt(__ldg(offset_x + xi), val)) { val = LDG(offset_x, xi);
val = __ldg(offset_x + xi);
max_idx = xi; max_idx = xi;
} }
#elif __CUDA_ARCH__ >= 350
if (__half2float(__ldg(offset_x + xi)) > val) {
val = __half2float(__ldg(offset_x + xi));
max_idx = xi;
} }
#else
if (__half2float(offset_x[xi]) > val) {
val = __half2float(offset_x[xi]);
max_idx = xi;
}
#endif
} }
}
#if __CUDA_ARCH__ >= 530
y[yi] = val; y[yi] = val;
#else
y[yi] = __float2half(val);
#endif
mask[yi] = max_idx; mask[yi] = max_idx;
} }
} }
template <typename T> template <typename T, typename AccT>
__global__ void _RoiPoolGrad( __global__ void _RoiPoolGrad(
const int nthreads, const int nthreads,
const int C, const int C,
...@@ -184,7 +94,7 @@ __global__ void _RoiPoolGrad( ...@@ -184,7 +94,7 @@ __global__ void _RoiPoolGrad(
const T* dy, const T* dy,
const float* rois, const float* rois,
const int* mask, const int* mask,
float* dx) { AccT* dx) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int c = (yi / out_w / out_h) % C; const int c = (yi / out_w / out_h) % C;
const int n = yi / out_w / out_h / C; const int n = yi / out_w / out_h / C;
...@@ -193,116 +103,20 @@ __global__ void _RoiPoolGrad( ...@@ -193,116 +103,20 @@ __global__ void _RoiPoolGrad(
const int batch_ind = roi[0]; const int batch_ind = roi[0];
if (batch_ind < 0) continue; if (batch_ind < 0) continue;
float* offset_dx = dx + (batch_ind * C + c) * H * W; AccT* offset_dx = dx + (batch_ind * C + c) * H * W;
#if __CUDA_ARCH__ >= 350 if (LDG(mask, yi) != -1) {
if (__ldg(mask + yi) != -1) { atomicAdd(offset_dx + LDG(mask, yi), convert::To<AccT>(dy[yi]));
atomicAdd(offset_dx + __ldg(mask + yi), (float)dy[yi]);
}
#else
if (mask[yi] != -1) {
atomicAdd(offset_dx + mask[yi], (float)dy[yi]);
} }
#endif
} }
} }
template <> #undef LDG
__global__ void _RoiPoolGrad<half>(
const int nthreads,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const float spatial_scale,
const half* dy,
const float* rois,
const int* mask,
float* dx) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int c = (yi / out_w / out_h) % C;
const int n = yi / out_w / out_h / C;
const float* roi = rois + n * 5;
const int batch_ind = roi[0];
if (batch_ind < 0) continue;
float* offset_dx = dx + (batch_ind * C + c) * H * W;
#if __CUDA_ARCH__ >= 350
if (__ldg(mask + yi) != -1) {
atomicAdd(offset_dx + __ldg(mask + yi), __half2float(dy[yi]));
}
#else
if (mask[yi] != -1) {
atomicAdd(offset_dx + mask[yi], __half2float(dy[yi]));
}
#endif
}
}
} // namespace } // namespace
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> #define DEFINE_KERNEL_LAUNCHER(T, ScalarT) \
void RoiPool<float16, CUDAContext>(
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int num_rois,
const float spatial_scale,
const float16* x,
const float* rois,
int* mask,
float16* y,
CUDAContext* ctx) {
auto nthreads = num_rois * C * out_h * out_w;
_RoiPool<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
spatial_scale,
reinterpret_cast<const half*>(x),
rois,
mask,
reinterpret_cast<half*>(y));
}
template <>
void RoiPoolGrad<float16, CUDAContext>(
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int num_rois,
const float spatial_scale,
const float16* dy,
const float* rois,
const int* mask,
float* dx,
CUDAContext* ctx) {
auto nthreads = num_rois * C * out_h * out_w;
_RoiPoolGrad<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
spatial_scale,
reinterpret_cast<const half*>(dy),
rois,
mask,
dx);
} // RoiPoolGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void RoiPool<T, CUDAContext>( \ void RoiPool<T, CUDAContext>( \
const int C, \ const int C, \
...@@ -319,10 +133,20 @@ void RoiPoolGrad<float16, CUDAContext>( ...@@ -319,10 +133,20 @@ void RoiPoolGrad<float16, CUDAContext>(
CUDAContext* ctx) { \ CUDAContext* ctx) { \
auto nthreads = num_rois * C * out_h * out_w; \ auto nthreads = num_rois * C * out_h * out_w; \
_RoiPool<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _RoiPool<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, C, H, W, out_h, out_w, spatial_scale, x, rois, mask, y); \ nthreads, \
} C, \
H, \
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \ W, \
out_h, \
out_w, \
spatial_scale, \
reinterpret_cast<const ScalarT*>(x), \
rois, \
mask, \
reinterpret_cast<ScalarT*>(y)); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T, ScalarT) \
template <> \ template <> \
void RoiPoolGrad<T, CUDAContext>( \ void RoiPoolGrad<T, CUDAContext>( \
const int C, \ const int C, \
...@@ -343,15 +167,25 @@ void RoiPoolGrad<float16, CUDAContext>( ...@@ -343,15 +167,25 @@ void RoiPoolGrad<float16, CUDAContext>(
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>( \ ctx->cuda_stream()>>>( \
nthreads, C, H, W, out_h, out_w, spatial_scale, dy, rois, mask, dx); \ nthreads, \
} C, \
H, \
DEFINE_KERNEL_LAUNCHER(float); W, \
DEFINE_KERNEL_LAUNCHER(double); out_h, \
out_w, \
DEFINE_GRAD_KERNEL_LAUNCHER(float); spatial_scale, \
DEFINE_GRAD_KERNEL_LAUNCHER(double); reinterpret_cast<const ScalarT*>(dy), \
rois, \
mask, \
dx); \
}
DEFINE_KERNEL_LAUNCHER(float16, half);
DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16, half);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(double, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
#include "dragon/operators/array/reshape_ops.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
// Forward pass of the Identity operator: shape output 0 like input 0 and
// copy the input's contents into it.
template <class Context>
void IdentityOp<Context>::RunOnDevice() {
auto& X = Input(0);
// NOTE(review): the {0} argument presumably marks input 0 as an in-place
// alias candidate for this output - confirm against Output().
auto* Y = Output(0, {0});
// Record the spec of input 0; the gradient calculation relies on it.
STORE_INPUT_SPEC(0);
// Match the input's shape first, then copy. The copy may be skipped when
// Y already aliases X - TODO confirm in CopyFrom().
Y->ReshapeLike(X)->CopyFrom(X, ctx());
}
// Operator deployment for Identity.
// NOTE(review): DEPLOY_*_OPERATOR presumably registers the per-device
// implementation with the operator registry - confirm in the macro definition.
DEPLOY_CPU_OPERATOR(Identity);
// The CUDA deployment is only compiled when USE_CUDA is defined.
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Identity);
#endif
// Schema of the forward operator: one input (X), one output (Y), and an
// in-place pairing between input 0 and output 0.
OPERATOR_SCHEMA(Identity)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
// Schema of the gradient operator: one input (dY), one output (dX), with the
// same in-place pairing between index 0 and index 0.
OPERATOR_SCHEMA(IdentityGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1)
/* dY => dX */
.AllowInplace({{0, 0}});
// Associate Identity with SimpleGradientMaker for gradient generation.
// NOTE(review): the exact wiring performed by SimpleGradientMaker is not
// visible here - verify against its definition.
REGISTER_GRADIENT(Identity, SimpleGradientMaker);
} // namespace dragon
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!