Commit 2f685b88 by Ting PAN

fix alpha and normalization for SparseSoftmaxFocalLoss

1 parent ddb76e7b
......@@ -18,18 +18,20 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
: SparseSoftmaxCrossEntropyOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)),
alpha(OperatorBase::GetSingleArg<float>("alpha", 0.5)),
gamma(OperatorBase::GetSingleArg<float>("gamma", 2.0)),
use_pseudo_metric(OperatorBase::GetSingleArg<bool>("use_pseudo_metric", true)) {
if (alpha == 1.0) use_pseudo_metric = false;
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {
pos_alpha = alpha * 2.0;
neg_alpha = (1 - alpha) * 2.0;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float alpha, gamma;
bool use_pseudo_metric;
float alpha, gamma;
int neg_id;
float pos_alpha, neg_alpha;
TIndex axis, outer_dim, inner_dim;
Tensor* scale;
string normalization;
......@@ -43,13 +45,15 @@ class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyG
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
gamma(OperatorBase::GetSingleArg<float>("gamma", 2.0)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))) {}
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float gamma, eps;
int neg_id;
TIndex axis, outer_dim, inner_dim;
Tensor* scale;
string normalization;
......
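The constructor changes above replace the single `alpha` weight (and `use_pseudo_metric`) with a `pos_alpha`/`neg_alpha` pair plus a `neg_id` threshold: labels greater than `neg_id` are treated as positives and weighted by `alpha * 2`, all other labels by `(1 - alpha) * 2`. A minimal sketch of that reparameterization, in Python for brevity (the helper name is made up):

```python
# Sketch of the alpha reparameterization from the constructor above.
# Assumption (taken from the diff): labels > neg_id are positives, the rest negatives.
def focal_alpha_weights(alpha=0.5):
    pos_alpha = alpha * 2.0          # weight applied when label > neg_id
    neg_alpha = (1.0 - alpha) * 2.0  # weight applied when label <= neg_id
    return pos_alpha, neg_alpha

print(focal_alpha_weights(0.5))   # (1.0, 1.0) -- the new default is symmetric
print(focal_alpha_weights(0.25))  # (0.5, 1.5) -- down-weights positives
```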
......@@ -340,8 +340,10 @@ void SparseSoftmaxFocalLoss(const int count,
const int classes,
const int outer_dim,
const int inner_dim,
const float alpha,
const float pos_alpha,
const float neg_alpha,
const float gamma,
const int neg_id,
const T* prob,
const T* labels,
T* scale,
......@@ -355,6 +357,7 @@ void SparseSoftmaxFocalLossGrad(const int count,
const int outer_dim,
const int inner_dim,
const float gamma,
const int neg_id,
const float eps,
const T* scale,
const T* prob,
......
......@@ -90,6 +90,7 @@ def FeedTensor(tensor, ndarray, force_cpu=False, dtype=None):
ndarray = np.array(ndarray, dtype=dtype)
FeedTensorCC(tensor, ndarray, StringfyProto(dev))
stages = {
'forward': {'include': '', 'exclude': 'Gradient'},
'backward': {'include': 'Gradient', 'exclude': 'Generate'},
......@@ -119,6 +120,7 @@ def RunGraph(graph_name, inputs=(), outputs=[], stage=None, return_outputs=True)
def PrintRawGraphDef(graph_def):
logger.info(graph_def)
def PrintOptimizedGraph(graph_def):
graph_name = graph_def.name
graph_tensor = 'GraphDef_' + graph_name
......@@ -156,6 +158,7 @@ def Snapshot(tensors, filename, prefix='', suffix='.bin', format=0):
names = [tensor.name for tensor in tensors]
SnapshotCC(filepath, names, format)
def Restore(filename, format=0):
if mpi.is_init():
if not mpi.allow_snapshot():
......
......@@ -123,7 +123,7 @@ def L2Loss(inputs, normalization='BATCH_SIZE', **kwargs):
def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=(),
alpha=0.25, gamma=2.0, eps=1e-10, use_pseudo_metric=True, **kwargs):
alpha=0.5, gamma=2.0, eps=1e-10, neg_id=-1, **kwargs):
"""
:param inputs: a list of Tensors containing [input, label]
:param axis: an int specifying the axis along which to compute softmax
......
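On the Python side the operator keeps the same positional usage; `use_pseudo_metric` is removed, `neg_id` is new, and the default `alpha` moves from 0.25 to 0.5. A hedged usage sketch (the import path and the `score`/`label` tensors are assumptions, not part of the commit):

```python
# Hypothetical call with the updated signature; 'score' and 'label' must be
# tensors produced elsewhere in the graph, and the import path is assumed.
import dragon.ops as ops

loss = ops.SparseSoftmaxFocalLoss([score, label],
                                  axis=1,
                                  normalization='VALID',
                                  alpha=0.5,   # split into pos/neg weights internally
                                  gamma=2.0,
                                  neg_id=0)    # labels <= 0 count as background
```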
......@@ -21,6 +21,9 @@ class Layer(object):
self._param = {}
self._common_param = {}
self._loss_weight = None if len(LayerParameter.loss_weight) == 0 \
else LayerParameter.loss_weight
for include in LayerParameter.include:
mpi_rank = [int(rank) for rank in include.mpi_rank]
if len(mpi_rank) > 0: self._common_param['mpi_rank'] = mpi_rank
......
......@@ -24,7 +24,9 @@ class SoftmaxWithLossLayer(Layer):
def Setup(self, bottom):
super(SoftmaxWithLossLayer, self).Setup(bottom)
return ops.SparseSoftmaxCrossEntropy(bottom, **self._param)
loss = ops.SparseSoftmaxCrossEntropy(bottom, **self._param)
if self._loss_weight is not None: loss *= self._loss_weight
return loss
class SigmoidCrossEntropyLossLayer(Layer):
......@@ -40,7 +42,8 @@ class SigmoidCrossEntropyLossLayer(Layer):
def Setup(self, bottom):
super(SigmoidCrossEntropyLossLayer, self).Setup(bottom)
return ops.SigmoidCrossEntropy(bottom, **self._param)
loss = ops.SigmoidCrossEntropy(bottom, **self._param)
if self._loss_weight is not None: loss *= self._loss_weight
return loss
class L2LossLayer(Layer):
......@@ -52,7 +55,9 @@ class L2LossLayer(Layer):
def Setup(self, bottom):
super(L2LossLayer, self).Setup(bottom)
return ops.L2Loss(bottom, **self._param)
loss = ops.L2Loss(bottom, **self._param)
if self._loss_weight is not None: loss *= self._loss_weight
return loss
class SmoothL1LossLayer(Layer):
......@@ -63,7 +68,9 @@ class SmoothL1LossLayer(Layer):
def Setup(self, bottom):
super(SmoothL1LossLayer, self).Setup(bottom)
return ops.SmoothL1Loss(bottom, **self._param)
loss = ops.SmoothL1Loss(bottom, **self._param)
if self._loss_weight is not None: loss *= self._loss_weight
return loss
class SoftmaxWithFocalLossLayer(Layer):
......@@ -83,8 +90,10 @@ class SoftmaxWithFocalLossLayer(Layer):
'alpha': float(focal_loss_param.alpha),
'gamma': float(focal_loss_param.gamma),
'eps': float(focal_loss_param.eps),
'use_pseudo_metric': focal_loss_param.use_pseudo_metric}
'neg_id': focal_loss_param.neg_id}
def Setup(self, bottom):
super(SoftmaxWithFocalLossLayer, self).Setup(bottom)
return ops.SparseSoftmaxFocalLoss(bottom, **self._param)
loss = ops.SparseSoftmaxFocalLoss(bottom, **self._param)
if self._loss_weight is not None: loss *= self._loss_weight
return loss
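Each loss layer now scales its operator output by Caffe's `loss_weight` (captured in `Layer.__init__` above) instead of returning the raw op result. The shared pattern, as a small sketch with placeholder names:

```python
# Generic form of the loss_weight handling added to every loss layer above;
# 'loss_op', 'bottom', and 'param' are placeholders, not names from the commit.
def setup_loss(loss_op, bottom, param, loss_weight=None):
    loss = loss_op(bottom, **param)
    if loss_weight is not None:
        loss *= loss_weight   # apply the Caffe-style loss_weight, if any
    return loss
```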
......@@ -1509,6 +1509,6 @@ message FocalLossParameter {
optional float alpha = 1 [default = 1.0];
optional float gamma = 2 [default = 0.25];
optional float eps = 3 [default = 1e-10];
optional bool use_pseudo_metric = 4 [default = true];
optional int32 neg_id = 4 [default = -1];
}
......@@ -26,7 +26,7 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() {
}
T normalizer;
if (normalization == "BATCH_SIZE") normalizer = outer_dim;
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
else if (normalization == "NONE") normalizer = 1;
T loss = math::ASum<T, Context>(losses.count(), Ldata);
......@@ -76,7 +76,7 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
}
T normalizer;
if (normalization == "BATCH_SIZE") normalizer = outer_dim;
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
else if (normalization == "NONE") normalizer = 1;
auto* dYdata = input(-1).template data<T, CPUContext>();
......
......@@ -33,7 +33,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() {
T normalizer;
if (normalization == "VALID")
normalizer = math::ASum<T, Context>(valid.count(), valid_data);
else if (normalization == "BATCH_SIZE") normalizer = outer_dim;
else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
else if (normalization == "NONE") normalizer = 1;
T loss = math::ASum<T, Context>(losses.count(), loss_data);
......@@ -91,7 +91,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
T normalizer;
if (normalization == "VALID") normalizer = math::ASum<T, Context>(valid.count(), valid_data);
else if (normalization == "BATCH_SIZE") normalizer = outer_dim;
else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
else if (normalization == "NONE") normalizer = 1;
auto* dYdata = input(-1).template data<T, CPUContext>();
......
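In all four cross-entropy variants above, the `BATCH_SIZE` normalizer changes from `outer_dim` to `input(0).dim(0)`. `outer_dim` is the product of all dimensions before the softmax `axis`, so it only equals the batch size when `axis == 1`; `dim(0)` is the batch size for any axis. A NumPy sketch of the normalizer selection, using plain shapes as a stand-in for the tensor API:

```python
import numpy as np

def normalizer(shape, axis, mode, num_valid=None):
    """Mirror of the normalizer selection in the cross-entropy ops (sketch).

    shape: full input shape, e.g. (N, C, H, W); axis: softmax axis.
    """
    outer_dim = int(np.prod(shape[:axis]))       # dims before the softmax axis
    inner_dim = int(np.prod(shape[axis + 1:]))   # dims after the softmax axis
    if mode == 'VALID':
        return num_valid                 # count of non-ignored targets (sparse ops only)
    if mode == 'BATCH_SIZE':
        return shape[0]                  # the fix: batch size, not outer_dim
    if mode == 'FULL':
        return outer_dim * inner_dim
    return 1                             # 'NONE'
```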
......@@ -19,8 +19,10 @@ void SparseSoftmaxFocalLossOp<Context>::RunWithType() {
input(0).dim(axis),
outer_dim,
inner_dim,
alpha,
pos_alpha,
neg_alpha,
gamma,
neg_id,
prob_data,
label_data,
scale_data,
......@@ -29,11 +31,6 @@ void SparseSoftmaxFocalLossOp<Context>::RunWithType() {
&this->ignore);
if (normalization == "UNIT") {
if (use_pseudo_metric) {
math::MulScalar<T, Context>(this->losses.count(),
1.0 / alpha,
loss_data);
}
output(0)->ReshapeLike(this->losses);
output(0)->Share(this->losses);
return;
......@@ -42,11 +39,10 @@ void SparseSoftmaxFocalLossOp<Context>::RunWithType() {
T normalizer;
if (normalization == "VALID")
normalizer = math::ASum<T, Context>(this->valid.count(), valid_data);
else if (normalization == "BATCH_SIZE") normalizer = outer_dim;
else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
else if (normalization == "NONE") normalizer = 1;
T loss = math::ASum<T, Context>(this->losses.count(), loss_data);
loss = use_pseudo_metric ? loss / alpha : loss;
output(0)->Reshape(vector<TIndex>(1, 1));
auto* Ydata = output(0)->template mutable_data<T, CPUContext>();
Ydata[0] = loss / normalizer;
......@@ -88,6 +84,7 @@ void SparseSoftmaxFocalLossGradientOp<Context>::RunWithType() {
outer_dim,
inner_dim,
gamma,
neg_id,
eps,
scale_data,
prob_data,
......@@ -110,7 +107,7 @@ void SparseSoftmaxFocalLossGradientOp<Context>::RunWithType() {
T normalizer;
if (normalization == "VALID") normalizer = math::ASum<T, Context>(this->valid.count(), valid_data);
else if (normalization == "BATCH_SIZE") normalizer = outer_dim;
else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = outer_dim * inner_dim;
else if (normalization == "NONE") normalizer = 1;
auto* dYdata = input(-1).template data<T, CPUContext>();
......
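In the focal loss op itself the `use_pseudo_metric` rescaling of the summed loss by `1 / alpha` is removed: the alpha weighting now lives entirely in the per-element `scale`, and the `valid` buffer filled by the kernels counts only positive targets (`label > neg_id`), so `VALID` normalization divides by the number of positives. A rough sketch of the resulting reduction:

```python
# Sketch of the forward reduction after this change. Assumptions from the diff:
# 'losses' holds per-position focal losses, 'valid' is 1 only where label > neg_id.
def reduce_focal_loss(losses, valid, mode='VALID', batch_size=None, full_dim=None):
    if mode == 'UNIT':
        return losses                 # per-position losses, no 1/alpha rescale anymore
    if mode == 'VALID':
        norm = valid.sum()            # number of positive (foreground) targets
    elif mode == 'BATCH_SIZE':
        norm = batch_size
    elif mode == 'FULL':
        norm = full_dim
    else:
        norm = 1.0
    return losses.sum() / norm
```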
......@@ -19,6 +19,8 @@ message NetParameter {
}
message LayerParameter {
optional string name = 1; // the layer name
optional string name = 1;
repeated BlobProto blobs = 7;
}
\ No newline at end of file
}
......@@ -780,8 +780,10 @@ template <> void SparseSoftmaxFocalLoss<float, CPUContext>(const int count,
const int classes,
const int outer_dim,
const int inner_dim,
const float alpha,
const float pos_alpha,
const float neg_alpha,
const float gamma,
const int neg_id,
const float* prob,
const float* labels,
float* scale,
......@@ -793,7 +795,7 @@ template <> void SparseSoftmaxFocalLoss<float, CPUContext>(const int count,
const int dim = count / outer_dim;
for (int i = 0; i < count; ++i) {
scale[i] = alpha * std::pow((1.0f - prob[i]), gamma);
scale[i] = std::pow((1.0f - prob[i]), gamma);
}
for (int i = 0; i < outer_dim; ++i) {
......@@ -809,9 +811,11 @@ template <> void SparseSoftmaxFocalLoss<float, CPUContext>(const int count,
}
if (k == ignore->count()) {
const int t_ = i * dim + label * inner_dim + j;
float labeled_prob = prob[t_];
loss[idx] = -scale[t_] * std::log(std::max(labeled_prob, FLT_MIN));
valid[idx] = 1;
float labeled_prob = std::max(prob[t_], FLT_MIN);
scale[t_] = label > neg_id ? pos_alpha * scale[t_] :
neg_alpha * scale[t_];
loss[idx] = -scale[t_] * std::log(labeled_prob);
valid[idx] = label > neg_id ? 1 : 0;
}
}
}
......@@ -822,6 +826,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const float gamma,
const int neg_id,
const float eps,
const float* scale,
const float* prob,
......@@ -855,7 +860,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CPUContext>(const int count,
dXdata[i_] = grad * prob[i_];
}
}
valid[0]++;
if (label > neg_id) valid[0]++;
}
}
}
......
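The CPU kernel first fills `scale` with the modulating factor `(1 - p)^gamma` for every class, then, at the labeled entry only, folds in `pos_alpha` or `neg_alpha` before taking the log loss; `valid` is set only for positives. A NumPy reference of that per-position computation (hypothetical helper, with `inner_dim` folded to 1 for brevity):

```python
import numpy as np

def sparse_softmax_focal_loss_ref(prob, labels, pos_alpha, neg_alpha,
                                  gamma=2.0, neg_id=-1, ignore=()):
    """NumPy sketch of the updated CPU kernel.

    prob: (N, C) softmax probabilities; labels: (N,) integer class ids.
    Returns per-position (loss, valid), both of shape (N,).
    """
    n, _ = prob.shape
    scale = np.power(1.0 - prob, gamma)      # modulating factor for every class
    loss = np.zeros(n)
    valid = np.zeros(n)
    for i in range(n):
        label = int(labels[i])
        if label in ignore:
            continue                          # ignored labels contribute nothing
        alpha = pos_alpha if label > neg_id else neg_alpha
        p = max(prob[i, label], np.finfo(np.float32).tiny)
        loss[i] = -alpha * scale[i, label] * np.log(p)
        valid[i] = 1.0 if label > neg_id else 0.0  # positives-only counter
    return loss, valid
```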
......@@ -1417,19 +1417,21 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, CUDAContext>(const int coun
/******************** loss.sparse_softmax_focal_loss ********************/
template <typename T>
__global__ void _FocalScale(const int count,
const float alpha,
const float gamma,
const T* prob,
T* scale) {
__global__ void _SparseSoftmaxFocalScale(const int count,
const float gamma,
const T* prob,
T* scale) {
CUDA_KERNEL_LOOP(idx, count) {
scale[idx] = alpha * std::pow((1.0f - prob[idx]), gamma);
scale[idx] = std::pow((1.0f - prob[idx]), gamma);
}
}
template <typename T>
__global__ void _SparseSoftmaxFocalLoss(const int count,
const T* scale,
const float pos_alpha,
const float neg_alpha,
const int neg_id,
T* scale,
const T* prob,
const T* labels,
T* loss,
......@@ -1445,14 +1447,16 @@ __global__ void _SparseSoftmaxFocalLoss(const int count,
int k;
for (k = 0; k < ignore_num; k++) {
if (label == ignores[k]) {
loss[idx] = valid[idx] = 0;
loss[idx] = valid[idx] = 0;
break;
}
}
if (k == ignore_num) {
const int t_ = (o_idx * classes + label) * inner_dim + i_idx;
scale[t_] = label > neg_id ? pos_alpha * scale[t_] :
neg_alpha * scale[t_];
loss[idx] = -scale[t_] * std::log(max(prob[t_], FLT_MIN));
valid[idx] = 1;
valid[idx] = label > neg_id ? 1 : 0;
}
}
}
......@@ -1461,8 +1465,10 @@ template <> void SparseSoftmaxFocalLoss<float, CUDAContext>(const int count,
const int classes,
const int outer_dim,
const int inner_dim,
const float alpha,
const float pos_alpha,
const float neg_alpha,
const float gamma,
const int neg_id,
const float* prob,
const float* labels,
float* scale,
......@@ -1472,12 +1478,14 @@ template <> void SparseSoftmaxFocalLoss<float, CUDAContext>(const int count,
const int* ignores = ignore->count() > 0 ?
ignore->data<int, CUDAContext>() : nullptr;
const int num_preds = outer_dim * inner_dim;
_FocalScale<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
alpha,
gamma,
prob,
scale);
_SparseSoftmaxFocalScale<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
gamma,
prob,
scale);
_SparseSoftmaxFocalLoss<float> << <GET_BLOCKS(num_preds), CUDA_NUM_THREADS >> >(num_preds,
pos_alpha,
neg_alpha,
neg_id,
scale,
prob,
labels,
......@@ -1493,6 +1501,7 @@ template <> void SparseSoftmaxFocalLoss<float, CUDAContext>(const int count,
template <typename T>
__global__ void _SparseSoftmaxFocalLossGrad(const int count,
const float gamma,
const int neg_id,
const float eps,
const T* scale,
const T* prob,
......@@ -1517,7 +1526,7 @@ __global__ void _SparseSoftmaxFocalLossGrad(const int count,
} else {
const int t_ = (o_idx * classes + label) * inner_dim + i_idx;
T grad = -gamma * (scale[t_] / max((1.0f - prob[t_]), eps))
* std::log(max(prob[t_], FLT_MIN))
* std::log(max(prob[t_], FLT_MIN))
* prob[t_] + scale[t_];
for (int c = 0; c < classes; c++) {
const int i_ = (o_idx * classes + c) * inner_dim + i_idx;
......@@ -1527,7 +1536,7 @@ __global__ void _SparseSoftmaxFocalLossGrad(const int count,
dx[i_] = grad * prob[i_];
}
}
valid[idx] = 1;
valid[idx] = label > neg_id ? 1 : 0;
}
}
}
......@@ -1537,6 +1546,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CUDAContext>(const int count,
const int outer_dim,
const int inner_dim,
const float gamma,
const int neg_id,
const float eps,
const float* scale,
const float* prob,
......@@ -1549,6 +1559,7 @@ template<> void SparseSoftmaxFocalLossGrad<float, CUDAContext>(const int count,
const int num_preds = outer_dim * inner_dim;
_SparseSoftmaxFocalLossGrad<float> << <GET_BLOCKS(num_preds), CUDA_NUM_THREADS >> >(num_preds,
gamma,
neg_id,
eps,
scale,
prob,
......
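The gradient kernels adopt the same `neg_id` convention: background positions still back-propagate, but they no longer bump the `valid` counter used for `VALID` normalization. The per-position gradient they apply follows from the loss `-s * log(p_t)` with `s = alpha_t * (1 - p_t)^gamma` held in `scale`; the off-target/target split below is the standard softmax bookkeeping. A NumPy sketch for a single position (hypothetical helper, `inner_dim` = 1):

```python
import numpy as np

def focal_loss_grad_ref(prob, label, scale, gamma=2.0, eps=1e-10):
    """Gradient w.r.t. the softmax logits for one position (NumPy sketch).

    prob:  (C,) softmax probabilities; label: target class id;
    scale: (C,) alpha-weighted modulating factors from the forward pass.
    """
    t = int(label)
    p_t = prob[t]
    # Same expression as the kernel: grad = -gamma * s/(1-p_t) * log(p_t) * p_t + s
    g = (-gamma * (scale[t] / max(1.0 - p_t, eps))
         * np.log(max(p_t, np.finfo(np.float32).tiny)) * p_t + scale[t])
    dx = g * prob             # off-target classes: grad * p_c
    dx[t] = g * (p_t - 1.0)   # target class: grad * (p_t - 1)
    return dx
```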