CE Loss,BCE Loss以及Focal Loss的原理理解

article/2025/10/28 10:42:15

一、交叉熵损失函数(CE Loss,BCE Loss)

最开始理解交叉熵损失函数被自己搞的晕头转向的,最后发现是对随机变量的理解有偏差,不知道有没有读者和我有着一样的困惑,所以在本文开始之前,先介绍一下随机变量是啥。

什么是概率分布?
概率分布,是指用于表述随机变量取值的概率规律。随机变量的概率表示了一次试验中某一个结果发生的可能性大小 ,想象画在图上就是横坐标(自变量)是随机变量。根据随机变量所属类型的不同,概率分布取不同的表现形式。举个最简单的例子:抛一枚硬币,随机变量为抛硬币的结果,产生的结果的概率分布为:p(正面)=0.5,p(背面)=0.5

随机变量是什么?
随机变量是将随机试验的结果数量化,具有随机性的,注意是结果!!!在概率论中,概率质量函数(probability mass function,简写为pmf)是离散随机变量在各特定取值上的概率。一个概率质量函数的图像。函数的所有值必须非负,且总和为1。

如在抛50次硬币这个事件中,随机变量是指抛硬币获得正面的次数。不要把随机变量理解为试验的次数的取值!!!再拿二分类任务举个例子,二分类的随机变量就是看做0和1两个类别。二分类猫狗任务就相当于二项分布中的伯努利分布(试验次数为1时就叫伯努利分布,就相当于只丢一次硬币),因为去识别一张图片,最后试验的结果只能要么是猫要么是狗,这任务中的随机变量不是每一个训练样本(训练集中的每一张图片),而是分类的结果即猫or狗!在训练过程中,如果用交叉熵损失函数,假如p(x)是目标真实的分布,而q(x)是预测得来的分布。网络对每一个训练样本来讲,这张图片经过网络输出后得到的q(x)尽可能和这张图像的p(x)分布相等,x为类别的随机变量,x1为猫,x2为狗。如p(x1)=1,就是表示这张图片得到的x1这个类别的结果概率是1,所以由标签可知它的真实分布即p就是p(猫,狗)~(1,0),从训练来讲就是让这张训练样本图片经过网络输出后,得到的q(x)去无限接近上面p(猫,狗)-(1,0)这个分布。 拟合分布就是让预测分布的参数不断接近分布的参数!如p就是伯努利分布中的参数。所谓的交叉熵的交叉就是指这两个分布之间的交叉,让两个分布越接近则交叉熵损失越小。

要充分理解交叉熵损失函数,首先要理解相对熵,又称互熵。设p(x)和q(x)是两个概率分布,相对熵用来表示两个概率分布的差异,当两个随机分布相同时,它们的相对熵为零,当两个随机分布的差别增大时,它们的相对熵也会增大。

而相对熵=交叉熵-信息熵!!!
由于在机器学习和深度学习中,样本和标签已知(即p已知,样本就是xi),那么信息熵H(p)相当于常量,此时,只需拟合交叉熵,使交叉熵拟合为0即可。关键点:所以最小化交叉熵损失函数就相当于使得交叉熵公式里的p和q这两个概率分布(指交叉熵公式里的那两个乘法因子)的差异最小!式子中的n就是随机变量的取值集合,在这里就是类别数,p(xi)就是事件X=xi的概率。
在这里插入图片描述
信息熵(公式里的两个乘法因子都是指同一个分布的):
信息熵则是在结果出来之前对可能产生的信息量的期望信息量表示一条信息消除不确定性的程度,如中国目前的高铁技术世界第一,这个概率为1,这句话本身是确定的,没有消除任何不确定性。而中国的高铁技术将一直保持世界第一,这句话是个不确定事件,包含的信息量就比较大。信息量的大小和事件发生的概率成反比。信息熵越小就表示这个事件发生的概率越大,-logP就是信息量的公式(P表示事件发生的概率)。
在这里插入图片描述
交叉熵(公式是针对一个样本的,公式里的两个乘法因子分别指两个分布,n为类别数):
在这里插入图片描述

下面进入正题,也就是BCE Loss和CE Loss:

对于二分类交叉熵,下图的x1和x2是指两个类别,比如x1和x2分别代表猫和狗两类,p就是这个样本为猫的标签,这个标签可能是0也有可能是1;q就是这个样本被预测为猫的概率!
在这里插入图片描述

下图给出了多分类问题(实现为F.cross_entropy)和二分类问题(实现为F.binary_cross_entropy)的交叉熵损失公式,下图中多分类问题中的公式是针对单个样本的,公式里的i表示每一个类别。而对于二分类问题的公式即BCE loss,公式里的i表示每一个样本,所以要注意区分! 对于多分类问题即CE loss,假设真实标签的one-hot编码是:[0,0,…,1,…,0],预测的softmax概率为[0.1,0.3,…,0.4,…,0.1],那么Loss=-log(0.4)。对于二分类问题即BCE loss来说,每个样本就输出一个数字。
在这里插入图片描述

需要注意的是,BCE loss在pytorch中实现多分类损失时,也就是通过多个二分类来实现多分类时,target要转换成one-hot形式(只能有1个元素为1,其余都为0)。如下图所示,下图就是一个用BCE loss实现6分类的例子,BCE loss就把这个问题当成6个二分类实现,因为一个目标只能是属于一个类别,所以可以转换成one-hot形式。然后对于用BCE loss处理多分类问题的情况,最后其实返回的是每个类别的二分类损失求和的平均值,所以真正返回的是:4.7938/6 = 0.7990
在这里插入图片描述

二、Focal loss

Focal loss的本质

  1. 首先给出原始二分类交叉熵的公式:

在这里插入图片描述

  1. 在二分类交叉熵损失的基础上,控制了正负样本的权重来解决了正负样本的不平衡,下图就是基于二分类交叉熵损失通过α来控制正负样本比例的例子,当α=0.5时,正负样本的比重是一样的。
    在这里插入图片描述
  2. 在上面图中损失的基础上,增加控制“容易分类和难分类样本的权重”来解决难例挖掘的问题。
  3. 结合这两个方法,就是最终的二分类的Focal loss(如下图所示),最前面红框的第一项是最普通的交叉熵;第二项是控制正负样本平衡的α参数;第三项是控制难易分类样本的平衡,即对于正样本而言,预测分数越接近于1的表示这个样本越简单,那么这个样本应该对损失的影响越小:
    在这里插入图片描述
  4. 同理,多分类的Focal loss(softmax)的公式如下图所示:

这里是引用在这里插入图片描述

Focal loss的具体代码实现

# 参考了:
# 1. https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
# 2. https://github.com/c0nn3r/RetinaNet/blob/master/focal_loss.pyimport torch
import torch.nn.functional as Fdef focal_loss(logits, labels, gamma=2, reduction="mean"):r"""focal loss for multi classification(简洁版实现)`https://arxiv.org/pdf/1708.02002.pdf`FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)"""# 这段代码比较简洁,具体可以看作者是怎么定义的,或者看 focal_lossv1 版本的实现# 经测试,reduction 加不加结果都一样,但是为了保险,还是加上# logits是过激活函数前的值,reduction="none"就是不对loss进行求mean或者sum 保留每个样本的CE lossce_loss = F.cross_entropy(logits, labels, reduction="none")log_pt = -ce_losspt = torch.exp(log_pt)weights = (1 - pt) ** gammafl = weights * ce_lossif reduction == "sum":fl = fl.sum()elif reduction == "mean":fl = fl.mean()else:raise ValueError(f"reduction '{reduction}' is not valid")return fldef balanced_focal_loss(logits, labels, alpha=0.25, gamma=2, reduction="mean"):r"""带平衡因子的 focal loss,这里的 alpha 在多分类中应该是个向量,向量中的每个值代表类别的权重。但是为了简单起见,我们假设每个类一样,直接传 0.25。如果是长尾数据集,则应该自行构造 alpha 向量,同时改写 focal loss 函数。"""return alpha * focal_loss(logits, labels, gamma, reduction)def focal_lossv1(logits, labels, gamma=2):r"""focal loss for multi classification(第一版)FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)"""# pt = F.softmax(logits, dim=-1)  # 直接调用可能会溢出#什么是softmax的溢出:https://blog.csdn.net/qq_35054151/article/details/125891745# 一个不会溢出的 tricklog_pt = F.log_softmax(logits, dim=-1)  # 这里相当于 CE loss#pt:tensor([[0.1617, 0.2182, 0.2946, 0.3255],#    [0.2455, 0.2010, 0.3314, 0.2221]])pt = torch.exp(log_pt)  # 通过 softmax 函数后打的分labels = labels.view(-1, 1)  # 多加一个维度,为使用 gather 函数做准备#.gather第一个参数表示根据哪个维度,第二个参数表示按照索引列表index从input中选取指定元素pt = pt.gather(1, labels)  # 从pt中挑选出真实值对应的 softmax 打分,也可以使用独热编码实现#pt,因为只有两个样本所以只有两项损失: tensor([[0.2182],#                                      [0.2221]])ce_loss = -torch.log(pt)weights = (1 - pt) ** gamma#对应元素相乘fl = weights * ce_loss#大家都是默认取均值而不是取sumfl = fl.mean()return flif __name__ == "__main__":#2个样本,4分类问题logits = torch.tensor([[0.3, 0.6, 0.9, 1], [0.6, 0.4, 0.9, 0.5]])labels = torch.tensor([1, 3])print(focal_loss(logits, labels))print(focal_loss(logits, labels, reduction="sum"))print(focal_lossv1(logits, labels))print(balanced_focal_loss(logits, labels))

Refer
交叉熵损失原理详解
随机变量的理解
GAN交叉熵
从二分类(二项分布)到多分类(多项分布)
FocalLoss 对样本不平衡的权重调节和减低损失值

再记录几个好的文章非常实用:
一文搞懂F.cross_entropy的具体实现
一文搞懂F.binary_cross_entropy以及weight参数
softmax loss详解,softmax与交叉熵的关系
二分类问题,应该选择sigmoid还是softmax?


http://chatgpt.dhexx.cn/article/23E0ty2U.shtml

相关文章

损失函数loss

http://blog.csdn.net/pipisorry/article/details/23538535 监督学习及其目标函数 损失函数(loss function)是用来估量你模型的预测值f(x)与真实值Y的不一致程度,它是一个非负实值函数,通常使用L(Y, f(x))来表示。 损失函数是经…

机器学习模型中的损失函数loss function

1. 概述 在机器学习算法中,有一个重要的概念就是损失函数(Loss Function)。损失函数的作用就是度量模型的预测值 f ( x ) f\left ( \mathbf{x} \right ) f(x)与真实值 y \mathbf{y} y之间的差异程度的函数,且是一个非负实值函数。…

损失函数(Loss)

如果我们定义了一个机器学习模型,比如一个三层的神经网络,那么就需要使得这个模型能够尽可能拟合所提供的训练数据。但是我们如何评价模型对于数据的拟合是否足够呢?那就需要使用相应的指标来评价它的拟合程度,所使用到的函数就称…

focal loss详解

文章目录 focal loss的整体理解易分辨样本、难分辨样本的含义focal loss的出现过程focal loss 举例说明focal loss的 α \alpha α变体 focal loss的整体理解 focal loss 是一种处理样本分类不均衡的损失函数,它侧重的点是根据样本分辨的难易程度给样本对应的损失添…

深度学习——损失函数(Regression Loss、Classification Loss)

简介 Loss function 损失函数 用于定义单个训练样本与真实值之间的误差 Cost function 代价函数 用于定义单个批次/整个训练集样本与真实值之间的误差 Objective function 目标函数 泛指任意可以被优化的函数 损失函数用于衡量模型所做出的预测离真实值(GT)之间的偏离程度。 …

深度学习中常见的损失函数(L1Loss、L2loss)

损失函数定义 损失函数:衡量模型输出与真实标签的差异。 L1_loss 平均绝对误差(L1 Loss):平均绝对误差(Mean Absolute Error,MAE)是指模型预测值f(x)和真实值y之间距离的平均值,公式如下: 优…

损失函数(loss function)

文章目录 1、什么是损失函数2、为什么要使用损失函数3、损失函数分类1、分类一2、分类二3、分类三3.1基于距离度量的损失函数3.1.1 均方误差损失函数(MSE)3.1.2 L2损失函数3.1.3 L1损失函数3.1.4 Smooth L1损失函数3.1.5 huber损失函数 3.2 基于概率分布…

Focal Loss损失函数(超级详细的解读)

什么是损失函数? 1、什么是损失呢? 在机器学习模型中,对于每一个样本的预测值与真实值的差称为损失。 2、什么是损失函数呢? 显而易见,是一个用来计算损失的函数。它是一个非负实值函数,通常使用L(Y, f(x))来表示。 3、…

损失函数loss大总结

分类任务loss: 二分类交叉熵损失sigmoid_cross_entropy: TensorFlow 接口: tf.losses.sigmoid_cross_entropy(multi_class_labels,logits,weights1.0,label_smoothing0,scopeNone,loss_collectiontf.GraphKeys.LOSSES,reductionReduction.SUM_BY_NONZER…

深度学习基础(三)loss函数

loss函数,即损失函数,是决定网络学习质量的关键。若网络结构不变的前提下,损失函数选择不当会导致模型精度差等后果。若有错误,敬请指正,Thank you! 目录 一、loss函数定义 二、常见的loss算法种类 1.M…

loss函数 激活函数

一、LOSS函数 loss函数指机器学习模型中用于最小化的目标函数,其一般意义是指判别错误的程度,因为我们要提高准确率,也就是降低错误率,所以可以归结为一个最小化损失函数的问题。 具体的,我们假设有一个十分类问题&a…

Loss损失函数

损失函数是将随机事件或其有关随机变量的取值映射为非负实数以表示该随机事件的“风险”或“损失”的函数,用于衡量预测值与实际值的偏离程度。在机器学习中,损失函数是代价函数的一部分,而代价函数是目标函数的一种类型。 在《神经网络中常…

深度学习之——损失函数(loss)

深度学习中的所有学习算法都必须有一个 最小化或最大化一个函数,称之为损失函数(loss function),或“目标函数”、“代价函数”。损失函数是衡量模型的效果评估。比如:求解一个函数最小点最常用的方法是梯度下降法&…

1_一些文献中的英文解释和用法整理

目录 1、Theorem、Proposition、Lemma和Corollary等的解释与区别 2、论文里的 Preliminaries 究竟是什么意思? (1)Preliminaries是什么? (2)Preliminaries应该写什么内容? (3)…

区分定理(Theorem)、引理(Lemma)、推论(Corollary)等概念

ZZ: http://blog.sina.com.cn/s/blog_a0e53bf70101jwv1.html Theorem:就是定理,比較重要的,簡寫是 Thm。 Lemma:小小的定理,通常是為了證明後面的定理,如果證明的篇幅很長時,可能會把證明拆成幾…

CodeForces - 1364D Ehabs Last Corollary(dfs树找最小环)

题目链接:点击查看 题目大意:给出一个由 n 个结点和 m 条边构成的无向图,再给出一个 k ,需要在图中完成下面任意一种操作: 找到一个大小恰好为 的独立集找到一个大小不超过 k 的环 题目分析: 题目已经…

Codeforces Round 649 (Rated for Div. 2)D. Ehab s Last Corollary详细题解(图论+简单环)

树 边 : 树边: 树边:深度优先森林中的边。如果结点v是因对(u,v)的探索而首先被发现,则(u,v)是一条树边。 后 向 边 : 后向边: 后向边:后向边(u,v)是将结点u连接到其在深度优先树中一个祖先节点v的边. (本文我就称之为反向边了,问题不大) 前…

#649 (Div. 2)D. Ehab‘s Last Corollary

题目描述 Given a connected undirected graph with n vertices and an integer k, you have to either: either find an independent set that has exactly ⌈k2⌉ vertices. or find a simple cycle of length at most k. An independent set is a set of vertices such that…

Ehabs Last Corollary

Given a connected undirected graph with n n n vertices and an integer k k k, you have to either: either find an independent set that has exactly ⌈ k 2 ⌉ ⌈\frac{k}{2}⌉ ⌈2k​⌉ vertices.or find a simple cycle of length at most k k k. An independen…

Latent Variables的理解

加入我们有X,Y两个随机变量,他们的概率分布如下。要直接用一个函数还表示这个分布是比较困难的。 但我们发现这个分布可以分成三个聚类。如果我们给每个聚类编号为。 那么就是简单的高斯函数了。 这里z就是 加入latent variable的意义在于&#xff0c…