label smoothing是一种在分类问题中,防止过拟合的方法。
label smoothing(标签平滑)
- 交叉熵损失函数在多分类任务中存在的问题
- label smoothing(标签平滑)
- 参考资料
交叉熵损失函数在多分类任务中存在的问题
多分类任务中,神经网络会输出一个当前数据对应于各个类别的置信度分数,将这些分数通过softmax进行归一化处理,最终会得到当前数据属于每个类别的概率。
q i = e x p ( z i ) ∑ j = 1 k e x p ( z j ) q_i={{exp(z_i)}\over{\sum_{j=1}^kexp(z_j)}} qi=∑j=1kexp(zj)exp(zi)
然后计算交叉熵损失函数:
L o s s = − ∑ i = 1 k p i l o g q i Loss=-\sum_{i=1}^k p_i \space log\space q_i Loss=−i=1∑kpi log qi
p i = { 1 , i f ( i = y ) 0 , i f ( i ≠ y ) p_i=\left\{\begin{matrix} 1,if(i=y)\\0,if(i\neq y) \end{matrix}\right. pi={1,if(i=y)0,if(i=y)
其 中 i 表 示 多 分 类 中 的 某 一 类 其中i表示多分类中的某一类 其中i表示多分类中的某一类
训练神经网络时,最小化预测概率和标签真实概率之间的交叉熵,从而得到最优的预测概率分布。最优的预测概率分布是:
Z i = { + ∞ , i f ( i = y ) 0 , i f ( i ≠ y ) Z_i=\left\{\begin{matrix} +\infty,if(i=y)\\0,if(i\neq y) \end{matrix}\right. Zi={+∞,if(i=y)0,if(i=y)
神经网络会促使自身往正确标签和错误标签差值最大的方向学习,在训练数据较少,不足以表征所有的样本特征的情况下,会导致网络过拟合。
label smoothing(标签平滑)
label smoothing可以解决上述问题,这是一种正则化策略,主要通过soft one-hot来加入噪声,减少真实样本标签的类别在计算损失函数时的权重,最终起到抑制过拟合的效果。
增加label smoothing后真实的概率分布有如下改变:
p i = { 1 , i f ( i = y ) 0 , i f ( i ≠ y ) p_i=\left\{\begin{matrix} 1,if(i=y)\\0,if(i\neq y) \end{matrix}\right. pi={1,if(i=y)0,if(i=y)
p i = { ( 1 − ϵ ) , i f ( i = y ) ϵ K − 1 , i f ( i ≠ y ) p_i=\left\{\begin{matrix} (1-\epsilon),if(i=y)\\{{\epsilon}\over{K-1}},if(i\neq y) \end{matrix}\right. pi={(1−ϵ),if(i=y)K−1ϵ,if(i=y)
K 表 示 多 分 类 的 类 别 总 数 K表示多分类的类别总数 K表示多分类的类别总数
ϵ 是 一 个 较 小 的 超 参 数 \epsilon是一个较小的超参数 ϵ是一个较小的超参数
交叉熵损失函数的改变如下:
L o s s = − ∑ i = 1 k p i l o g q i Loss=-\sum_{i=1}^k p_i \space log\space q_i Loss=−i=1∑kpi log qi
L o s s = { ( 1 − ϵ ) ∗ L o s s , i f ( i = y ) ϵ ∗ L o s s , i f ( i ≠ y ) Loss=\left\{\begin{matrix} (1-\epsilon)*Loss,if(i=y)\\ \epsilon*Loss,if(i\neq y) \end{matrix}\right. Loss={(1−ϵ)∗Loss,if(i=y)ϵ∗Loss,if(i=y)
最优预测概率分布如下:
Z i = { + ∞ , i f ( i = y ) 0 , i f ( i ≠ y ) Z_i=\left\{\begin{matrix} +\infty,if(i=y)\\0,if(i\neq y) \end{matrix}\right. Zi={+∞,if(i=y)0,if(i=y)
Z i = { l o g ( k − 1 ) ( 1 − ϵ ) ϵ + α , i f ( i = y ) α , i f ( i ≠ y ) Z_i=\left\{\begin{matrix} log{{(k-1)(1-\epsilon)}\over{\epsilon+\alpha}},if(i=y)\\\alpha,if(i\neq y) \end{matrix}\right. Zi={logϵ+α(k−1)(1−ϵ),if(i=y)α,if(i=y)
这里的α是任意实数,最终模型通过抑制正负样本输出差值,使得网络有更强的泛化能力。
参考资料
- https://zhuanlan.zhihu.com/p/116466239
- https://blog.csdn.net/qq_43211132/article/details/100510113