Commit 2f5edb5c by Ting PAN

soft-target support for softmax crossentropy

1 parent 2356c658
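For context: with this change the second input is treated as a full per-class target distribution t (soft targets) rather than a hard class index. A brief sketch of the math the diff relies on, with x the logits and p the softmax probabilities (assuming the targets along the class axis sum to 1):

    L = -\sum_i t_i \log p_i,        \partial L / \partial x_i = p_i - t_i

Because the gradient reduces to p - t, the dedicated SoftmaxCrossEntropyGrad kernel is removed below and replaced by a Copy of the probabilities followed by an Axpy with coefficient -1.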
@@ -303,7 +303,7 @@ void AbsGrad(const int count, const T* dy, T* dx);
 /******************** loss.sigmoid_cross_entropy_loss ********************/
 template <typename T, class Context>
-void SigmoidCrossEntropy(const int count, const T* x, const T* targets, T* loss);
+void SigmoidCrossEntropy(const int count, const T* x, const T* target, T* loss);
 /******************** loss.smooth_l1_loss ********************/

@@ -316,10 +316,7 @@ void SmoothL1Grad(const int count, const float sigma2, const T* dy, T* dx);
 /******************** loss.softmax_cross_entropy_loss ********************/
 template <typename T, class Context>
-void SoftmaxCrossEntropy(const int count, const T* prob, const T* labels, T* loss);
-template <typename T, class Context>
-void SoftmaxCrossEntropyGrad(const int count, const T* prob, const T* labels, T* dx);
+void SoftmaxCrossEntropy(const int count, const T* prob, const T* target, T* loss);
 /******************** loss.softmax_loss ********************/
......
@@ -8,13 +8,12 @@ namespace dragon {
 template <class Context> template <typename T>
 void SigmoidCrossEntropyLossOp<Context>::RunWithType() {
     auto* Xdata = input(0).template data<T, Context>();
-    auto* prob_data = prob->template mutable_data<T, Context>();
-    kernel::Sigmoid<T, Context>(prob->count(), Xdata, prob_data);
-    auto* label_data = input(1).template data<T, Context>();
-    auto* loss_data = losses.template mutable_data<T, Context>();
-    kernel::SigmoidCrossEntropy<T, Context>(input(0).count(),
-        Xdata, label_data, loss_data);
+    auto* Pdata = prob->template mutable_data<T, Context>();
+    kernel::Sigmoid<T, Context>(prob->count(), Xdata, Pdata);
+    auto* Tdata = input(1).template data<T, Context>();
+    auto* Ldata = losses.template mutable_data<T, Context>();
+    kernel::SigmoidCrossEntropy<T, Context>(input(0).count(), Xdata, Tdata, Ldata);
     if (normalization == "UNIT") {
         output(0)->ReshapeLike(losses);

@@ -26,7 +25,7 @@ void SigmoidCrossEntropyLossOp<Context>::RunWithType() {
     if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
     else if (normalization == "FULL") normalizer = input(0).count();
     else if (normalization == "NONE") normalizer = 1;
-    T loss = math::ASum<T, Context>(losses.count(), loss_data);
+    T loss = math::ASum<T, Context>(losses.count(), Ldata);
     output(0)->Reshape(vector<TIndex>(1, 1));
     auto* Ydata = output(0)->template mutable_data<T, CPUContext>();
     Ydata[0] = loss / normalizer;

@@ -52,11 +51,11 @@ OPERATOR_SCHEMA(SigmoidCrossEntropyLoss).NumInputs(2).NumOutputs(1);
 template <class Context> template <typename T>
 void SigmoidCrossEntropyLossGradientOp<Context>::RunWithType() {
-    auto* prob_data = prob->template data<T, Context>();
-    auto* label_data = input(1).template data<T, Context>();
+    auto* Pdata = prob->template data<T, Context>();
+    auto* Tdata = input(1).template data<T, Context>();
     auto* dXdata = output(0)->template mutable_data<T, Context>();
-    ctx().template Copy<T, Context, Context>(prob->count(), dXdata, prob_data);
-    math::Axpy<T, Context>(output(0)->count(), -1.0, label_data, dXdata);
+    ctx().template Copy<T, Context, Context>(prob->count(), dXdata, Pdata);
+    math::Axpy<T, Context>(output(0)->count(), -1.0, Tdata, dXdata);
     if (normalization == "UNIT") {
         auto* dYdata = input(-1).template data<T, Context>();
......
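A note on the gradient hunk above: in the lines shown, the computation itself is only renamed, because the binary cross-entropy gradient with respect to the logits already has the soft-target form

    d/dx [ -t \log \sigma(x) - (1 - t) \log(1 - \sigma(x)) ] = \sigma(x) - t,

which is exactly what the Copy of the sigmoid output followed by Axpy with coefficient -1 computes.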
@@ -9,17 +9,19 @@ namespace dragon {
 template <class Context> template <typename T>
 void SoftmaxCrossEntropyLossOp<Context>::RunWithType() {
-    auto* prob_data = prob->template data<T, Context>();
-    auto* label_data = input(1).template data<T, Context>();
-    auto* loss_data = losses.template mutable_data<T, Context>();
-    kernel::SoftmaxCrossEntropy<T, Context>(input(0).count(),
-        prob_data, label_data, loss_data);
+    auto* Pdata = prob->template data<T, Context>();
+    auto* Tdata = input(1).template data<T, Context>();
+    auto* Ldata = losses.template mutable_data<T, Context>();
+    kernel::SoftmaxCrossEntropy<T, Context>(input(0).count(), Pdata, Tdata, Ldata);
     if (normalization == "UNIT") {
         output(0)->Reshape(vector<TIndex>(1, outer_dim * inner_dim));
         auto* Ydata = output(0)->template mutable_data<T, Context>();
         kernel::Sum<T, Context>(losses.count(),
-            input(0).dim(axis), inner_dim, loss_data, Ydata);
+                                input(0).dim(axis),
+                                inner_dim,
+                                Ldata,
+                                Ydata);
         return;
     }

@@ -27,7 +29,7 @@ void SoftmaxCrossEntropyLossOp<Context>::RunWithType() {
     if (normalization == "BATCH_SIZE") normalizer = outer_dim;
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
-    T loss = math::ASum<T, Context>(losses.count(), loss_data);
+    T loss = math::ASum<T, Context>(losses.count(), Ldata);
     output(0)->Reshape(vector<TIndex>(1, 1));
     auto* Ydata = output(0)->template mutable_data<T, Context>();
     Ydata[0] = loss / normalizer;

@@ -55,17 +57,21 @@ OPERATOR_SCHEMA(SoftmaxCrossEntropyLoss).NumInputs(2).NumOutputs(1);
 template <class Context> template <typename T>
 void SoftmaxCrossEntropyLossGradientOp<Context>::RunWithType() {
-    auto* label_data = input(1).template data<T, Context>();
-    auto* prob_data = prob->template mutable_data<T, Context>();
+    auto* Tdata = input(1).template data<T, Context>();
+    auto* Pdata = prob->template mutable_data<T, Context>();
     auto* dXdata = output(0)->template mutable_data<T, Context>();
-    kernel::SoftmaxCrossEntropyGrad<T, Context>(output(0)->count(),
-        prob_data, label_data, dXdata);
+    ctx().template Copy<T, Context, Context>(prob->count(), dXdata, Pdata);
+    math::Axpy<T, Context>(output(0)->count(), -1.0, Tdata, dXdata);
     if (normalization == "UNIT") {
         auto* dYdata = input(-1).template data<T, Context>();
         kernel::SumGrad<T, Context>(input(0).count() / input(0).dim(axis),
-            input(0).dim(axis), inner_dim, 1.0, dYdata, prob_data);
-        math::Mul<T, Context>(output(0)->count(), prob_data, dXdata, dXdata);
+                                    input(0).dim(axis),
+                                    inner_dim,
+                                    1.0,
+                                    dYdata,
+                                    Pdata);
+        math::Mul<T, Context>(output(0)->count(), Pdata, dXdata, dXdata);
         return;
     }
......
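For readers less familiar with the Copy/Axpy idiom that replaces the old gradient kernel, here is a minimal standalone C++ sketch of what the new path computes; the function name and the use of std::vector are illustrative only, not Dragon's API:

    #include <vector>

    // dX = prob - target: Copy(prob -> dX), then Axpy(alpha = -1, target, dX).
    // Assumes prob and target have the same length.
    void SoftTargetSoftmaxGrad(const std::vector<float>& prob,
                               const std::vector<float>& target,
                               std::vector<float>& dX) {
        dX.assign(prob.begin(), prob.end());       // Copy: dX <- prob
        for (size_t i = 0; i < dX.size(); ++i)
            dX[i] += -1.0f * target[i];            // Axpy: dX <- dX + (-1) * target
    }

With a one-hot target this reproduces the removed hard-label kernel's dx[i] = prob[i] - (labels[i] > 0); with soft targets it yields the general p - t gradient.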
@@ -677,11 +677,11 @@ template<> void AbsGrad<float, CPUContext>(const int count, const float* dy, flo
 template <> void SigmoidCrossEntropy<float, CPUContext>(const int count,
                                                         const float* x,
-                                                        const float* targets,
+                                                        const float* target,
                                                         float* loss) {
     for (int i = 0; i < count; ++i) {
         loss[i] = std::log(1 + std::exp(x[i] - 2 * x[i] * (x[i] >= 0)))
-                + x[i] * ((x[i] >= 0) - targets[i]);
+                + x[i] * ((x[i] >= 0) - target[i]);
     }
 }
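A side note on the SigmoidCrossEntropy kernel above: the expression is a branch-free, numerically stable form of binary cross-entropy on logits. Writing s = (x >= 0), one can check that

    \log(1 + e^{x - 2xs}) + x (s - t) = \max(x, 0) - x t + \log(1 + e^{-|x|}) = -t \log \sigma(x) - (1 - t) \log(1 - \sigma(x)),

so the argument of exp is always -|x| and never overflows for large logits.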
@@ -716,19 +716,10 @@ template<> void SmoothL1Grad<float, CPUContext>(const int count,
 template <> void SoftmaxCrossEntropy<float, CPUContext>(const int count,
                                                         const float* prob,
-                                                        const float* labels,
+                                                        const float* target,
                                                         float* loss) {
     for (int i = 0; i < count; ++i) {
-        loss[i] = - labels[i] * std::log(std::max(prob[i], FLT_MIN));
+        loss[i] = - target[i] * std::log(std::max(prob[i], FLT_MIN));
     }
 }
-template <> void SoftmaxCrossEntropyGrad<float, CPUContext>(const int count,
-                                                            const float* prob,
-                                                            const float* labels,
-                                                            float* dx) {
-    for (int i = 0; i < count; ++i) {
-        dx[i] = prob[i] - (labels[i] > 0);
-    }
-}
......
@@ -1316,45 +1316,24 @@ template<> void SmoothL1Grad<float, CUDAContext>(const int count,
 template <typename T>
 __global__ void _SoftmaxCrossEntropy(const int count,
                                      const T* prob,
-                                     const T* labels,
+                                     const T* target,
                                      T* loss) {
     CUDA_KERNEL_LOOP(idx, count) {
-        loss[idx] = - labels[idx] * log(max(prob[idx], FLT_MIN));
+        loss[idx] = -target[idx] * log(max(prob[idx], FLT_MIN));
     }
 }
 template <> void SoftmaxCrossEntropy<float, CUDAContext>(const int count,
                                                          const float* prob,
-                                                         const float* labels,
+                                                         const float* target,
                                                          float* loss) {
     _SoftmaxCrossEntropy<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
                                                                              prob,
-                                                                             labels,
+                                                                             target,
                                                                              loss);
     CUDA_POST_KERNEL_CHECK;
 }
-template <typename T>
-__global__ void _SoftmaxCrossEntropyGrad(const int count,
-                                         const T* prob,
-                                         const T* labels,
-                                         T* dx) {
-    CUDA_KERNEL_LOOP(idx, count) {
-        dx[idx] = prob[idx] - (labels[idx] > 0);
-    }
-}
-template <> void SoftmaxCrossEntropyGrad<float, CUDAContext>(const int count,
-                                                             const float* prob,
-                                                             const float* labels,
-                                                             float* dx) {
-    _SoftmaxCrossEntropyGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
-                                                                                 prob,
-                                                                                 labels,
-                                                                                 dx);
-    CUDA_POST_KERNEL_CHECK;
-}
 /******************** loss.softmax_loss ********************/
 template <typename T>
......
@@ -66,7 +66,7 @@
 7. Setup MPI [Optional]
 #### Linux:
-    - We use OpenMPI which support "cuda-aware-mpi"
+    - We use OpenMPI which supports "cuda-aware-mpi"
     - See more:
         - https://devblogs.nvidia.com/parallelforall/introduction-cuda-aware-mpi/
         - https://www.open-mpi.org/faq/?category=buildcuda
......