最全的交叉熵损失函数(Pytorch)

article/2025/9/10 3:43:52

损失函数

  • 引言
  • BCELoss
  • BCEWithLogitsLoss
  • NLLLoss
  • CrossEntropyLoss
  • 总结
  • 参考

引言

这里主要讲述pytorch中的几种交叉熵损失类,熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度。交叉熵越小,表示数据越接近真实样本。公式为:
在这里插入图片描述

在pytorch中,损失可以通过函数或者类来计算,这里BCELoss、BCEWithLogitsLoss、NLLLoss、CrossEntropyLoss都是类,需要先进行类的定义,再调用函数方法,例如:

import torch
one = torch.nn.CrossEntropyLoss()(pre, label)

当然也可以直接使用函数方法,例如:

import torch.nn.functional as F
one = F.cross_entropy(pre,label)

BCELoss

BCELoss是一个二分类损失函数,全称:Binary Cross Entropy Loss,是交叉熵损失函数应用于二分类损失的特殊形式,一般配合sigmoid使用。公式为:
在这里插入图片描述
例如:

import torch
import torch.nn.functional as F
pre_a = torch.tensor([[0.8, 0.2]
],dtype=torch.float)
pre_b = torch.tensor([[0.6, 0.4]
],dtype=torch.float)
target = torch.tensor([[1, 0]
],dtype=torch.float)
a = F.binary_cross_entropy(pre_a, target)
b = F.binary_cross_entropy(pre_b, target)
print(a, b)
# 结果:tensor(0.2231) tensor(0.5108)

对于预测值a[0.8,0.2]、预测值b[0.6,0.4]和标签[1,0],其中预测值都是经过sigmoid得到的,他们的和为1,带入公式中计算得:
在这里插入图片描述
笔试计算和代码计算结果一致。可以看出预测值a与标签更加接近,同时损失函数也更小,符合预期结果。当然使用多个二分类损失也可以达到多分类的效果,例如是不是猫,是不是狗,是不是熊,我最先看到这样实现的好像是在YoloV4中吧!

BCEWithLogitsLoss

BCEWithLogitsLoss = sigmoid + BCELoss,使用BCEWithLogisLoss会自动帮助预测值进行sigmoid计算。

pre_a = torch.tensor([[1.8, 0.2]
],dtype=torch.float)
pre_a_sigmoid = F.sigmoid(pre_a)
a = F.binary_cross_entropy(pre_a_sigmoid,target)
b = F.binary_cross_entropy_with_logits(pre_a,target)
print(a,b) # tensor(0.4756) tensor(0.4756)

NLLLoss

NLLoss又称负对数似然损失函数,用于处理多分类问题,输入是对数化的概率值。公式为:
在这里插入图片描述

import torch
import torch.nna = torch.Tensor([[1,2,3]])
nll = nn.NLLLoss()
target1 = torch.Tensor([0]).long()
target2 = torch.Tensor([1]).long()
target3 = torch.Tensor([2]).long()#测试
n1 = nll(a,target1)
#输出:tensor(-1.)
n2 = nll(a,target2)
#输出:tensor(-2.)
n3 = nll(a,target3)
#输出:tensor(-3.)

NLLLoss负对数似然损失函数实现的是将标签target位置的预测值取出求反

CrossEntropyLoss

CrossEntropyLoss是一种用于多分类的损失函数,输入是未经过softmax的tensor型值。在这里插入图片描述
CrossEntropyLoss就是将softmax、log、NLLLoss合为一体。公式为:
在这里插入图片描述
下面我将通过三种方法实现多分类交叉熵损失函数:

  1. 一步实现 CrossEntropyLoss
  2. 二步实现 log_softmax+nll_loss
  3. 三步实现 softmax+log+nll_loss
pre = np.array([[0.5, 0.1],[0.3, 0.8],[0.6, 0.1]
])
target = np.array([[1, 0],[0, 1],[1, 0]
])
pre = torch.from_numpy(pre)
target = torch.from_numpy(target)label = torch.argmax(target, dim=1)
label = torch.LongTensor(label)# 一步实现
one = torch.nn.CrossEntropyLoss()(pre, label)
print(one)
one = F.cross_entropy(pre,label)
print(one)print('*' * 20)# 两步实现
pre_two = torch.nn.LogSoftmax(dim=1)(pre)
print(pre_two)
two = torch.nn.NLLLoss()(pre_two, label)
print(two)print('*' * 20)# 三步实现
pre_three = torch.nn.Softmax(dim=1)(pre)
print(pre_three)
pre_log = torch.log(pre_three)
print(pre_log)
three = torch.nn.NLLLoss()(pre_log, label)
print(three)

结果都是tensor(0.4871, dtype=torch.float64)

总结

交叉熵损失函数广泛应用于图像分类、分割中,可以发现,不管是二分类还是多分类,其实计算损失函数都经历三个步骤,其中步骤之间可以合并:

  • 激活函数,通过sigmoid或者softmax将预测值缩放到[0,1]之间
  • log操作,对计算缩放求取log操作,进一步缩放至[-无穷,0]之间
  • 累积求和,根据函数定义,将标签和缩放后的预测值进行相乘求和

参考

https://blog.csdn.net/sdu_hao/article/details/103499223
https://blog.csdn.net/watermelon1123/article/details/91044856
https://blog.csdn.net/qq_16949707/article/details/79929951
https://zhuanlan.zhihu.com/p/159477597


http://chatgpt.dhexx.cn/article/xMM9jVYd.shtml

相关文章

简单的交叉熵损失函数,你真的懂了吗?

个人网站:红色石头的机器学习之路 CSDN博客:红色石头的专栏 知乎:红色石头 微博:RedstoneWill的微博 GitHub:RedstoneWill的GitHub 微信公众号:AI有道(ID:redstonewill&#xf…

交叉熵损失概念

交叉熵是信息论中的一个概念,要想了解交叉熵的本质,需要先从最基本的概念讲起。 1. 信息量 首先是信息量。假设我们听到了两件事,分别如下: 事件A:巴西队进入了2018世界杯决赛圈。 事件B:中国队进入了2…

softmax交叉熵损失函数深入理解(二)

0、前言 前期博文提到经过两步smooth化之后,我们将一个难以收敛的函数逐步改造成了softmax交叉熵损失函数,解决了原始的目标函数难以优化的问题。Softmax 交叉熵损失函数是目前最常用的分类损失函数,本博文继续学习Softmax 交叉熵损失函数的改…

史上最全交叉熵损失函数详解

在我们自学神经网络神经网络的损失函数的时候会发现有一个思路就是交叉熵损失函数,交叉熵的概念源于信息论,一般用来求目标与预测值之间的差距。比如说我们在人脑中有一个模型,在神经网络中还有一个模型,我们需要找到神经网络模型…

交叉熵损失函数原理详解

交叉熵损失函数原理详解 之前在代码中经常看见交叉熵损失函数(CrossEntropy Loss),只知道它是分类问题中经常使用的一种损失函数,对于其内部的原理总是模模糊糊,而且一般使用交叉熵作为损失函数时,在模型的输出层总会接一个softm…

损失函数——交叉熵损失函数

一篇弄懂交叉熵损失函数 一、定义二、交叉熵损失函数:知识准备:1、信息熵:将熵引入到信息论中,命名为“信息熵”2、 KL散度(相对熵): 交叉熵:结论: Softmax公式Sigmoid常…

交叉熵损失函数详解

我们知道,在二分类问题模型:例如逻辑回归「Logistic Regression」、神经网络「Neural Network」等,真实样本的标签为 [0,1],分别表示负类和正类。模型的最后通常会经过一个 Sigmoid 函数,输出一个概率值&am…

交叉熵损失函数(CrossEntropy Loss)(原理详解)

监督学习主要分为两类: 分类:目标变量是离散的,如判断一个西瓜是好瓜还是坏瓜,那么目标变量只能是1(好瓜),0(坏瓜)回归:目标变量是连续的,如预测西瓜的含糖率…

nn.CrossEntropyLoss()交叉熵损失函数

1、nn.CrossEntropyLoss() 在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题。在使用nn.CrossEntropyLoss()其内部会自动加上Sofrmax层 nn.CrossEntropyLoss()的计算公式如下: 其中&#xff0c…

损失函数——交叉熵损失函数(CrossEntropy Loss)

损失函数——交叉熵损失函数(CrossEntropy Loss) 交叉熵函数为在处理分类问题中常用的一种损失函数,其具体公式为: 1.交叉熵损失函数由来 交叉熵是信息论中的一个重要概念,主要用于度量两个概率分布间的差异性。首先…

损失函数——交叉熵损失(Cross-entropy loss)

交叉熵损失(Cross-entropy loss)是深度学习中常用的一种损失函数,通常用于分类问题。它衡量了模型预测结果与实际结果之间的差距,是优化模型参数的关键指标之一。以下是交叉熵损失的详细介绍。 假设我们有一个分类问题&#xff0…

【Pytorch】交叉熵损失函数 CrossEntropyLoss() 详解

文章目录 一、损失函数 nn.CrossEntropyLoss()二、什么是交叉熵三、Pytorch 中的 CrossEntropyLoss() 函数参考链接 一、损失函数 nn.CrossEntropyLoss() 交叉熵损失函数 nn.CrossEntropyLoss() ,结合了 nn.LogSoftmax() 和 nn.NLLLoss() 两个函数。 它在做分类&a…

一文读懂交叉熵损失函数

进行二分类或多分类问题时,在众多损失函数中交叉熵损失函数较为常用。 下面的内容将以这三个问题来展开 什么是交叉熵损失以图片分类问题为例,理解交叉熵损失函数从0开始实现交叉熵损失函数 1,什么是交叉熵损失 交叉熵是信息论中的一个重…

交叉熵损失函数

目录 一、交叉熵损失函数含义 二、交叉熵损失函数定义为:​ 三、交叉熵损失函数计算案例 一、交叉熵损失函数含义 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的。给定两个 概率分布p和q,通过q来表示p的交叉熵为 交叉熵刻画…

交叉熵损失函数(Cross Entropy Loss)

基础不牢,地动山摇,读研到现在有一年多了,发现自己对很多经常打交道的知识并不了解,仅仅是会改一改别人的代码,这使我感到非常焦虑,自此开始我的打基础之路。如果博客中有错误的地方,欢迎大家评…

js遍历数组中的对象并拿到值

拿到一组数组,数组中是对象,想拿到这个对象里面的某个值,可以参考以下例子: 这样就拿到所有n1的值. 想拿到这个对象里面所有对应的值如下: 也可以这样取值: 往数组里面push多个值: js中!!用法 …

js遍历数组以及获取数组对象的key和key的值方法

数组: let arr [{ appendData: { "Expiration Date mm- dd - yyyy(2D)": "03-04-2025" }},{appendData: { "Manufacturer(21P)": "MURATA" }}]arr.forEach((value,i)>{ //数组循环for(var pl in value){ //数组对象遍…

javascript遍历数组的方法总结

一、for循环 var arr[javascript,jquery,html,css,学习,加油,1,2]; for(var i0;i<arr.length;i){console.log(输出值,arr[i]); } 二、for...in 遍历的是key 适合遍历对象 var arr[javascript,jquery,html,css,学习,加油,1,2]; for(var i in arr){ console.log(输出值---…

html函数参数数组遍历,JavaScript foreach遍历数组

JavaScript forEach遍历数组教程 JavaScript forEach详解 定义 forEach() 方法为每个数组元素调用一次函数(回调函数)。 语法 array.forEach(function(currentValue, index, arr), thisValue); 参数 参数 描述 function(currentValue, index, arr) 必须。数组每个元素需要执行的…

js中遍历数组加到新数组_js数组遍历:JavaScript如何遍历数组?

什么是数组的遍历? 操作数组中的每一个数组元素。 使用for循环来遍历数组 因为数组的下标是连续的&#xff0c;数组的下标是从0开始。 我们也可以得到数组的长度。 格式&#xff1a;for(var i0;i 数组变量名[i] } 注意&#xff1a;条件表达式的写法 i i<数组的长度-1 // 数…