Commit 9ca4b60f by Ting PAN

Add HardSigmoid && HardSwish && Swish Operators

Summary:
This commit adds the hardsigmoid, hardswish and swish ops with specialized kernels;
they are widely used in MobileNetV3 and EfficientNet (a NumPy reference sketch follows below).
1 parent f76c693e
Showing with 2165 additions and 374 deletions
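For orientation, the three activations are small pointwise functions. A minimal NumPy sketch, assuming the usual conventions (alpha=0.2, beta=0.5 for hard sigmoid as in ONNX, and the relu6-based hard swish from the MobileNetV3 paper) — the kernels below take alpha and beta as runtime arguments, so the actual defaults live in the operator layer, not here:

```python
import numpy as np

def hardsigmoid(x, alpha=0.2, beta=0.5):
    # Piecewise-linear approximation of sigmoid: clip(alpha * x + beta, 0, 1).
    return np.clip(alpha * x + beta, 0.0, 1.0)

def hardswish(x, alpha=1.0 / 6.0, beta=0.5):
    # MobileNetV3 form: x * relu6(x + 3) / 6 == x * clip(x / 6 + 0.5, 0, 1).
    return x * np.clip(alpha * x + beta, 0.0, 1.0)

def swish(x):
    # Swish (a.k.a. SiLU): x * sigmoid(x).
    return x / (1.0 + np.exp(-x))

x = np.linspace(-4.0, 4.0, 9)
print(hardsigmoid(x), hardswish(x), swish(x), sep="\n")
```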
...@@ -18,7 +18,7 @@ from dragon.core.util import tls ...@@ -18,7 +18,7 @@ from dragon.core.util import tls
def device(device_type, device_index=0): def device(device_type, device_index=0):
"""Context-manager to nest the the device spec. """Context-manager to nest the device spec.
Examples: Examples:
......
...@@ -16,7 +16,7 @@ vm.dali ...@@ -16,7 +16,7 @@ vm.dali
######### #########
`device(...) <dali/device.html>`_ `device(...) <dali/device.html>`_
: Context-manager to nest the the device spec. : Context-manager to nest the device spec.
`get_device_type(...) <dali/get_device_type.html>`_ `get_device_type(...) <dali/get_device_type.html>`_
: Return the current nesting device type. : Return the current nesting device type.
......
...@@ -55,7 +55,7 @@ dragon ...@@ -55,7 +55,7 @@ dragon
: Create a callable graph from the specified outputs. : Create a callable graph from the specified outputs.
`device(...) <dragon/device.html>`_ `device(...) <dragon/device.html>`_
: Context-manager to nest the the device spec. : Context-manager to nest the device spec.
`eager_mode(...) <dragon/eager_mode.html>`_ `eager_mode(...) <dragon/eager_mode.html>`_
: Context-manager set the eager execution mode. : Context-manager set the eager execution mode.
......
...@@ -63,6 +63,13 @@ dragon.nn ...@@ -63,6 +63,13 @@ dragon.nn
: Apply the group normalization. : Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_. `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`hardsigmoid(...) <nn/hardsigmoid.html>`_
: Apply the hard sigmoid function.
`hardswish(...) <nn/hardswish.html>`_
: Apply the hard swish function.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
`instance_norm(...) <nn/instance_norm.html>`_ `instance_norm(...) <nn/instance_norm.html>`_
: Apply the instance normalization. : Apply the instance normalization.
`[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_ `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_
...@@ -79,7 +86,7 @@ dragon.nn ...@@ -79,7 +86,7 @@ dragon.nn
`[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_. `[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_.
`log_softmax(...) <nn/log_softmax.html>`_ `log_softmax(...) <nn/log_softmax.html>`_
: Apply the composite of logarithm and softmax. : Compute the composite of logarithm and softmax.
`prelu(...) <nn/prelu.html>`_ `prelu(...) <nn/prelu.html>`_
: Apply the parametric rectified linear unit. : Apply the parametric rectified linear unit.
...@@ -101,10 +108,14 @@ dragon.nn ...@@ -101,10 +108,14 @@ dragon.nn
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_. `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`softmax(...) <nn/softmax.html>`_ `softmax(...) <nn/softmax.html>`_
: Apply the softmax function. : Compute the softmax result.
`space_to_depth(...) <nn/space_to_depth.html>`_ `space_to_depth(...) <nn/space_to_depth.html>`_
: Rearrange blocks of spatial data into depth. : Rearrange blocks of spatial data into depth.
`swish(...) <nn/swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`sync_batch_norm(...) <nn/sync_batch_norm.html>`_ `sync_batch_norm(...) <nn/sync_batch_norm.html>`_
: Apply the batch normalization with synced statistics. : Apply the batch normalization with synced statistics.
...@@ -128,6 +139,8 @@ dragon.nn ...@@ -128,6 +139,8 @@ dragon.nn
nn/elu nn/elu
nn/fully_connected nn/fully_connected
nn/group_norm nn/group_norm
nn/hardsigmoid
nn/hardswish
nn/instance_norm nn/instance_norm
nn/layer_norm nn/layer_norm
nn/leaky_relu nn/leaky_relu
...@@ -140,6 +153,7 @@ dragon.nn ...@@ -140,6 +153,7 @@ dragon.nn
nn/selu nn/selu
nn/softmax nn/softmax
nn/space_to_depth nn/space_to_depth
nn/swish
nn/sync_batch_norm nn/sync_batch_norm
.. raw:: html .. raw:: html
......
hardsigmoid
===========
.. autofunction:: dragon.nn.hardsigmoid
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
hardswish
=========
.. autofunction:: dragon.nn.hardswish
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
swish
=====
.. autofunction:: dragon.nn.swish
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
...@@ -63,7 +63,7 @@ Name Supported Reference ...@@ -63,7 +63,7 @@ Name Supported Reference
`GlobalLpPool`_ `GlobalLpPool`_
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d` `GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d`
`Greater`_ |v| :func:`dragon.math.greater` `Greater`_ |v| :func:`dragon.math.greater`
`HardSigmoid`_ `HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid`
`Hardmax`_ `Hardmax`_
`Identity`_ `Identity`_
`If`_ `If`_
......
...@@ -37,7 +37,7 @@ vm.tensorflow ...@@ -37,7 +37,7 @@ vm.tensorflow
: Return a tensor initialized from the value. : Return a tensor initialized from the value.
`device(...) <tensorflow/device.html>`_ `device(...) <tensorflow/device.html>`_
: Context-manager to nest the the device spec. : Context-manager to nest the device spec.
`expand_dims(...) <tensorflow/expand_dims.html>`_ `expand_dims(...) <tensorflow/expand_dims.html>`_
: Expand the dimensions of input with size 1. : Expand the dimensions of input with size 1.
......
...@@ -16,6 +16,9 @@ activations ...@@ -16,6 +16,9 @@ activations
`get(...) <activations/get.html>`_ `get(...) <activations/get.html>`_
: Return the activation callable by identifier. : Return the activation callable by identifier.
`hard_sigmoid(...) <activations/hard_sigmoid.html>`_
: Apply the hard sigmoid function to input.
`linear(...) <activations/linear.html>`_ `linear(...) <activations/linear.html>`_
: Apply the linear activation to input. : Apply the linear activation to input.
...@@ -33,6 +36,10 @@ activations ...@@ -33,6 +36,10 @@ activations
`softmax(...) <activations/softmax.html>`_ `softmax(...) <activations/softmax.html>`_
: Apply the softmax function to input. : Apply the softmax function to input.
`swish(...) <activations/swish.html>`_
: Apply the swish function to input.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`tanh(...) <activations/tanh.html>`_ `tanh(...) <activations/tanh.html>`_
: Apply the tanh function to input. : Apply the tanh function to input.
...@@ -42,11 +49,13 @@ activations ...@@ -42,11 +49,13 @@ activations
activations/elu activations/elu
activations/exponential activations/exponential
activations/get activations/get
activations/hard_sigmoid
activations/linear activations/linear
activations/relu activations/relu
activations/selu activations/selu
activations/sigmoid activations/sigmoid
activations/softmax activations/softmax
activations/swish
activations/tanh activations/tanh
.. raw:: html .. raw:: html
......
hard_sigmoid
============
.. autofunction:: dragon.vm.tensorflow.keras.activations.hard_sigmoid
.. raw:: html
<style>
h1:before {
content: "tf.keras.activations.";
color: #103d3e;
}
</style>
swish
=====
.. autofunction:: dragon.vm.tensorflow.keras.activations.swish
.. raw:: html
<style>
h1:before {
content: "tf.keras.activations.";
color: #103d3e;
}
</style>
...@@ -86,6 +86,10 @@ vm.tensorflow.nn ...@@ -86,6 +86,10 @@ vm.tensorflow.nn
`sparse_softmax_cross_entropy_with_logits(...) <nn/sparse_softmax_cross_entropy_with_logits.html>`_ `sparse_softmax_cross_entropy_with_logits(...) <nn/sparse_softmax_cross_entropy_with_logits.html>`_
: Compute the softmax cross entropy with sparse labels. : Compute the softmax cross entropy with sparse labels.
`swish(...) <nn/swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -113,6 +117,7 @@ vm.tensorflow.nn ...@@ -113,6 +117,7 @@ vm.tensorflow.nn
nn/softmax_cross_entropy_with_logits nn/softmax_cross_entropy_with_logits
nn/space_to_depth nn/space_to_depth
nn/sparse_softmax_cross_entropy_with_logits nn/sparse_softmax_cross_entropy_with_logits
nn/swish
.. raw:: html .. raw:: html
......
swish
=====
.. autofunction:: dragon.vm.tensorflow.nn.swish
.. raw:: html
<style>
h1:before {
content: "tf.nn.";
color: #103d3e;
}
</style>
...@@ -109,6 +109,12 @@ vm.torch ...@@ -109,6 +109,12 @@ vm.torch
`from_numpy(...) <torch/from_numpy.html>`_ `from_numpy(...) <torch/from_numpy.html>`_
: Create a tensor from the given numpy array. : Create a tensor from the given numpy array.
`full(...) <torch/full.html>`_
: Return a tensor filled with a scalar.
`full_like(...) <torch/full_like.html>`_
: Return a tensor filled with a scalar, with the same size as the input.
`ge(...) <torch/ge.html>`_ `ge(...) <torch/ge.html>`_
: Compute the element-wise greater-equal comparison. : Compute the element-wise greater-equal comparison.
...@@ -172,6 +178,9 @@ vm.torch ...@@ -172,6 +178,9 @@ vm.torch
`ne(...) <torch/ne.html>`_ `ne(...) <torch/ne.html>`_
: Compute the element-wise not-equal comparison. : Compute the element-wise not-equal comparison.
`neg(...) <torch/neg.html>`_
: Compute the element-wise negative.
`nonzero(...) <torch/nonzero.html>`_ `nonzero(...) <torch/nonzero.html>`_
: Return the index of non-zero elements. : Return the index of non-zero elements.
...@@ -299,6 +308,8 @@ vm.torch ...@@ -299,6 +308,8 @@ vm.torch
torch/flatten torch/flatten
torch/floor torch/floor
torch/from_numpy torch/from_numpy
torch/full
torch/full_like
torch/ge torch/ge
torch/gt torch/gt
torch/index_select torch/index_select
...@@ -320,6 +331,7 @@ vm.torch ...@@ -320,6 +331,7 @@ vm.torch
torch/multinomial torch/multinomial
torch/narrow torch/narrow
torch/ne torch/ne
torch/neg
torch/no_grad torch/no_grad
torch/nonzero torch/nonzero
torch/ones torch/ones
......
...@@ -314,9 +314,33 @@ ndimension ...@@ -314,9 +314,33 @@ ndimension
.. automethod:: dragon.vm.torch.Tensor.ndimension .. automethod:: dragon.vm.torch.Tensor.ndimension
ne ne
### ##
.. automethod:: dragon.vm.torch.Tensor.ne .. automethod:: dragon.vm.torch.Tensor.ne
neg
###
.. automethod:: dragon.vm.torch.Tensor.neg
neg\_
#####
.. automethod:: dragon.vm.torch.Tensor.neg_
new_empty
#########
.. automethod:: dragon.vm.torch.Tensor.new_empty
new_full
########
.. automethod:: dragon.vm.torch.Tensor.new_full
new_ones
########
.. automethod:: dragon.vm.torch.Tensor.new_ones
new_zeros
#########
.. automethod:: dragon.vm.torch.Tensor.new_zeros
nonzero nonzero
####### #######
.. automethod:: dragon.vm.torch.Tensor.nonzero .. automethod:: dragon.vm.torch.Tensor.nonzero
...@@ -497,10 +521,12 @@ zero\_ ...@@ -497,10 +521,12 @@ zero\_
.. _torch.cos(...): cos.html .. _torch.cos(...): cos.html
.. _torch.cumsum(...): cumsum.html .. _torch.cumsum(...): cumsum.html
.. _torch.div(...): div.html .. _torch.div(...): div.html
.. _torch.empty(...): empty.html
.. _torch.eq(...): eq.html .. _torch.eq(...): eq.html
.. _torch.exp(...): exp.html .. _torch.exp(...): exp.html
.. _torch.flatten(...): flatten.html .. _torch.flatten(...): flatten.html
.. _torch.floor(...): floor.html .. _torch.floor(...): floor.html
.. _torch.full(...): full.html
.. _torch.ge(...): ge.html .. _torch.ge(...): ge.html
.. _torch.gt(...): gt.html .. _torch.gt(...): gt.html
.. _torch.le(...): le.html .. _torch.le(...): le.html
...@@ -509,6 +535,7 @@ zero\_ ...@@ -509,6 +535,7 @@ zero\_
.. _torch.ne(...): ne.html .. _torch.ne(...): ne.html
.. _torch.neg(...): neg.html .. _torch.neg(...): neg.html
.. _torch.nonzero(...): nonzero.html .. _torch.nonzero(...): nonzero.html
.. _torch.ones(...): ones.html
.. _torch.pow(...): pow.html .. _torch.pow(...): pow.html
.. _torch.reciprocal(...): reciprocal.html .. _torch.reciprocal(...): reciprocal.html
.. _torch.reshape(...): reshape.html .. _torch.reshape(...): reshape.html
...@@ -526,6 +553,7 @@ zero\_ ...@@ -526,6 +553,7 @@ zero\_
.. _torch.unique(...): unique.html .. _torch.unique(...): unique.html
.. _torch.unsqueeze(...): unsqueeze.html .. _torch.unsqueeze(...): unsqueeze.html
.. _torch.where(...): where.html .. _torch.where(...): where.html
.. _torch.zeros(...): zeros.html
.. raw:: html .. raw:: html
......
full
====
.. autofunction:: dragon.vm.torch.full
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
full_like
=========
.. autofunction:: dragon.vm.torch.full_like
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
neg
===
.. autofunction:: dragon.vm.torch.neg
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
...@@ -84,6 +84,13 @@ vm.torch.nn ...@@ -84,6 +84,13 @@ vm.torch.nn
: Apply the gumbel softmax with a temperature. : Apply the gumbel softmax with a temperature.
`[Jang et.al, 2016] <https://arxiv.org/abs/1611.01144>`_. `[Jang et.al, 2016] <https://arxiv.org/abs/1611.01144>`_.
`class Hardsigmoid <nn/Hardsigmoid.html>`_
: Apply the hard sigmoid function.
`class Hardswish <nn/Hardswish.html>`_
: Apply the hard swish function.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
`class KLDivLoss <nn/KLDivLoss.html>`_ `class KLDivLoss <nn/KLDivLoss.html>`_
: Compute the Kullback-Leibler divergence. : Compute the Kullback-Leibler divergence.
...@@ -178,6 +185,10 @@ vm.torch.nn ...@@ -178,6 +185,10 @@ vm.torch.nn
`class Softmax <nn/Softmax.html>`_ `class Softmax <nn/Softmax.html>`_
: Apply the softmax function. : Apply the softmax function.
`class Swish <nn/Swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`class Tanh <nn/Tanh.html>`_ `class Tanh <nn/Tanh.html>`_
: Apply the tanh function. : Apply the tanh function.
...@@ -222,6 +233,8 @@ vm.torch.nn ...@@ -222,6 +233,8 @@ vm.torch.nn
nn/GroupNorm nn/GroupNorm
nn/GRU nn/GRU
nn/GumbelSoftmax nn/GumbelSoftmax
nn/Hardsigmoid
nn/Hardswish
nn/KLDivLoss nn/KLDivLoss
nn/L1Loss nn/L1Loss
nn/LeakyReLU nn/LeakyReLU
...@@ -250,6 +263,7 @@ vm.torch.nn ...@@ -250,6 +263,7 @@ vm.torch.nn
nn/SigmoidFocalLoss nn/SigmoidFocalLoss
nn/SmoothL1Loss nn/SmoothL1Loss
nn/Softmax nn/Softmax
nn/Swish
nn/Tanh nn/Tanh
nn/SyncBatchNorm nn/SyncBatchNorm
nn/Upsample nn/Upsample
......
Hardsigmoid
===========
.. autoclass:: dragon.vm.torch.nn.Hardsigmoid
__init__
--------
.. automethod:: dragon.vm.torch.nn.Hardsigmoid.__init__
.. _torch.nn.functional.hardsigmoid(...): functional/hardsigmoid.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
Hardswish
=========
.. autoclass:: dragon.vm.torch.nn.Hardswish
__init__
--------
.. automethod:: dragon.vm.torch.nn.Hardswish.__init__
.. _torch.nn.functional.hardswish(...): functional/hardswish.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
Swish
=====
.. autoclass:: dragon.vm.torch.nn.Swish
__init__
--------
.. automethod:: dragon.vm.torch.nn.Swish.__init__
.. _torch.nn.functional.swish(...): functional/swish.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
...@@ -53,6 +53,13 @@ vm.torch.nn.functional ...@@ -53,6 +53,13 @@ vm.torch.nn.functional
: Apply the group normalization to input. : Apply the group normalization to input.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_. `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`hardsigmoid(...) <functional/hardsigmoid.html>`_
: Apply the hard sigmoid function to input.
`hardswish(...) <functional/hardswish.html>`_
: Apply the hard swish function to input.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
`kl_div(...) <functional/kl_div.html>`_ `kl_div(...) <functional/kl_div.html>`_
: Compute the Kullback-Leibler divergence. : Compute the Kullback-Leibler divergence.
...@@ -113,6 +120,10 @@ vm.torch.nn.functional ...@@ -113,6 +120,10 @@ vm.torch.nn.functional
`softmax(...) <functional/softmax.html>`_ `softmax(...) <functional/softmax.html>`_
: Apply the softmax function to input. : Apply the softmax function to input.
`swish(...) <functional/swish.html>`_
: Apply the swish function to input.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`sync_batch_norm(...) <functional/sync_batch_norm.html>`_ `sync_batch_norm(...) <functional/sync_batch_norm.html>`_
: Apply the sync batch normalization to input. : Apply the sync batch normalization to input.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_. `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
...@@ -145,6 +156,8 @@ vm.torch.nn.functional ...@@ -145,6 +156,8 @@ vm.torch.nn.functional
functional/dropout functional/dropout
functional/elu functional/elu
functional/group_norm functional/group_norm
functional/hardsigmoid
functional/hardswish
functional/kl_div functional/kl_div
functional/l1_loss functional/l1_loss
functional/leaky_relu functional/leaky_relu
...@@ -165,6 +178,7 @@ vm.torch.nn.functional ...@@ -165,6 +178,7 @@ vm.torch.nn.functional
functional/sigmoid_focal_loss functional/sigmoid_focal_loss
functional/smooth_l1_loss functional/smooth_l1_loss
functional/softmax functional/softmax
functional/swish
functional/sync_batch_norm functional/sync_batch_norm
functional/tanh functional/tanh
functional/upsample functional/upsample
......
hardsigmoid
===========
.. autofunction:: dragon.vm.torch.nn.functional.hardsigmoid
.. _torch.nn.Hardsigmoid(...): ../Hardsigmoid.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
hardswish
=========
.. autofunction:: dragon.vm.torch.nn.functional.hardswish
.. _torch.nn.Hardswish(...): ../Hardswish.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
swish
=====
.. autofunction:: dragon.vm.torch.nn.functional.swish
.. _torch.nn.Swish(...): ../Swish.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
...@@ -25,9 +25,12 @@ template <> ...@@ -25,9 +25,12 @@ template <>
__global__ void __global__ void
_Elu<half>(const int nthreads, const float alpha, const half* x, half* y) { _Elu<half>(const int nthreads, const float alpha, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 350
const float val = __half2float(__ldg(x + i)); const float val = __half2float(__ldg(x + i));
y[i] = val > 0.f ? __ldg(x + i) : __float2half(alpha * (exp(val) - 1.f)); y[i] = val > 0.f ? __ldg(x + i) : __float2half(alpha * (exp(val) - 1.f));
#else
const float val = __half2float(x[i]);
y[i] = val > 0.f ? x[i] : __float2half(alpha * (exp(val) - 1.f));
#endif #endif
} }
} }
...@@ -36,12 +39,10 @@ template <> ...@@ -36,12 +39,10 @@ template <>
__global__ void __global__ void
_Elu<half2>(const int nthreads, const float alpha, const half2* x, half2* y) { _Elu<half2>(const int nthreads, const float alpha, const half2* x, half2* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float2 val = __half22float2(x[i]); const float2 val = __half22float2(x[i]);
y[i] = __floats2half2_rn( y[i] = __floats2half2_rn(
val.x > 0.f ? val.x : alpha * (exp(val.x) - 1.f), val.x > 0.f ? val.x : alpha * (exp(val.x) - 1.f),
val.y > 0.f ? val.y : alpha * (exp(val.y) - 1.f)); val.y > 0.f ? val.y : alpha * (exp(val.y) - 1.f));
#endif
} }
} }
...@@ -69,10 +70,9 @@ __global__ void _EluGrad<half>( ...@@ -69,10 +70,9 @@ __global__ void _EluGrad<half>(
const half* y, const half* y,
half* dx) { half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float val = __half2float(y[i]); const float val = __half2float(y[i]);
dx[i] = __hmul(dy[i], __float2half(val > 0.f ? 1.f : (alpha + val))); dx[i] =
#endif __float2half(__half2float(dy[i]) * (val > 0.f ? 1.f : (alpha + val)));
} }
} // EluGrad } // EluGrad
...@@ -84,14 +84,11 @@ __global__ void _EluGrad<half2>( ...@@ -84,14 +84,11 @@ __global__ void _EluGrad<half2>(
const half2* y, const half2* y,
half2* dx) { half2* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float2 val = __half22float2(y[i]); const float2 val = __half22float2(y[i]);
dx[i] = __hmul2( const float2 grad = __half22float2(dy[i]);
dy[i], dx[i] = __floats2half2_rn(
__floats2half2_rn( grad.x * (val.x > 0.f ? 1.f : (alpha + val.x)),
val.x > 0.f ? 1.f : (alpha + val.x), grad.y * (val.y > 0.f ? 1.f : (alpha + val.y)));
val.y > 0.f ? 1.f : (alpha + val.y)));
#endif
} }
} // EluGrad } // EluGrad
......
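The Elu hunks above drop the `__CUDA_ARCH__ >= 530` half-math guards by widening to float (previously, devices below that arch compiled an empty loop body). Numerically, the backward pass exploits that for x <= 0 the output y = alpha * (exp(x) - 1) implies dy/dx = alpha * exp(x) = alpha + y. A NumPy sketch of that identity, with hypothetical helper names:

```python
import numpy as np

def elu(x, alpha=1.0):
    return np.where(x > 0.0, x, alpha * (np.exp(x) - 1.0))

def elu_grad_from_output(dy, y, alpha=1.0):
    # For x <= 0: y = alpha * (exp(x) - 1)  =>  dy/dx = alpha * exp(x) = alpha + y,
    # so the gradient is recoverable from the forward output alone.
    return dy * np.where(y > 0.0, 1.0, alpha + y)

x = np.linspace(-3.0, 3.0, 7)
print(elu_grad_from_output(np.ones_like(x), elu(x)))
```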
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _HardSigmoid(
const int count,
const T alpha,
const T beta,
const T* x,
T* y) {
EigenVectorArrayMap<T>(y, count) =
(ConstEigenVectorArrayMap<T>(x, count) * alpha + beta)
.cwiseMin(T(1))
.cwiseMax(T(0));
}
template <>
void _HardSigmoid<float16>(
const int count,
const float16 alpha,
const float16 beta,
const float16* x,
float16* y) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename T>
void _HardSigmoidGrad(
const int count,
const T alpha,
const T* dy,
const T* y,
T* dx) {
ConstEigenVectorArrayMap<T> Y(y, count);
EigenVectorArrayMap<T>(dx, count) =
(Y > T(0) && Y < T(1))
.select(ConstEigenVectorArrayMap<T>(dy, count) * alpha, T(0));
}
template <>
void _HardSigmoidGrad<float16>(
const int count,
const float16 alpha,
const float16* dy,
const float16* y,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void HardSigmoid<T, CPUContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_HardSigmoid(count, cast::to<T>(alpha), cast::to<T>(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void HardSigmoidGrad<T, CPUContext>( \
const int count, \
const float alpha, \
const T* dy, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
_HardSigmoidGrad(count, cast::to<T>(alpha), dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void _HardSigmoid(
const int nthreads,
const T alpha,
const T beta,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = max(T(0), min(T(1), fma(x[i], alpha, beta)));
}
}
__global__ void _HardSigmoid(
const int nthreads,
const float alpha,
const float beta,
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] =
__float2half(max(0.f, min(1.f, fma(__half2float(x[i]), alpha, beta))));
}
}
template <typename T>
__global__ void _HardSigmoidGrad(
const int nthreads,
const float alpha,
const T* dy,
const T* y,
T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = (y[i] > T(0) && y[i] < T(1)) ? dy[i] * alpha : T(0);
}
}
template <>
__global__ void _HardSigmoidGrad<half>(
const int nthreads,
const float alpha,
const half* dy,
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float val = __half2float(y[i]);
dx[i] = __float2half(
(val > 0.f && val < 1.f) ? __half2float(dy[i]) * alpha : 0.f);
}
} // HardSigmoidGrad
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void HardSigmoid<float16, CUDAContext>(
const int count,
const float alpha,
const float beta,
const float16* x,
float16* y,
CUDAContext* ctx) {
_HardSigmoid<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
alpha,
beta,
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
template <>
void HardSigmoidGrad<float16, CUDAContext>(
const int count,
const float alpha,
const float16* dy,
const float16* y,
float16* dx,
CUDAContext* ctx) {
_HardSigmoidGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
alpha,
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(y),
reinterpret_cast<half*>(dx));
} // HardSigmoidGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void HardSigmoid<T, CUDAContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
_HardSigmoid<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, T(alpha), T(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void HardSigmoidGrad<T, CUDAContext>( \
const int count, \
const float alpha, \
const T* dy, \
const T* y, \
T* dx, \
CUDAContext* ctx) { \
_HardSigmoidGrad<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(count, alpha, dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
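A NumPy cross-check of the HardSigmoid kernels above, assuming the ONNX-style defaults alpha=0.2, beta=0.5 (the kernels themselves receive alpha and beta from the launcher): the forward pass is clip(alpha * x + beta, 0, 1), and the backward pass gates on the saved output.

```python
import numpy as np

def hardsigmoid(x, alpha=0.2, beta=0.5):
    return np.clip(alpha * x + beta, 0.0, 1.0)

def hardsigmoid_grad(dy, y, alpha=0.2):
    # Matches _HardSigmoidGrad: slope alpha strictly inside (0, 1),
    # zero on the saturated plateaus.
    return np.where((y > 0.0) & (y < 1.0), dy * alpha, 0.0)

x = np.linspace(-4.0, 4.0, 9)
y = hardsigmoid(x)
print(y)
print(hardsigmoid_grad(np.ones_like(x), y))
```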
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _HardSwish(
const int count,
const T alpha,
const T beta,
const T* x,
T* y) {
ConstEigenVectorArrayMap<T> X(x, count);
EigenVectorArrayMap<T>(y, count) =
X * ((X * alpha + beta).cwiseMin(T(1)).cwiseMax(T(0)));
}
template <>
void _HardSwish<float16>(
const int count,
const float16 alpha,
const float16 beta,
const float16* x,
float16* y) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename T>
void _HardSwishGrad(
const int count,
const T alpha,
const T beta,
const T* dy,
const T* x,
T* dx) {
const auto bound = beta / alpha;
const auto alpha2x = alpha * T(2);
EigenVectorArrayMap<T>(dx, count) = ConstEigenVectorArrayMap<T>(dy, count) *
ConstEigenVectorArrayMap<T>(x, count).unaryExpr([&](T a) {
return (a < -bound) ? T(0) : (a < bound ? a * alpha2x + beta : T(1));
});
}
template <>
void _HardSwishGrad<float16>(
const int count,
const float16 alpha,
const float16 beta,
const float16* dy,
const float16* x,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void HardSwish<T, CPUContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_HardSwish(count, cast::to<T>(alpha), cast::to<T>(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void HardSwishGrad<T, CPUContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* dy, \
const T* x, \
T* dx, \
CPUContext* ctx) { \
_HardSwishGrad(count, cast::to<T>(alpha), cast::to<T>(beta), dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void
_HardSwish(const int nthreads, const T alpha, const T beta, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __ldg(x + i) * max(T(0), min(T(1), fma(__ldg(x + i), alpha, beta)));
#else
y[i] = x[i] * max(T(0), min(T(1), fma(x[i], alpha, beta)));
#endif
}
}
__global__ void _HardSwish(
const int nthreads,
const float alpha,
const float beta,
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __float2half(
__half2float(__ldg(x + i)) *
max(0.f, min(1.f, fma(__half2float(__ldg(x + i)), alpha, beta))));
#else
y[i] = __float2half(
__half2float(x[i]) *
max(0.f, min(1.f, fma(__half2float(x[i]), alpha, beta))));
#endif
}
}
template <typename T>
__global__ void _HardSwishGrad(
const int nthreads,
const T alpha,
const T beta,
const T* dy,
const T* x,
T* dx) {
const T bound = beta / alpha;
const T alpha2x = alpha * T(2);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
dx[i] = (__ldg(x + i) < -bound)
? T(0)
: (__ldg(x + i) < bound) ? dy[i] * fma(__ldg(x + i), alpha2x, beta)
: dy[i];
#else
dx[i] = (x[i] < -bound)
? T(0)
: (x[i] < bound) ? dy[i] * fma(x[i], alpha2x, beta) : dy[i];
#endif
}
}
__global__ void _HardSwishGrad(
const int nthreads,
const float alpha,
const float beta,
const half* dy,
const half* x,
half* dx) {
const float bound = beta / alpha;
const float alpha2x = alpha * 2.f;
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float val = __half2float(x[i]);
dx[i] = (val < -bound) ? kZero
: (val < bound)
? __float2half(__half2float(dy[i]) * fma(val, alpha2x, beta))
: dy[i];
}
} // HardSwishGrad
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void HardSwish<float16, CUDAContext>(
const int count,
const float alpha,
const float beta,
const float16* x,
float16* y,
CUDAContext* ctx) {
_HardSwish<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
alpha,
beta,
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
template <>
void HardSwishGrad<float16, CUDAContext>(
const int count,
const float alpha,
const float beta,
const float16* dy,
const float16* x,
float16* dx,
CUDAContext* ctx) {
_HardSwishGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
alpha,
beta,
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(dx));
} // HardSwishGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void HardSwish<T, CUDAContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
_HardSwish<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, T(alpha), T(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void HardSwishGrad<T, CUDAContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* dy, \
const T* x, \
T* dx, \
CUDAContext* ctx) { \
_HardSwishGrad<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(count, T(alpha), T(beta), dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
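Likewise for HardSwish: the forward pass is x * clip(alpha * x + beta, 0, 1), and the gradient is piecewise with bound = beta / alpha, mirroring the kernels' symmetric-bound assumption (exact when beta = 0.5, since the upper saturation point (1 - beta) / alpha then equals beta / alpha). A NumPy sketch with the MobileNetV3-style defaults assumed:

```python
import numpy as np

def hardswish(x, alpha=1.0 / 6.0, beta=0.5):
    return x * np.clip(alpha * x + beta, 0.0, 1.0)

def hardswish_grad(dy, x, alpha=1.0 / 6.0, beta=0.5):
    bound = beta / alpha  # saturation points at +/- bound, as in _HardSwishGrad
    mid = dy * (2.0 * alpha * x + beta)  # d/dx of x * (alpha * x + beta)
    return np.where(x < -bound, 0.0, np.where(x < bound, mid, dy))

x = np.linspace(-4.0, 4.0, 9)
print(hardswish(x))
print(hardswish_grad(np.ones_like(x), x))
```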
...@@ -10,9 +10,9 @@ namespace { ...@@ -10,9 +10,9 @@ namespace {
template <typename T> template <typename T>
void _Relu(const int count, const T alpha, const T* x, T* y) { void _Relu(const int count, const T alpha, const T* x, T* y) {
ConstEigenVectorArrayMap<T> X(x, count);
EigenVectorArrayMap<T>(y, count) = EigenVectorArrayMap<T>(y, count) =
ConstEigenVectorArrayMap<T>(x, count).unaryExpr( X.cwiseMax(T(0)) + X.cwiseMin(T(0)) * alpha;
[&](T a) { return a > T(0) ? a : alpha * a; });
} }
template <> template <>
......
...@@ -10,8 +10,8 @@ namespace { ...@@ -10,8 +10,8 @@ namespace {
template <typename T> template <typename T>
void _Softmax( void _Softmax(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const T* x, const T* x,
T* y) { T* y) {
int row_offset, col_offset, yi; int row_offset, col_offset, yi;
...@@ -45,8 +45,8 @@ void _Softmax( ...@@ -45,8 +45,8 @@ void _Softmax(
template <> template <>
void _Softmax<float16>( void _Softmax<float16>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* x, const float16* x,
float16* y) { float16* y) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
...@@ -55,8 +55,8 @@ void _Softmax<float16>( ...@@ -55,8 +55,8 @@ void _Softmax<float16>(
template <typename T> template <typename T>
void _SoftmaxGrad( void _SoftmaxGrad(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const T* dy, const T* dy,
const T* y, const T* y,
T* dx) { T* dx) {
...@@ -82,8 +82,8 @@ void _SoftmaxGrad( ...@@ -82,8 +82,8 @@ void _SoftmaxGrad(
template <> template <>
void _SoftmaxGrad<float16>( void _SoftmaxGrad<float16>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* dy, const float16* dy,
const float16* y, const float16* y,
float16* dx) { float16* dx) {
...@@ -98,25 +98,25 @@ void _SoftmaxGrad<float16>( ...@@ -98,25 +98,25 @@ void _SoftmaxGrad<float16>(
template <> \ template <> \
void Softmax<T, CPUContext>( \ void Softmax<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_Softmax(outer_dim, axis_dim, inner_dim, x, y); \ _Softmax(outer_dim, inner_dim, axis_dim, x, y); \
} }
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void SoftmaxGrad<T, CPUContext>( \ void SoftmaxGrad<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* dy, \ const T* dy, \
const T* y, \ const T* y, \
T* dx, \ T* dx, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_SoftmaxGrad(outer_dim, axis_dim, inner_dim, dy, y, dx); \ _SoftmaxGrad(outer_dim, inner_dim, axis_dim, dy, y, dx); \
} }
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
......
...@@ -185,8 +185,8 @@ __global__ void _SoftmaxGrad<half>( ...@@ -185,8 +185,8 @@ __global__ void _SoftmaxGrad<half>(
template <> template <>
void Softmax<float16, CUDAContext>( void Softmax<float16, CUDAContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* x, const float16* x,
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
...@@ -203,8 +203,8 @@ void Softmax<float16, CUDAContext>( ...@@ -203,8 +203,8 @@ void Softmax<float16, CUDAContext>(
template <> template <>
void SoftmaxGrad<float16, CUDAContext>( void SoftmaxGrad<float16, CUDAContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* dy, const float16* dy,
const float16* y, const float16* y,
float16* dx, float16* dx,
...@@ -223,8 +223,8 @@ void SoftmaxGrad<float16, CUDAContext>( ...@@ -223,8 +223,8 @@ void SoftmaxGrad<float16, CUDAContext>(
template <> \ template <> \
void Softmax<T, CUDAContext>( \ void Softmax<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
...@@ -237,8 +237,8 @@ void SoftmaxGrad<float16, CUDAContext>( ...@@ -237,8 +237,8 @@ void SoftmaxGrad<float16, CUDAContext>(
template <> \ template <> \
void SoftmaxGrad<T, CUDAContext>( \ void SoftmaxGrad<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* dy, \ const T* dy, \
const T* y, \ const T* y, \
T* dx, \ T* dx, \
......
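The softmax hunks only permute the dimension arguments into the (outer_dim, inner_dim, axis_dim) order used throughout this commit; the math is untouched. As a sanity sketch, softmax over an arbitrary axis decomposes into outer_dim * inner_dim independent fibers of length axis_dim (names here are illustrative, not Dragon's API):

```python
import numpy as np

def softmax_ref(x, axis):
    m = x.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(x - m)
    return e / e.sum(axis=axis, keepdims=True)

def softmax_decomposed(x, axis):
    # View the tensor as (outer_dim, axis_dim, inner_dim) and reduce each
    # length-axis_dim fiber independently, as the kernels do.
    outer_dim = int(np.prod(x.shape[:axis], dtype=np.int64))
    axis_dim = x.shape[axis]
    inner_dim = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    y = x.reshape(outer_dim, axis_dim, inner_dim).copy()
    for i in range(outer_dim):
        for j in range(inner_dim):
            y[i, :, j] = softmax_ref(y[i, :, j], axis=0)
    return y.reshape(x.shape)

x = np.random.rand(2, 3, 4)
assert np.allclose(softmax_decomposed(x, 1), softmax_ref(x, 1))
```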
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _Swish(const int count, const T* x, T* y) {
ConstEigenVectorArrayMap<T> X(x, count);
EigenVectorArrayMap<T>(y, count) = X / (T(1) + (-X).exp());
}
template <>
void _Swish<float16>(const int count, const float16* x, float16* y) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename T>
void _SwishGrad(const int count, const T* dy, const T* x, const T* y, T* dx) {
ConstEigenVectorArrayMap<T> X(x, count);
ConstEigenVectorArrayMap<T> Y(y, count);
EigenVectorArrayMap<T>(dx, count) = ConstEigenVectorArrayMap<T>(dy, count) *
(Y + (T(1) / (T(1) + (-X).exp())) * (T(1) - Y));
}
template <>
void _SwishGrad<float16>(
const int count,
const float16* dy,
const float16* x,
const float16* y,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Swish<T, CPUContext>( \
const int count, const T* x, T* y, CPUContext* ctx) { \
_Swish(count, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void SwishGrad<T, CPUContext>( \
const int count, \
const T* dy, \
const T* x, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
_SwishGrad(count, dy, x, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void _Swish(const int nthreads, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __ldg(x + i) / (T(1) + exp(-__ldg(x + i)));
#else
y[i] = x[i] / (T(1) + exp(-x[i]));
#endif
}
}
template <>
__global__ void _Swish<half>(const int nthreads, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __float2half(
__half2float(__ldg(x + i)) / (1.f + exp(-__half2float(__ldg(x + i)))));
#else
y[i] = __float2half(__half2float(x[i]) / (1.f + exp(-__half2float(x[i]))));
#endif
}
}
template <typename T>
__global__ void
_SwishGrad(const int nthreads, const T* dy, const T* x, const T* y, T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
dx[i] =
dy[i] * (__ldg(y + i) + (T(1) - __ldg(y + i)) / (T(1) + exp(-x[i])));
#else
dx[i] = dy[i] * (y[i] + (T(1) - y[i]) / (T(1) + exp(-x[i])));
#endif
}
}
template <>
__global__ void _SwishGrad<half>(
const int nthreads,
const half* dy,
const half* x,
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(__ldg(y + i)) +
(1.f - __half2float(__ldg(y + i))) /
(1.f + exp(-__half2float(x[i])))));
#else
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(y[i]) +
(1.f - __half2float(y[i])) / (1.f + exp(-__half2float(x[i])))));
#endif
}
} // SwishGrad
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void Swish<float16, CUDAContext>(
const int count,
const float16* x,
float16* y,
CUDAContext* ctx) {
_Swish<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, reinterpret_cast<const half*>(x), reinterpret_cast<half*>(y));
}
template <>
void SwishGrad<float16, CUDAContext>(
const int count,
const float16* dy,
const float16* x,
const float16* y,
float16* dx,
CUDAContext* ctx) {
_SwishGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(y),
reinterpret_cast<half*>(dx));
} // SwishGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Swish<T, CUDAContext>( \
const int count, const T* x, T* y, CUDAContext* ctx) { \
_Swish<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void SwishGrad<T, CUDAContext>( \
const int count, \
const T* dy, \
const T* x, \
const T* y, \
T* dx, \
CUDAContext* ctx) { \
_SwishGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, dy, x, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
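The Swish kernels save the forward output y so the backward pass can use the identity d/dx[x * sigmoid(x)] = y + sigmoid(x) * (1 - y) instead of recomputing the product. A NumPy sketch with a finite-difference check:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swish(x):
    return x * sigmoid(x)

def swish_grad(dy, x, y):
    # Matches _SwishGrad: y' = y + sigmoid(x) * (1 - y).
    return dy * (y + sigmoid(x) * (1.0 - y))

x = np.linspace(-3.0, 3.0, 7)
eps = 1e-5
numeric = (swish(x + eps) - swish(x - eps)) / (2.0 * eps)
assert np.allclose(swish_grad(np.ones_like(x), x, swish(x)), numeric, atol=1e-6)
```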
...@@ -31,8 +31,8 @@ void _ChannelAffine( ...@@ -31,8 +31,8 @@ void _ChannelAffine(
template <typename T> template <typename T>
void _ChannelAffine( void _ChannelAffine(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const T* x, const T* x,
const T* w, const T* w,
const T* b, const T* b,
...@@ -59,8 +59,8 @@ void _ChannelAffine( ...@@ -59,8 +59,8 @@ void _ChannelAffine(
template <> template <>
void ChannelAffine<float16, CPUContext>( void ChannelAffine<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* x, const float16* x,
const float16* w, const float16* w,
const float16* b, const float16* b,
...@@ -73,8 +73,8 @@ void ChannelAffine<float16, CPUContext>( ...@@ -73,8 +73,8 @@ void ChannelAffine<float16, CPUContext>(
template <> \ template <> \
void ChannelAffine<T, CPUContext>( \ void ChannelAffine<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* x, \ const T* x, \
const T* w, \ const T* w, \
const T* b, \ const T* b, \
...@@ -83,7 +83,7 @@ void ChannelAffine<float16, CPUContext>( ...@@ -83,7 +83,7 @@ void ChannelAffine<float16, CPUContext>(
if (inner_dim == 1) { \ if (inner_dim == 1) { \
_ChannelAffine(outer_dim, axis_dim, x, w, b, y); \ _ChannelAffine(outer_dim, axis_dim, x, w, b, y); \
} else { \ } else { \
_ChannelAffine(outer_dim, axis_dim, inner_dim, x, w, b, y); \ _ChannelAffine(outer_dim, inner_dim, axis_dim, x, w, b, y); \
} \ } \
} }
...@@ -93,6 +93,7 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -93,6 +93,7 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -12,8 +12,8 @@ namespace { ...@@ -12,8 +12,8 @@ namespace {
template <typename T> template <typename T>
__global__ void _ChannelAffine( __global__ void _ChannelAffine(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const T* x, const T* x,
const T* w, const T* w,
T* y) { T* y) {
...@@ -29,8 +29,8 @@ __global__ void _ChannelAffine( ...@@ -29,8 +29,8 @@ __global__ void _ChannelAffine(
template <> template <>
__global__ void _ChannelAffine<half>( __global__ void _ChannelAffine<half>(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const half* x, const half* x,
const half* w, const half* w,
half* y) { half* y) {
...@@ -51,8 +51,8 @@ __global__ void _ChannelAffine<half>( ...@@ -51,8 +51,8 @@ __global__ void _ChannelAffine<half>(
template <typename T> template <typename T>
__global__ void _ChannelAffine( __global__ void _ChannelAffine(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const T* x, const T* x,
const T* w, const T* w,
const T* b, const T* b,
...@@ -70,8 +70,8 @@ __global__ void _ChannelAffine( ...@@ -70,8 +70,8 @@ __global__ void _ChannelAffine(
template <> template <>
__global__ void _ChannelAffine<half>( __global__ void _ChannelAffine<half>(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const half* x, const half* x,
const half* w, const half* w,
const half* b, const half* b,
...@@ -95,8 +95,8 @@ __global__ void _ChannelAffine<half>( ...@@ -95,8 +95,8 @@ __global__ void _ChannelAffine<half>(
template <> template <>
__global__ void _ChannelAffine<float>( __global__ void _ChannelAffine<float>(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float* x, const float* x,
const float* w, const float* w,
const float* b, const float* b,
...@@ -114,8 +114,8 @@ __global__ void _ChannelAffine<float>( ...@@ -114,8 +114,8 @@ __global__ void _ChannelAffine<float>(
template <> template <>
__global__ void _ChannelAffine<double>( __global__ void _ChannelAffine<double>(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const double* x, const double* x,
const double* w, const double* w,
const double* b, const double* b,
...@@ -137,14 +137,14 @@ __global__ void _ChannelAffine<double>( ...@@ -137,14 +137,14 @@ __global__ void _ChannelAffine<double>(
template <> template <>
void ChannelAffine<float16, CUDAContext>( void ChannelAffine<float16, CUDAContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* x, const float16* x,
const float16* w, const float16* w,
const float16* b, const float16* b,
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
const int nthreads = outer_dim * axis_dim * inner_dim; const auto nthreads = outer_dim * axis_dim * inner_dim;
if (b != nullptr) { if (b != nullptr) {
_ChannelAffine<<< _ChannelAffine<<<
CUDA_BLOCKS(nthreads), CUDA_BLOCKS(nthreads),
...@@ -152,8 +152,8 @@ void ChannelAffine<float16, CUDAContext>( ...@@ -152,8 +152,8 @@ void ChannelAffine<float16, CUDAContext>(
0, 0,
ctx->cuda_stream()>>>( ctx->cuda_stream()>>>(
nthreads, nthreads,
axis_dim,
inner_dim, inner_dim,
axis_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(w), reinterpret_cast<const half*>(w),
reinterpret_cast<const half*>(b), reinterpret_cast<const half*>(b),
...@@ -165,8 +165,8 @@ void ChannelAffine<float16, CUDAContext>( ...@@ -165,8 +165,8 @@ void ChannelAffine<float16, CUDAContext>(
0, 0,
ctx->cuda_stream()>>>( ctx->cuda_stream()>>>(
nthreads, nthreads,
axis_dim,
inner_dim, inner_dim,
axis_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(w), reinterpret_cast<const half*>(w),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -177,26 +177,26 @@ void ChannelAffine<float16, CUDAContext>( ...@@ -177,26 +177,26 @@ void ChannelAffine<float16, CUDAContext>(
template <> \ template <> \
void ChannelAffine<T, CUDAContext>( \ void ChannelAffine<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* x, \ const T* x, \
const T* w, \ const T* w, \
const T* b, \ const T* b, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \ const auto nthreads = outer_dim * axis_dim * inner_dim; \
if (b != nullptr) { \ if (b != nullptr) { \
_ChannelAffine<<< \ _ChannelAffine<<< \
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(nthreads, axis_dim, inner_dim, x, w, b, y); \ ctx->cuda_stream()>>>(nthreads, inner_dim, axis_dim, x, w, b, y); \
} else { \ } else { \
_ChannelAffine<<< \ _ChannelAffine<<< \
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(nthreads, axis_dim, inner_dim, x, w, y); \ ctx->cuda_stream()>>>(nthreads, inner_dim, axis_dim, x, w, y); \
} \ } \
} }
...@@ -206,6 +206,7 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -206,6 +206,7 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
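The ChannelAffine hunks apply the same argument-order permutation. Functionally the op computes y = x * w (+ b) with per-channel w and b broadcast across the outer and inner dimensions; the kernel bodies are not shown in this diff, so treat the broadcasting details in this NumPy sketch as an assumption:

```python
import numpy as np

def channel_affine(x, w, b=None, axis=1):
    # w and b have length x.shape[axis]; reshape them to broadcast over
    # the outer (before `axis`) and inner (after `axis`) dimensions.
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]
    y = x * w.reshape(shape)
    return y if b is None else y + b.reshape(shape)

x = np.random.rand(2, 3, 4, 4)
w, b = np.random.rand(3), np.random.rand(3)
print(channel_affine(x, w, b).shape)  # (2, 3, 4, 4)
```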
...@@ -10,8 +10,8 @@ namespace { ...@@ -10,8 +10,8 @@ namespace {
template <typename T> template <typename T>
void _CumSum( void _CumSum(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const bool exclusive, const bool exclusive,
const T* x, const T* x,
T* y, T* y,
...@@ -33,8 +33,8 @@ void _CumSum( ...@@ -33,8 +33,8 @@ void _CumSum(
template <> template <>
void _CumSum<float16>( void _CumSum<float16>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const bool exclusive, const bool exclusive,
const float16* x, const float16* x,
float16* y, float16* y,
...@@ -45,8 +45,8 @@ void _CumSum<float16>( ...@@ -45,8 +45,8 @@ void _CumSum<float16>(
template <typename T> template <typename T>
void _CumSumReverse( void _CumSumReverse(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const bool exclusive, const bool exclusive,
const T* x, const T* x,
T* y, T* y,
...@@ -72,8 +72,8 @@ void _CumSumReverse( ...@@ -72,8 +72,8 @@ void _CumSumReverse(
template <> template <>
void _CumSumReverse<float16>( void _CumSumReverse<float16>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const bool exclusive, const bool exclusive,
const float16* x, const float16* x,
float16* y, float16* y,
...@@ -89,17 +89,17 @@ void _CumSumReverse<float16>( ...@@ -89,17 +89,17 @@ void _CumSumReverse<float16>(
template <> \ template <> \
void CumSum<T, CPUContext>( \ void CumSum<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const bool exclusive, \ const bool exclusive, \
const bool reverse, \ const bool reverse, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
if (reverse) { \ if (reverse) { \
_CumSumReverse(outer_dim, axis_dim, inner_dim, exclusive, x, y, ctx); \ _CumSumReverse(outer_dim, inner_dim, axis_dim, exclusive, x, y, ctx); \
} else { \ } else { \
_CumSum(outer_dim, axis_dim, inner_dim, exclusive, x, y, ctx); \ _CumSum(outer_dim, inner_dim, axis_dim, exclusive, x, y, ctx); \
} \ } \
} }
...@@ -110,7 +110,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t); ...@@ -110,7 +110,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -98,8 +98,8 @@ __global__ void _CumSumReverse<half>( ...@@ -98,8 +98,8 @@ __global__ void _CumSumReverse<half>(
template <> template <>
void CumSum<float16, CUDAContext>( void CumSum<float16, CUDAContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const bool exclusive, const bool exclusive,
const bool reverse, const bool reverse,
const float16* x, const float16* x,
...@@ -129,8 +129,8 @@ void CumSum<float16, CUDAContext>( ...@@ -129,8 +129,8 @@ void CumSum<float16, CUDAContext>(
template <> \ template <> \
void CumSum<T, CUDAContext>( \ void CumSum<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const bool exclusive, \ const bool exclusive, \
const bool reverse, \ const bool reverse, \
const T* x, \ const T* x, \
...@@ -155,7 +155,6 @@ DEFINE_KERNEL_LAUNCHER(int); ...@@ -155,7 +155,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -13,16 +13,16 @@ void _IndexSelect( ...@@ -13,16 +13,16 @@ void _IndexSelect(
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int select_dim, const int select_dim,
const int64_t* indices, const int64_t* index,
const T* x, const T* x,
T* y, T* y,
CPUContext* ctx) { CPUContext* ctx) {
int index; int pos;
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < select_dim; ++j) { for (int j = 0; j < select_dim; ++j) {
index = indices[j]; pos = index[j];
index = index >= 0 ? index : index + axis_dim; pos = pos >= 0 ? pos : pos + axis_dim;
const T* offset_x = x + (i * axis_dim + index) * inner_dim; const T* offset_x = x + (i * axis_dim + pos) * inner_dim;
math::Copy(inner_dim, offset_x, y, ctx); math::Copy(inner_dim, offset_x, y, ctx);
y += inner_dim; y += inner_dim;
} }
...@@ -35,16 +35,16 @@ void _IndexSelectGrad( ...@@ -35,16 +35,16 @@ void _IndexSelectGrad(
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int select_dim, const int select_dim,
const int64_t* indices, const int64_t* index,
const T* dy, const T* dy,
T* dx, T* dx,
CPUContext* ctx) { CPUContext* ctx) {
int index; int pos;
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < select_dim; ++j) { for (int j = 0; j < select_dim; ++j) {
index = indices[j]; pos = index[j];
index = index >= 0 ? index : index + axis_dim; pos = pos >= 0 ? pos : pos + axis_dim;
T* offset_dx = dx + (i * axis_dim + index) * inner_dim; T* offset_dx = dx + (i * axis_dim + pos) * inner_dim;
math::Add(inner_dim, dy, offset_dx, offset_dx, ctx); math::Add(inner_dim, dy, offset_dx, offset_dx, ctx);
dy += inner_dim; dy += inner_dim;
} }
...@@ -55,18 +55,18 @@ void _IndexSelectGrad( ...@@ -55,18 +55,18 @@ void _IndexSelectGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \ #define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void name<T, CPUContext>( \ void name<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int select_dim, \ const int select_dim, \
const int64_t* indices, \ const int64_t* index, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name(outer_dim, inner_dim, axis_dim, select_dim, indices, x, y, ctx); \ _##name(outer_dim, inner_dim, axis_dim, select_dim, index, x, y, ctx); \
} }
DEFINE_KERNEL_LAUNCHER(IndexSelect, bool); DEFINE_KERNEL_LAUNCHER(IndexSelect, bool);
......
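The IndexSelect hunks rename indices/index to index/pos (and the CUDA half kernel below also fixes an indexing bug where `indices + j` should have been `indices + k`). Semantically the op gathers slices along an axis with negative indices wrapped by axis_dim, and the gradient scatter-adds, since the same slice may be selected more than once. A NumPy sketch:

```python
import numpy as np

def index_select(x, index, axis):
    pos = np.where(index >= 0, index, index + x.shape[axis])  # wrap negatives
    return np.take(x, pos, axis=axis)

def index_select_grad(dy, index, axis, in_shape):
    pos = np.where(index >= 0, index, index + in_shape[axis])
    dx = np.zeros(in_shape, dtype=dy.dtype)
    # Accumulate rather than assign: duplicate positions must sum.
    np.add.at(dx, (slice(None),) * axis + (pos,), dy)
    return dx

x = np.arange(24.0).reshape(2, 3, 4)
idx = np.array([2, -1, 0])  # -1 wraps to 2, so slice 2 is selected twice
y = index_select(x, idx, axis=1)
dx = index_select_grad(np.ones_like(y), idx, axis=1, in_shape=x.shape)
print(y.shape, dx[0, 2, 0])  # (2, 3, 4) 2.0
```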
...@@ -15,19 +15,19 @@ __global__ void _IndexSelect( ...@@ -15,19 +15,19 @@ __global__ void _IndexSelect(
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int select_dim, const int select_dim,
const int64_t* indices, const int64_t* index,
const T* x, const T* x,
T* y) { T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int i = yi / inner_dim / select_dim; const int i = yi / inner_dim / select_dim;
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
int index = __ldg(indices + ((yi / inner_dim) % select_dim)); int pos = __ldg(index + ((yi / inner_dim) % select_dim));
#else #else
int index = indices[(yi / inner_dim) % select_dim]; int pos = index[(yi / inner_dim) % select_dim];
#endif #endif
index = index >= 0 ? index : index + axis_dim; pos = pos >= 0 ? pos : pos + axis_dim;
y[yi] = x[(i * axis_dim + index) * inner_dim + j]; y[yi] = x[(i * axis_dim + pos) * inner_dim + j];
} }
} }
...@@ -37,22 +37,22 @@ __global__ void _IndexSelectGrad( ...@@ -37,22 +37,22 @@ __global__ void _IndexSelectGrad(
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int select_dim, const int select_dim,
const int64_t* indices, const int64_t* index,
const T* dy, const T* dy,
T* dx) { T* dx) {
CUDA_1D_KERNEL_LOOP(ti, nthreads) { CUDA_1D_KERNEL_LOOP(ti, nthreads) {
const int i = ti / inner_dim; const int i = ti / inner_dim;
const int j = ti % inner_dim; const int j = ti % inner_dim;
const int c = i * axis_dim * inner_dim + j; const int x_offset = i * axis_dim * inner_dim + j;
const T* offset_dy = dy + i * select_dim * inner_dim + j; const T* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < select_dim; ++k) { for (int k = 0; k < select_dim; ++k) {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
int index = __ldg(indices + k); int pos = __ldg(index + k);
#else #else
int index = indices[k]; int pos = index[k];
#endif #endif
index = index >= 0 ? index : index + axis_dim; pos = pos >= 0 ? pos : pos + axis_dim;
dx[c + index * inner_dim] += (*offset_dy); dx[x_offset + pos * inner_dim] += (*offset_dy);
offset_dy += inner_dim; offset_dy += inner_dim;
} }
} }
...@@ -64,23 +64,30 @@ __global__ void _IndexSelectGrad<half>( ...@@ -64,23 +64,30 @@ __global__ void _IndexSelectGrad<half>(
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int select_dim, const int select_dim,
const int64_t* indices, const int64_t* index,
const half* dy, const half* dy,
half* dx) { half* dx) {
CUDA_1D_KERNEL_LOOP(ti, nthreads) { CUDA_1D_KERNEL_LOOP(ti, nthreads) {
#if __CUDA_ARCH__ >= 530
const int i = ti / inner_dim; const int i = ti / inner_dim;
const int j = ti % inner_dim; const int j = ti % inner_dim;
const int c = i * axis_dim * inner_dim + j; const int x_offset = i * axis_dim * inner_dim + j;
const half* offset_dy = dy + i * select_dim * inner_dim + j; const half* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < select_dim; ++k) { for (int k = 0; k < select_dim; ++k) {
int index = __ldg(indices + j); #if __CUDA_ARCH__ >= 350
index = index >= 0 ? index : index + axis_dim; int pos = __ldg(index + k);
index = c + index * inner_dim; #else
dx[index] = __hadd(dx[index], *(offset_dy)); int pos = index[k];
#endif
pos = pos >= 0 ? pos : pos + axis_dim;
pos = x_offset + pos * inner_dim;
#if __CUDA_ARCH__ >= 530
dx[pos] = __hadd(dx[pos], *(offset_dy));
#else
dx[pos] =
__float2half(__half2float(dx[pos]) + __half2float(*(offset_dy)));
#endif
offset_dy += inner_dim; offset_dy += inner_dim;
} }
#endif
} }
} }
...@@ -94,7 +101,7 @@ void IndexSelectGrad<float16, CUDAContext>( ...@@ -94,7 +101,7 @@ void IndexSelectGrad<float16, CUDAContext>(
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int select_dim, const int select_dim,
const int64_t* indices, const int64_t* index,
const float16* dy, const float16* dy,
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
...@@ -108,49 +115,49 @@ void IndexSelectGrad<float16, CUDAContext>( ...@@ -108,49 +115,49 @@ void IndexSelectGrad<float16, CUDAContext>(
inner_dim, inner_dim,
axis_dim, axis_dim,
select_dim, select_dim,
indices, index,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx)); reinterpret_cast<half*>(dx));
} // IndexSelectGrad } // IndexSelectGrad
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void IndexSelect<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int64_t* index, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * select_dim * inner_dim; \
_IndexSelect<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, select_dim, index, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void IndexSelect<T, CUDAContext>( \ void IndexSelectGrad<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int select_dim, \ const int select_dim, \
const int64_t* indices, \ const int64_t* index, \
const T* x, \ const T* dy, \
T* y, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * select_dim * inner_dim; \ const int nthreads = outer_dim * inner_dim; \
_IndexSelect<<< \ _IndexSelectGrad<<< \
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>( \ ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, select_dim, indices, x, y); \ nthreads, inner_dim, axis_dim, select_dim, index, dy, dx); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void IndexSelectGrad<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int64_t* indices, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * inner_dim; \
_IndexSelectGrad<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, select_dim, indices, dy, dx); \
} }
DEFINE_KERNEL_LAUNCHER(bool); DEFINE_KERNEL_LAUNCHER(bool);
......
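The rewritten half-precision gradient kernel above fixes two visible defects: the old code read `__ldg(indices + j)` (the inner-dim coordinate) instead of `indices + k` (the selection coordinate), and it wrapped the whole loop body in `#if __CUDA_ARCH__ >= 530`, so pre-sm_53 devices compiled an empty loop and produced no gradient at all. The new version keeps only the `__hadd` behind the guard and falls back to float arithmetic elsewhere. A sketch of that fallback idiom as a hypothetical standalone helper, shown for illustration only:

#include <cuda_fp16.h>

// Accumulate natively where half arithmetic exists (sm_53+),
// otherwise round-trip through float.
__device__ __forceinline__ half HalfAdd(const half a, const half b) {
#if __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  return __float2half(__half2float(a) + __half2float(b));
#endif
}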
...@@ -10,8 +10,8 @@ namespace { ...@@ -10,8 +10,8 @@ namespace {
template <typename T> template <typename T>
void _BroadcastLossGrad( void _BroadcastLossGrad(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const T* dy, const T* dy,
T* dx) { T* dx) {
std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim}; std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim};
...@@ -52,8 +52,8 @@ void ReduceLossGrad<float16, CPUContext>( ...@@ -52,8 +52,8 @@ void ReduceLossGrad<float16, CPUContext>(
template <> template <>
void BroadcastLossGrad<float16, CPUContext>( void BroadcastLossGrad<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* dy, const float16* dy,
float16* dx, float16* dx,
CPUContext* ctx) { CPUContext* ctx) {
...@@ -98,12 +98,12 @@ void BroadcastLossGrad<float16, CPUContext>( ...@@ -98,12 +98,12 @@ void BroadcastLossGrad<float16, CPUContext>(
template <> \ template <> \
void BroadcastLossGrad<T, CPUContext>( \ void BroadcastLossGrad<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* dy, \ const T* dy, \
T* dx, \ T* dx, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_BroadcastLossGrad(outer_dim, axis_dim, inner_dim, dy, dx); \ _BroadcastLossGrad(outer_dim, inner_dim, axis_dim, dy, dx); \
} }
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
......
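The same mechanical reordering seen here recurs throughout this commit: shape arguments move from (outer_dim, axis_dim, inner_dim) to a canonical (outer_dim, inner_dim, axis_dim), with every call site updated in lockstep — compare the NLLLoss operator hunk further down, where `kernel::BroadcastLossGrad(outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx())` becomes `kernel::BroadcastLossGrad(outer_dim, inner_dim, dX->dim(axis), dy, dx, ctx())`.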
...@@ -146,12 +146,12 @@ void ReduceLossGrad<float16, CUDAContext>( ...@@ -146,12 +146,12 @@ void ReduceLossGrad<float16, CUDAContext>(
template <> template <>
void BroadcastLossGrad<float16, CUDAContext>( void BroadcastLossGrad<float16, CUDAContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* dy, const float16* dy,
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
auto nthreads = outer_dim * axis_dim * inner_dim; const auto nthreads = outer_dim * axis_dim * inner_dim;
_BroadcastLossGrad<<< _BroadcastLossGrad<<<
CUDA_BLOCKS(nthreads), CUDA_BLOCKS(nthreads),
CUDA_THREADS, CUDA_THREADS,
...@@ -214,12 +214,12 @@ void BroadcastLossGrad<float16, CUDAContext>( ...@@ -214,12 +214,12 @@ void BroadcastLossGrad<float16, CUDAContext>(
template <> \ template <> \
void BroadcastLossGrad<T, CUDAContext>( \ void BroadcastLossGrad<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* dy, \ const T* dy, \
T* dx, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
auto nthreads = outer_dim * axis_dim * inner_dim; \ const auto nthreads = outer_dim * axis_dim * inner_dim; \
_BroadcastLossGrad<<< \ _BroadcastLossGrad<<< \
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \ CUDA_THREADS, \
......
...@@ -10,10 +10,10 @@ namespace { ...@@ -10,10 +10,10 @@ namespace {
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
void _NLLLoss( void _NLLLoss(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* log_prob, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
LogitType* mask) { LogitType* mask) {
...@@ -26,7 +26,7 @@ void _NLLLoss( ...@@ -26,7 +26,7 @@ void _NLLLoss(
loss[i] = mask[i] = LogitType(0); loss[i] = mask[i] = LogitType(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
loss[i] = -log_prob[k], mask[i] = LogitType(1); loss[i] = -logit[k], mask[i] = LogitType(1);
} }
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data()); utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
...@@ -35,12 +35,12 @@ void _NLLLoss( ...@@ -35,12 +35,12 @@ void _NLLLoss(
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
void _NLLLossGrad( void _NLLLossGrad(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* log_prob, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dlogit,
LogitType* mask) { LogitType* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
...@@ -51,7 +51,7 @@ void _NLLLossGrad( ...@@ -51,7 +51,7 @@ void _NLLLossGrad(
mask[i] = LogitType(0); mask[i] = LogitType(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
dx[k] = LogitType(-1), mask[i] = LogitType(1); dlogit[k] = LogitType(-1), mask[i] = LogitType(1);
} }
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data()); utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
...@@ -65,20 +65,20 @@ void _NLLLossGrad( ...@@ -65,20 +65,20 @@ void _NLLLossGrad(
template <> \ template <> \
void name<LogitType, TargetType, CPUContext>( \ void name<LogitType, TargetType, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const int ignore_index, \ const int ignore_index, \
const LogitType* log_prob, \ const LogitType* logit, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
LogitType* mask, \ LogitType* mask, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
axis_dim, \
inner_dim, \ inner_dim, \
axis_dim, \
ignore_index, \ ignore_index, \
log_prob, \ logit, \
target, \ target, \
loss, \ loss, \
mask); \ mask); \
......
...@@ -12,10 +12,10 @@ namespace { ...@@ -12,10 +12,10 @@ namespace {
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
__global__ void _NLLLoss( __global__ void _NLLLoss(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* log_prob, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* loss, LogitType* loss,
LogitType* mask) { LogitType* mask) {
...@@ -26,7 +26,7 @@ __global__ void _NLLLoss( ...@@ -26,7 +26,7 @@ __global__ void _NLLLoss(
if (label == ignore_index) { if (label == ignore_index) {
loss[yi] = mask[yi] = LogitType(0); loss[yi] = mask[yi] = LogitType(0);
} else { } else {
loss[yi] = -log_prob[(i * axis_dim + label) * inner_dim + j]; loss[yi] = -logit[(i * axis_dim + label) * inner_dim + j];
mask[yi] = LogitType(1); mask[yi] = LogitType(1);
} }
} }
...@@ -35,12 +35,12 @@ __global__ void _NLLLoss( ...@@ -35,12 +35,12 @@ __global__ void _NLLLoss(
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
__global__ void _NLLLossGrad( __global__ void _NLLLossGrad(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* log_prob, const LogitType* logit,
const TargetType* target, const TargetType* target,
LogitType* dx, LogitType* dlogit,
LogitType* mask) { LogitType* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
...@@ -49,7 +49,7 @@ __global__ void _NLLLossGrad( ...@@ -49,7 +49,7 @@ __global__ void _NLLLossGrad(
if (label == ignore_index) { if (label == ignore_index) {
mask[yi] = LogitType(0); mask[yi] = LogitType(0);
} else { } else {
dx[(i * axis_dim + label) * inner_dim + j] = LogitType(-1); dlogit[(i * axis_dim + label) * inner_dim + j] = LogitType(-1);
mask[yi] = LogitType(1); mask[yi] = LogitType(1);
} }
} }
...@@ -63,21 +63,21 @@ __global__ void _NLLLossGrad( ...@@ -63,21 +63,21 @@ __global__ void _NLLLossGrad(
template <> \ template <> \
void name<LogitType, TargetType, CUDAContext>( \ void name<LogitType, TargetType, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const int ignore_index, \ const int ignore_index, \
const LogitType* log_prob, \ const LogitType* logit, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
LogitType* mask, \ LogitType* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
auto nthreads = outer_dim * inner_dim; \ const auto nthreads = outer_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \ nthreads, \
axis_dim, \
inner_dim, \ inner_dim, \
axis_dim, \
ignore_index, \ ignore_index, \
log_prob, \ logit, \
target, \ target, \
loss, \ loss, \
mask); \ mask); \
......
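Note the rename from `log_prob` to `logit` is purely cosmetic here: the kernel only negates the value stored at the target position, so callers remain responsible for passing log-probabilities. A CPU sketch equivalent to the kernels above (illustrative name, not dragon's API):

// Per (outer, inner) position: pick the value at the target class and
// negate it; ignored targets yield zero loss and a zero mask entry.
template <typename LogitType, typename TargetType>
void NLLLossRef(
    const int outer_dim, const int inner_dim, const int axis_dim,
    const int ignore_index, const LogitType* logit,
    const TargetType* target, LogitType* loss, LogitType* mask) {
  for (int i = 0; i < outer_dim; ++i) {
    for (int j = 0; j < inner_dim; ++j) {
      const int yi = i * inner_dim + j;
      const int label = static_cast<int>(target[yi]);
      if (label == ignore_index) {
        loss[yi] = mask[yi] = LogitType(0);
      } else {
        loss[yi] = -logit[(i * axis_dim + label) * inner_dim + j];
        mask[yi] = LogitType(1);
      }
    }
  }
}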
...@@ -10,8 +10,8 @@ namespace { ...@@ -10,8 +10,8 @@ namespace {
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
void _SigmoidFocalLoss( void _SigmoidFocalLoss(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const LogitType pos_alpha, const LogitType pos_alpha,
const LogitType neg_alpha, const LogitType neg_alpha,
const LogitType gamma, const LogitType gamma,
...@@ -55,8 +55,8 @@ void _SigmoidFocalLoss( ...@@ -55,8 +55,8 @@ void _SigmoidFocalLoss(
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
void _SigmoidFocalLossGrad( void _SigmoidFocalLossGrad(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const LogitType pos_alpha, const LogitType pos_alpha,
const LogitType neg_alpha, const LogitType neg_alpha,
const LogitType gamma, const LogitType gamma,
...@@ -108,8 +108,8 @@ void _SigmoidFocalLossGrad( ...@@ -108,8 +108,8 @@ void _SigmoidFocalLossGrad(
template <> \ template <> \
void name<LogitType, TargetType, CPUContext>( \ void name<LogitType, TargetType, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const float pos_alpha, \ const float pos_alpha, \
const float neg_alpha, \ const float neg_alpha, \
const float gamma, \ const float gamma, \
...@@ -121,8 +121,8 @@ void _SigmoidFocalLossGrad( ...@@ -121,8 +121,8 @@ void _SigmoidFocalLossGrad(
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
axis_dim, \
inner_dim, \ inner_dim, \
axis_dim, \
(LogitType)pos_alpha, \ (LogitType)pos_alpha, \
(LogitType)neg_alpha, \ (LogitType)neg_alpha, \
(LogitType)gamma, \ (LogitType)gamma, \
......
...@@ -12,8 +12,8 @@ namespace { ...@@ -12,8 +12,8 @@ namespace {
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
__global__ void _SigmoidFocalLoss( __global__ void _SigmoidFocalLoss(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const LogitType pos_alpha, const LogitType pos_alpha,
const LogitType neg_alpha, const LogitType neg_alpha,
const LogitType gamma, const LogitType gamma,
...@@ -53,8 +53,8 @@ __global__ void _SigmoidFocalLoss( ...@@ -53,8 +53,8 @@ __global__ void _SigmoidFocalLoss(
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
__global__ void _SigmoidFocalLossGrad( __global__ void _SigmoidFocalLossGrad(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const LogitType pos_alpha, const LogitType pos_alpha,
const LogitType neg_alpha, const LogitType neg_alpha,
const LogitType gamma, const LogitType gamma,
...@@ -102,8 +102,8 @@ __global__ void _SigmoidFocalLossGrad( ...@@ -102,8 +102,8 @@ __global__ void _SigmoidFocalLossGrad(
template <> \ template <> \
void name<LogitType, TargetType, CUDAContext>( \ void name<LogitType, TargetType, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const float pos_alpha, \ const float pos_alpha, \
const float neg_alpha, \ const float neg_alpha, \
const float gamma, \ const float gamma, \
...@@ -113,11 +113,11 @@ __global__ void _SigmoidFocalLossGrad( ...@@ -113,11 +113,11 @@ __global__ void _SigmoidFocalLossGrad(
LogitType* loss, \ LogitType* loss, \
LogitType* mask, \ LogitType* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \ const auto nthreads = outer_dim * axis_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \ nthreads, \
axis_dim, \
inner_dim, \ inner_dim, \
axis_dim, \
(LogitType)pos_alpha, \ (LogitType)pos_alpha, \
(LogitType)neg_alpha, \ (LogitType)neg_alpha, \
(LogitType)gamma, \ (LogitType)gamma, \
......
...@@ -10,8 +10,8 @@ namespace { ...@@ -10,8 +10,8 @@ namespace {
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
void _SparseSoftmaxCrossEntropy( void _SparseSoftmaxCrossEntropy(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
...@@ -36,8 +36,8 @@ void _SparseSoftmaxCrossEntropy( ...@@ -36,8 +36,8 @@ void _SparseSoftmaxCrossEntropy(
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
void _SparseSoftmaxCrossEntropyGrad( void _SparseSoftmaxCrossEntropyGrad(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
...@@ -72,8 +72,8 @@ void _SparseSoftmaxCrossEntropyGrad( ...@@ -72,8 +72,8 @@ void _SparseSoftmaxCrossEntropyGrad(
template <> \ template <> \
void name<LogitType, TargetType, CPUContext>( \ void name<LogitType, TargetType, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const int ignore_index, \ const int ignore_index, \
const LogitType* prob, \ const LogitType* prob, \
const TargetType* target, \ const TargetType* target, \
...@@ -82,8 +82,8 @@ void _SparseSoftmaxCrossEntropyGrad( ...@@ -82,8 +82,8 @@ void _SparseSoftmaxCrossEntropyGrad(
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
axis_dim, \
inner_dim, \ inner_dim, \
axis_dim, \
ignore_index, \ ignore_index, \
prob, \ prob, \
target, \ target, \
......
...@@ -12,8 +12,8 @@ namespace { ...@@ -12,8 +12,8 @@ namespace {
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
__global__ void _SparseSoftmaxCrossEntropy( __global__ void _SparseSoftmaxCrossEntropy(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
...@@ -36,8 +36,8 @@ __global__ void _SparseSoftmaxCrossEntropy( ...@@ -36,8 +36,8 @@ __global__ void _SparseSoftmaxCrossEntropy(
template <typename LogitType, typename TargetType> template <typename LogitType, typename TargetType>
__global__ void _SparseSoftmaxCrossEntropyGrad( __global__ void _SparseSoftmaxCrossEntropyGrad(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitType* prob,
const TargetType* target, const TargetType* target,
...@@ -69,19 +69,19 @@ __global__ void _SparseSoftmaxCrossEntropyGrad( ...@@ -69,19 +69,19 @@ __global__ void _SparseSoftmaxCrossEntropyGrad(
template <> \ template <> \
void name<LogitType, TargetType, CUDAContext>( \ void name<LogitType, TargetType, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const int ignore_index, \ const int ignore_index, \
const LogitType* prob, \ const LogitType* prob, \
const TargetType* target, \ const TargetType* target, \
LogitType* loss, \ LogitType* loss, \
LogitType* mask, \ LogitType* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * inner_dim; \ const auto nthreads = outer_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \ nthreads, \
axis_dim, \
inner_dim, \ inner_dim, \
axis_dim, \
ignore_index, \ ignore_index, \
prob, \ prob, \
target, \ target, \
......
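The kernel bodies are collapsed in this diff, but the signatures match the standard sparse softmax cross-entropy: the probability of the target class is looked up and its negative log taken. A hedged per-element sketch; the stability clamp is an assumption, not verbatim kernel code:

#include <algorithm>
#include <cmath>

// loss[i, j] = -log(prob at the target class), one (outer, inner) position.
template <typename T>
T SparseCrossEntropyAt(
    const T* prob, const int axis_dim, const int inner_dim,
    const int i, const int j, const int label) {
  const T p = prob[(i * axis_dim + label) * inner_dim + j];
  return -std::log(std::max(p, T(1e-12)));  // clamp is an assumption
}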
...@@ -10,8 +10,8 @@ namespace { ...@@ -10,8 +10,8 @@ namespace {
template <typename T> template <typename T>
void _L1Normalize( void _L1Normalize(
const int outer_dim, const int outer_dim,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const T scale, const T scale,
const T eps, const T eps,
const T* x, const T* x,
...@@ -32,8 +32,8 @@ void _L1Normalize( ...@@ -32,8 +32,8 @@ void _L1Normalize(
template <typename T> template <typename T>
void _L2Normalize( void _L2Normalize(
const int outer_dim, const int outer_dim,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const T scale, const T scale,
const T eps, const T eps,
const T* x, const T* x,
...@@ -54,8 +54,8 @@ void _L2Normalize( ...@@ -54,8 +54,8 @@ void _L2Normalize(
template <typename T> template <typename T>
void _L1NormalizeGrad( void _L1NormalizeGrad(
const int outer_dim, const int outer_dim,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const T scale, const T scale,
const T eps, const T eps,
const T* dy, const T* dy,
...@@ -81,8 +81,8 @@ void _L1NormalizeGrad( ...@@ -81,8 +81,8 @@ void _L1NormalizeGrad(
template <typename T> template <typename T>
void _L2NormalizeGrad( void _L2NormalizeGrad(
const int outer_dim, const int outer_dim,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const T scale, const T scale,
const T eps, const T eps,
const T* dy, const T* dy,
...@@ -112,8 +112,8 @@ void _L2NormalizeGrad( ...@@ -112,8 +112,8 @@ void _L2NormalizeGrad(
template <> template <>
void L1Normalize<float16, CPUContext>( void L1Normalize<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const float scale, const float scale,
const float eps, const float eps,
const float16* x, const float16* x,
...@@ -125,8 +125,8 @@ void L1Normalize<float16, CPUContext>( ...@@ -125,8 +125,8 @@ void L1Normalize<float16, CPUContext>(
template <> template <>
void L2Normalize<float16, CPUContext>( void L2Normalize<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const float scale, const float scale,
const float eps, const float eps,
const float16* x, const float16* x,
...@@ -138,8 +138,8 @@ void L2Normalize<float16, CPUContext>( ...@@ -138,8 +138,8 @@ void L2Normalize<float16, CPUContext>(
template <> template <>
void L1NormalizeGrad<float16, CPUContext>( void L1NormalizeGrad<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const float scale, const float scale,
const float eps, const float eps,
const float16* dy, const float16* dy,
...@@ -152,8 +152,8 @@ void L1NormalizeGrad<float16, CPUContext>( ...@@ -152,8 +152,8 @@ void L1NormalizeGrad<float16, CPUContext>(
template <> template <>
void L2NormalizeGrad<float16, CPUContext>( void L2NormalizeGrad<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const float scale, const float scale,
const float eps, const float eps,
const float16* dy, const float16* dy,
...@@ -167,14 +167,14 @@ void L2NormalizeGrad<float16, CPUContext>( ...@@ -167,14 +167,14 @@ void L2NormalizeGrad<float16, CPUContext>(
template <> \ template <> \
void name<T, CPUContext>( \ void name<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \
const float scale, \ const float scale, \
const float eps, \ const float eps, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name(outer_dim, reduce_dim, inner_dim, (T)scale, (T)eps, x, y); \ _##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, x, y); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float); DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
...@@ -187,15 +187,15 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double); ...@@ -187,15 +187,15 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
template <> \ template <> \
void name<T, CPUContext>( \ void name<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \
const float scale, \ const float scale, \
const float eps, \ const float eps, \
const T* dy, \ const T* dy, \
const T* x, \ const T* x, \
T* dx, \ T* dx, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name(outer_dim, reduce_dim, inner_dim, (T)scale, (T)eps, dy, x, dx); \ _##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, dy, x, dx); \
} }
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float);
......
...@@ -14,8 +14,8 @@ namespace { ...@@ -14,8 +14,8 @@ namespace {
template <typename T> template <typename T>
__global__ void _L1Normalize( __global__ void _L1Normalize(
const int nblocks, const int nblocks,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const T scale, const T scale,
const T eps, const T eps,
const T* x, const T* x,
...@@ -42,8 +42,8 @@ __global__ void _L1Normalize( ...@@ -42,8 +42,8 @@ __global__ void _L1Normalize(
__global__ void _L1Normalize( __global__ void _L1Normalize(
const int nblocks, const int nblocks,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const float scale, const float scale,
const float eps, const float eps,
const half* x, const half* x,
...@@ -71,8 +71,8 @@ __global__ void _L1Normalize( ...@@ -71,8 +71,8 @@ __global__ void _L1Normalize(
template <typename T> template <typename T>
__global__ void _L2Normalize( __global__ void _L2Normalize(
const int nblocks, const int nblocks,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const T scale, const T scale,
const T eps, const T eps,
const T* x, const T* x,
...@@ -99,8 +99,8 @@ __global__ void _L2Normalize( ...@@ -99,8 +99,8 @@ __global__ void _L2Normalize(
__global__ void _L2Normalize( __global__ void _L2Normalize(
const int nblocks, const int nblocks,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const float scale, const float scale,
const float eps, const float eps,
const half* x, const half* x,
...@@ -128,8 +128,8 @@ __global__ void _L2Normalize( ...@@ -128,8 +128,8 @@ __global__ void _L2Normalize(
template <typename T> template <typename T>
__global__ void _L1NormalizeGrad( __global__ void _L1NormalizeGrad(
const int nblocks, const int nblocks,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const T scale, const T scale,
const T eps, const T eps,
const T* dy, const T* dy,
...@@ -162,8 +162,8 @@ __global__ void _L1NormalizeGrad( ...@@ -162,8 +162,8 @@ __global__ void _L1NormalizeGrad(
__global__ void _L1NormalizeGrad( __global__ void _L1NormalizeGrad(
const int nblocks, const int nblocks,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const float scale, const float scale,
const float eps, const float eps,
const half* dy, const half* dy,
...@@ -199,8 +199,8 @@ __global__ void _L1NormalizeGrad( ...@@ -199,8 +199,8 @@ __global__ void _L1NormalizeGrad(
template <typename T> template <typename T>
__global__ void _L2NormalizeGrad( __global__ void _L2NormalizeGrad(
const int nblocks, const int nblocks,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const T scale, const T scale,
const T eps, const T eps,
const T* dy, const T* dy,
...@@ -233,8 +233,8 @@ __global__ void _L2NormalizeGrad( ...@@ -233,8 +233,8 @@ __global__ void _L2NormalizeGrad(
__global__ void _L2NormalizeGrad( __global__ void _L2NormalizeGrad(
const int nblocks, const int nblocks,
const int reduce_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim,
const float scale, const float scale,
const float eps, const float eps,
const half* dy, const half* dy,
...@@ -275,8 +275,8 @@ __global__ void _L2NormalizeGrad( ...@@ -275,8 +275,8 @@ __global__ void _L2NormalizeGrad(
template <> \ template <> \
void name<T, CUDAContext>( \ void name<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \
const float scale, \ const float scale, \
const float eps, \ const float eps, \
const float16* x, \ const float16* x, \
...@@ -285,8 +285,8 @@ __global__ void _L2NormalizeGrad( ...@@ -285,8 +285,8 @@ __global__ void _L2NormalizeGrad(
const auto nblocks = outer_dim * inner_dim; \ const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, \ nblocks, \
reduce_dim, \
inner_dim, \ inner_dim, \
reduce_dim, \
scale, \ scale, \
eps, \ eps, \
reinterpret_cast<const half*>(x), \ reinterpret_cast<const half*>(x), \
...@@ -301,8 +301,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, float16); ...@@ -301,8 +301,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, float16);
template <> \ template <> \
void name<T, CUDAContext>( \ void name<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \
const float scale, \ const float scale, \
const float eps, \ const float eps, \
const T* x, \ const T* x, \
...@@ -310,7 +310,7 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, float16); ...@@ -310,7 +310,7 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, float16);
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const auto nblocks = outer_dim * inner_dim; \ const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, reduce_dim, inner_dim, (T)scale, (T)eps, x, y); \ nblocks, inner_dim, reduce_dim, (T)scale, (T)eps, x, y); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float); DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
...@@ -323,8 +323,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double); ...@@ -323,8 +323,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
template <> \ template <> \
void name<T, CUDAContext>( \ void name<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \
const float scale, \ const float scale, \
const float eps, \ const float eps, \
const float16* dy, \ const float16* dy, \
...@@ -334,8 +334,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double); ...@@ -334,8 +334,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
const auto nblocks = outer_dim * inner_dim; \ const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, \ nblocks, \
reduce_dim, \
inner_dim, \ inner_dim, \
reduce_dim, \
scale, \ scale, \
eps, \ eps, \
reinterpret_cast<const half*>(dy), \ reinterpret_cast<const half*>(dy), \
...@@ -351,8 +351,8 @@ DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16); ...@@ -351,8 +351,8 @@ DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16);
template <> \ template <> \
void name<T, CUDAContext>( \ void name<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \
const float scale, \ const float scale, \
const float eps, \ const float eps, \
const T* dy, \ const T* dy, \
...@@ -361,7 +361,7 @@ DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16); ...@@ -361,7 +361,7 @@ DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16);
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const auto nblocks = outer_dim * inner_dim; \ const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, reduce_dim, inner_dim, (T)scale, (T)eps, dy, x, dx); \ nblocks, inner_dim, reduce_dim, (T)scale, (T)eps, dy, x, dx); \
} }
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float);
......
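With the reduction bodies collapsed above, a hedged sketch of what one L2-normalized slice computes may help; this assumes the contiguous case (inner_dim == 1), and the exact placement of `eps` (added to the norm, versus a max) is an assumption here:

#include <cmath>

// Normalize one contiguous slice of length n: y = scale * x / (||x||_2 + eps).
template <typename T>
void L2NormalizeSlice(const int n, const T scale, const T eps,
                      const T* x, T* y) {
  T sq = T(0);
  for (int k = 0; k < n; ++k) sq += x[k] * x[k];
  const T norm = std::sqrt(sq) + eps;  // assumed eps placement
  for (int k = 0; k < n; ++k) y[k] = scale * x[k] / norm;
}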
...@@ -22,8 +22,8 @@ void _BiasAdd( ...@@ -22,8 +22,8 @@ void _BiasAdd(
template <typename T> template <typename T>
void _BiasAdd( void _BiasAdd(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const T* x, const T* x,
const T* b, const T* b,
T* y) { T* y) {
...@@ -44,8 +44,8 @@ void _BiasAdd( ...@@ -44,8 +44,8 @@ void _BiasAdd(
template <> template <>
void BiasAdd<float16, CPUContext>( void BiasAdd<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* x, const float16* x,
const float16* b, const float16* b,
float16* y, float16* y,
...@@ -57,8 +57,8 @@ void BiasAdd<float16, CPUContext>( ...@@ -57,8 +57,8 @@ void BiasAdd<float16, CPUContext>(
template <> \ template <> \
void BiasAdd<T, CPUContext>( \ void BiasAdd<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* x, \ const T* x, \
const T* b, \ const T* b, \
T* y, \ T* y, \
...@@ -66,7 +66,7 @@ void BiasAdd<float16, CPUContext>( ...@@ -66,7 +66,7 @@ void BiasAdd<float16, CPUContext>(
if (inner_dim == 1) { \ if (inner_dim == 1) { \
_BiasAdd(outer_dim, axis_dim, x, b, y); \ _BiasAdd(outer_dim, axis_dim, x, b, y); \
} else { \ } else { \
_BiasAdd(outer_dim, axis_dim, inner_dim, x, b, y); \ _BiasAdd(outer_dim, inner_dim, axis_dim, x, b, y); \
} \ } \
} }
......
...@@ -38,8 +38,8 @@ __global__ void _BiasAdd<half>( ...@@ -38,8 +38,8 @@ __global__ void _BiasAdd<half>(
template <typename T> template <typename T>
__global__ void _BiasAdd( __global__ void _BiasAdd(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const T* x, const T* x,
const T* b, const T* b,
T* y) { T* y) {
...@@ -55,8 +55,8 @@ __global__ void _BiasAdd( ...@@ -55,8 +55,8 @@ __global__ void _BiasAdd(
template <> template <>
__global__ void _BiasAdd<half>( __global__ void _BiasAdd<half>(
const int nthreads, const int nthreads,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const half* x, const half* x,
const half* b, const half* b,
half* y) { half* y) {
...@@ -74,13 +74,13 @@ __global__ void _BiasAdd<half>( ...@@ -74,13 +74,13 @@ __global__ void _BiasAdd<half>(
template <> template <>
void BiasAdd<float16, CUDAContext>( void BiasAdd<float16, CUDAContext>(
const int outer_dim, const int outer_dim,
const int axis_dim,
const int inner_dim, const int inner_dim,
const int axis_dim,
const float16* x, const float16* x,
const float16* b, const float16* b,
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
const int nthreads = outer_dim * axis_dim * inner_dim; const auto nthreads = outer_dim * axis_dim * inner_dim;
if (inner_dim == 1) { if (inner_dim == 1) {
_BiasAdd<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( _BiasAdd<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads, nthreads,
...@@ -91,8 +91,8 @@ void BiasAdd<float16, CUDAContext>( ...@@ -91,8 +91,8 @@ void BiasAdd<float16, CUDAContext>(
} else { } else {
_BiasAdd<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( _BiasAdd<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads, nthreads,
axis_dim,
inner_dim, inner_dim,
axis_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(b), reinterpret_cast<const half*>(b),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -103,13 +103,13 @@ void BiasAdd<float16, CUDAContext>( ...@@ -103,13 +103,13 @@ void BiasAdd<float16, CUDAContext>(
template <> \ template <> \
void BiasAdd<T, CUDAContext>( \ void BiasAdd<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int axis_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \
const T* x, \ const T* x, \
const T* b, \ const T* b, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \ const auto nthreads = outer_dim * axis_dim * inner_dim; \
if (inner_dim == 1) { \ if (inner_dim == 1) { \
_BiasAdd<<< \ _BiasAdd<<< \
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
...@@ -121,7 +121,7 @@ void BiasAdd<float16, CUDAContext>( ...@@ -121,7 +121,7 @@ void BiasAdd<float16, CUDAContext>(
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(nthreads, axis_dim, inner_dim, x, b, y); \ ctx->cuda_stream()>>>(nthreads, inner_dim, axis_dim, x, b, y); \
} \ } \
} }
......
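For the generic (inner_dim > 1) path above, the bias binds to the axis coordinate only. A standalone CPU reference under the reordered signature; names are illustrative, not dragon's API:

// y[(i * axis_dim + c) * inner_dim + j] = x[same] + b[c]
template <typename T>
void BiasAddRef(const int outer_dim, const int inner_dim,
                const int axis_dim, const T* x, const T* b, T* y) {
  for (int i = 0; i < outer_dim; ++i) {
    for (int c = 0; c < axis_dim; ++c) {
      for (int j = 0; j < inner_dim; ++j) {
        const int k = (i * axis_dim + c) * inner_dim + j;
        y[k] = x[k] + b[c];
      }
    }
  }
}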
#include "dragon/operators/activation/hardsigmoid_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void HardSigmoidOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
kernel::HardSigmoid(
X.count(),
alpha_,
beta_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void HardSigmoidOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void HardSigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
kernel::HardSigmoidGrad(
Y.count(),
alpha_,
dY.template data<T, Context>(),
Y.template data<T, Context>(),
dX->ReshapeLike(Y)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void HardSigmoidGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(HardSigmoid);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSigmoid);
#endif
DEPLOY_CPU_OPERATOR(HardSigmoidGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSigmoidGradient);
#endif
OPERATOR_SCHEMA(HardSigmoid)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(HardSigmoidGradient)
/* Y, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1)
/* dY => dX */
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(HardSigmoid, InplaceGradientMaker);
} // namespace dragon
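The HardSigmoid kernel bodies live elsewhere in this commit; a plausible reference for the math, given that the forward op takes (alpha, beta) and the gradient op stores only alpha and reads only Y (which is also why the in-place gradient maker suffices). With the defaults alpha = 0.2, beta = 0.5 from the header below, the output saturates outside |x| <= 2.5. A hedged sketch, not the shipped kernels:

#include <algorithm>

// Forward: y = clip(alpha * x + beta, 0, 1).
template <typename T>
void HardSigmoidRef(const int n, const T alpha, const T beta,
                    const T* x, T* y) {
  for (int i = 0; i < n; ++i) {
    y[i] = std::max(T(0), std::min(T(1), alpha * x[i] + beta));
  }
}

// Backward: dX = dY * alpha where Y is strictly inside (0, 1), else 0.
template <typename T>
void HardSigmoidGradRef(const int n, const T alpha,
                        const T* dy, const T* y, T* dx) {
  for (int i = 0; i < n; ++i) {
    dx[i] = (y[i] > T(0) && y[i] < T(1)) ? dy[i] * alpha : T(0);
  }
}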
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_HARDSIGMOID_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_HARDSIGMOID_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class HardSigmoidOp : public Operator<Context> {
public:
HardSigmoidOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)),
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
float alpha_, beta_;
};
template <class Context>
class HardSigmoidGradientOp : public Operator<Context> {
public:
HardSigmoidGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
float alpha_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_HARDSIGMOID_OP_H_
#include "dragon/operators/activation/hardswish_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void HardSwishOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
kernel::HardSwish(
X.count(),
alpha_,
beta_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void HardSwishOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void HardSwishGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(1), *dX = Output(0);
kernel::HardSwishGrad(
X.count(),
alpha_,
beta_,
dY.template data<T, Context>(),
X.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void HardSwishGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(HardSwish);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSwish);
#endif
DEPLOY_CPU_OPERATOR(HardSwishGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSwishGradient);
#endif
OPERATOR_SCHEMA(HardSwish)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(HardSwishGradient)
/* X, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(HardSwish, GenericGradientMaker);
} // namespace dragon
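HardSwish gates the input by its hard sigmoid: y = x * clip(alpha * x + beta, 0, 1), the MobileNetV3-style piecewise-quadratic activation. Unlike HardSigmoid, the gradient needs X (hence GenericGradientMaker above). The kernel bodies are elided; a hedged sketch of the derivative implied by the forward formula — zero below the clip, identity above it, and 2*alpha*x + beta on the quadratic segment between:

#include <algorithm>

// Forward: y = x * clip(alpha * x + beta, 0, 1).
template <typename T>
void HardSwishRef(const int n, const T alpha, const T beta,
                  const T* x, T* y) {
  for (int i = 0; i < n; ++i) {
    const T gate = std::max(T(0), std::min(T(1), alpha * x[i] + beta));
    y[i] = x[i] * gate;
  }
}

// Backward: piecewise derivative of the forward formula.
template <typename T>
void HardSwishGradRef(const int n, const T alpha, const T beta,
                      const T* dy, const T* x, T* dx) {
  for (int i = 0; i < n; ++i) {
    const T s = alpha * x[i] + beta;
    dx[i] = s <= T(0)
        ? T(0)
        : (s >= T(1) ? dy[i] : dy[i] * (T(2) * alpha * x[i] + beta));
  }
}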
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_HARDSWISH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_HARDSWISH_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class HardSwishOp : public Operator<Context> {
public:
HardSwishOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)),
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
float alpha_, beta_;
};
template <class Context>
class HardSwishGradientOp : public Operator<Context> {
public:
HardSwishGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)),
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
float alpha_, beta_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_HARDSWISH_OP_H_
...@@ -12,8 +12,8 @@ void SoftmaxOp<Context>::DoRunWithType() { ...@@ -12,8 +12,8 @@ void SoftmaxOp<Context>::DoRunWithType() {
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
kernel::Softmax( kernel::Softmax(
X.count(0, axis), X.count(0, axis),
X.dim(axis),
X.count(axis + 1), X.count(axis + 1),
X.dim(axis),
X.template data<T, Context>(), X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
...@@ -31,8 +31,8 @@ void SoftmaxGradientOp<Context>::DoRunWithType() { ...@@ -31,8 +31,8 @@ void SoftmaxGradientOp<Context>::DoRunWithType() {
CANONICALIZE_AXIS_WITH_TENSOR(Y); CANONICALIZE_AXIS_WITH_TENSOR(Y);
kernel::SoftmaxGrad( kernel::SoftmaxGrad(
Y.count(0, axis), Y.count(0, axis),
Y.dim(axis),
Y.count(axis + 1), Y.count(axis + 1),
Y.dim(axis),
dY.template data<T, Context>(), dY.template data<T, Context>(),
Y.template data<T, Context>(), Y.template data<T, Context>(),
dX->ReshapeLike(Y)->template mutable_data<T, Context>(), dX->ReshapeLike(Y)->template mutable_data<T, Context>(),
......
#include "dragon/operators/activation/swish_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void SwishOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
kernel::Swish(
X.count(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void SwishOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void SwishGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &Y = Input(1);
auto &dY = Input(2), *dX = Output(0);
kernel::SwishGrad(
X.count(),
dY.template data<T, Context>(),
X.template data<T, Context>(),
Y.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void SwishGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Swish);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Swish);
#endif
DEPLOY_CPU_OPERATOR(SwishGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SwishGradient);
#endif
OPERATOR_SCHEMA(Swish)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(SwishGradient)
/* X, Y, dY */
.NumInputs(3)
/* dX */
.NumOutputs(1);
namespace {
class GradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
vector<OperatorDef> MakeDef() override {
return SingleDef(
def.type() + "Gradient",
"",
vector<string>({I(0), O(0), GO(0)}),
vector<string>({GI(0)}));
}
};
} // namespace
REGISTER_GRADIENT(Swish, GradientMaker);
} // namespace dragon
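Swish is y = x * sigmoid(x). The custom GradientMaker above feeds X, Y, and dY to the gradient kernel because, with s = sigmoid(x), dy/dx = s + x*s*(1 - s) = s + y*(1 - s): reusing the saved forward output y avoids recomputing the x*s*(1 - s) term. A hedged CPU sketch of both directions (the kernel bodies are elided in this diff):

#include <cmath>

// Forward: y = x * sigmoid(x).
template <typename T>
void SwishRef(const int n, const T* x, T* y) {
  for (int i = 0; i < n; ++i) {
    y[i] = x[i] / (T(1) + std::exp(-x[i]));
  }
}

// Backward: dx = dy * (s + y * (1 - s)), with s = sigmoid(x).
template <typename T>
void SwishGradRef(const int n, const T* dy, const T* x,
                  const T* y, T* dx) {
  for (int i = 0; i < n; ++i) {
    const T s = T(1) / (T(1) + std::exp(-x[i]));
    dx[i] = dy[i] * (s + y[i] * (T(1) - s));
  }
}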
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class SwishOp : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(SwishOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
};
template <class Context>
class SwishGradientOp : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(SwishGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_
...@@ -39,8 +39,8 @@ void ChannelAffineOp<Context>::DoRunWithType() { ...@@ -39,8 +39,8 @@ void ChannelAffineOp<Context>::DoRunWithType() {
kernel::ChannelAffine( kernel::ChannelAffine(
X.count(0, axis), X.count(0, axis),
X.count(axis, axis + num_axes),
X.count(axis + num_axes), X.count(axis + num_axes),
X.count(axis, axis + num_axes),
X.template data<T, Context>(), X.template data<T, Context>(),
W.template data<T, Context>(), W.template data<T, Context>(),
InputSize() <= 2 ? nullptr : Input(2).template data<T, Context>(), InputSize() <= 2 ? nullptr : Input(2).template data<T, Context>(),
...@@ -121,8 +121,8 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() { ...@@ -121,8 +121,8 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
Output(0)->ReshapeLike(Input(-1)); Output(0)->ReshapeLike(Input(-1));
kernel::ChannelAffine( kernel::ChannelAffine(
X.count(0, axis), X.count(0, axis),
X.count(axis, axis + num_axes),
X.count(axis + num_axes), X.count(axis + num_axes),
X.count(axis, axis + num_axes),
dY.template data<T, Context>(), dY.template data<T, Context>(),
W.template data<T, Context>(), W.template data<T, Context>(),
(const T*)nullptr, (const T*)nullptr,
......
...@@ -12,8 +12,8 @@ void CumSumOp<Context>::DoRunWithType() { ...@@ -12,8 +12,8 @@ void CumSumOp<Context>::DoRunWithType() {
kernel::CumSum( kernel::CumSum(
X.count(0, axis), X.count(0, axis),
X.dim(axis),
X.count(axis + 1), X.count(axis + 1),
X.dim(axis),
exclusive_, exclusive_,
reverse_, reverse_,
X.template data<T, Context>(), X.template data<T, Context>(),
...@@ -34,8 +34,8 @@ void CumSumGradientOp<Context>::DoRunWithType() { ...@@ -34,8 +34,8 @@ void CumSumGradientOp<Context>::DoRunWithType() {
kernel::CumSum( kernel::CumSum(
dY.count(0, axis), dY.count(0, axis),
dY.dim(axis),
dY.count(axis + 1), dY.count(axis + 1),
dY.dim(axis),
exclusive_, exclusive_,
!reverse_, !reverse_,
dY.template data<T, Context>(), dY.template data<T, Context>(),
......
...@@ -26,7 +26,7 @@ void MultinomialOp<Context>::DoRunWithType() { ...@@ -26,7 +26,7 @@ void MultinomialOp<Context>::DoRunWithType() {
CPUContext cpu_ctx; CPUContext cpu_ctx;
auto* prob = Buffer("prob")->template mutable_data<T, CPUContext>(); auto* prob = Buffer("prob")->template mutable_data<T, CPUContext>();
kernel::Softmax( kernel::Softmax(
X.count(0, axis), X.dim(axis), X.count(axis + 1), x, prob, &cpu_ctx); X.count(0, axis), X.count(axis + 1), X.dim(axis), x, prob, &cpu_ctx);
x = prob; x = prob;
} }
......
...@@ -27,8 +27,8 @@ void NLLLossOp<Context>::DoRunWithType() { ...@@ -27,8 +27,8 @@ void NLLLossOp<Context>::DoRunWithType() {
kernel::NLLLoss( kernel::NLLLoss(
outer_dim, outer_dim,
X.dim(axis),
inner_dim, inner_dim,
X.dim(axis),
ignore_index_, ignore_index_,
X.template data<LogitType, Context>(), X.template data<LogitType, Context>(),
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetType, Context>(),
...@@ -109,8 +109,8 @@ void NLLLossGradientOp<Context>::DoRunWithType() { ...@@ -109,8 +109,8 @@ void NLLLossGradientOp<Context>::DoRunWithType() {
kernel::NLLLossGrad( kernel::NLLLossGrad(
outer_dim, outer_dim,
dX->dim(axis),
inner_dim, inner_dim,
dX->dim(axis),
ignore_index_, ignore_index_,
X.template data<LogitType, Context>(), X.template data<LogitType, Context>(),
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetType, Context>(),
...@@ -120,7 +120,7 @@ void NLLLossGradientOp<Context>::DoRunWithType() { ...@@ -120,7 +120,7 @@ void NLLLossGradientOp<Context>::DoRunWithType() {
if (reduction_ == "NONE") { if (reduction_ == "NONE") {
kernel::BroadcastLossGrad( kernel::BroadcastLossGrad(
outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx()); outer_dim, inner_dim, dX->dim(axis), dy, dx, ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
......
...@@ -26,8 +26,8 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() { ...@@ -26,8 +26,8 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() {
kernel::SigmoidFocalLoss( kernel::SigmoidFocalLoss(
outer_dim, outer_dim,
X.dim(axis),
inner_dim, inner_dim,
X.dim(axis),
pos_alpha_, pos_alpha_,
neg_alpha_, neg_alpha_,
gamma_, gamma_,
...@@ -107,8 +107,8 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() { ...@@ -107,8 +107,8 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() {
kernel::SigmoidFocalLossGrad( kernel::SigmoidFocalLossGrad(
outer_dim, outer_dim,
dX->dim(axis),
inner_dim, inner_dim,
dX->dim(axis),
pos_alpha_, pos_alpha_,
neg_alpha_, neg_alpha_,
gamma_, gamma_,
......
...@@ -24,8 +24,8 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -24,8 +24,8 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() {
kernel::Softmax( kernel::Softmax(
outer_dim, outer_dim,
X.dim(axis),
inner_dim, inner_dim,
X.dim(axis),
X.template data<T, Context>(), X.template data<T, Context>(),
prob, prob,
ctx()); ctx());
...@@ -90,7 +90,7 @@ void SoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -90,7 +90,7 @@ void SoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
if (reduction_ == "NONE") { if (reduction_ == "NONE") {
kernel::BroadcastLossGrad( kernel::BroadcastLossGrad(
outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx()); outer_dim, inner_dim, dX->dim(axis), dy, dx, ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_MEAN") { if (reduction_ == "BATCH_MEAN") {
......
...@@ -29,16 +29,16 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -29,16 +29,16 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
kernel::Softmax( kernel::Softmax(
outer_dim, outer_dim,
X.dim(axis),
inner_dim, inner_dim,
X.dim(axis),
X.template data<LogitType, Context>(), X.template data<LogitType, Context>(),
prob, prob,
ctx()); ctx());
kernel::SparseSoftmaxCrossEntropy( kernel::SparseSoftmaxCrossEntropy(
outer_dim, outer_dim,
X.dim(axis),
inner_dim, inner_dim,
X.dim(axis),
ignore_index_, ignore_index_,
prob, prob,
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetType, Context>(),
...@@ -120,8 +120,8 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -120,8 +120,8 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
kernel::SparseSoftmaxCrossEntropyGrad( kernel::SparseSoftmaxCrossEntropyGrad(
outer_dim, outer_dim,
dX->dim(axis),
inner_dim, inner_dim,
dX->dim(axis),
ignore_index_, ignore_index_,
prob, prob,
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetType, Context>(),
...@@ -131,7 +131,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -131,7 +131,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
if (reduction_ == "NONE") { if (reduction_ == "NONE") {
kernel::BroadcastLossGrad( kernel::BroadcastLossGrad(
outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx()); outer_dim, inner_dim, dX->dim(axis), dy, dx, ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
......
...@@ -59,8 +59,8 @@ void FullyConnectedOp<Context>::DoRunWithType() { ...@@ -59,8 +59,8 @@ void FullyConnectedOp<Context>::DoRunWithType() {
TENSOR_FILL(Input(2), vec64_t({N})); TENSOR_FILL(Input(2), vec64_t({N}));
kernel::BiasAdd( kernel::BiasAdd(
M, M,
N,
1, 1,
N,
Y->template data<T, Context>(), Y->template data<T, Context>(),
Input(2).template data<T, Context>(), Input(2).template data<T, Context>(),
Y->template mutable_data<T, Context>(), Y->template mutable_data<T, Context>(),
......
...@@ -56,9 +56,9 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -56,9 +56,9 @@ void BatchNormOp<Context>::TrainingImpl() {
// Compute affine transformation // Compute affine transformation
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
} }
} }
...@@ -91,9 +91,9 @@ void BatchNormOp<Context>::InferenceImpl() { ...@@ -91,9 +91,9 @@ void BatchNormOp<Context>::InferenceImpl() {
// Compute affine transformation // Compute affine transformation
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
} }
} }
......
...@@ -89,9 +89,9 @@ void SyncBatchNormOp<Context>::TrainingImpl() { ...@@ -89,9 +89,9 @@ void SyncBatchNormOp<Context>::TrainingImpl() {
// Compute affine transformation // Compute affine transformation
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx()); kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
} }
} }
......
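ChannelAffine gets the same treatment in the batch-norm paths: NCHW data is (outer = N, inner = S, channels = C), and NHWC collapses to (outer = N * S, inner = 1, channels = C). A NumPy sketch of the affine transform itself (illustrative, not the kernel):

```python
# y = x * scale + bias, broadcast over the channel axis; the data is viewed
# physically as (outer_dim, channel_dim, inner_dim).
import numpy as np

def channel_affine_ref(x, scale, bias, outer_dim, inner_dim, channel_dim):
    v = x.reshape(outer_dim, channel_dim, inner_dim)
    v = v * scale.reshape(1, channel_dim, 1) + bias.reshape(1, channel_dim, 1)
    return v.reshape(x.shape)

x = np.random.randn(2, 3, 4).astype('float32')  # NCHW-like: N=2, C=3, S=4
y = channel_affine_ref(x, np.full(3, 2.0, 'float32'), np.zeros(3, 'float32'), 2, 4, 3)
assert np.allclose(y, 2.0 * x)
```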
...@@ -28,8 +28,8 @@ void LpNormalizeOp<Context>::DoRunWithType() { ...@@ -28,8 +28,8 @@ void LpNormalizeOp<Context>::DoRunWithType() {
if (p_ == 1) { if (p_ == 1) {
kernel::L1Normalize( kernel::L1Normalize(
X.count(0, axis), X.count(0, axis),
reduce_dim,
X.count(axis + num_axes), X.count(axis + num_axes),
reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f, reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
epsilon_, epsilon_,
X.template data<T, Context>(), X.template data<T, Context>(),
...@@ -38,8 +38,8 @@ void LpNormalizeOp<Context>::DoRunWithType() { ...@@ -38,8 +38,8 @@ void LpNormalizeOp<Context>::DoRunWithType() {
} else if (p_ == 2) { } else if (p_ == 2) {
kernel::L2Normalize( kernel::L2Normalize(
X.count(0, axis), X.count(0, axis),
reduce_dim,
X.count(axis + num_axes), X.count(axis + num_axes),
reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f, reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
epsilon_, epsilon_,
X.template data<T, Context>(), X.template data<T, Context>(),
...@@ -65,8 +65,8 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() { ...@@ -65,8 +65,8 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
if (p_ == 1) { if (p_ == 1) {
kernel::L1NormalizeGrad( kernel::L1NormalizeGrad(
X.count(0, axis), X.count(0, axis),
reduce_dim,
X.count(axis + num_axes), X.count(axis + num_axes),
reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f, reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
epsilon_, epsilon_,
dY.template data<T, Context>(), dY.template data<T, Context>(),
...@@ -76,8 +76,8 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() { ...@@ -76,8 +76,8 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
} else if (p_ == 2) { } else if (p_ == 2) {
kernel::L2NormalizeGrad( kernel::L2NormalizeGrad(
X.count(0, axis), X.count(0, axis),
reduce_dim,
X.count(axis + num_axes), X.count(axis + num_axes),
reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f, reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
epsilon_, epsilon_,
dY.template data<T, Context>(), dY.template data<T, Context>(),
......
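The Lp-normalization kernels likewise move `reduce_dim` to the last position. A hedged NumPy reference for the L2 case, with the `MEAN` scaling shown in the call sites above:

```python
import numpy as np

def l2_normalize_ref(x, axis, scale=1.0, epsilon=1e-12):
    # scale is 1/reduce_dim for reduction="MEAN", else 1; the exact placement
    # of epsilon is an assumption of this sketch.
    norm = np.sqrt(np.maximum(scale * (x * x).sum(axis=axis, keepdims=True), epsilon))
    return x / norm

x = np.random.randn(4, 8).astype('float32')
y = l2_normalize_ref(x, axis=1)
assert np.allclose((y * y).sum(axis=1), 1.0, atol=1e-4)
```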
...@@ -23,8 +23,8 @@ void BiasAddOp<Context>::DoRunWithType() { ...@@ -23,8 +23,8 @@ void BiasAddOp<Context>::DoRunWithType() {
TENSOR_FILL(B, vec64_t({C})); TENSOR_FILL(B, vec64_t({C}));
kernel::BiasAdd( kernel::BiasAdd(
N, N,
C,
S, S,
C,
X.template data<T, Context>(), X.template data<T, Context>(),
B.template data<T, Context>(), B.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
......
...@@ -118,10 +118,10 @@ template <typename T> ...@@ -118,10 +118,10 @@ template <typename T>
void ConvOpBase<Context>::Pb(const T* bias, T* y) { void ConvOpBase<Context>::Pb(const T* bias, T* y) {
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
kernel::BiasAdd( kernel::BiasAdd(
Input(0).dim(0), out_channels_, out_dim_, y, bias, y, ctx()); Input(0).dim(0), out_dim_, out_channels_, y, bias, y, ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
kernel::BiasAdd( kernel::BiasAdd(
Input(0).dim(0) * out_dim_, out_channels_, 1, y, bias, y, ctx()); Input(0).dim(0) * out_dim_, 1, out_channels_, y, bias, y, ctx());
} }
} }
......
...@@ -23,6 +23,8 @@ from dragon.core.ops.activation_ops import dropout ...@@ -23,6 +23,8 @@ from dragon.core.ops.activation_ops import dropout
from dragon.core.ops.activation_ops import drop_block2d from dragon.core.ops.activation_ops import drop_block2d
from dragon.core.ops.activation_ops import drop_path from dragon.core.ops.activation_ops import drop_path
from dragon.core.ops.activation_ops import elu from dragon.core.ops.activation_ops import elu
from dragon.core.ops.activation_ops import hardsigmoid
from dragon.core.ops.activation_ops import hardswish
from dragon.core.ops.activation_ops import leaky_relu from dragon.core.ops.activation_ops import leaky_relu
from dragon.core.ops.activation_ops import log_softmax from dragon.core.ops.activation_ops import log_softmax
from dragon.core.ops.activation_ops import prelu from dragon.core.ops.activation_ops import prelu
...@@ -30,6 +32,7 @@ from dragon.core.ops.activation_ops import relu ...@@ -30,6 +32,7 @@ from dragon.core.ops.activation_ops import relu
from dragon.core.ops.activation_ops import relu6 from dragon.core.ops.activation_ops import relu6
from dragon.core.ops.activation_ops import selu from dragon.core.ops.activation_ops import selu
from dragon.core.ops.activation_ops import softmax from dragon.core.ops.activation_ops import softmax
from dragon.core.ops.activation_ops import swish
from dragon.core.ops.math_ops import fully_connected from dragon.core.ops.math_ops import fully_connected
from dragon.core.ops.normalization_ops import batch_norm from dragon.core.ops.normalization_ops import batch_norm
from dragon.core.ops.normalization_ops import group_norm from dragon.core.ops.normalization_ops import group_norm
......
...@@ -20,7 +20,7 @@ from dragon.core.util import tls ...@@ -20,7 +20,7 @@ from dragon.core.util import tls
def device(device_type, device_index=0): def device(device_type, device_index=0):
"""Context-manager to nest the the device spec. """Context-manager to nest the device spec.
Examples: Examples:
......
...@@ -223,6 +223,96 @@ def elu(inputs, alpha=1., **kwargs): ...@@ -223,6 +223,96 @@ def elu(inputs, alpha=1., **kwargs):
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def hardsigmoid(inputs, alpha=0.2, beta=0.5, **kwargs):
r"""Apply the hard sigmoid function.
The **HardSigmoid** function is defined as:
.. math:: \text{HardSigmoid}(x) = \max(0, \min(1, \alpha * x + \beta))
Examples:
```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.hardsigmoid(x, inplace=False))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
alpha : float, optional, default=0.2
The value to :math:`\alpha`.
beta : float, optional, default=0.5
The value to :math:`\beta`.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
args['alpha'] = float(alpha)
args['beta'] = float(beta)
inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.HardSigmoid
if context.executing_eagerly():
return op_lib \
.instantiate(
alpha=args['alpha'],
beta=args['beta'],
).apply([inputs], inplace=inplace)
else:
return op_lib.blend(**args)
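A quick numeric check of the docstring's example, assuming the defaults alpha=0.2 and beta=0.5 (plain NumPy, independent of Dragon):

```python
import numpy as np
x = np.array([-2.5, -1.0, 0.0, 1.0, 2.5], 'float32')
print(np.clip(0.2 * x + 0.5, 0.0, 1.0))  # [0.  0.3 0.5 0.7 1. ]
```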
@OpSchema.num_inputs(1)
def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs):
r"""Apply the hard swish function.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
The **HardSwish** function is defined as:
.. math:: \text{HardSwish}(x) = x \cdot \max(0, \min(1, \alpha * x + \beta))
Examples:
```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.hardswish(x))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
alpha : float, optional, default=0.2
The value to :math:`\alpha`.
beta : float, optional, default=0.5
The value to :math:`\beta`.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
args['alpha'] = float(alpha)
args['beta'] = float(beta)
op_lib = activation_ops_lib.HardSwish
if context.executing_eagerly():
return op_lib \
.instantiate(
alpha=args['alpha'],
beta=args['beta'],
).apply([inputs])
else:
return op_lib.blend(**args)
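The same check for hardswish, which multiplies the input by its hard sigmoid gate. Note the defaults here are alpha=0.2 and beta=0.5; some frameworks instead use alpha=1/6 for the MobileNetV3 form.

```python
import numpy as np
x = np.array([-2.5, -1.0, 0.0, 1.0, 2.5], 'float32')
print(x * np.clip(0.2 * x + 0.5, 0.0, 1.0))  # [-0.  -0.3  0.   0.7  2.5]
```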
@OpSchema.num_inputs(1)
def leaky_relu(inputs, alpha=0.2, **kwargs): def leaky_relu(inputs, alpha=0.2, **kwargs):
r"""Apply the leaky rectified linear unit. r"""Apply the leaky rectified linear unit.
...@@ -269,7 +359,7 @@ def leaky_relu(inputs, alpha=0.2, **kwargs): ...@@ -269,7 +359,7 @@ def leaky_relu(inputs, alpha=0.2, **kwargs):
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def log_softmax(inputs, axis=-1, **kwargs): def log_softmax(inputs, axis=-1, **kwargs):
r"""Apply the composite of logarithm and softmax. r"""Compute the composite of logarithm and softmax.
The **LogSoftmax** function is defined as: The **LogSoftmax** function is defined as:
...@@ -492,7 +582,7 @@ def selu(inputs, alpha=1.67326, gamma=1.0507, **kwargs): ...@@ -492,7 +582,7 @@ def selu(inputs, alpha=1.67326, gamma=1.0507, **kwargs):
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def sigmoid(inputs, **kwargs): def sigmoid(inputs, **kwargs):
r"""Apply the sigmoid function. r"""Compute the sigmoid result of input.
The **Sigmoid** function is defined as: The **Sigmoid** function is defined as:
...@@ -529,7 +619,7 @@ def sigmoid(inputs, **kwargs): ...@@ -529,7 +619,7 @@ def sigmoid(inputs, **kwargs):
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def softmax(inputs, axis=-1, **kwargs): def softmax(inputs, axis=-1, **kwargs):
r"""Apply the softmax function. r"""Compute the softmax result.
The **Softmax** function is defined as: The **Softmax** function is defined as:
...@@ -602,3 +692,40 @@ def tanh(inputs, **kwargs): ...@@ -602,3 +692,40 @@ def tanh(inputs, **kwargs):
.apply([inputs], inplace=inplace) .apply([inputs], inplace=inplace)
else: else:
return op_lib.blend('Tanh', **args) return op_lib.blend('Tanh', **args)
@OpSchema.num_inputs(1)
def swish(inputs, **kwargs):
r"""Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
The **Swish** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Examples:
```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.swish(x))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
op_lib = activation_ops_lib.Activation
if context.executing_eagerly():
return op_lib \
.instantiate(op_type='Swish') \
.apply([inputs])
else:
return op_lib.blend('Swish', **args)
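And for swish, the smooth counterpart :math:`x \cdot \text{sigmoid}(x)`:

```python
import numpy as np
x = np.array([-2.5, -1.0, 0.0, 1.0, 2.5], 'float32')
print(np.round(x / (1.0 + np.exp(-x)), 4))  # [-0.1896 -0.2689  0.      0.7311  2.3104]
```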
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Activation(Operator): class Activation(Operator):
"""Base activation operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Activation, self).__init__(key, dev, **kwargs) super(Activation, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '') self.op_type = kwargs.get('op_type', '')
...@@ -31,6 +33,8 @@ class Activation(Operator): ...@@ -31,6 +33,8 @@ class Activation(Operator):
class Dropout(Activation): class Dropout(Activation):
"""Dropout operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs) super(Dropout, self).__init__(key, dev, **kwargs)
self.prob = kwargs.get('prob', 0.5) self.prob = kwargs.get('prob', 0.5)
...@@ -47,6 +51,8 @@ class Dropout(Activation): ...@@ -47,6 +51,8 @@ class Dropout(Activation):
class DropBlock2d(Activation): class DropBlock2d(Activation):
"""DropBlock2d operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(DropBlock2d, self).__init__(key, dev, **kwargs) super(DropBlock2d, self).__init__(key, dev, **kwargs)
self.block_size = kwargs.get('block_size', 7) self.block_size = kwargs.get('block_size', 7)
...@@ -69,6 +75,8 @@ class DropBlock2d(Activation): ...@@ -69,6 +75,8 @@ class DropBlock2d(Activation):
class DropPath(Activation): class DropPath(Activation):
"""DropPath operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs) super(DropPath, self).__init__(key, dev, **kwargs)
self.prob = kwargs.get('prob', 0.2) self.prob = kwargs.get('prob', 0.2)
...@@ -85,6 +93,8 @@ class DropPath(Activation): ...@@ -85,6 +93,8 @@ class DropPath(Activation):
class Elu(Activation): class Elu(Activation):
"""Elu operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Elu, self).__init__(key, dev, **kwargs) super(Elu, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.) self.alpha = kwargs.get('alpha', 1.)
...@@ -96,7 +106,45 @@ class Elu(Activation): ...@@ -96,7 +106,45 @@ class Elu(Activation):
} }
class HardSigmoid(Activation):
"""HardSigmoid operator."""
def __init__(self, key, dev, **kwargs):
super(HardSigmoid, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.2)
self.beta = kwargs.get('beta', 0.5)
def attributes(self):
return {
'op_type': 'HardSigmoid',
'arguments': {
'alpha': float(self.alpha),
'beta': float(self.beta),
},
}
class HardSwish(Activation):
"""HardSwish operator."""
def __init__(self, key, dev, **kwargs):
super(HardSwish, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.2)
self.beta = kwargs.get('beta', 0.5)
def attributes(self):
return {
'op_type': 'HardSwish',
'arguments': {
'alpha': float(self.alpha),
'beta': float(self.beta),
},
}
class PRelu(Operator): class PRelu(Operator):
"""PRelu operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(PRelu, self).__init__(key, dev, **kwargs) super(PRelu, self).__init__(key, dev, **kwargs)
self.data_format = kwargs.get('data_format', 'NCHW') self.data_format = kwargs.get('data_format', 'NCHW')
...@@ -112,6 +160,8 @@ class PRelu(Operator): ...@@ -112,6 +160,8 @@ class PRelu(Operator):
class Relu(Activation): class Relu(Activation):
"""Relu operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Relu, self).__init__(key, dev, **kwargs) super(Relu, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.) self.alpha = kwargs.get('alpha', 0.)
...@@ -124,6 +174,8 @@ class Relu(Activation): ...@@ -124,6 +174,8 @@ class Relu(Activation):
class Relu6(Activation): class Relu6(Activation):
"""Relu6 operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Relu6, self).__init__(key, dev, **kwargs) super(Relu6, self).__init__(key, dev, **kwargs)
...@@ -135,6 +187,8 @@ class Relu6(Activation): ...@@ -135,6 +187,8 @@ class Relu6(Activation):
class Selu(Activation): class Selu(Activation):
"""Selu operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Selu, self).__init__(key, dev, **kwargs) super(Selu, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.67326) self.alpha = kwargs.get('alpha', 1.67326)
...@@ -151,6 +205,8 @@ class Selu(Activation): ...@@ -151,6 +205,8 @@ class Selu(Activation):
class Softmax(Activation): class Softmax(Activation):
"""Softmax operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Softmax, self).__init__(key, dev, **kwargs) super(Softmax, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1) self.axis = kwargs.get('axis', 1)
......
...@@ -19,6 +19,8 @@ from dragon.core.framework.ops import Operator ...@@ -19,6 +19,8 @@ from dragon.core.framework.ops import Operator
class ArgReduce(Operator): class ArgReduce(Operator):
"""ArgReduce operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ArgReduce, self).__init__(key, dev, **kwargs) super(ArgReduce, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', 'ArgMax') self.op_type = kwargs.get('op_type', 'ArgMax')
...@@ -39,6 +41,8 @@ class ArgReduce(Operator): ...@@ -39,6 +41,8 @@ class ArgReduce(Operator):
class Cast(Operator): class Cast(Operator):
"""Cast operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Cast, self).__init__(key, dev, **kwargs) super(Cast, self).__init__(key, dev, **kwargs)
self.dtype = kwargs.get('dtype', 'float32') self.dtype = kwargs.get('dtype', 'float32')
...@@ -58,6 +62,8 @@ class Cast(Operator): ...@@ -58,6 +62,8 @@ class Cast(Operator):
class ChannelAffine(Operator): class ChannelAffine(Operator):
"""ChannelAffine operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ChannelAffine, self).__init__(key, dev, **kwargs) super(ChannelAffine, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1) self.axis = kwargs.get('axis', 1)
...@@ -78,6 +84,8 @@ class ChannelAffine(Operator): ...@@ -78,6 +84,8 @@ class ChannelAffine(Operator):
class ChannelNormalize(Operator): class ChannelNormalize(Operator):
"""ChannelNormalize operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ChannelNormalize, self).__init__(key, dev, **kwargs) super(ChannelNormalize, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
...@@ -115,6 +123,8 @@ class ChannelNormalize(Operator): ...@@ -115,6 +123,8 @@ class ChannelNormalize(Operator):
class ChannelShuffle(Operator): class ChannelShuffle(Operator):
"""ChannelShuffle operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ChannelShuffle, self).__init__(key, dev, **kwargs) super(ChannelShuffle, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0) self.axis = kwargs.get('axis', 0)
...@@ -134,6 +144,8 @@ class ChannelShuffle(Operator): ...@@ -134,6 +144,8 @@ class ChannelShuffle(Operator):
class Concat(Operator): class Concat(Operator):
"""Concat operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Concat, self).__init__(key, dev, **kwargs) super(Concat, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0) self.axis = kwargs.get('axis', 0)
...@@ -149,6 +161,8 @@ class Concat(Operator): ...@@ -149,6 +161,8 @@ class Concat(Operator):
class Cumulative(Operator): class Cumulative(Operator):
"""Cumulative operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Cumulative, self).__init__(key, dev, **kwargs) super(Cumulative, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0) self.axis = kwargs.get('axis', 0)
...@@ -171,6 +185,8 @@ class Cumulative(Operator): ...@@ -171,6 +185,8 @@ class Cumulative(Operator):
class Expand(Operator): class Expand(Operator):
"""Expand operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Expand, self).__init__(key, dev, **kwargs) super(Expand, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -200,6 +216,8 @@ class Expand(Operator): ...@@ -200,6 +216,8 @@ class Expand(Operator):
class ExpandDims(Operator): class ExpandDims(Operator):
"""ExpandDims operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ExpandDims, self).__init__(key, dev, **kwargs) super(ExpandDims, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', [0]) self.axes = kwargs.get('axes', [0])
...@@ -218,6 +236,8 @@ class ExpandDims(Operator): ...@@ -218,6 +236,8 @@ class ExpandDims(Operator):
class Flatten(Operator): class Flatten(Operator):
"""Flatten operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Flatten, self).__init__(key, dev, **kwargs) super(Flatten, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0) self.axis = kwargs.get('axis', 0)
...@@ -240,6 +260,8 @@ class Flatten(Operator): ...@@ -240,6 +260,8 @@ class Flatten(Operator):
class IndexSelect(Operator): class IndexSelect(Operator):
"""IndexSelect operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(IndexSelect, self).__init__(key, dev, **kwargs) super(IndexSelect, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0) self.axis = kwargs.get('axis', 0)
...@@ -259,6 +281,8 @@ class IndexSelect(Operator): ...@@ -259,6 +281,8 @@ class IndexSelect(Operator):
class LinSpace(Operator): class LinSpace(Operator):
"""LinSpace operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(LinSpace, self).__init__(key, dev, **kwargs) super(LinSpace, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -309,6 +333,8 @@ class LinSpace(Operator): ...@@ -309,6 +333,8 @@ class LinSpace(Operator):
class MaskedSelect(Operator): class MaskedSelect(Operator):
"""MaskedSelect operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(MaskedSelect, self).__init__(key, dev, **kwargs) super(MaskedSelect, self).__init__(key, dev, **kwargs)
...@@ -320,6 +346,8 @@ class MaskedSelect(Operator): ...@@ -320,6 +346,8 @@ class MaskedSelect(Operator):
class Moments(Operator): class Moments(Operator):
"""Moments operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Moments, self).__init__(key, dev, **kwargs) super(Moments, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None) self.axes = kwargs.get('axes', None)
...@@ -339,6 +367,8 @@ class Moments(Operator): ...@@ -339,6 +367,8 @@ class Moments(Operator):
class Multinomial(Operator): class Multinomial(Operator):
"""Multinomial operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Multinomial, self).__init__(key, dev, **kwargs) super(Multinomial, self).__init__(key, dev, **kwargs)
self.epsilon = kwargs.get('epsilon', 0.) self.epsilon = kwargs.get('epsilon', 0.)
...@@ -360,6 +390,8 @@ class Multinomial(Operator): ...@@ -360,6 +390,8 @@ class Multinomial(Operator):
class NonZero(Operator): class NonZero(Operator):
"""NonZero operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(NonZero, self).__init__(key, dev, **kwargs) super(NonZero, self).__init__(key, dev, **kwargs)
...@@ -371,6 +403,8 @@ class NonZero(Operator): ...@@ -371,6 +403,8 @@ class NonZero(Operator):
class OneHot(Operator): class OneHot(Operator):
"""OneHot operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(OneHot, self).__init__(key, dev, **kwargs) super(OneHot, self).__init__(key, dev, **kwargs)
self.depth = kwargs.get('depth', 1) self.depth = kwargs.get('depth', 1)
...@@ -392,6 +426,8 @@ class OneHot(Operator): ...@@ -392,6 +426,8 @@ class OneHot(Operator):
class Pad(Operator): class Pad(Operator):
"""Pad operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Pad, self).__init__(key, dev, **kwargs) super(Pad, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -425,6 +461,8 @@ class Pad(Operator): ...@@ -425,6 +461,8 @@ class Pad(Operator):
class Permutation(Operator): class Permutation(Operator):
"""Permutation operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Permutation, self).__init__(key, dev, **kwargs) super(Permutation, self).__init__(key, dev, **kwargs)
self.dtype = kwargs.get('dtype', 'int64') self.dtype = kwargs.get('dtype', 'int64')
...@@ -453,6 +491,8 @@ class Permutation(Operator): ...@@ -453,6 +491,8 @@ class Permutation(Operator):
class Range(Operator): class Range(Operator):
"""Range operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Range, self).__init__(key, dev, **kwargs) super(Range, self).__init__(key, dev, **kwargs)
self.num_args = kwargs.get('num_args', 3) self.num_args = kwargs.get('num_args', 3)
...@@ -487,6 +527,8 @@ class Range(Operator): ...@@ -487,6 +527,8 @@ class Range(Operator):
class Reduce(Operator): class Reduce(Operator):
"""Reduce operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Reduce, self).__init__(key, dev, **kwargs) super(Reduce, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None) self.axes = kwargs.get('axes', None)
...@@ -507,6 +549,8 @@ class Reduce(Operator): ...@@ -507,6 +549,8 @@ class Reduce(Operator):
class Repeat(Operator): class Repeat(Operator):
"""Repeat operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Repeat, self).__init__(key, dev, **kwargs) super(Repeat, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 2147483647) self.axis = kwargs.get('axis', 2147483647)
...@@ -526,6 +570,8 @@ class Repeat(Operator): ...@@ -526,6 +570,8 @@ class Repeat(Operator):
class Reshape(Operator): class Reshape(Operator):
"""Reshape operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Reshape, self).__init__(key, dev, **kwargs) super(Reshape, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -555,6 +601,8 @@ class Reshape(Operator): ...@@ -555,6 +601,8 @@ class Reshape(Operator):
class Slice(Operator): class Slice(Operator):
"""Slice operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Slice, self).__init__(key, dev, **kwargs) super(Slice, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -590,6 +638,8 @@ class Slice(Operator): ...@@ -590,6 +638,8 @@ class Slice(Operator):
class Shape(Operator): class Shape(Operator):
"""Shape operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Shape, self).__init__(key, dev, **kwargs) super(Shape, self).__init__(key, dev, **kwargs)
self._device = device_spec.DeviceSpec() self._device = device_spec.DeviceSpec()
...@@ -602,6 +652,8 @@ class Shape(Operator): ...@@ -602,6 +652,8 @@ class Shape(Operator):
class Sort(Operator): class Sort(Operator):
"""Sort operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Sort, self).__init__(key, dev, **kwargs) super(Sort, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
...@@ -621,6 +673,8 @@ class Sort(Operator): ...@@ -621,6 +673,8 @@ class Sort(Operator):
class Split(Operator): class Split(Operator):
"""Split operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Split, self).__init__(key, dev, **kwargs) super(Split, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0) self.axis = kwargs.get('axis', 0)
...@@ -643,6 +697,8 @@ class Split(Operator): ...@@ -643,6 +697,8 @@ class Split(Operator):
class Squeeze(Operator): class Squeeze(Operator):
"""Squeeze operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Squeeze, self).__init__(key, dev, **kwargs) super(Squeeze, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None) self.axes = kwargs.get('axes', None)
...@@ -674,6 +730,8 @@ class Stack(Operator): ...@@ -674,6 +730,8 @@ class Stack(Operator):
class Tile(Operator): class Tile(Operator):
"""Tile operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Tile, self).__init__(key, dev, **kwargs) super(Tile, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -703,6 +761,8 @@ class Tile(Operator): ...@@ -703,6 +761,8 @@ class Tile(Operator):
class Transpose(Operator): class Transpose(Operator):
"""Transpose operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Transpose, self).__init__(key, dev, **kwargs) super(Transpose, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -732,6 +792,8 @@ class Transpose(Operator): ...@@ -732,6 +792,8 @@ class Transpose(Operator):
class TopK(Operator): class TopK(Operator):
"""TopK operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(TopK, self).__init__(key, dev, **kwargs) super(TopK, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 1) self.k = kwargs.get('k', 1)
...@@ -755,6 +817,8 @@ class TopK(Operator): ...@@ -755,6 +817,8 @@ class TopK(Operator):
class Unique(Operator): class Unique(Operator):
"""Unique operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Unique, self).__init__(key, dev, **kwargs) super(Unique, self).__init__(key, dev, **kwargs)
self.return_inverse = kwargs.get('return_inverse', False) self.return_inverse = kwargs.get('return_inverse', False)
...@@ -776,6 +840,8 @@ class Unique(Operator): ...@@ -776,6 +840,8 @@ class Unique(Operator):
class Where(Operator): class Where(Operator):
"""Where operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Where, self).__init__(key, dev, **kwargs) super(Where, self).__init__(key, dev, **kwargs)
......
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Assign(Operator): class Assign(Operator):
"""Assign operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs) super(Assign, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -54,6 +56,8 @@ class Assign(Operator): ...@@ -54,6 +56,8 @@ class Assign(Operator):
class Copy(Operator): class Copy(Operator):
"""Copy operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Copy, self).__init__(key, dev, **kwargs) super(Copy, self).__init__(key, dev, **kwargs)
...@@ -66,6 +70,8 @@ class Copy(Operator): ...@@ -66,6 +70,8 @@ class Copy(Operator):
class MaskedAssign(Operator): class MaskedAssign(Operator):
"""MaskedAssign operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(MaskedAssign, self).__init__(key, dev, **kwargs) super(MaskedAssign, self).__init__(key, dev, **kwargs)
......
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Collective(Operator): class Collective(Operator):
"""Collective operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Collective, self).__init__(key, dev, **kwargs) super(Collective, self).__init__(key, dev, **kwargs)
self.root = kwargs.get('root', 0) self.root = kwargs.get('root', 0)
......
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Initializer(Operator): class Initializer(Operator):
"""Initializer operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Initializer, self).__init__(key, dev, **kwargs) super(Initializer, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
...@@ -43,6 +45,8 @@ class Initializer(Operator): ...@@ -43,6 +45,8 @@ class Initializer(Operator):
class Eye(Initializer): class Eye(Initializer):
"""Eye operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Eye, self).__init__(key, dev, **kwargs) super(Eye, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 0) self.k = kwargs.get('k', 0)
...@@ -79,6 +83,8 @@ class Fill(Initializer): ...@@ -79,6 +83,8 @@ class Fill(Initializer):
class GlorotNormal(Initializer): class GlorotNormal(Initializer):
"""GlorotNormal operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(GlorotNormal, self).__init__(key, dev, **kwargs) super(GlorotNormal, self).__init__(key, dev, **kwargs)
self.scale = kwargs.get('scale', 2.) self.scale = kwargs.get('scale', 2.)
...@@ -99,6 +105,8 @@ class GlorotNormal(Initializer): ...@@ -99,6 +105,8 @@ class GlorotNormal(Initializer):
class GlorotUniform(Initializer): class GlorotUniform(Initializer):
"""GlorotUniform operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(GlorotUniform, self).__init__(key, dev, **kwargs) super(GlorotUniform, self).__init__(key, dev, **kwargs)
self.scale = kwargs.get('scale', 3.) self.scale = kwargs.get('scale', 3.)
...@@ -119,6 +127,8 @@ class GlorotUniform(Initializer): ...@@ -119,6 +127,8 @@ class GlorotUniform(Initializer):
class RandomNormal(Initializer): class RandomNormal(Initializer):
"""RandomNormal operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(RandomNormal, self).__init__(key, dev, **kwargs) super(RandomNormal, self).__init__(key, dev, **kwargs)
self.mean = kwargs.get('mean', 0.) self.mean = kwargs.get('mean', 0.)
...@@ -139,6 +149,8 @@ class RandomNormal(Initializer): ...@@ -139,6 +149,8 @@ class RandomNormal(Initializer):
class RandomUniform(Initializer): class RandomUniform(Initializer):
"""RandomUniform operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(RandomUniform, self).__init__(key, dev, **kwargs) super(RandomUniform, self).__init__(key, dev, **kwargs)
self.low = kwargs.get('low', 0.) self.low = kwargs.get('low', 0.)
...@@ -159,6 +171,8 @@ class RandomUniform(Initializer): ...@@ -159,6 +171,8 @@ class RandomUniform(Initializer):
class TruncatedNormal(Initializer): class TruncatedNormal(Initializer):
"""TruncatedNormal operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(TruncatedNormal, self).__init__(key, dev, **kwargs) super(TruncatedNormal, self).__init__(key, dev, **kwargs)
self.mean = kwargs.get('mean', 0.) self.mean = kwargs.get('mean', 0.)
......
...@@ -17,9 +17,11 @@ from __future__ import print_function ...@@ -17,9 +17,11 @@ from __future__ import print_function
from dragon.core.framework.ops import Operator from dragon.core.framework.ops import Operator
class _Loss(Operator): class Loss(Operator):
"""Loss operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(_Loss, self).__init__(key, dev, **kwargs) super(Loss, self).__init__(key, dev, **kwargs)
self.reduction = kwargs.get('reduction', 'MEAN') self.reduction = kwargs.get('reduction', 'MEAN')
def attributes(self): def attributes(self):
...@@ -34,17 +36,23 @@ class _Loss(Operator): ...@@ -34,17 +36,23 @@ class _Loss(Operator):
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(inputs, [self.alloc()])
class L1Loss(_Loss): class L1Loss(Loss):
"""L1Loss operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(L1Loss, self).__init__(key, dev, **kwargs) super(L1Loss, self).__init__(key, dev, **kwargs)
class L2Loss(_Loss): class L2Loss(Loss):
"""L2Loss operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(L2Loss, self).__init__(key, dev, **kwargs) super(L2Loss, self).__init__(key, dev, **kwargs)
class NLLLoss(_Loss): class NLLLoss(Loss):
"""NLLLoss operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(NLLLoss, self).__init__(key, dev, **kwargs) super(NLLLoss, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
...@@ -61,12 +69,16 @@ class NLLLoss(_Loss): ...@@ -61,12 +69,16 @@ class NLLLoss(_Loss):
} }
class SigmoidCrossEntropy(_Loss): class SigmoidCrossEntropy(Loss):
"""SigmoidCrossEntropy operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(SigmoidCrossEntropy, self).__init__(key, dev, **kwargs) super(SigmoidCrossEntropy, self).__init__(key, dev, **kwargs)
class SmoothL1Loss(_Loss): class SmoothL1Loss(Loss):
"""SmoothL1Loss operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(SmoothL1Loss, self).__init__(key, dev, **kwargs) super(SmoothL1Loss, self).__init__(key, dev, **kwargs)
self.beta = kwargs.get('beta', 1.) self.beta = kwargs.get('beta', 1.)
...@@ -81,7 +93,9 @@ class SmoothL1Loss(_Loss): ...@@ -81,7 +93,9 @@ class SmoothL1Loss(_Loss):
} }
class SoftmaxCrossEntropy(_Loss): class SoftmaxCrossEntropy(Loss):
"""SoftmaxCrossEntropy operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(SoftmaxCrossEntropy, self).__init__(key, dev, **kwargs) super(SoftmaxCrossEntropy, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
...@@ -96,7 +110,9 @@ class SoftmaxCrossEntropy(_Loss): ...@@ -96,7 +110,9 @@ class SoftmaxCrossEntropy(_Loss):
} }
class SparseSoftmaxCrossEntropy(_Loss): class SparseSoftmaxCrossEntropy(Loss):
"""SparseSoftmaxCrossEntropy operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(SparseSoftmaxCrossEntropy, self).__init__(key, dev, **kwargs) super(SparseSoftmaxCrossEntropy, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
...@@ -113,7 +129,9 @@ class SparseSoftmaxCrossEntropy(_Loss): ...@@ -113,7 +129,9 @@ class SparseSoftmaxCrossEntropy(_Loss):
} }
class SigmoidFocalLoss(_Loss): class SigmoidFocalLoss(Loss):
"""SigmoidFocalLoss operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(SigmoidFocalLoss, self).__init__(key, dev, **kwargs) super(SigmoidFocalLoss, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
......
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Axpby(Operator): class Axpby(Operator):
"""Axpby operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Axpby, self).__init__(key, dev, **kwargs) super(Axpby, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.) self.alpha = kwargs.get('alpha', 1.)
...@@ -40,6 +42,8 @@ class Axpby(Operator): ...@@ -40,6 +42,8 @@ class Axpby(Operator):
class BinaryOp(Operator): class BinaryOp(Operator):
"""Binary operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(BinaryOp, self).__init__(key, dev, **kwargs) super(BinaryOp, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '') self.op_type = kwargs.get('op_type', '')
...@@ -52,6 +56,8 @@ class BinaryOp(Operator): ...@@ -52,6 +56,8 @@ class BinaryOp(Operator):
class Clip(Operator): class Clip(Operator):
"""Clip operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Clip, self).__init__(key, dev, **kwargs) super(Clip, self).__init__(key, dev, **kwargs)
self.low = kwargs.get('low', None) self.low = kwargs.get('low', None)
...@@ -75,6 +81,8 @@ class Clip(Operator): ...@@ -75,6 +81,8 @@ class Clip(Operator):
class FullyConnected(Operator): class FullyConnected(Operator):
"""FullyConnected operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(FullyConnected, self).__init__(key, dev, **kwargs) super(FullyConnected, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1) self.axis = kwargs.get('axis', 1)
...@@ -94,6 +102,8 @@ class FullyConnected(Operator): ...@@ -94,6 +102,8 @@ class FullyConnected(Operator):
class MatMul(Operator): class MatMul(Operator):
"""MatMul operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs) super(MatMul, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False) self.transpose_a = kwargs.get('transpose_a', False)
...@@ -113,6 +123,8 @@ class MatMul(Operator): ...@@ -113,6 +123,8 @@ class MatMul(Operator):
class UnaryOp(Operator): class UnaryOp(Operator):
"""Unary operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(UnaryOp, self).__init__(key, dev, **kwargs) super(UnaryOp, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '') self.op_type = kwargs.get('op_type', '')
......
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Metric(Operator): class Metric(Operator):
"""Metric operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Metric, self).__init__(key, dev, **kwargs) super(Metric, self).__init__(key, dev, **kwargs)
self.reduction = kwargs.get('reduction', 'MEAN') self.reduction = kwargs.get('reduction', 'MEAN')
...@@ -27,6 +29,8 @@ class Metric(Operator): ...@@ -27,6 +29,8 @@ class Metric(Operator):
class Accuracy(Metric): class Accuracy(Metric):
"""Accuracy operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Accuracy, self).__init__(key, dev, **kwargs) super(Accuracy, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1) self.axis = kwargs.get('axis', 1)
......
...@@ -36,19 +36,11 @@ def batch_norm( ...@@ -36,19 +36,11 @@ def batch_norm(
The normalization is defined as: The normalization is defined as:
.. math:: .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The running average of statistics are calculated as: The running average of statistics are calculated as:
.. math:: .. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
Note that the number of inputs should be **5**, i.e.,
this operators is implemented into the fused version.
However, you can still fix the ``gamma`` and ``beta``,
by disabling the their gradients directly.
Parameters Parameters
---------- ----------
...@@ -91,18 +83,11 @@ def group_norm(inputs, axis=-1, group=32, epsilon=1e-5, **kwargs): ...@@ -91,18 +83,11 @@ def group_norm(inputs, axis=-1, group=32, epsilon=1e-5, **kwargs):
The normalization is defined as: The normalization is defined as:
.. math:: .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
It turns out to be **InstanceNorm**, if ``group`` is **0**, Note that it turns out to be **InstanceNorm**, if ``group`` is **0**,
or **LayerNorm**, if ``group`` is **1**. or **LayerNorm**, if ``group`` is **1**.
Note that the number of inputs should be **3**, i.e.,
this operators is implemented into the fused version.
However, you can still fix the ``gamma`` and ``beta``,
by disabling the their gradients directly.
Parameters Parameters
---------- ----------
inputs : Sequence[dragon.Tensor] inputs : Sequence[dragon.Tensor]
...@@ -141,14 +126,34 @@ def instance_norm(inputs, axis=-1, epsilon=1e-5, **kwargs): ...@@ -141,14 +126,34 @@ def instance_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
The normalization is defined as: The normalization is defined as:
.. math:: .. math:: \text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
\text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Parameters
----------
inputs : Sequence[dragon.Tensor]
The tensor ``x``, ``gamma`` and ``beta``.
axis : int, optional, default=-1
The channel axis.
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
Note that the number of inputs should be **3**, i.e.,
this operators is implemented into the fused version.
However, you can still fix the **gamma** and **beta**,
by disabling the their gradients directly.
Returns
-------
dragon.Tensor
The output tensor.
"""
return group_norm(inputs, axis=axis, group=0, epsilon=epsilon, **kwargs)
@OpSchema.num_inputs(3)
def layer_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
r"""Apply the layer normalization.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
The normalization is defined as:
.. math:: \text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Parameters Parameters
---------- ----------
...@@ -165,7 +170,7 @@ def instance_norm(inputs, axis=-1, epsilon=1e-5, **kwargs): ...@@ -165,7 +170,7 @@ def instance_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
The output tensor. The output tensor.
""" """
return group_norm(inputs, axis=axis, group=0, epsilon=epsilon, **kwargs) return group_norm(inputs, axis=axis, group=1, epsilon=epsilon, **kwargs)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
...@@ -238,40 +243,6 @@ def lp_normalize(inputs, axis=None, p=2, epsilon=1e-12, reduction='sum', **kwarg ...@@ -238,40 +243,6 @@ def lp_normalize(inputs, axis=None, p=2, epsilon=1e-12, reduction='sum', **kwarg
return op_lib.blend(**args) return op_lib.blend(**args)
@OpSchema.num_inputs(3)
def layer_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
r"""Apply the layer normalization.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
The normalization is defined as:
.. math::
\text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Note that the number of inputs should be *3*, i.e.,
this operators is implemented into the fused version.
However, you can still fix the *gamma* and *beta*,
by disabling the their gradients directly.
Parameters
----------
inputs : Sequence[dragon.Tensor]
The tensor ``x``, ``gamma`` and ``beta``.
axis : int, optional, default=-1
The channel axis.
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
Returns
-------
dragon.Tensor
The output tensor.
"""
return group_norm(inputs, axis=axis, group=1, epsilon=epsilon, **kwargs)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def local_response_norm( def local_response_norm(
inputs, inputs,
...@@ -289,8 +260,8 @@ def local_response_norm( ...@@ -289,8 +260,8 @@ def local_response_norm(
.. math:: .. math::
out_{i} = x_{i}\left(k + \frac{\alpha}{n} out_{i} = x_{i}\left(k + \frac{\alpha}{n}
\sum_{j=\max(0, i-n/2)}^{\min(N-1,i+n/2)}x_{j}^2 \sum_{j=\max(0, i-n/2)}^{\min(N-1,i+n/2)}x_{j}^2
\right)^{-\beta} \right)^{-\beta}
Parameters Parameters
---------- ----------
...@@ -347,19 +318,11 @@ def sync_batch_norm( ...@@ -347,19 +318,11 @@ def sync_batch_norm(
The normalization is defined as: The normalization is defined as:
.. math:: .. math:: \text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
\text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The running average of statistics are calculated as: The running average of statistics are calculated as:
.. math:: .. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
Note that the number of inputs should be **5**, i.e.,
this operators is implemented into the fused version.
However, you can still fix the ``gamma`` and ``beta``,
by disabling the their gradients directly.
Parameters Parameters
---------- ----------
......
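A one-line worked example of the running-statistics rule that the batch_norm and sync_batch_norm docstrings now state inline, with momentum = 0.9 (values are illustrative):

```python
running, batch_stat, momentum = 0.0, 1.0, 0.9
running = momentum * running + (1.0 - momentum) * batch_stat
print(running)  # 0.1
```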
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class BatchNorm(Operator): class BatchNorm(Operator):
"""BatchNorm operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(BatchNorm, self).__init__(key, dev, **kwargs) super(BatchNorm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
...@@ -43,6 +45,8 @@ class BatchNorm(Operator): ...@@ -43,6 +45,8 @@ class BatchNorm(Operator):
class GroupNorm(Operator): class GroupNorm(Operator):
"""GroupNorm operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(GroupNorm, self).__init__(key, dev, **kwargs) super(GroupNorm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
...@@ -64,6 +68,8 @@ class GroupNorm(Operator): ...@@ -64,6 +68,8 @@ class GroupNorm(Operator):
class LpNormalize(Operator): class LpNormalize(Operator):
"""LpNormalize operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(LpNormalize, self).__init__(key, dev, **kwargs) super(LpNormalize, self).__init__(key, dev, **kwargs)
self.p = kwargs.get('p', 2) self.p = kwargs.get('p', 2)
...@@ -89,6 +95,8 @@ class LpNormalize(Operator): ...@@ -89,6 +95,8 @@ class LpNormalize(Operator):
class LocalResponseNorm(Operator): class LocalResponseNorm(Operator):
"""LocalResponseNorm operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(LocalResponseNorm, self).__init__(key, dev, **kwargs) super(LocalResponseNorm, self).__init__(key, dev, **kwargs)
self.size = kwargs.get('size', 5) self.size = kwargs.get('size', 5)
...@@ -114,6 +122,8 @@ class LocalResponseNorm(Operator): ...@@ -114,6 +122,8 @@ class LocalResponseNorm(Operator):
class SyncBatchNorm(BatchNorm): class SyncBatchNorm(BatchNorm):
"""SyncBatchNorm operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(SyncBatchNorm, self).__init__(key, dev, **kwargs) super(SyncBatchNorm, self).__init__(key, dev, **kwargs)
self.process_group = kwargs.get('process_group', None) self.process_group = kwargs.get('process_group', None)
......
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class LSTMCell(Operator): class LSTMCell(Operator):
"""LSTMCell operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(LSTMCell, self).__init__(key, dev, **kwargs) super(LSTMCell, self).__init__(key, dev, **kwargs)
...@@ -30,6 +32,8 @@ class LSTMCell(Operator): ...@@ -30,6 +32,8 @@ class LSTMCell(Operator):
class Recurrent(Operator): class Recurrent(Operator):
"""Recurrent operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Recurrent, self).__init__(key, dev, **kwargs) super(Recurrent, self).__init__(key, dev, **kwargs)
self.mode = kwargs.get('mode', 'rnn_tanh') self.mode = kwargs.get('mode', 'rnn_tanh')
...@@ -58,6 +62,8 @@ class Recurrent(Operator): ...@@ -58,6 +62,8 @@ class Recurrent(Operator):
class RNNParamSet(Operator): class RNNParamSet(Operator):
"""RNNParamSet operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(RNNParamSet, self).__init__(key, dev, **kwargs) super(RNNParamSet, self).__init__(key, dev, **kwargs)
self.param_type = kwargs.get('param_type', 'matrix') self.param_type = kwargs.get('param_type', 'matrix')
......
...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator ...@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class ParamUpdate(Operator): class ParamUpdate(Operator):
"""ParamUpdate operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ParamUpdate, self).__init__(key, dev, **kwargs) super(ParamUpdate, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '') self.op_type = kwargs.get('op_type', '')
......
...@@ -17,9 +17,11 @@ from __future__ import print_function ...@@ -17,9 +17,11 @@ from __future__ import print_function
from dragon.core.framework.ops import Operator from dragon.core.framework.ops import Operator
class _ConvNd(Operator): class ConvNd(Operator):
"""ConvNd operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(_ConvNd, self).__init__(key, dev, **kwargs) super(ConvNd, self).__init__(key, dev, **kwargs)
self.num_output = kwargs.get('dim_out', 1) self.num_output = kwargs.get('dim_out', 1)
self.kernel_shape = kwargs.get('kernel_shape', 1) self.kernel_shape = kwargs.get('kernel_shape', 1)
self.strides = kwargs.get('strides', 1) self.strides = kwargs.get('strides', 1)
...@@ -46,9 +48,11 @@ class _ConvNd(Operator): ...@@ -46,9 +48,11 @@ class _ConvNd(Operator):
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(inputs, [self.alloc()])
class _PoolNd(Operator): class PoolNd(Operator):
"""PoolNd operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(_PoolNd, self).__init__(key, dev, **kwargs) super(PoolNd, self).__init__(key, dev, **kwargs)
self.kernel_shape = kwargs.get('kernel_shape', 1) self.kernel_shape = kwargs.get('kernel_shape', 1)
     self.strides = kwargs.get('strides', 1)
     self.pads = kwargs.get('pads', 0)
@@ -78,6 +82,8 @@ class _PoolNd(Operator):

 class BiasAdd(Operator):
+    """BiasAdd operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(BiasAdd, self).__init__(key, dev, **kwargs)
         self.data_format = kwargs.get('data_format', 'NCHW')
@@ -93,12 +99,16 @@ class BiasAdd(Operator):
         return self.dispatch(inputs, outputs)

-class Conv2d(_ConvNd):
+class Conv2d(ConvNd):
+    """Conv2d operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(Conv2d, self).__init__(key, dev, **kwargs)

-class ConvTranspose2d(_ConvNd):
+class ConvTranspose2d(ConvNd):
+    """ConvTranspose2d operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(ConvTranspose2d, self).__init__(key, dev, **kwargs)
         self.output_padding = kwargs.get('output_padding', None)
@@ -121,6 +131,8 @@ class ConvTranspose2d(_ConvNd):

 class DepthToSpace(Operator):
+    """DepthToSpace operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(DepthToSpace, self).__init__(key, dev, **kwargs)
         self.block_size = kwargs.get('block_size', '2')
@@ -139,17 +151,23 @@ class DepthToSpace(Operator):
         return self.dispatch(inputs, [self.alloc()])

-class DepthwiseConv2d(_ConvNd):
+class DepthwiseConv2d(ConvNd):
+    """DepthwiseConv2d operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(DepthwiseConv2d, self).__init__(key, dev, **kwargs)

-class Pool2d(_PoolNd):
+class Pool2d(PoolNd):
+    """Pool2d operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(Pool2d, self).__init__(key, dev, **kwargs)

 class Resize(Operator):
+    """Resize operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(Resize, self).__init__(key, dev, **kwargs)
         self.num_sizes = kwargs.get('num_sizes', 0)
@@ -193,6 +211,8 @@ class Resize(Operator):

 class RoiAlign(Operator):
+    """RoiAlign operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(RoiAlign, self).__init__(key, dev, **kwargs)
         self.pooled_h = kwargs.get('pooled_h', 0)
@@ -216,6 +236,8 @@ class RoiAlign(Operator):

 class RoiPool(Operator):
+    """RoiPool operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(RoiPool, self).__init__(key, dev, **kwargs)
         self.pooled_h = kwargs.get('pooled_h', 7)
@@ -237,6 +259,8 @@ class RoiPool(Operator):

 class SpaceToDepth(Operator):
+    """SpaceToDepth operator."""
+
     def __init__(self, key, dev, **kwargs):
         super(SpaceToDepth, self).__init__(key, dev, **kwargs)
         self.block_size = kwargs.get('block_size', '2')
...
@@ -30,6 +30,20 @@ def dropout_exporter(op_def, shape_dict, ws):
     return node, const_tensors
+@exporter.register('HardSigmoid')
+def hardsigmoid_exporter(op_def, shape_dict, ws):
+    node, const_tensors = exporter.translate(**locals())
+    alpha, beta = 0.2, 0.5
+    for arg in op_def.arg:
+        if arg.name == 'alpha':
+            alpha = arg.f
+        elif arg.name == 'beta':
+            beta = arg.f
+    helper.add_attribute(node, 'alpha', alpha)
+    helper.add_attribute(node, 'beta', beta)
+    return node, const_tensors
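
For reference, the node emitted above targets the ONNX `HardSigmoid` operator, and the fallback values `alpha = 0.2`, `beta = 0.5` match the ONNX defaults:

$$\mathrm{HardSigmoid}(x) = \max\big(0, \min(1, \alpha x + \beta)\big)$$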

 @exporter.register('PRelu')
 def prelu_exporter(op_def, shape_dict, ws):
     node, const_tensors = exporter.translate(**locals())
...
@@ -84,6 +84,47 @@ void EluGrad(
     T* dx,
     Context* ctx);

+/* activation.hardsigmoid */
+
+template <typename T, class Context>
+void HardSigmoid(
+    const int count,
+    const float alpha,
+    const float beta,
+    const T* x,
+    T* y,
+    Context* ctx);
+
+template <typename T, class Context>
+void HardSigmoidGrad(
+    const int count,
+    const float alpha,
+    const T* dy,
+    const T* y,
+    T* dx,
+    Context* ctx);
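
Only the declarations land in this header; as a rough guide to their contract, here is a minimal CPU sketch (the names `HardSigmoidCPU` / `HardSigmoidGradCPU` and the plain loops are illustrative, not the commit's actual specializations):

```cpp
#include <algorithm>

// Forward: clamp the affine map of x into [0, 1].
template <typename T>
void HardSigmoidCPU(int count, float alpha, float beta, const T* x, T* y) {
  for (int i = 0; i < count; ++i) {
    y[i] = std::max(T(0), std::min(T(1), T(alpha) * x[i] + T(beta)));
  }
}

// Backward: the saved output y is sufficient, because dy/dx equals alpha
// exactly where the clamp is inactive (0 < y < 1) and zero elsewhere.
template <typename T>
void HardSigmoidGradCPU(int count, float alpha, const T* dy, const T* y, T* dx) {
  for (int i = 0; i < count; ++i) {
    dx[i] = (y[i] > T(0) && y[i] < T(1)) ? dy[i] * T(alpha) : T(0);
  }
}
```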
+/* activation.hardswish */
+
+template <typename T, class Context>
+void HardSwish(
+    const int count,
+    const float alpha,
+    const float beta,
+    const T* x,
+    T* y,
+    Context* ctx);
+
+template <typename T, class Context>
+void HardSwishGrad(
+    const int count,
+    const float alpha,
+    const float beta,
+    const T* dy,
+    const T* x,
+    T* dx,
+    Context* ctx);
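
HardSwish composes the input with its hard sigmoid, which is why the backward declaration takes the saved input `x` rather than the output:

$$\mathrm{HardSwish}(x) = x \cdot \max\big(0, \min(1, \alpha x + \beta)\big), \qquad
\frac{d}{dx}\,\mathrm{HardSwish}(x) =
\begin{cases}
0 & \alpha x + \beta \le 0 \\
2\alpha x + \beta & 0 < \alpha x + \beta < 1 \\
1 & \alpha x + \beta \ge 1
\end{cases}$$

Each piece of the derivative depends on `x` directly, so the output alone would not suffice.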

 /* activation.prelu */

 template <typename T, class Context>
@@ -185,8 +226,8 @@ void SigmoidGrad(const int count, const T* dy, const T* y, T* dx, Context* ctx);
 template <typename T, class Context>
 void Softmax(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const T* x,
     T* y,
     Context* ctx);
@@ -194,13 +235,27 @@ void Softmax(
 template <typename T, class Context>
 void SoftmaxGrad(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const T* dy,
     const T* y,
     T* dx,
     Context* ctx);
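
A note on the recurring signature change in this header: every kernel built on an `(outer, axis, inner)` decomposition now takes `inner_dim` before `axis_dim` (the same reorder appears below in ChannelAffine, CumSum, the loss kernels, the normalize kernels, and BiasAdd). Only the argument order changes; the flat offset the kernels walk is presumably still the usual row-major one, sketched here for illustration:

```cpp
// Illustrative helper (not part of the commit): the row-major offset of
// element (i, j, k) with i < outer_dim, j < axis_dim, k < inner_dim.
inline int FlatIndex(int i, int j, int k, int axis_dim, int inner_dim) {
  return (i * axis_dim + j) * inner_dim + k;
}
```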
+/* activation.swish */
+
+template <typename T, class Context>
+void Swish(const int count, const T* x, T* y, Context* ctx);
+
+template <typename T, class Context>
+void SwishGrad(
+    const int count,
+    const T* dy,
+    const T* x,
+    const T* y,
+    T* dx,
+    Context* ctx);
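
Swish multiplies the input by its (soft) sigmoid. Writing $y = x\,\sigma(x)$ with $\sigma(x) = 1/(1 + e^{-x})$, the derivative needs only one extra sigmoid evaluation:

$$\frac{dy}{dx} = \sigma(x) + x\,\sigma(x)\big(1 - \sigma(x)\big) = y + \sigma(x)\,(1 - y)$$

which is presumably why `SwishGrad` is declared over both the saved input `x` (to recompute $\sigma(x)$) and the saved output `y`.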

 /* activation.tanh */

 template <typename T, class Context>
@@ -236,8 +291,8 @@ void ArgMin(
 template <typename T, class Context>
 void ChannelAffine(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const T* x,
     const T* w,
     const T* b,
@@ -275,8 +330,8 @@ void ChannelShuffle(
 template <typename T, class Context>
 void CumSum(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const bool exclusive,
     const bool reverse,
     const T* x,
@@ -296,7 +351,7 @@ void IndexSelect(
     const int inner_dim,
     const int axis_dim,
     const int select_dim,
-    const int64_t* indices,
+    const int64_t* index,
     const T* x,
     T* y,
     Context* ctx);
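
The `indices` → `index` rename above is cosmetic. For illustration, a hypothetical reference loop for the contract these arguments suggest (gather `select_dim` slices along the axis):

```cpp
#include <cstdint>

// Hypothetical reference implementation, assuming the usual
// outer/axis/inner row-major decomposition used throughout this header.
template <typename T>
void IndexSelectRef(
    int outer_dim, int inner_dim, int axis_dim, int select_dim,
    const int64_t* index, const T* x, T* y) {
  for (int i = 0; i < outer_dim; ++i)
    for (int s = 0; s < select_dim; ++s)
      for (int k = 0; k < inner_dim; ++k)
        y[(i * select_dim + s) * inner_dim + k] =
            x[(i * axis_dim + index[s]) * inner_dim + k];
}
```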
@@ -529,7 +584,7 @@ void TopSelect(
     const int outer_dim,
     const int inner_dim,
     const int axis_dim,
-    const int topk,
+    const int select_dim,
     const int largest,
     const T* x,
     T* value,
@@ -585,8 +640,8 @@ void ReduceLossGrad(
 template <typename T, class Context>
 void BroadcastLossGrad(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const T* dy,
     T* dx,
     Context* ctx);
@@ -596,10 +651,10 @@ void BroadcastLossGrad(
 template <typename LogitType, typename TargetType, class Context>
 void NLLLoss(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const int ignore_index,
-    const LogitType* log_prob,
+    const LogitType* logit,
     const TargetType* target,
     LogitType* loss,
     LogitType* mask,
@@ -608,12 +663,12 @@ void NLLLoss(
 template <typename LogitType, typename TargetType, class Context>
 void NLLLossGrad(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const int ignore_index,
-    const LogitType* log_prob,
+    const LogitType* logit,
     const TargetType* target,
-    LogitType* dx,
+    LogitType* dlogit,
     LogitType* mask,
     Context* ctx);
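
Reading the `NLLLoss` signature above (hedged, since only the declarations are shown here): with `logit` holding log-probabilities laid out as `(outer_dim, axis_dim, inner_dim)` and integer targets $t_{i,k}$, the usual negative log-likelihood with an ignore label is

$$\ell_{i,k} = -\,\mathrm{logit}\big[(i \cdot \texttt{axis\_dim} + t_{i,k}) \cdot \texttt{inner\_dim} + k\big], \qquad
\mathrm{mask}_{i,k} = \mathbb{1}\,[\,t_{i,k} \ne \texttt{ignore\_index}\,]$$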
@@ -642,12 +697,12 @@ void SigmoidCrossEntropyGrad(
 template <typename LogitType, typename TargetType, class Context>
 void SigmoidFocalLoss(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const float pos_alpha,
     const float neg_alpha,
     const float gamma,
-    const int neg_id,
+    const int negative_index,
     const LogitType* logit,
     const TargetType* target,
     LogitType* loss,
@@ -657,8 +712,8 @@ void SigmoidFocalLoss(
 template <typename LogitType, typename TargetType, class Context>
 void SigmoidFocalLossGrad(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const float pos_alpha,
     const float neg_alpha,
     const float gamma,
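
For reference, `SigmoidFocalLoss` computes the focal loss of [Lin et al., 2017] on sigmoid probabilities $p = \sigma(\mathrm{logit})$:

$$\mathrm{FL}(p_t) = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t)$$

where $p_t$ is the probability assigned to the ground-truth class; the paired `pos_alpha` / `neg_alpha` arguments presumably supply $\alpha_t$ for positive and negative examples, with `negative_index` marking the background label.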
@@ -693,8 +748,8 @@ template <typename T, class Context>
 void SoftmaxCrossEntropy(
     const int count,
     const T* prob,
-    const T* targets,
-    T* losses,
+    const T* target,
+    T* loss,
     Context* ctx);

 /* loss.sparse_softmax_cross_entropy */

@@ -702,8 +757,8 @@ void SoftmaxCrossEntropy(
 template <typename LogitType, typename TargetType, class Context>
 void SparseSoftmaxCrossEntropy(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const int ignore_index,
     const LogitType* prob,
     const TargetType* target,
@@ -714,8 +769,8 @@ void SparseSoftmaxCrossEntropy(
 template <typename LogitType, typename TargetType, class Context>
 void SparseSoftmaxCrossEntropyGrad(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const int ignore_index,
     const LogitType* prob,
     const TargetType* target,
@@ -907,8 +962,8 @@ void GroupNormBackward(
 template <typename T, class Context>
 void L1Normalize(
     const int outer_dim,
-    const int reduce_dim,
     const int inner_dim,
+    const int reduce_dim,
     const float scale,
     const float eps,
     const T* x,
@@ -918,8 +973,8 @@ void L1Normalize(
 template <typename T, class Context>
 void L1NormalizeGrad(
     const int outer_dim,
-    const int reduce_dim,
     const int inner_dim,
+    const int reduce_dim,
     const float scale,
     const float eps,
     const T* dy,
@@ -930,8 +985,8 @@ void L1NormalizeGrad(
 template <typename T, class Context>
 void L2Normalize(
     const int outer_dim,
-    const int reduce_dim,
     const int inner_dim,
+    const int reduce_dim,
     const float scale,
     const float eps,
     const T* x,
@@ -941,8 +996,8 @@ void L2Normalize(
 template <typename T, class Context>
 void L2NormalizeGrad(
     const int outer_dim,
-    const int reduce_dim,
     const int inner_dim,
+    const int reduce_dim,
     const float scale,
     const float eps,
     const T* dy,
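
The four `L{1,2}Normalize` declarations share the same parameter reorder; under one common reading of `scale` and `eps` (an assumption, since the definitions live elsewhere), each `reduce_dim`-long span is normalized as

$$y = \frac{x}{\max\big(\|x\|_p \cdot \texttt{scale}, \; \texttt{eps}\big)}, \qquad p \in \{1, 2\}$$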
@@ -1030,8 +1085,8 @@ void SGDUpdate(
 template <typename T, class Context>
 void BiasAdd(
     const int outer_dim,
-    const int axis_dim,
     const int inner_dim,
+    const int axis_dim,
     const T* x,
     const T* b,
     T* y,
...