交叉熵损失函数详解

article/2025/9/10 3:49:39

我们知道,在二分类问题模型:例如逻辑回归「Logistic Regression」、神经网络「Neural Network」等,真实样本的标签为 [0,1],分别表示负类和正类。模型的最后通常会经过一个 Sigmoid 函数,输出一个概率值,这个概率值反映了预测为正类的可能性:概率越大,可能性越大。

Sigmoid 函数的表达式和图形如下所示:

g(s)=\frac{1}{1+e^{-s}}

 

 

其中 s 是模型上一层的输出,Sigmoid 函数有这样的特点:s = 0 时,g(s) = 0.5;s >> 0 时, g ≈ 1,s << 0 时,g ≈ 0。显然,g(s) 将前一级的线性输出映射到 [0,1] 之间的数值概率上。这里的 g(s) 就是交叉熵公式中的模型预测输出 。

我们说了,预测输出即 Sigmoid 函数的输出表征了当前样本标签为 1 的概率:

 

\hat y=P(y=1|x)

 

很明显,当前样本标签为 0 的概率就可以表达成:

 

1-\hat y=P(y=0|x)

 

重点来了,如果我们从极大似然性的角度出发,把上面两种情况整合到一起:

 

P(y|x)=\hat y^y\cdot (1-\hat y)^{1-y}

 

不懂极大似然估计也没关系。我们可以这么来看:

当真实样本标签 y = 0 时,上面式子第一项就为 1,概率等式转化为:

 

P(y=0|x)=1-\hat y

 

当真实样本标签 y = 1 时,上面式子第二项就为 1,概率等式转化为:

 

P(y=1|x)=\hat y

 

两种情况下概率表达式跟之前的完全一致,只不过我们把两种情况整合在一起了。

重点看一下整合之后的概率表达式,我们希望的是概率 P(y|x) 越大越好。首先,我们对 P(y|x) 引入 log 函数,因为 log 运算并不会影响函数本身的单调性。则有:

 

log\ P(y|x)=log(\hat y^y\cdot (1-\hat y)^{1-y})=ylog\ \hat y+(1-y)log(1-\hat y)

 

我们希望 log P(y|x) 越大越好,反过来,只要 log P(y|x) 的负值 -log P(y|x) 越小就行了。那我们就可以引入损失函数,且令 Loss = -log P(y|x)即可。则得到损失函数为:

 

L=-[ylog\ \hat y+(1-y)log\ (1-\hat y)]

 

非常简单,我们已经推导出了单个样本的损失函数,是如果是计算 N 个样本的总的损失函数,只要将 N 个 Loss 叠加起来就可以了:

 

L=-\sum_{i=1}^Ny^{(i)}log\ \hat y^{(i)}+(1-y^{(i)})log\ (1-\hat y^{(i)})

 

这样,我们已经完整地实现了交叉熵损失函数的推导过程。

2. 交叉熵损失函数的直观理解

可能会有读者说,我已经知道了交叉熵损失函数的推导过程。但是能不能从更直观的角度去理解这个表达式呢?而不是仅仅记住这个公式。好问题!接下来,我们从图形的角度,分析交叉熵函数,加深大家的理解。

首先,还是写出单个样本的交叉熵损失函数:

 

L=-[ylog\ \hat y+(1-y)log\ (1-\hat y)]

 

我们知道,当 y = 1 时:

 

L=-log\ \hat y

 

这时候,L 与预测输出的关系如下图所示:

 

看了 L 的图形,简单明了!横坐标是预测输出,纵坐标是交叉熵损失函数 L。显然,预测输出越接近真实样本标签 1,损失函数 L 越小;预测输出越接近 0,L 越大。因此,函数的变化趋势完全符合实际需要的情况。

当 y = 0 时:

 

L=-log\ (1-\hat y)

 

这时候,L 与预测输出的关系如下图所示:

 

 

同样,预测输出越接近真实样本标签 0,损失函数 L 越小;预测函数越接近 1,L 越大。函数的变化趋势也完全符合实际需要的情况。

从上面两种图,可以帮助我们对交叉熵损失函数有更直观的理解。无论真实样本标签 y 是 0 还是 1,L 都表征了预测输出与 y 的差距。

另外,重点提一点的是,从图形中我们可以发现:预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 “ 惩罚 ” 越大,而且是非线性增大,是一种类似指数增长的级别。这是由 log 函数本身的特性所决定的。这样的好处是模型会倾向于让预测输出更接近真实样本标签 y。

3. 交叉熵损失函数的其它形式

什么?交叉熵损失函数还有其它形式?没错!我刚才介绍的是一个典型的形式。接下来我将从另一个角度推导新的交叉熵损失函数。

这种形式下假设真实样本的标签为 +1 和 -1,分别表示正类和负类。有个已知的知识点是Sigmoid 函数具有如下性质:

 

1-g(s)=g(-s)

 

这个性质我们先放在这,待会有用。

好了,我们之前说了 y = +1 时,下列等式成立:

 

P(y=+1|x)=g(s)

 

如果 y = -1 时,并引入 Sigmoid 函数的性质,下列等式成立:

 

P(y=-1|x)=1-g(s)=g(-s)

 

重点来了,因为 y 取值为 +1 或 -1,可以把 y 值带入,将上面两个式子整合到一起:

 

P(y|x)=g(ys)

 

这个比较好理解,分别令 y = +1 和 y = -1 就能得到上面两个式子。

接下来,同样引入 log 函数,得到:

 

log\ P(y|x)=log\ g(ys)

 

要让概率最大,反过来,只要其负数最小即可。那么就可以定义相应的损失函数为:

 

L=-log g(ys)

 

还记得 Sigmoid 函数的表达式吧?将 g(ys) 带入:

 

L=-log\frac{1}{1+e^{-ys}}=log\ (1+e^{-ys})

 

好咯,L 就是我要推导的交叉熵损失函数。如果是 N 个样本,其交叉熵损失函数为:

 

L=\sum_{i=1}^Nlog\ (1+e^{-ys})

 

接下来,我们从图形化直观角度来看。当 y = +1 时:

 

L=log\ (1+e^{-s})

 

这时候,L 与上一层得分函数 s 的关系如下图所示:

 

横坐标是 s,纵坐标是 L。显然,s 越接近真实样本标签 1,损失函数 L 越小;s 越接近 -1,L 越大。

另一方面,当 y = -1 时:

 

L=log(1+e^s)

 

这时候,L 与上一层得分函数 s 的关系如下图所示:

 

同样,s 越接近真实样本标签 -1,损失函数 L 越小;s 越接近 +1,L 越大。

4. 总结

本文主要介绍了交叉熵损失函数的数学原理和推导过程,也从不同角度介绍了交叉熵损失函数的两种形式。第一种形式在实际应用中更加常见,例如神经网络等复杂模型;第二种多用于简单的逻辑回归模型。

5.多分类交叉熵

交叉熵可在神经网络(机器学习)中作为损失函数,p表示真实标记的分布,q则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量p与q的相似性。

 

交叉熵作为损失函数还有一个好处是使用sigmoid函数在梯度下降时能避免均方误差损失函数学习速率降低的问题,因为学习速率可以被输出的误差所控制。

 

 

单次观测下的多项式分布

其中,C代表类别数。p代表向量形式的模型参数,即各个类别的发生概率,如p=[0.1, 0.1, 0.7, 0.1],则p1=0.1, p3=0.7等。即,多项式分布的模型参数就是各个类别的发生概率!x代表one-hot形式的观测值,如x=类别3,则x=[0, 0, 1, 0]。xi代表x的第i个元素,比如x=类别3时,x1=0,x2=0,x3=1,x4=0。

 

机器学习model对某个样本的输出,就代表各个类别发生的概率。但是,对于当前这一个样本而言,它肯定只能有一个类别,所以这一个样本就可以看成是一次实验(观察),而这次实验(观察)的结果要服从上述各个类别发生的概率,那不就是服从多项式分布嘛!而且是单次观察!各个类别发生的概率predict当然就是这个多项式分布的参数阿。

对于多类分类问题,似然函数就是衡量当前这个以predict为参数的单次观测下的多项式分布模型与样本值label之间的似然度。

 

所以,根据似然函数的定义,单个样本的似然函数即:

 

所以,整个样本集(或者一个batch)的似然函数即:

而由于式子里有累乘运算,所以习惯性的加个log函数来将累乘化成累加以提高运算速度(虽然对于每个样本来说只有一个类别,但是哪怕是算0.2^0也是算了一遍指数函数啊,计算机可不会直接口算出1)。所以在累乘号前面加上log函数后,就成了所谓的对数似然函数:

 

而最大化对数似然函数就等效于最小化负对数似然函数,所以前面加个负号后不就是我们平常照着敲的公式嘛。。。

而这个形式跟交叉熵的形式是一模一样的:

这里X的分布模型即样本集label的真实分布模型,这里模型q(x)即想要模拟真实分布模型的机器学习模型。可以说交叉熵是直接衡量两个分布,或者说两个model之间的差异。而似然函数则是解释以model的输出为参数的某分布模型对样本集的解释程度。因此,可以说这两者是“同貌不同源”,但是“殊途同归”啦。

6.交叉熵概率分布差异之间的理解https://juejin.im/post/5b40a5156fb9a04faf478a45

原文:https://blog.csdn.net/ccj_ok/article/details/78066619

原文https://zhuanlan.zhihu.com/p/38241764

 

 


http://chatgpt.dhexx.cn/article/62fDvnGI.shtml

相关文章

交叉熵损失函数(CrossEntropy Loss)(原理详解)

监督学习主要分为两类&#xff1a; 分类&#xff1a;目标变量是离散的&#xff0c;如判断一个西瓜是好瓜还是坏瓜&#xff0c;那么目标变量只能是1&#xff08;好瓜&#xff09;,0&#xff08;坏瓜&#xff09;回归&#xff1a;目标变量是连续的&#xff0c;如预测西瓜的含糖率…

nn.CrossEntropyLoss()交叉熵损失函数

1、nn.CrossEntropyLoss() 在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数&#xff0c;用于解决多分类问题&#xff0c;也可用于解决二分类问题。在使用nn.CrossEntropyLoss()其内部会自动加上Sofrmax层 nn.CrossEntropyLoss()的计算公式如下&#xff1a; 其中&#xff0c…

损失函数——交叉熵损失函数(CrossEntropy Loss)

损失函数——交叉熵损失函数&#xff08;CrossEntropy Loss&#xff09; 交叉熵函数为在处理分类问题中常用的一种损失函数&#xff0c;其具体公式为&#xff1a; 1.交叉熵损失函数由来 交叉熵是信息论中的一个重要概念&#xff0c;主要用于度量两个概率分布间的差异性。首先…

损失函数——交叉熵损失(Cross-entropy loss)

交叉熵损失&#xff08;Cross-entropy loss&#xff09;是深度学习中常用的一种损失函数&#xff0c;通常用于分类问题。它衡量了模型预测结果与实际结果之间的差距&#xff0c;是优化模型参数的关键指标之一。以下是交叉熵损失的详细介绍。 假设我们有一个分类问题&#xff0…

【Pytorch】交叉熵损失函数 CrossEntropyLoss() 详解

文章目录 一、损失函数 nn.CrossEntropyLoss()二、什么是交叉熵三、Pytorch 中的 CrossEntropyLoss() 函数参考链接 一、损失函数 nn.CrossEntropyLoss() 交叉熵损失函数 nn.CrossEntropyLoss() &#xff0c;结合了 nn.LogSoftmax() 和 nn.NLLLoss() 两个函数。 它在做分类&a…

一文读懂交叉熵损失函数

进行二分类或多分类问题时&#xff0c;在众多损失函数中交叉熵损失函数较为常用。 下面的内容将以这三个问题来展开 什么是交叉熵损失以图片分类问题为例&#xff0c;理解交叉熵损失函数从0开始实现交叉熵损失函数 1&#xff0c;什么是交叉熵损失 交叉熵是信息论中的一个重…

交叉熵损失函数

目录 一、交叉熵损失函数含义 二、交叉熵损失函数定义为&#xff1a;​ 三、交叉熵损失函数计算案例 一、交叉熵损失函数含义 交叉熵是一个信息论中的概念&#xff0c;它原来是用来估算平均编码长度的。给定两个 概率分布p和q&#xff0c;通过q来表示p的交叉熵为 交叉熵刻画…

交叉熵损失函数(Cross Entropy Loss)

基础不牢&#xff0c;地动山摇&#xff0c;读研到现在有一年多了&#xff0c;发现自己对很多经常打交道的知识并不了解&#xff0c;仅仅是会改一改别人的代码&#xff0c;这使我感到非常焦虑&#xff0c;自此开始我的打基础之路。如果博客中有错误的地方&#xff0c;欢迎大家评…

js遍历数组中的对象并拿到值

拿到一组数组&#xff0c;数组中是对象&#xff0c;想拿到这个对象里面的某个值&#xff0c;可以参考以下例子&#xff1a; 这样就拿到所有n1的值. 想拿到这个对象里面所有对应的值如下&#xff1a; 也可以这样取值&#xff1a; 往数组里面push多个值&#xff1a; js中!!用法 …

js遍历数组以及获取数组对象的key和key的值方法

数组&#xff1a; let arr [{ appendData: { "Expiration Date mm- dd - yyyy(2D)": "03-04-2025" }},{appendData: { "Manufacturer(21P)": "MURATA" }}]arr.forEach((value,i)>{ //数组循环for(var pl in value){ //数组对象遍…

javascript遍历数组的方法总结

一、for循环 var arr[javascript,jquery,html,css,学习,加油,1,2]; for(var i0;i<arr.length;i){console.log(输出值,arr[i]); } 二、for...in 遍历的是key 适合遍历对象 var arr[javascript,jquery,html,css,学习,加油,1,2]; for(var i in arr){ console.log(输出值---…

html函数参数数组遍历,JavaScript foreach遍历数组

JavaScript forEach遍历数组教程 JavaScript forEach详解 定义 forEach() 方法为每个数组元素调用一次函数(回调函数)。 语法 array.forEach(function(currentValue, index, arr), thisValue); 参数 参数 描述 function(currentValue, index, arr) 必须。数组每个元素需要执行的…

js中遍历数组加到新数组_js数组遍历:JavaScript如何遍历数组?

什么是数组的遍历? 操作数组中的每一个数组元素。 使用for循环来遍历数组 因为数组的下标是连续的&#xff0c;数组的下标是从0开始。 我们也可以得到数组的长度。 格式&#xff1a;for(var i0;i 数组变量名[i] } 注意&#xff1a;条件表达式的写法 i i<数组的长度-1 // 数…

html页面遍历数组,javascript如何遍历数组?

作为一个程序员对于数组遍历大家都不是很陌生&#xff0c;在开发中我们也经常要处理数组。这里我们讨论下JavaScript中常用的数组遍历方法。 数组中常用的遍历方法有四种&#xff0c;分别是&#xff1a;for for-in forEach for-of (ES6) 1、第一种for循环var arr [1, 2, 3, 4]…

JavaScript遍历数组,附5个案例

先给大家分享一些JavaScript的相关资料&#xff1a; 认识JavaScript到初体验JavaScript 注释以及输入输出语句JavaScript变量的使用、语法扩展、命名规范JavaScript数据类型简介以及简单的数据类型JavaScript获取变量数据类型JavaScript 运算符&算数运算符JavaScript递增和…

1.9 JavaScript 遍历数组

遍历数组 数组的长度 使用 “数组名.length” 可以访问数组元素的数量&#xff08;数组长度&#xff09; a.length 动态监测数组元素的个数 案例 请将 [“关羽”, “张飞”, “赵云”,“小脆筒”], 将数组里的元素依次打印到控制台 代码实现 <!DOCTYPE html> <html&…

html怎么遍历数组,JavaScript如何遍历数组?遍历数组方法介绍

在往期文章中为大家介绍了 JavaScript 如何定义数组。那么这篇文章中 w3cschool 小编来为大家介绍下 JavaScript 如何遍历数组。 方法一&#xff1a;for 循环遍历数组 var arr[Tom,Jenny,Jan,Marry]; for(var i0;i console.log(arr[i]); } 实现效果&#xff1a; 方法二&#xf…

小程序 js 遍历数组

js 方式一&#xff1a; for (var index in res.data) { title : res.data[index].title } res.data&#xff1a;数组 index&#xff1a;下标 title&#xff1a;数组中的一个字段 方式二&#xff1a; for (var i 0; i < datas.length; i) { console.log(i); if( i > 1) b…

JS遍历数组的方法【详解】

法一&#xff1a;for循环 法二&#xff1a;forEach遍历&#xff08;可以同时取出数组中的值和值对应的下标&#xff09; 必须搭配函数使用&#xff0c;而且可以直接取出数组中的每个对象和对象对应的下标 let arr [{er: qwe},{er: asd}];arr.forEach((item,index)>{cons…

js遍历数组的方法

JS遍历数组的8种方法如下&#xff1a; 1.for循环 (改变原数组&#xff0c;无返回值) 2.forEach()&#xff08;改变原数组&#xff0c;无返回值&#xff09; 3.map() 4.filter() 5.reduce() 6.some() 7.every() 8.find() 1.for 循环&#xff1a;可以改变原数组。 2.f…