Commit 2f685b88 by Ting PAN

fix alpha and normalization for SparseSoftmaxFocalLoss

1 parent ddb76e7b
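
What the change does, read from the diff below: the single focal-loss weight `alpha` is replaced by a pair of class-conditional weights derived in the operator constructor, and a new `neg_id` argument marks which label ids count as background (a label is treated as foreground when `label > neg_id`). The loss itself is the standard focal loss of Lin et al.:

    FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t),
    \qquad
    \alpha_t =
    \begin{cases}
    \text{pos\_alpha} = 2\alpha, & \text{label} > \text{neg\_id} \\
    \text{neg\_alpha} = 2(1 - \alpha), & \text{otherwise}
    \end{cases}

This is the usual alpha / (1 - alpha) class weighting, scaled by 2 so that the new default alpha = 0.5 reduces exactly to the unweighted focal loss. Two further effects visible in the kernels: under "VALID" normalization only foreground positions (label > neg_id) are counted in the denominator, and the old `use_pseudo_metric` path, which rescaled the reported loss by 1/alpha, is removed.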
@@ -18,18 +18,20 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
     : SparseSoftmaxCrossEntropyOp<Context>(op_def, ws),
       axis(OperatorBase::GetSingleArg<int>("axis", 1)),
       normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
-      alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)),
+      alpha(OperatorBase::GetSingleArg<float>("alpha", 0.5)),
       gamma(OperatorBase::GetSingleArg<float>("gamma", 2.0)),
-      use_pseudo_metric(OperatorBase::GetSingleArg<bool>("use_pseudo_metric", true)) {
-        if (alpha == 1.0) use_pseudo_metric = false;
+      neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {
+        pos_alpha = alpha * 2.0;
+        neg_alpha = (1 - alpha) * 2.0;
     }
     void RunOnDevice() override;
     template <typename T> void RunWithType();

 protected:
     float alpha, gamma;
-    bool use_pseudo_metric;
+    int neg_id;
+    float pos_alpha, neg_alpha;
     TIndex axis, outer_dim, inner_dim;
     Tensor* scale;
     string normalization;
@@ -43,13 +45,15 @@ class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyG
       axis(OperatorBase::GetSingleArg<int>("axis", 1)),
       normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
       gamma(OperatorBase::GetSingleArg<float>("gamma", 2.0)),
-      eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))) {}
+      eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))),
+      neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {}
     void RunOnDevice() override;
     template <typename T> void RunWithType();

 protected:
     float gamma, eps;
+    int neg_id;
     TIndex axis, outer_dim, inner_dim;
     Tensor* scale;
     string normalization;
......
@@ -340,8 +340,10 @@ void SparseSoftmaxFocalLoss(const int count,
         const int classes,
         const int outer_dim,
         const int inner_dim,
-        const float alpha,
+        const float pos_alpha,
+        const float neg_alpha,
         const float gamma,
+        const int neg_id,
         const T* prob,
         const T* labels,
         T* scale,
@@ -355,6 +357,7 @@ void SparseSoftmaxFocalLossGrad(const int count,
         const int outer_dim,
         const int inner_dim,
         const float gamma,
+        const int neg_id,
         const float eps,
         const T* scale,
         const T* prob,
......
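
To make the new kernel arguments concrete, here is a minimal NumPy sketch of what the declared SparseSoftmaxFocalLoss kernel computes at a single spatial position, based on the CPU implementation further down in this commit. The function and variable names are illustrative, not part of the framework:

    import numpy as np

    FLT_MIN = np.finfo(np.float32).tiny

    def focal_loss_at_position(prob, label, pos_alpha, neg_alpha, gamma, neg_id,
                               ignore_labels=()):
        """prob: softmax probabilities over the classes at one position (1-D array)."""
        if label in ignore_labels:
            return 0.0, 0                                 # loss, valid flag
        scale = (1.0 - prob[label]) ** gamma              # focusing term
        scale *= pos_alpha if label > neg_id else neg_alpha
        loss = -scale * np.log(max(prob[label], FLT_MIN))
        valid = 1 if label > neg_id else 0                # only foreground counts for VALID
        return loss, valid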
@@ -90,6 +90,7 @@ def FeedTensor(tensor, ndarray, force_cpu=False, dtype=None):
         ndarray = np.array(ndarray, dtype=dtype)
     FeedTensorCC(tensor, ndarray, StringfyProto(dev))
+
 stages = {
     'forward': {'include': '', 'exclude': 'Gradient'},
     'backward': {'include': 'Gradient', 'exclude': 'Generate'},
@@ -119,6 +120,7 @@ def RunGraph(graph_name, inputs=(), outputs=[], stage=None, return_outputs=True)
 def PrintRawGraphDef(graph_def):
     logger.info(graph_def)
+
 def PrintOptimizedGraph(graph_def):
     graph_name = graph_def.name
     graph_tensor = 'GraphDef_' + graph_name
@@ -156,6 +158,7 @@ def Snapshot(tensors, filename, prefix='', suffix='.bin', format=0):
     names = [tensor.name for tensor in tensors]
     SnapshotCC(filepath, names, format)
+
 def Restore(filename, format=0):
     if mpi.is_init():
         if not mpi.allow_snapshot():
......
@@ -123,7 +123,7 @@ def L2Loss(inputs, normalization='BATCH_SIZE', **kwargs):
 def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=(),
-                           alpha=0.25, gamma=2.0, eps=1e-10, use_pseudo_metric=True, **kwargs):
+                           alpha=0.5, gamma=2.0, eps=1e-10, neg_id=-1, **kwargs):
     """
     :param inputs: a list of Tensor contains [input, label]
     :param axis a int of using which axis to compute softmax
......
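
Under the new signature, a call from the Python API would look roughly like the sketch below. This is illustrative only: `score` and `label` are assumed to be existing Dragon Tensors (e.g. outputs of preceding ops), and `ops` is the dragon ops module as imported elsewhere in the codebase.

    # RetinaNet-style settings: class 0 is background, so neg_id=0;
    # alpha balances foreground vs background, gamma focuses on hard examples.
    loss = ops.SparseSoftmaxFocalLoss([score, label], axis=1,
                                      normalization='VALID',
                                      alpha=0.25, gamma=2.0, neg_id=0)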
@@ -21,6 +21,9 @@ class Layer(object):
         self._param = {}
         self._common_param = {}
+        self._loss_weight = None if len(LayerParameter.loss_weight) == 0 \
+            else LayerParameter.loss_weight
         for include in LayerParameter.include:
             mpi_rank = [int(rank) for rank in include.mpi_rank]
             if len(mpi_rank) > 0: self._common_param['mpi_rank'] = mpi_rank
......
@@ -24,7 +24,9 @@ class SoftmaxWithLossLayer(Layer):
     def Setup(self, bottom):
         super(SoftmaxWithLossLayer, self).Setup(bottom)
-        return ops.SparseSoftmaxCrossEntropy(bottom, **self._param)
+        loss = ops.SparseSoftmaxCrossEntropy(bottom, **self._param)
+        if self._loss_weight is not None: loss *= self._loss_weight
+        return loss
 class SigmoidCrossEntropyLossLayer(Layer):
@@ -40,7 +42,8 @@ class SigmoidCrossEntropyLossLayer(Layer):
     def Setup(self, bottom):
         super(SigmoidCrossEntropyLossLayer, self).Setup(bottom)
-        return ops.SigmoidCrossEntropy(bottom, **self._param)
+        loss = ops.SigmoidCrossEntropy(bottom, **self._param)
+        if self._loss_weight is not None: loss *= self._loss_weight
 class L2LossLayer(Layer):
@@ -52,7 +55,9 @@ class L2LossLayer(Layer):
     def Setup(self, bottom):
         super(L2LossLayer, self).Setup(bottom)
-        return ops.L2Loss(bottom, **self._param)
+        loss = ops.L2Loss(bottom, **self._param)
+        if self._loss_weight is not None: loss *= self._loss_weight
+        return loss
 class SmoothL1LossLayer(Layer):
@@ -63,7 +68,9 @@ class SmoothL1LossLayer(Layer):
     def Setup(self, bottom):
         super(SmoothL1LossLayer, self).Setup(bottom)
-        return ops.SmoothL1Loss(bottom, **self._param)
+        loss = ops.SmoothL1Loss(bottom, **self._param)
+        if self._loss_weight is not None: loss *= self._loss_weight
+        return loss
 class SoftmaxWithFocalLossLayer(Layer):
@@ -83,8 +90,10 @@ class SoftmaxWithFocalLossLayer(Layer):
             'alpha': float(focal_loss_param.alpha),
             'gamma': float(focal_loss_param.gamma),
             'eps': float(focal_loss_param.eps),
-            'use_pseudo_metric': focal_loss_param.use_pseudo_metric}
+            'neg_id': focal_loss_param.neg_id}

     def Setup(self, bottom):
         super(SoftmaxWithFocalLossLayer, self).Setup(bottom)
-        return ops.SparseSoftmaxFocalLoss(bottom, **self._param)
+        loss = ops.SparseSoftmaxFocalLoss(bottom, **self._param)
+        if self._loss_weight is not None: loss *= self._loss_weight
+        return loss
@@ -1509,6 +1509,6 @@ message FocalLossParameter {
   optional float alpha = 1 [default = 1.0];
   optional float gamma = 2 [default = 0.25];
   optional float eps = 3 [default = 1e-10];
-  optional bool use_pseudo_metric = 4 [default = true];
+  optional int32 neg_id = 4 [default = -1];
 }
@@ -26,7 +26,7 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() {
     }
     T normalizer;
-    if (normalization == "BATCH_SIZE") normalizer = outer_dim;
+    if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
     T loss = math::ASum<T, Context>(losses.count(), Ldata);
@@ -76,7 +76,7 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
     }
     T normalizer;
-    if (normalization == "BATCH_SIZE") normalizer = outer_dim;
+    if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
     auto* dYdata = input(-1).template data<T, CPUContext>();
......
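
The BATCH_SIZE change above replaces `outer_dim` with `input(0).dim(0)`. The two only coincide when the softmax axis is 1: for a (N, C, H, W) input with axis=1, `outer_dim == N`, but for any later axis `outer_dim` is the product of all leading dimensions and overstates the batch size. A small NumPy-flavoured sketch of the normalizer choices (illustrative names, not framework code):

    import numpy as np

    def normalizer(mode, input_shape, axis, valid):
        """valid: per-position 0/1 flags produced by the loss kernel."""
        outer_dim = int(np.prod(input_shape[:axis]))
        inner_dim = int(np.prod(input_shape[axis + 1:]))
        if mode == 'VALID':
            return max(float(valid.sum()), 1.0)   # zero-guard added for this sketch only
        if mode == 'BATCH_SIZE':
            return input_shape[0]                 # was outer_dim before this commit
        if mode == 'FULL':
            return outer_dim * inner_dim
        if mode == 'NONE':
            return 1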
@@ -33,7 +33,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() {
     T normalizer;
     if (normalization == "VALID")
         normalizer = math::ASum<T, Context>(valid.count(), valid_data);
-    else if (normalization == "BATCH_SIZE") normalizer = outer_dim;
+    else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
     T loss = math::ASum<T, Context>(losses.count(), loss_data);
@@ -91,7 +91,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
     T normalizer;
     if (normalization == "VALID") normalizer = math::ASum<T, Context>(valid.count(), valid_data);
-    else if (normalization == "BATCH_SIZE") normalizer = outer_dim;
+    else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
     auto* dYdata = input(-1).template data<T, CPUContext>();
......
@@ -19,8 +19,10 @@ void SparseSoftmaxFocalLossOp<Context>::RunWithType() {
         input(0).dim(axis),
         outer_dim,
         inner_dim,
-        alpha,
+        pos_alpha,
+        neg_alpha,
         gamma,
+        neg_id,
         prob_data,
         label_data,
         scale_data,
@@ -29,11 +31,6 @@ void SparseSoftmaxFocalLossOp<Context>::RunWithType() {
         &this->ignore);
     if (normalization == "UNIT") {
-        if (use_pseudo_metric) {
-            math::MulScalar<T, Context>(this->losses.count(),
-                                        1.0 / alpha,
-                                        loss_data);
-        }
         output(0)->ReshapeLike(this->losses);
         output(0)->Share(this->losses);
         return;
@@ -42,11 +39,10 @@ void SparseSoftmaxFocalLossOp<Context>::RunWithType() {
     T normalizer;
     if (normalization == "VALID")
         normalizer = math::ASum<T, Context>(this->valid.count(), valid_data);
-    else if (normalization == "BATCH_SIZE") normalizer = outer_dim;
+    else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
     T loss = math::ASum<T, Context>(this->losses.count(), loss_data);
-    loss = use_pseudo_metric ? loss / alpha : loss;
     output(0)->Reshape(vector<TIndex>(1, 1));
     auto* Ydata = output(0)->template mutable_data<T, CPUContext>();
     Ydata[0] = loss / normalizer;
@@ -88,6 +84,7 @@ void SparseSoftmaxFocalLossGradientOp<Context>::RunWithType() {
         outer_dim,
         inner_dim,
         gamma,
+        neg_id,
         eps,
         scale_data,
         prob_data,
@@ -110,7 +107,7 @@ void SparseSoftmaxFocalLossGradientOp<Context>::RunWithType() {
     T normalizer;
     if (normalization == "VALID") normalizer = math::ASum<T, Context>(this->valid.count(), valid_data);
-    else if (normalization == "BATCH_SIZE") normalizer = outer_dim;
+    else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
     else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
     else if (normalization == "NONE") normalizer = 1;
     auto* dYdata = input(-1).template data<T, CPUContext>();
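
Two effects of the operator-level changes above, read from the diff: the old use_pseudo_metric path (which rescaled the reported loss by 1/alpha) is gone, and under VALID normalization the denominator is now the number of foreground positions rather than all non-ignored ones. A tiny numeric illustration with made-up numbers, not taken from the code:

    # Suppose one image yields 100 non-ignored anchors, 5 of them foreground.
    num_valid_old = 100     # old kernels set valid[idx] = 1 for every non-ignored anchor
    num_valid_new = 5       # new kernels set valid[idx] = 1 only when label > neg_id
    summed_loss = 12.0      # hypothetical sum of per-anchor focal losses
    print(summed_loss / num_valid_old)   # 0.12  (old VALID normalization)
    print(summed_loss / num_valid_new)   # 2.4   (new VALID normalization, RetinaNet-style)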
@@ -19,6 +19,8 @@ message NetParameter {
 }
 message LayerParameter {
-  optional string name = 1; // the layer name
+  optional string name = 1;
   repeated BlobProto blobs = 7;
 }
\ No newline at end of file
@@ -780,8 +780,10 @@ template <> void SparseSoftmaxFocalLoss<float, CPUContext>(const int count,
         const int classes,
         const int outer_dim,
         const int inner_dim,
-        const float alpha,
+        const float pos_alpha,
+        const float neg_alpha,
         const float gamma,
+        const int neg_id,
         const float* prob,
         const float* labels,
         float* scale,
@@ -793,7 +795,7 @@ template <> void SparseSoftmaxFocalLoss<float, CPUContext>(const int count,
     const int dim = count / outer_dim;
     for (int i = 0; i < count; ++i) {
-        scale[i] = alpha * std::pow((1.0f - prob[i]), gamma);
+        scale[i] = std::pow((1.0f - prob[i]), gamma);
     }
     for (int i = 0; i < outer_dim; ++i) {
@@ -809,9 +811,11 @@ template <> void SparseSoftmaxFocalLoss<float, CPUContext>(const int count,
             }
             if (k == ignore->count()) {
                 const int t_ = i * dim + label * inner_dim + j;
-                float labeled_prob = prob[t_];
-                loss[idx] = -scale[t_] * std::log(std::max(labeled_prob, FLT_MIN));
-                valid[idx] = 1;
+                float labeled_prob = std::max(labeled_prob, FLT_MIN);
+                scale[t_] = label > neg_id ? pos_alpha * scale[t_] :
+                                             neg_alpha * scale[t_];
+                loss[idx] = -scale[t_] * std::log(labeled_prob);
+                valid[idx] = label > neg_id ? 1 : 0;
             }
         }
     }
@@ -822,6 +826,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CPUContext>(const int count,
         const int outer_dim,
         const int inner_dim,
         const float gamma,
+        const int neg_id,
         const float eps,
         const float* scale,
         const float* prob,
@@ -855,7 +860,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CPUContext>(const int count,
                 dXdata[i_] = grad * prob[i_];
             }
         }
-        valid[0]++;
+        if (label > neg_id) valid[0]++;
     }
     }
 }
......
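
For reference, the gradient these kernels implement follows from the focal loss above. Writing p = softmax(x), t for the target class, and s = alpha_t (1 - p_t)^gamma for the per-position `scale` buffer produced by the forward kernel, a standard derivation gives:

    \frac{\partial\,\mathrm{FL}}{\partial x_j} = g\,(p_j - \delta_{jt}),
    \qquad
    g = s - \gamma\,\frac{s\,p_t}{1 - p_t}\,\log p_t .

This matches the code path above: `grad = -gamma * (scale[t_] / max(1 - prob[t_], eps)) * log(max(prob[t_], FLT_MIN)) * prob[t_] + scale[t_]`, followed by `dXdata[i_] = grad * prob[i_]` for the non-target classes (the target-class branch sits in context lines collapsed out of this hunk).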
@@ -1417,19 +1417,21 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, CUDAContext>(const int coun
 /******************** loss.sparse_softmax_focal_loss ********************/

 template <typename T>
-__global__ void _FocalScale(const int count,
-                            const float alpha,
-                            const float gamma,
-                            const T* prob,
-                            T* scale) {
+__global__ void _SparseSoftmaxFocalScale(const int count,
+                                         const float gamma,
+                                         const T* prob,
+                                         T* scale) {
     CUDA_KERNEL_LOOP(idx, count) {
-        scale[idx] = alpha * std::pow((1.0f - prob[idx]), gamma);
+        scale[idx] = std::pow((1.0f - prob[idx]), gamma);
     }
 }

 template <typename T>
 __global__ void _SparseSoftmaxFocalLoss(const int count,
-                                        const T* scale,
+                                        const float pos_alpha,
+                                        const float neg_alpha,
+                                        const int neg_id,
+                                        T* scale,
                                         const T* prob,
                                         const T* labels,
                                         T* loss,
@@ -1445,14 +1447,16 @@ __global__ void _SparseSoftmaxFocalLoss(const int count,
         int k;
         for (k = 0; k < ignore_num; k++) {
             if (label == ignores[k]) {
                 loss[idx] = valid[idx] = 0;
                 break;
             }
         }
         if (k == ignore_num) {
             const int t_ = (o_idx * classes + label) * inner_dim + i_idx;
+            scale[t_] = label > neg_id ? pos_alpha * scale[t_] :
+                                         neg_alpha * scale[t_];
             loss[idx] = -scale[t_] * std::log(max(prob[t_], FLT_MIN));
-            valid[idx] = 1;
+            valid[idx] = label > neg_id ? 1 : 0;
         }
     }
 }
@@ -1461,8 +1465,10 @@ template <> void SparseSoftmaxFocalLoss<float, CUDAContext>(const int count,
         const int classes,
         const int outer_dim,
         const int inner_dim,
-        const float alpha,
+        const float pos_alpha,
+        const float neg_alpha,
         const float gamma,
+        const int neg_id,
         const float* prob,
         const float* labels,
         float* scale,
@@ -1472,12 +1478,14 @@ template <> void SparseSoftmaxFocalLoss<float, CUDAContext>(const int count,
     const int* ignores = ignore->count() > 0 ?
                          ignore->data<int, CUDAContext>() : nullptr;
     const int num_preds = outer_dim * inner_dim;
-    _FocalScale<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
-        alpha,
-        gamma,
-        prob,
-        scale);
+    _SparseSoftmaxFocalScale<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
+        gamma,
+        prob,
+        scale);
     _SparseSoftmaxFocalLoss<float> << <GET_BLOCKS(num_preds), CUDA_NUM_THREADS >> >(num_preds,
+        pos_alpha,
+        neg_alpha,
+        neg_id,
         scale,
         prob,
         labels,
@@ -1493,6 +1501,7 @@ template <> void SparseSoftmaxFocalLoss<float, CUDAContext>(const int count,
 template <typename T>
 __global__ void _SparseSoftmaxFocalLossGrad(const int count,
         const float gamma,
+        const int neg_id,
         const float eps,
         const T* scale,
         const T* prob,
@@ -1517,7 +1526,7 @@ __global__ void _SparseSoftmaxFocalLossGrad(const int count,
         } else {
             const int t_ = (o_idx * classes + label) * inner_dim + i_idx;
             T grad = -gamma * (scale[t_] / max((1.0f - prob[t_]), eps))
                            * std::log(max(prob[t_], FLT_MIN))
                            * prob[t_] + scale[t_];
             for (int c = 0; c < classes; c++) {
                 const int i_ = (o_idx * classes + c) * inner_dim + i_idx;
@@ -1527,7 +1536,7 @@ __global__ void _SparseSoftmaxFocalLossGrad(const int count,
                 dx[i_] = grad * prob[i_];
             }
         }
-        valid[idx] = 1;
+        valid[idx] = label > neg_id ? 1 : 0;
     }
     }
 }
@@ -1537,6 +1546,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CUDAContext>(const int count,
         const int outer_dim,
         const int inner_dim,
         const float gamma,
+        const int neg_id,
         const float eps,
         const float* scale,
         const float* prob,
@@ -1549,6 +1559,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CUDAContext>(const int count,
     const int num_preds = outer_dim * inner_dim;
     _SparseSoftmaxFocalLossGrad<float> << <GET_BLOCKS(num_preds), CUDA_NUM_THREADS >> >(num_preds,
         gamma,
+        neg_id,
         eps,
         scale,
         prob,
......
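
Putting the pieces together, here is a small self-contained NumPy reference of the forward pass (softmax over the class axis, the new pos_alpha/neg_alpha weighting, ignore labels, and VALID normalization) that could be used to sanity-check the kernels above. Everything here is a sketch with illustrative names; it is not the framework's code.

    import numpy as np

    FLT_MIN = np.finfo(np.float32).tiny

    def sparse_softmax_focal_loss(logits, labels, alpha=0.5, gamma=2.0,
                                  neg_id=-1, ignore_labels=(), axis=1):
        """logits: (N, C, ...) scores; labels: integer labels over the non-class axes.
        Returns the scalar loss under VALID normalization."""
        pos_alpha, neg_alpha = alpha * 2.0, (1.0 - alpha) * 2.0
        # softmax over the class axis
        x = logits - logits.max(axis=axis, keepdims=True)
        prob = np.exp(x) / np.exp(x).sum(axis=axis, keepdims=True)
        # mask ignored labels, then gather the probability of the labeled class
        keep = ~np.isin(labels, list(ignore_labels))
        safe = np.where(keep, labels, 0)            # any in-range index for ignored positions
        p_t = np.take_along_axis(prob, np.expand_dims(safe, axis), axis=axis)
        p_t = np.squeeze(p_t, axis=axis)
        scale = np.where(safe > neg_id, pos_alpha, neg_alpha) * (1.0 - p_t) ** gamma
        loss = -scale * np.log(np.maximum(p_t, FLT_MIN))
        valid = keep & (labels > neg_id)            # only foreground counts for VALID
        return loss[keep].sum() / max(valid.sum(), 1)

    # toy check: 1 image, 3 classes, 2x2 positions, class 0 treated as background
    logits = np.random.randn(1, 3, 2, 2).astype(np.float32)
    labels = np.array([[[0, 1], [2, 255]]])         # 255 marked as an ignore label
    print(sparse_softmax_focal_loss(logits, labels, alpha=0.25, neg_id=0,
                                    ignore_labels=(255,)))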