Commit 2f5edb5c by Ting PAN

soft-target support for softmax crossentropy

1 parent 2356c658
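
In short, this change switches the softmax cross-entropy loss from hard (one-hot) labels to an arbitrary soft target distribution. A brief sketch of the math being relied on (assuming each target row sums to 1): for softmax probabilities p and targets t, the per-position loss is L = -sum_k t[k] * log(p[k]), and its gradient with respect to the logits is dL/dx[k] = p[k] - t[k]. The same p - t form holds element-wise for the sigmoid case with p = sigmoid(x). That is why the dedicated SoftmaxCrossEntropyGrad kernel, which assumed one-hot labels (dx[i] = prob[i] - (labels[i] > 0)), can be removed below and replaced in the gradient op by a Copy of the probabilities followed by an Axpy of -1 times the targets.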
......@@ -303,7 +303,7 @@ void AbsGrad(const int count, const T* dy, T* dx);
/******************** loss.sigmoid_cross_entropy_loss ********************/
template <typename T, class Context>
-void SigmoidCrossEntropy(const int count, const T* x, const T* targets, T* loss);
+void SigmoidCrossEntropy(const int count, const T* x, const T* target, T* loss);
/******************** loss.smooth_l1_loss ********************/
......@@ -316,10 +316,7 @@ void SmoothL1Grad(const int count, const float sigma2, const T* dy, T* dx);
/******************** loss.softmax_cross_entropy_loss ********************/
template <typename T, class Context>
-void SoftmaxCrossEntropy(const int count, const T* prob, const T* labels, T* loss);
-template <typename T, class Context>
-void SoftmaxCrossEntropyGrad(const int count, const T* prob, const T* labels, T* dx);
+void SoftmaxCrossEntropy(const int count, const T* prob, const T* target, T* loss);
/******************** loss.softmax_loss ********************/
......
......@@ -8,13 +8,12 @@ namespace dragon {
template <class Context> template <typename T>
void SigmoidCrossEntropyLossOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>();
-auto* prob_data = prob->template mutable_data<T, Context>();
-kernel::Sigmoid<T, Context>(prob->count(), Xdata, prob_data);
+auto* Pdata = prob->template mutable_data<T, Context>();
+kernel::Sigmoid<T, Context>(prob->count(), Xdata, Pdata);
-auto* label_data = input(1).template data<T, Context>();
-auto* loss_data = losses.template mutable_data<T, Context>();
-kernel::SigmoidCrossEntropy<T, Context>(input(0).count(),
-Xdata, label_data, loss_data);
+auto* Tdata = input(1).template data<T, Context>();
+auto* Ldata = losses.template mutable_data<T, Context>();
+kernel::SigmoidCrossEntropy<T, Context>(input(0).count(), Xdata, Tdata, Ldata);
if (normalization == "UNIT") {
output(0)->ReshapeLike(losses);
......@@ -26,7 +25,7 @@ void SigmoidCrossEntropyLossOp<Context>::RunWithType() {
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = input(0).count();
else if (normalization == "NONE") normalizer = 1;
-T loss = math::ASum<T, Context>(losses.count(), loss_data);
+T loss = math::ASum<T, Context>(losses.count(), Ldata);
output(0)->Reshape(vector<TIndex>(1, 1));
auto* Ydata = output(0)->template mutable_data<T, CPUContext>();
Ydata[0] = loss / normalizer;
......@@ -52,11 +51,11 @@ OPERATOR_SCHEMA(SigmoidCrossEntropyLoss).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T>
void SigmoidCrossEntropyLossGradientOp<Context>::RunWithType() {
-auto* prob_data = prob->template data<T, Context>();
-auto* label_data = input(1).template data<T, Context>();
+auto* Pdata = prob->template data<T, Context>();
+auto* Tdata = input(1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
-ctx().template Copy<T, Context, Context>(prob->count(), dXdata, prob_data);
-math::Axpy<T, Context>(output(0)->count(), -1.0, label_data, dXdata);
+ctx().template Copy<T, Context, Context>(prob->count(), dXdata, Pdata);
+math::Axpy<T, Context>(output(0)->count(), -1.0, Tdata, dXdata);
if (normalization == "UNIT") {
auto* dYdata = input(-1).template data<T, Context>();
......
......@@ -9,17 +9,19 @@ namespace dragon {
template <class Context> template <typename T>
void SoftmaxCrossEntropyLossOp<Context>::RunWithType() {
-auto* prob_data = prob->template data<T, Context>();
-auto* label_data = input(1).template data<T, Context>();
-auto* loss_data = losses.template mutable_data<T, Context>();
-kernel::SoftmaxCrossEntropy<T, Context>(input(0).count(),
-prob_data, label_data, loss_data);
+auto* Pdata = prob->template data<T, Context>();
+auto* Tdata = input(1).template data<T, Context>();
+auto* Ldata = losses.template mutable_data<T, Context>();
+kernel::SoftmaxCrossEntropy<T, Context>(input(0).count(), Pdata, Tdata, Ldata);
if (normalization == "UNIT") {
output(0)->Reshape(vector<TIndex>(1, outer_dim * inner_dim));
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::Sum<T, Context>(losses.count(),
-input(0).dim(axis), inner_dim, loss_data, Ydata);
+input(0).dim(axis),
+inner_dim,
+Ldata,
+Ydata);
return;
}
......@@ -27,7 +29,7 @@ void SoftmaxCrossEntropyLossOp<Context>::RunWithType() {
if (normalization == "BATCH_SIZE") normalizer = outer_dim;
else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
else if (normalization == "NONE") normalizer = 1;
-T loss = math::ASum<T, Context>(losses.count(), loss_data);
+T loss = math::ASum<T, Context>(losses.count(), Ldata);
output(0)->Reshape(vector<TIndex>(1, 1));
auto* Ydata = output(0)->template mutable_data<T, Context>();
Ydata[0] = loss / normalizer;
......@@ -55,17 +57,21 @@ OPERATOR_SCHEMA(SoftmaxCrossEntropyLoss).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T>
void SoftmaxCrossEntropyLossGradientOp<Context>::RunWithType() {
-auto* label_data = input(1).template data<T, Context>();
-auto* prob_data = prob->template mutable_data<T, Context>();
+auto* Tdata = input(1).template data<T, Context>();
+auto* Pdata = prob->template mutable_data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
-kernel::SoftmaxCrossEntropyGrad<T, Context>(output(0)->count(),
-prob_data, label_data, dXdata);
+ctx().template Copy<T, Context, Context>(prob->count(), dXdata, Pdata);
+math::Axpy<T, Context>(output(0)->count(), -1.0, Tdata, dXdata);
if (normalization == "UNIT") {
auto* dYdata = input(-1).template data<T, Context>();
kernel::SumGrad<T, Context>(input(0).count() / input(0).dim(axis),
-input(0).dim(axis), inner_dim, 1.0, dYdata, prob_data);
-math::Mul<T, Context>(output(0)->count(), prob_data, dXdata, dXdata);
+input(0).dim(axis),
+inner_dim,
+1.0,
+dYdata,
+Pdata);
+math::Mul<T, Context>(output(0)->count(), Pdata, dXdata, dXdata);
return;
}
......
......@@ -677,11 +677,11 @@ template<> void AbsGrad<float, CPUContext>(const int count, const float* dy, flo
template <> void SigmoidCrossEntropy<float, CPUContext>(const int count,
const float* x,
-const float* targets,
+const float* target,
float* loss) {
for (int i = 0; i < count; ++i) {
loss[i] = std::log(1 + std::exp(x[i] - 2 * x[i] * (x[i] >= 0)))
-+ x[i] * ((x[i] >= 0) - targets[i]);
++ x[i] * ((x[i] >= 0) - target[i]);
}
}
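
The expression above is a numerically stable form of the usual sigmoid cross-entropy -t * log(sigmoid(x)) - (1 - t) * log(1 - sigmoid(x)): for x >= 0 it reduces to log(1 + exp(-x)) + x * (1 - t), and for x < 0 to log(1 + exp(x)) - x * t, so the exponential argument is never positive and cannot overflow for large |x|.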
......@@ -716,19 +716,10 @@ template<> void SmoothL1Grad<float, CPUContext>(const int count,
template <> void SoftmaxCrossEntropy<float, CPUContext>(const int count,
const float* prob,
-const float* labels,
+const float* target,
float* loss) {
for (int i = 0; i < count; ++i) {
-loss[i] = - labels[i] * std::log(std::max(prob[i], FLT_MIN));
-}
-}
-template <> void SoftmaxCrossEntropyGrad<float, CPUContext>(const int count,
-const float* prob,
-const float* labels,
-float* dx) {
-for (int i = 0; i < count; ++i) {
-dx[i] = prob[i] - (labels[i] > 0);
+loss[i] = - target[i] * std::log(std::max(prob[i], FLT_MIN));
}
}
......
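
To sanity-check the new CPU path end to end, a minimal standalone sketch in plain C++ (illustrative only: it does not use Dragon's headers, assumes a single spatial position, and assumes the target row sums to 1). It softmaxes one row of logits, evaluates the same element-wise loss as the kernel above, and forms the gradient as prob - target, mirroring the Copy + Axpy path in the updated gradient op.

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // One position with 3 classes; the soft target sums to 1.
    std::vector<float> x = {2.0f, 0.5f, -1.0f};   // logits
    std::vector<float> t = {0.7f, 0.2f, 0.1f};    // soft target

    // Softmax (shift by the max for numerical stability).
    std::vector<float> p(x.size());
    float m = *std::max_element(x.begin(), x.end()), z = 0.0f;
    for (size_t k = 0; k < x.size(); ++k) z += (p[k] = std::exp(x[k] - m));
    for (size_t k = 0; k < p.size(); ++k) p[k] /= z;

    // Loss: same element-wise form as the CPU SoftmaxCrossEntropy kernel above,
    // summed over the classes of this position.
    float loss = 0.0f;
    for (size_t k = 0; k < p.size(); ++k)
        loss += -t[k] * std::log(std::max(p[k], FLT_MIN));

    // Gradient w.r.t. the logits: prob - target, i.e. the Copy + Axpy(-1)
    // sequence used by the updated SoftmaxCrossEntropyLossGradientOp.
    std::vector<float> dx(p.size());
    for (size_t k = 0; k < p.size(); ++k) dx[k] = p[k] - t[k];

    std::printf("loss = %f\n", loss);
    for (size_t k = 0; k < dx.size(); ++k)
        std::printf("dx[%zu] = %f\n", k, dx[k]);
    return 0;
}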
......@@ -1316,45 +1316,24 @@ template<> void SmoothL1Grad<float, CUDAContext>(const int count,
template <typename T>
__global__ void _SoftmaxCrossEntropy(const int count,
const T* prob,
-const T* labels,
+const T* target,
T* loss) {
CUDA_KERNEL_LOOP(idx, count) {
-loss[idx] = - labels[idx] * log(max(prob[idx], FLT_MIN));
+loss[idx] = -target[idx] * log(max(prob[idx], FLT_MIN));
}
}
template <> void SoftmaxCrossEntropy<float, CUDAContext>(const int count,
const float* prob,
-const float* labels,
+const float* target,
float* loss) {
_SoftmaxCrossEntropy<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
prob,
-labels,
+target,
loss);
CUDA_POST_KERNEL_CHECK;
}
-template <typename T>
-__global__ void _SoftmaxCrossEntropyGrad(const int count,
-const T* prob,
-const T* labels,
-T* dx) {
-CUDA_KERNEL_LOOP(idx, count) {
-dx[idx] = prob[idx] - (labels[idx] > 0);
-}
-}
-template <> void SoftmaxCrossEntropyGrad<float, CUDAContext>(const int count,
-const float* prob,
-const float* labels,
-float* dx) {
-_SoftmaxCrossEntropyGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
-prob,
-labels,
-dx);
-CUDA_POST_KERNEL_CHECK;
-}
/******************** loss.softmax_loss ********************/
template <typename T>
......
......@@ -66,7 +66,7 @@
7. Setup MPI [Optional]
#### Linux:
-- We use OpenMPI which support "cuda-aware-mpi"
+- We use OpenMPI which supports "cuda-aware-mpi"
- See more:
- https://devblogs.nvidia.com/parallelforall/introduction-cuda-aware-mpi/
- https://www.open-mpi.org/faq/?category=buildcuda
......
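
Related to the MPI setup step above, a minimal sketch for checking whether the installed Open MPI build is actually CUDA-aware. It assumes Open MPI's optional <mpi-ext.h> extension header, which provides the MPIX_CUDA_AWARE_SUPPORT macro and MPIX_Query_cuda_support() described in the FAQ links; other MPI implementations will simply skip those checks.

// Compile with mpicxx; reports compile-time and run-time CUDA-aware support.
#include <cstdio>
#include <mpi.h>
#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h>   // Open MPI extensions (MPIX_*)
#endif

int main(int argc, char** argv) {
    MPI_Init(&argc, &argv);
#if defined(MPIX_CUDA_AWARE_SUPPORT) && MPIX_CUDA_AWARE_SUPPORT
    std::printf("Compile-time CUDA-aware support: yes\n");
    std::printf("Run-time CUDA-aware support: %s\n",
                MPIX_Query_cuda_support() ? "yes" : "no");
#else
    std::printf("This MPI library does not advertise CUDA-aware support.\n");
#endif
    MPI_Finalize();
    return 0;
}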