Commit f76c693e by Ting PAN

Optimize Concat && Split Operator

Summary:
This commit implements the concat and split operators generically with the
CopyMatrix routine instead of specialized per-operator kernels.
1 parent 77dcd71d
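For orientation only (not part of the commit), here is a minimal NumPy sketch of the idea behind the change: viewed as row-major matrices of shape (outer_dim, axis_dim * inner_dim), concatenation writes each input as a strided block at a growing column offset of the output, and split reads strided blocks back out. That is exactly the access pattern a single generic CopyMatrix(m, n, ldx, ldy, x, y) routine covers, so no per-operator kernel is needed. The helper names below are hypothetical.

```python
import numpy as np

def concat_axis1(xs):
    """Concatenate 2-D blocks along axis 1 via per-row strided copies."""
    m = xs[0].shape[0]
    y = np.empty((m, sum(x.shape[1] for x in xs)), dtype=xs[0].dtype)
    offset = 0
    for x in xs:
        n = x.shape[1]
        # corresponds to CopyMatrix(m, n, ldx=n, ldy=y.shape[1], x, y + offset)
        y[:, offset:offset + n] = x
        offset += n
    return y

def split_axis1(y, sizes):
    """Split a 2-D block along axis 1 via per-row strided copies."""
    outs, offset = [], 0
    for n in sizes:
        # corresponds to CopyMatrix(m, n, ldx=y.shape[1], ldy=n, y + offset, out)
        outs.append(y[:, offset:offset + n].copy())
        offset += n
    return outs

a, b = np.arange(6.).reshape(2, 3), np.arange(4.).reshape(2, 2)
assert np.array_equal(concat_axis1([a, b]), np.concatenate([a, b], axis=1))
assert all(np.array_equal(u, v)
           for u, v in zip(split_axis1(concat_axis1([a, b]), [3, 2]), [a, b]))
```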
...@@ -44,8 +44,8 @@ class EuclideanLoss(Layer): ...@@ -44,8 +44,8 @@ class EuclideanLoss(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(EuclideanLoss, self).__init__(layer_param) super(EuclideanLoss, self).__init__(layer_param)
param = layer_param.loss_param param = layer_param.loss_param
norm_dict = {0: 'mean', 1: 'mean', 2: 'batch_size', 3: 'sum'} norm_dict = {0: 'mean', 1: 'mean', 2: 'batch_mean', 3: 'sum'}
reduction = 'batch_size' reduction = 'batch_mean'
if param.HasField('normalize'): if param.HasField('normalize'):
if param.normalize: if param.normalize:
reduction = 'mean' reduction = 'mean'
...@@ -81,11 +81,11 @@ class SigmoidCrossEntropyLoss(Layer): ...@@ -81,11 +81,11 @@ class SigmoidCrossEntropyLoss(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(SigmoidCrossEntropyLoss, self).__init__(layer_param) super(SigmoidCrossEntropyLoss, self).__init__(layer_param)
param = layer_param.loss_param param = layer_param.loss_param
norm_dict = {0: 'mean', 1: 'valid', 2: 'batch_size', 3: 'sum'} norm_dict = {0: 'mean', 1: 'valid', 2: 'batch_mean', 3: 'sum'}
reduction = 'valid' reduction = 'valid'
if param.HasField('normalize'): if param.HasField('normalize'):
if not param.normalize: if not param.normalize:
reduction = 'batch_size' reduction = 'batch_mean'
else: else:
reduction = norm_dict[param.normalization] reduction = norm_dict[param.normalization]
self.arguments = {'reduction': reduction} self.arguments = {'reduction': reduction}
...@@ -123,8 +123,8 @@ class SmoothL1Loss(Layer): ...@@ -123,8 +123,8 @@ class SmoothL1Loss(Layer):
super(SmoothL1Loss, self).__init__(layer_param) super(SmoothL1Loss, self).__init__(layer_param)
param = layer_param.loss_param param = layer_param.loss_param
smooth_l1_param = layer_param.smooth_l1_loss_param smooth_l1_param = layer_param.smooth_l1_loss_param
norm_dict = {0: 'mean', 1: 'mean', 2: 'batch_size', 3: 'sum'} norm_dict = {0: 'mean', 1: 'mean', 2: 'batch_mean', 3: 'sum'}
reduction = 'batch_size' reduction = 'batch_mean'
if param.HasField('normalize'): if param.HasField('normalize'):
if param.normalize: if param.normalize:
reduction = 'mean' reduction = 'mean'
...@@ -174,11 +174,11 @@ class SoftmaxWithLoss(Layer): ...@@ -174,11 +174,11 @@ class SoftmaxWithLoss(Layer):
super(SoftmaxWithLoss, self).__init__(layer_param) super(SoftmaxWithLoss, self).__init__(layer_param)
param = layer_param.loss_param param = layer_param.loss_param
softmax_param = layer_param.softmax_param softmax_param = layer_param.softmax_param
norm_dict = {0: 'mean', 1: 'valid', 2: 'batch_size', 3: 'sum'} norm_dict = {0: 'mean', 1: 'valid', 2: 'batch_mean', 3: 'sum'}
reduction = 'valid' reduction = 'valid'
if param.HasField('normalize'): if param.HasField('normalize'):
if not param.normalize: if not param.normalize:
reduction = 'batch_size' reduction = 'batch_mean'
else: else:
reduction = norm_dict[param.normalization] reduction = norm_dict[param.normalization]
self.arguments = { self.arguments = {
......
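As a reading aid (not part of the commit): the renamed reductions differ only in the normalizer applied to the summed loss. 'mean' divides by the element count, the renamed 'batch_mean' divides by the leading (batch) dimension, 'sum' leaves the total untouched, and 'valid' divides by the number of non-ignored targets selected from a mask, matching the normalizer selection in the loss operators further down in this diff. A minimal NumPy sketch of the first three:

```python
import numpy as np

def reduce_loss(loss, reduction):
    """Reduce an element-wise loss the way the loss operators below do."""
    total = loss.sum()
    if reduction == 'mean':          # normalizer = X.count()
        return total / loss.size
    if reduction == 'batch_mean':    # normalizer = X.dim(0)
        return total / loss.shape[0]
    return total                     # 'sum': no normalizer

loss = np.ones((4, 8), dtype=np.float32)       # batch of 4, 8 elements each
assert reduce_loss(loss, 'mean') == 1.0        # 32 / 32
assert reduce_loss(loss, 'batch_mean') == 8.0  # 32 / 4
assert reduce_loss(loss, 'sum') == 32.0
```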
...@@ -84,6 +84,9 @@ vm.torch.nn ...@@ -84,6 +84,9 @@ vm.torch.nn
: Apply the gumbel softmax with a temperature. : Apply the gumbel softmax with a temperature.
`[Jang et.al, 2016] <https://arxiv.org/abs/1611.01144>`_. `[Jang et.al, 2016] <https://arxiv.org/abs/1611.01144>`_.
`class KLDivLoss <nn/KLDivLoss.html>`_
: Compute the Kullback-Leibler divergence.
`class L1Loss <nn/L1Loss.html>`_ `class L1Loss <nn/L1Loss.html>`_
: Compute the element-wise absolute value difference. : Compute the element-wise absolute value difference.
...@@ -219,6 +222,7 @@ vm.torch.nn ...@@ -219,6 +222,7 @@ vm.torch.nn
nn/GroupNorm nn/GroupNorm
nn/GRU nn/GRU
nn/GumbelSoftmax nn/GumbelSoftmax
nn/KLDivLoss
nn/L1Loss nn/L1Loss
nn/LeakyReLU nn/LeakyReLU
nn/Linear nn/Linear
......
KLDivLoss
=========
.. autoclass:: dragon.vm.torch.nn.KLDivLoss
__init__
--------
.. automethod:: dragon.vm.torch.nn.KLDivLoss.__init__
.. _torch.nn.functional.kl_div(...): functional/kl_div.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
...@@ -53,6 +53,9 @@ vm.torch.nn.functional ...@@ -53,6 +53,9 @@ vm.torch.nn.functional
: Apply the group normalization to input. : Apply the group normalization to input.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_. `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`kl_div(...) <functional/kl_div.html>`_
: Compute the Kullback-Leibler divergence.
`l1_loss(...) <functional/l1_loss.html>`_ `l1_loss(...) <functional/l1_loss.html>`_
: Compute the element-wise absolute value difference. : Compute the element-wise absolute value difference.
...@@ -142,6 +145,7 @@ vm.torch.nn.functional ...@@ -142,6 +145,7 @@ vm.torch.nn.functional
functional/dropout functional/dropout
functional/elu functional/elu
functional/group_norm functional/group_norm
functional/kl_div
functional/l1_loss functional/l1_loss
functional/leaky_relu functional/leaky_relu
functional/linear functional/linear
......
kl_div
======
.. autofunction:: dragon.vm.torch.nn.functional.kl_div
.. _torch.nn.KLDivLoss(...): ../KLDivLoss.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
...@@ -148,13 +148,13 @@ class CUDAObjects { ...@@ -148,13 +148,13 @@ class CUDAObjects {
Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES]; Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES];
#endif #endif
/*! \brief The flag that alllows cuDNN or not */ /*! \brief The flag that allows cuDNN or not */
bool cudnn_enabled_ = true; bool cudnn_enabled_ = true;
/*! \brief The flag that allows cuDNN benchmark or not */ /*! \brief The flag that allows cuDNN benchmark or not */
bool cudnn_benchmark_ = false; bool cudnn_benchmark_ = false;
/*! \brief The flag thats allow cuDNN TF32 math type or not */ /*! \brief The flag that allows cuDNN TF32 math type or not */
bool cudnn_allow_tf32_ = false; bool cudnn_allow_tf32_ = false;
private: private:
......
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _Concat(
const int outer_dim,
const int inner_dim,
const int x_axis_dim,
const int y_axis_dim,
const int index,
const T* x,
T* y) {
const int offset = index * inner_dim;
const int x_cols = x_axis_dim * inner_dim;
const int y_cols = y_axis_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) {
std::memcpy(y + i * y_cols + offset, x + i * x_cols, x_cols * sizeof(T));
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Concat<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int x_axis_dim, \
const int y_axis_dim, \
const int index, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_Concat(outer_dim, inner_dim, x_axis_dim, y_axis_dim, index, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void _Concat(
const int nthreads,
const int inner_dim,
const int x_cols,
const int y_cols,
const int offset,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int i = xi / x_cols;
const int j = xi % x_cols;
y[i * y_cols + offset + j] = x[xi];
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Concat<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int x_axis_dim, \
const int y_axis_dim, \
const int index, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const int offset = index * inner_dim; \
const int x_cols = x_axis_dim * inner_dim; \
const int y_cols = y_axis_dim * inner_dim; \
const int nthreads = outer_dim * x_cols; \
_Concat<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, inner_dim, x_cols, y_cols, offset, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _Split(
const int outer_dim,
const int inner_dim,
const int x_axis_dim,
const int y_axis_dim,
const int index,
const T* x,
T* y) {
const int offset = index * inner_dim;
const int x_cols = x_axis_dim * inner_dim;
const int y_cols = y_axis_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) {
std::memcpy(y + i * y_cols, x + i * x_cols + offset, y_cols * sizeof(T));
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Split<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int x_axis_dim, \
const int y_axis_dim, \
const int index, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_Split(outer_dim, inner_dim, x_axis_dim, y_axis_dim, index, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void _Split(
const int nthreads,
const int inner_dim,
const int x_cols,
const int y_cols,
const int offset,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / y_cols;
const int j = yi % y_cols;
y[yi] = x[i * x_cols + offset + j];
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Split<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int x_axis_dim, \
const int y_axis_dim, \
const int index, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const int offset = index * inner_dim; \
const int x_cols = x_axis_dim * inner_dim; \
const int y_cols = y_axis_dim * inner_dim; \
const int nthreads = outer_dim * y_cols; \
_Split<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, inner_dim, x_cols, y_cols, offset, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
#include "dragon/operators/array/concat_op.h" #include "dragon/operators/array/concat_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/math_functions.h"
namespace dragon { namespace dragon {
...@@ -26,20 +26,20 @@ void ConcatOp<Context>::DoRunWithType() { ...@@ -26,20 +26,20 @@ void ConcatOp<Context>::DoRunWithType() {
Y_dims[axis] += Input(i).dim(axis); Y_dims[axis] += Input(i).dim(axis);
} }
int64_t index = 0; Y->Reshape(Y_dims);
auto* y = Y->Reshape(Y_dims)->template mutable_data<T, Context>(); int64_t output_offset = 0;
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); ++i) {
kernel::Concat( const auto& Xi = Input(i);
X.count(0, axis), math::CopyMatrix(
X.count(axis + 1), Xi.count(0, axis),
Input(i).dim(axis), Xi.count(axis),
Y_dims[axis], Xi.count(axis),
index, Y->count(axis),
Input(i).template data<T, Context>(), Xi.template data<T, Context>(),
y, Y->template mutable_data<T, Context>() + output_offset,
ctx()); ctx());
index += Input(i).dim(axis); output_offset += Xi.count(axis);
} }
} }
...@@ -54,21 +54,21 @@ void ConcatGradientOp<Context>::DoRunWithType() { ...@@ -54,21 +54,21 @@ void ConcatGradientOp<Context>::DoRunWithType() {
auto& dY = Input(0); auto& dY = Input(0);
CANONICALIZE_AXIS_WITH_TENSOR(dY); CANONICALIZE_AXIS_WITH_TENSOR(dY);
int64_t index = 0; int64_t input_offset = 0;
for (int i = 0; i < OutputSize(); i++) {
for (int i = 0; i < OutputSize(); ++i) {
auto &X = RESTORE_INPUT_SPEC(i), *dX = Output(i); auto &X = RESTORE_INPUT_SPEC(i), *dX = Output(i);
if (dX->has_name()) { if (dX->has_name()) {
kernel::Split( math::CopyMatrix(
dY.count(0, axis), dY.count(0, axis),
dY.count(axis + 1), X.count(axis),
dY.dim(axis), dY.count(axis),
X.dim(axis), X.count(axis),
index, dY.template data<T, Context>() + input_offset,
dY.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(), dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
index += X.dim(axis); input_offset += X.count(axis);
} }
} }
......
#include "dragon/operators/array/split_op.h" #include "dragon/operators/array/split_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -34,27 +33,27 @@ void SplitOp<Context>::DoRunWithType() { ...@@ -34,27 +33,27 @@ void SplitOp<Context>::DoRunWithType() {
// Store for the gradient calculation // Store for the gradient calculation
STORE_INPUT_SPEC(0); STORE_INPUT_SPEC(0);
int64_t index = 0, next_index;
vec64_t Y_dims(X.dims()); vec64_t Y_dims(X.dims());
int64_t input_offset = 0, total_size = 0;
for (int i = 0; i < num_splits; ++i) { for (int i = 0; i < num_splits; ++i) {
next_index = index + size_splits[i]; total_size += size_splits[i];
CHECK(size_splits[i] > 0 && next_index <= X.dim(axis)) CHECK(size_splits[i] > 0 && total_size <= X.dim(axis))
<< "\nIllegal size of splits: " << Tensor::DimString(size_splits) << "\nIllegal size of splits: " << Tensor::DimString(size_splits)
<< " for dimension: " << X.dim(axis); << " for dimension: " << X.dim(axis);
auto* Y = Output(i); auto* Y = Output(i);
if (Y->has_name()) { if (Y->has_name()) {
Y_dims[axis] = size_splits[i]; Y_dims[axis] = size_splits[i];
kernel::Split( math::CopyMatrix(
X.count(0, axis), X.count(0, axis),
X.count(axis + 1), size_splits[i] * X.count(axis + 1),
X.dim(axis), X.count(axis),
size_splits[i], size_splits[i] * X.count(axis + 1),
index, X.template data<T, Context>() + input_offset,
X.template data<T, Context>(),
Y->Reshape(Y_dims)->template mutable_data<T, Context>(), Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
index = next_index; input_offset += size_splits[i] * X.count(axis + 1);
} }
} }
...@@ -67,7 +66,6 @@ template <class Context> ...@@ -67,7 +66,6 @@ template <class Context>
template <typename T> template <typename T>
void SplitGradientOp<Context>::DoRunWithType() { void SplitGradientOp<Context>::DoRunWithType() {
auto* dX = Output(0); auto* dX = Output(0);
int num_splits = InputSize(); int num_splits = InputSize();
CANONICALIZE_AXIS_WITH_TENSOR((*dX)); CANONICALIZE_AXIS_WITH_TENSOR((*dX));
DETERMINE_RUNTIME_ARGS((*dX)); DETERMINE_RUNTIME_ARGS((*dX));
...@@ -84,21 +82,21 @@ void SplitGradientOp<Context>::DoRunWithType() { ...@@ -84,21 +82,21 @@ void SplitGradientOp<Context>::DoRunWithType() {
} }
} }
int64_t index = 0; int64_t output_offset = 0;
for (int i = 0; i < num_splits; i++) { for (int i = 0; i < num_splits; i++) {
auto& dY = Input(i); auto& dY = Input(i);
if (dY.has_name()) { if (dY.has_name()) {
kernel::Concat( math::CopyMatrix(
dX->count(0, axis), dY.count(0, axis),
dX->count(axis + 1), dY.count(axis),
size_splits[i], dY.count(axis),
dX->dim(axis), dX->count(axis),
index,
dY.template data<T, Context>(), dY.template data<T, Context>(),
dX->template mutable_data<T, Context>(), dX->template mutable_data<T, Context>() + output_offset,
ctx()); ctx());
} }
index += size_splits[i]; output_offset += size_splits[i] * dX->count(axis + 1);
} }
} }
......
#include "dragon/operators/array/stack_op.h" #include "dragon/operators/array/stack_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/math_functions.h"
namespace dragon { namespace dragon {
...@@ -26,18 +26,20 @@ void StackOp<Context>::DoRunWithType() { ...@@ -26,18 +26,20 @@ void StackOp<Context>::DoRunWithType() {
STORE_INPUT_SPEC(i); STORE_INPUT_SPEC(i);
} }
auto* y = Y->Reshape(Y_dims)->template mutable_data<T, Context>(); Y->Reshape(Y_dims);
int64_t output_offset = 0;
for (int i = 0; i < num_stacks; i++) { for (int i = 0; i < num_stacks; i++) {
kernel::Concat( const auto& Xi = Input(i);
X.count(0, axis), math::CopyMatrix(
X.count(axis), Xi.count(0, axis),
1, Xi.count(axis),
num_stacks, Xi.count(axis),
i, Y->count(axis),
Input(i).template data<T, Context>(), Xi.template data<T, Context>(),
y, Y->template mutable_data<T, Context>() + output_offset,
ctx()); ctx());
output_offset += Xi.count(axis);
} }
} }
...@@ -52,20 +54,21 @@ void StackGradientOp<Context>::DoRunWithType() { ...@@ -52,20 +54,21 @@ void StackGradientOp<Context>::DoRunWithType() {
auto &X_ref = RESTORE_INPUT_SPEC(0), &dY = Input(0); auto &X_ref = RESTORE_INPUT_SPEC(0), &dY = Input(0);
CANONICALIZE_AXIS_WITH_TENSOR_AND_OFFSET(X_ref, 1) CANONICALIZE_AXIS_WITH_TENSOR_AND_OFFSET(X_ref, 1)
int num_stacks = OutputSize(); int64_t input_offset = 0;
for (int i = 0; i < num_stacks; ++i) {
for (int i = 0; i < OutputSize(); ++i) {
auto &X = RESTORE_INPUT_SPEC(i), *dX = Output(i); auto &X = RESTORE_INPUT_SPEC(i), *dX = Output(i);
if (dX->has_name()) { if (dX->has_name()) {
kernel::Split( math::CopyMatrix(
X.count(0, axis), dY.count(0, axis),
X.count(axis),
dY.count(axis),
X.count(axis), X.count(axis),
num_stacks, dY.template data<T, Context>() + input_offset,
1,
i,
dY.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(), dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
input_offset += X.count(axis);
} }
} }
......
...@@ -45,7 +45,7 @@ void L1LossOp<Context>::DoRunWithType() { ...@@ -45,7 +45,7 @@ void L1LossOp<Context>::DoRunWithType() {
ctx()); ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_SIZE") { if (reduction_ == "BATCH_MEAN") {
normalizer *= X.dim(0); normalizer *= X.dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer *= X.count(); normalizer *= X.count();
...@@ -94,7 +94,7 @@ void L1LossGradientOp<Context>::DoRunWithType() { ...@@ -94,7 +94,7 @@ void L1LossGradientOp<Context>::DoRunWithType() {
math::Mul(dX->count(), dy, dx, dx, ctx()); math::Mul(dX->count(), dy, dx, dx, ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_SIZE") { if (reduction_ == "BATCH_MEAN") {
normalizer *= dX->dim(0); normalizer *= dX->dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer *= dX->count(); normalizer *= dX->count();
......
...@@ -45,7 +45,7 @@ void L2LossOp<Context>::DoRunWithType() { ...@@ -45,7 +45,7 @@ void L2LossOp<Context>::DoRunWithType() {
ctx()); ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_SIZE") { if (reduction_ == "BATCH_MEAN") {
normalizer *= X.dim(0); normalizer *= X.dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer *= X.count(); normalizer *= X.count();
...@@ -92,7 +92,7 @@ void L2LossGradientOp<Context>::DoRunWithType() { ...@@ -92,7 +92,7 @@ void L2LossGradientOp<Context>::DoRunWithType() {
math::Scale(dX->count(), 2.f, dx, dx, ctx()); math::Scale(dX->count(), 2.f, dx, dx, ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_SIZE") { if (reduction_ == "BATCH_MEAN") {
normalizer *= dX->dim(0); normalizer *= dX->dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer *= dX->count(); normalizer *= dX->count();
......
...@@ -48,7 +48,7 @@ void NLLLossOp<Context>::DoRunWithType() { ...@@ -48,7 +48,7 @@ void NLLLossOp<Context>::DoRunWithType() {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
normalizer = -1; // Select from mask normalizer = -1; // Select from mask
} else if (reduction_ == "BATCH_SIZE") { } else if (reduction_ == "BATCH_MEAN") {
normalizer = X.dim(0); normalizer = X.dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = num_preds; normalizer = num_preds;
...@@ -125,7 +125,7 @@ void NLLLossGradientOp<Context>::DoRunWithType() { ...@@ -125,7 +125,7 @@ void NLLLossGradientOp<Context>::DoRunWithType() {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
normalizer = -1; // Select from mask normalizer = -1; // Select from mask
} else if (reduction_ == "BATCH_SIZE") { } else if (reduction_ == "BATCH_MEAN") {
normalizer = dX->dim(0); normalizer = dX->dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = num_preds; normalizer = num_preds;
......
...@@ -38,7 +38,7 @@ void SigmoidCrossEntropyOp<Context>::DoRunWithType() { ...@@ -38,7 +38,7 @@ void SigmoidCrossEntropyOp<Context>::DoRunWithType() {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
normalizer = -1; // Select from mask normalizer = -1; // Select from mask
} else if (reduction_ == "BATCH_SIZE") { } else if (reduction_ == "BATCH_MEAN") {
normalizer = X.dim(0); normalizer = X.dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = X.count(); normalizer = X.count();
...@@ -83,7 +83,7 @@ void SigmoidCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -83,7 +83,7 @@ void SigmoidCrossEntropyGradientOp<Context>::DoRunWithType() {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
normalizer = -1; // Select from mask normalizer = -1; // Select from mask
} else if (reduction_ == "BATCH_SIZE") { } else if (reduction_ == "BATCH_MEAN") {
normalizer = dX->dim(0); normalizer = dX->dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = dX->count(); normalizer = dX->count();
......
...@@ -48,7 +48,7 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() { ...@@ -48,7 +48,7 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
normalizer = -1; // Select from mask normalizer = -1; // Select from mask
} else if (reduction_ == "BATCH_SIZE") { } else if (reduction_ == "BATCH_MEAN") {
normalizer = X.dim(0); normalizer = X.dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = X.count(); normalizer = X.count();
...@@ -125,7 +125,7 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() { ...@@ -125,7 +125,7 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
normalizer = -1; // Select from mask normalizer = -1; // Select from mask
} else if (reduction_ == "BATCH_SIZE") { } else if (reduction_ == "BATCH_MEAN") {
normalizer = dX->dim(0); normalizer = dX->dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = dX->count(); normalizer = dX->count();
......
...@@ -45,7 +45,7 @@ void SmoothL1LossOp<Context>::DoRunWithType() { ...@@ -45,7 +45,7 @@ void SmoothL1LossOp<Context>::DoRunWithType() {
ctx()); ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_SIZE") { if (reduction_ == "BATCH_MEAN") {
normalizer *= X.dim(0); normalizer *= X.dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer *= X.count(); normalizer *= X.count();
...@@ -94,7 +94,7 @@ void SmoothL1LossGradientOp<Context>::DoRunWithType() { ...@@ -94,7 +94,7 @@ void SmoothL1LossGradientOp<Context>::DoRunWithType() {
math::Mul(dX->count(), dy, dx, dx, ctx()); math::Mul(dX->count(), dy, dx, dx, ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_SIZE") { if (reduction_ == "BATCH_MEAN") {
normalizer *= dX->dim(0); normalizer *= dX->dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer *= dX->count(); normalizer *= dX->count();
......
...@@ -49,7 +49,7 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -49,7 +49,7 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() {
ctx()); ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_SIZE") { if (reduction_ == "BATCH_MEAN") {
normalizer = X.dim(0); normalizer = X.dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = num_preds; normalizer = num_preds;
...@@ -93,7 +93,7 @@ void SoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -93,7 +93,7 @@ void SoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx()); outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "BATCH_SIZE") { if (reduction_ == "BATCH_MEAN") {
normalizer = dX->dim(0); normalizer = dX->dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = num_preds; normalizer = num_preds;
......
...@@ -58,7 +58,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -58,7 +58,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
normalizer = -1; // Select from mask normalizer = -1; // Select from mask
} else if (reduction_ == "BATCH_SIZE") { } else if (reduction_ == "BATCH_MEAN") {
normalizer = X.dim(0); normalizer = X.dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = num_preds; normalizer = num_preds;
...@@ -136,7 +136,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -136,7 +136,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
int64_t normalizer = 1; int64_t normalizer = 1;
if (reduction_ == "VALID") { if (reduction_ == "VALID") {
normalizer = -1; // Select from mask normalizer = -1; // Select from mask
} else if (reduction_ == "BATCH_SIZE") { } else if (reduction_ == "BATCH_MEAN") {
normalizer = dX->dim(0); normalizer = dX->dim(0);
} else if (reduction_ == "MEAN") { } else if (reduction_ == "MEAN") {
normalizer = num_preds; normalizer = num_preds;
......
...@@ -89,7 +89,8 @@ def l1_loss(inputs, reduction='mean', **kwargs): ...@@ -89,7 +89,8 @@ def l1_loss(inputs, reduction='mean', **kwargs):
op_lib = loss_ops_lib.L1Loss op_lib = loss_ops_lib.L1Loss
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate(reduction=args['reduction']).apply(inputs) .instantiate(reduction=args['reduction']) \
.apply(inputs)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -128,9 +129,8 @@ def l2_loss(inputs, reduction='mean', **kwargs): ...@@ -128,9 +129,8 @@ def l2_loss(inputs, reduction='mean', **kwargs):
op_lib = loss_ops_lib.L2Loss op_lib = loss_ops_lib.L2Loss
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(reduction=args['reduction']) \
reduction=args['reduction'], .apply(inputs)
).apply(inputs)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -213,9 +213,9 @@ def sigmoid_cross_entropy(inputs, reduction='valid', **kwargs): ...@@ -213,9 +213,9 @@ def sigmoid_cross_entropy(inputs, reduction='valid', **kwargs):
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
op_lib = loss_ops_lib.SigmoidCrossEntropy op_lib = loss_ops_lib.SigmoidCrossEntropy
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate( return op_lib \
reduction=args['reduction'], .instantiate(reduction=args['reduction']) \
).apply(inputs) .apply(inputs)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -36,7 +36,7 @@ DEFINE_SCALE_FUNC(double); ...@@ -36,7 +36,7 @@ DEFINE_SCALE_FUNC(double);
DRAGON_API void Copy<T, CPUContext>( \ DRAGON_API void Copy<T, CPUContext>( \
const int n, const T* x, T* y, CPUContext* ctx) { \ const int n, const T* x, T* y, CPUContext* ctx) { \
if (x != y && n > 0) { \ if (x != y && n > 0) { \
memcpy(y, x, n * sizeof(T)); \ memcpy(y, x, sizeof(T) * n); \
} \ } \
} }
...@@ -75,6 +75,36 @@ DEFINE_COPY_FUNC(float); ...@@ -75,6 +75,36 @@ DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double); DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC #undef DEFINE_COPY_FUNC
#define DEFINE_COPY_FUNC(T) \
template <> \
DRAGON_API void CopyMatrix<T, CPUContext>( \
const int m, \
const int n, \
const int ldx, \
const int ldy, \
const T* x, \
T* y, \
CPUContext* ctx) { \
if (m <= 0 || n <= 0) return; \
if (ldx == n && ldy == n) { \
if (x != y) memcpy(y, x, sizeof(T) * m * n); \
return; \
} \
for (int i = 0; i < m; ++i) { \
memcpy(y + ldy * i, x + ldx * i, sizeof(T) * n); \
} \
}
DEFINE_COPY_FUNC(bool);
DEFINE_COPY_FUNC(int8_t);
DEFINE_COPY_FUNC(uint8_t);
DEFINE_COPY_FUNC(int);
DEFINE_COPY_FUNC(int64_t);
DEFINE_COPY_FUNC(float16);
DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC
template <> template <>
DRAGON_API void Axpy<float16, CPUContext>( DRAGON_API void Axpy<float16, CPUContext>(
const int n, const int n,
......
...@@ -241,7 +241,7 @@ DEFINE_COPY_FUNC(float16); ...@@ -241,7 +241,7 @@ DEFINE_COPY_FUNC(float16);
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
if (x != y && n > 0) { \ if (x != y && n > 0) { \
cublas_func(ctx->cublas_handle(), n, x, incx, y, incy); \ CUBLAS_CHECK(cublas_func(ctx->cublas_handle(), n, x, incx, y, incy)); \
} \ } \
} }
...@@ -249,6 +249,38 @@ DEFINE_COPY_FUNC(float, cublasScopy); ...@@ -249,6 +249,38 @@ DEFINE_COPY_FUNC(float, cublasScopy);
DEFINE_COPY_FUNC(double, cublasDcopy); DEFINE_COPY_FUNC(double, cublasDcopy);
#undef DEFINE_COPY_FUNC #undef DEFINE_COPY_FUNC
#define DEFINE_COPY_FUNC(T) \
template <> \
DRAGON_API void CopyMatrix<T, CUDAContext>( \
const int m, \
const int n, \
const int ldx, \
const int ldy, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
if (m <= 0 || n <= 0) return; \
CUDA_CHECK(cudaMemcpy2DAsync( \
y, \
sizeof(T) * ldy, \
x, \
sizeof(T) * ldx, \
sizeof(T) * n, \
m, \
cudaMemcpyDeviceToDevice, \
ctx->cuda_stream())); \
}
DEFINE_COPY_FUNC(bool);
DEFINE_COPY_FUNC(int8_t);
DEFINE_COPY_FUNC(uint8_t);
DEFINE_COPY_FUNC(int);
DEFINE_COPY_FUNC(int64_t);
DEFINE_COPY_FUNC(float16);
DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC
#define DEFINE_AXPY_FUNC(T) \ #define DEFINE_AXPY_FUNC(T) \
template <> \ template <> \
DRAGON_API void Axpy<T, CUDAContext>( \ DRAGON_API void Axpy<T, CUDAContext>( \
...@@ -398,7 +430,7 @@ DEFINE_DOT_FUNC(double, cublasDdot); ...@@ -398,7 +430,7 @@ DEFINE_DOT_FUNC(double, cublasDdot);
const int n, const T* x, T* y, CUDAContext* ctx) { \ const int n, const T* x, T* y, CUDAContext* ctx) { \
CUBLAS_CHECK(cublasSetPointerMode( \ CUBLAS_CHECK(cublasSetPointerMode( \
ctx->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ ctx->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \
cublas_func(ctx->cublas_handle(), n, x, 1, y); \ CUBLAS_CHECK(cublas_func(ctx->cublas_handle(), n, x, 1, y)); \
} \ } \
template <> \ template <> \
DRAGON_API T ASum<T, CUDAContext>( \ DRAGON_API T ASum<T, CUDAContext>( \
...@@ -406,7 +438,7 @@ DEFINE_DOT_FUNC(double, cublasDdot); ...@@ -406,7 +438,7 @@ DEFINE_DOT_FUNC(double, cublasDdot);
T ret; \ T ret; \
CUBLAS_CHECK( \ CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
cublas_func(ctx->cublas_handle(), n, x, 1, &ret); \ CUBLAS_CHECK(cublas_func(ctx->cublas_handle(), n, x, 1, &ret)); \
return ret; \ return ret; \
} }
......
...@@ -41,6 +41,16 @@ DRAGON_API void Copy( ...@@ -41,6 +41,16 @@ DRAGON_API void Copy(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void CopyMatrix(
const int m,
const int n,
const int ldx,
const int ldy,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
DRAGON_API void DRAGON_API void
Axpy(const int n, const float alpha, const T* x, T* y, Context* ctx); Axpy(const int n, const float alpha, const T* x, T* y, Context* ctx);
......
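To pin down the contract of the primitive declared above: CopyMatrix copies an m x n block between flat row-major buffers whose row strides are ldx and ldy elements; the CPU version is the memcpy loop shown earlier (with a contiguous fast path when ldx == ldy == n), and the CUDA version maps the same contract onto a single cudaMemcpy2DAsync with byte pitches sizeof(T) * ldx and sizeof(T) * ldy. A hedged NumPy equivalent, for illustration only:

```python
import numpy as np

def copy_matrix(m, n, ldx, ldy, x, y):
    """Copy an m x n block between flat buffers with row strides ldx / ldy."""
    if m <= 0 or n <= 0:
        return
    if ldx == n and ldy == n:        # contiguous fast path
        y[:m * n] = x[:m * n]
        return
    for i in range(m):               # one row copy per iteration
        y[i * ldy:i * ldy + n] = x[i * ldx:i * ldx + n]

x = np.arange(12, dtype=np.float32)  # 3 rows of 4 elements
y = np.zeros(18, dtype=np.float32)   # 3 rows of 6 elements
copy_matrix(3, 4, ldx=4, ldy=6, x=x, y=y[2:])  # write at column offset 2
assert np.array_equal(y.reshape(3, 6)[:, 2:6], x.reshape(3, 4))
```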
...@@ -270,19 +270,6 @@ void ChannelShuffle( ...@@ -270,19 +270,6 @@ void ChannelShuffle(
T* y, T* y,
Context* ctx); Context* ctx);
/* array.concat */
template <typename T, class Context>
void Concat(
const int outer_dim,
const int inner_dim,
const int x_axis_dim,
const int y_cat_dim,
const int index,
const T* x,
T* y,
Context* ctx);
/* array.cumsum */ /* array.cumsum */
template <typename T, class Context> template <typename T, class Context>
...@@ -494,19 +481,6 @@ void SliceGrad( ...@@ -494,19 +481,6 @@ void SliceGrad(
T* dx, T* dx,
Context* ctx); Context* ctx);
/* array.split */
template <typename T, class Context>
void Split(
const int outer_dim,
const int inner_dim,
const int x_axis_dim,
const int y_axis_dim,
const int index,
const T* x,
T* y,
Context* ctx);
/* array.tile */ /* array.tile */
template <typename T, class Context> template <typename T, class Context>
......
...@@ -48,6 +48,7 @@ from dragon.vm.torch.core.nn.modules.linear import Linear ...@@ -48,6 +48,7 @@ from dragon.vm.torch.core.nn.modules.linear import Linear
from dragon.vm.torch.core.nn.modules.loss import CTCLoss from dragon.vm.torch.core.nn.modules.loss import CTCLoss
from dragon.vm.torch.core.nn.modules.loss import BCEWithLogitsLoss from dragon.vm.torch.core.nn.modules.loss import BCEWithLogitsLoss
from dragon.vm.torch.core.nn.modules.loss import CrossEntropyLoss from dragon.vm.torch.core.nn.modules.loss import CrossEntropyLoss
from dragon.vm.torch.core.nn.modules.loss import KLDivLoss
from dragon.vm.torch.core.nn.modules.loss import L1Loss from dragon.vm.torch.core.nn.modules.loss import L1Loss
from dragon.vm.torch.core.nn.modules.loss import MSELoss from dragon.vm.torch.core.nn.modules.loss import MSELoss
from dragon.vm.torch.core.nn.modules.loss import NLLLoss from dragon.vm.torch.core.nn.modules.loss import NLLLoss
......
...@@ -27,6 +27,7 @@ from dragon.vm.torch.core.nn.functional import drop_path ...@@ -27,6 +27,7 @@ from dragon.vm.torch.core.nn.functional import drop_path
from dragon.vm.torch.core.nn.functional import dropout from dragon.vm.torch.core.nn.functional import dropout
from dragon.vm.torch.core.nn.functional import elu from dragon.vm.torch.core.nn.functional import elu
from dragon.vm.torch.core.nn.functional import group_norm from dragon.vm.torch.core.nn.functional import group_norm
from dragon.vm.torch.core.nn.functional import kl_div
from dragon.vm.torch.core.nn.functional import l1_loss from dragon.vm.torch.core.nn.functional import l1_loss
from dragon.vm.torch.core.nn.functional import leaky_relu from dragon.vm.torch.core.nn.functional import leaky_relu
from dragon.vm.torch.core.nn.functional import linear from dragon.vm.torch.core.nn.functional import linear
......
...@@ -18,6 +18,7 @@ from dragon.core.util import nest ...@@ -18,6 +18,7 @@ from dragon.core.util import nest
from dragon.vm.torch.core.nn.modules import _functions from dragon.vm.torch.core.nn.modules import _functions
from dragon.vm.torch.core.nn import _reduction from dragon.vm.torch.core.nn import _reduction
from dragon.vm.torch.core.nn.modules import utils from dragon.vm.torch.core.nn.modules import utils
from dragon.vm.torch.core.ops.math import functional as math_funcs
def avg_pool2d( def avg_pool2d(
...@@ -715,6 +716,59 @@ def interpolate( ...@@ -715,6 +716,59 @@ def interpolate(
).apply(input, size, scale_factor) ).apply(input, size, scale_factor)
def kl_div(
input,
target,
size_average=None,
reduce=None,
reduction='mean',
log_target=False,
):
"""Compute the Kullback-Leibler divergence.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
target : dragon.vm.torch.Tensor
The target tensor.
size_average : bool, optional
Whether to average the loss.
reduce : bool, optional
Whether to reduce the loss.
reduction : {'none', 'batchmean', 'mean', 'sum'}, optional
The reduce method.
log_target : bool, optional, default=False
The flag indicating whether ``target`` is passed in log space.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.KLDivLoss(...)`_
"""
if size_average is not None or reduce is not None:
reduction = _reduction.legacy_get_string(size_average, reduce)
if not log_target:
out = target * (math_funcs.log(target) - input)
else:
out = math_funcs.exp(target) * (target - input)
if reduction == 'none':
return out
elif reduction == 'batchmean':
return out.sum() / input.size()[0]
elif reduction == 'mean':
return out.mean()
else:
return out.sum()
def l1_loss( def l1_loss(
input, input,
target, target,
......
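A quick, self-contained check of what kl_div above computes (NumPy only, for illustration; the dragon tensors are not required): with a log-space input and a probability-space target, the pointwise term is target * (log(target) - input), and 'batchmean' divides the summed result by the leading dimension, as in the function body.

```python
import numpy as np

def kl_div_ref(log_q, p, reduction='mean', log_target=False):
    """NumPy reference for the pointwise/reduction logic of kl_div above."""
    out = np.exp(p) * (p - log_q) if log_target else p * (np.log(p) - log_q)
    if reduction == 'none':
        return out
    if reduction == 'batchmean':
        return out.sum() / out.shape[0]   # input.size()[0] in the op
    return out.mean() if reduction == 'mean' else out.sum()

p = np.array([[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]])    # target distribution
q = np.array([[0.2, 0.2, 0.6], [0.25, 0.25, 0.5]])  # model distribution
print(kl_div_ref(np.log(q), p, reduction='batchmean'))  # average per-sample KL(P || Q)
```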
...@@ -256,6 +256,58 @@ class CrossEntropyLoss(_WeightedLoss): ...@@ -256,6 +256,58 @@ class CrossEntropyLoss(_WeightedLoss):
) )
class KLDivLoss(_Loss):
"""Compute the Kullback-Leibler divergence.
Examples:
```python
m = torch.nn.KLDivLoss()
eps = 1e-12 # Epsilon to avoid log(0)
# Compute KL(P || Q)
q = torch.tensor([0.0, 0.1, 0.2, 0.3, 1.0])
p = torch.tensor([0.0, 0.3, 0.2, 0.1, 0.9])
loss = m(torch.log(torch.clamp(q, eps)), torch.clamp(p, eps))
```
See Also
--------
`torch.nn.functional.kl_div(...)`_
"""
def __init__(
self,
size_average=None,
reduce=None,
reduction='mean',
log_target=False,
):
"""Create a ``KDivLoss`` module.
Parameters
----------
size_average : bool, optional
**True** to set the ``reduction`` to *'mean'*.
reduce : bool, optional
**True** to set the ``reduction`` to *'sum'* or *'mean'*.
reduction : {'none', 'batchmean', 'mean', 'sum'}, optional
The reduce method.
log_target : bool, optional, default=False
The flag indicating whether ``target`` is passed in log space.
"""
super(KLDivLoss, self).__init__(size_average, reduce, reduction)
self.log_target = log_target
def forward(self, input, target):
return F.kl_div(
input, target,
reduction=self.reduction,
log_target=self.log_target,
)
class L1Loss(_Loss): class L1Loss(_Loss):
r"""Compute the element-wise absolute value difference. r"""Compute the element-wise absolute value difference.
......