pytorch:多标签分类的损失函数和准确率计算

article/2025/9/1 19:54:19

1 损失函数

我们先用sklearn生成一个多标签分类数据集。

from sklearn.datasets import make_multilabel_classificationX, y = make_multilabel_classification(n_samples=1000,n_features=10,n_classes=3,n_labels=2,random_state=1)
print(X.shape, y.shape)

在这里插入图片描述
看一下标签长啥样。
在这里插入图片描述
每一行都是0、1标签,1可能有多个,这就是多标签了。
由于仍然是二分类(标签只有0和1),所以激活函数用Sigmoid(对输出的每一个维度用Sigmoid)。这个时候损失函数就是BCELoss。
如果是普通的二分类,Sigmoid的输出是一个值。用 N N N表示样本数量, p n p_n pn表示预测第 n n n个样本为正例的概率, y n y_n yn表示第 n n n个样本的标签,则BCELoss计算公式为:
l o s s = − 1 N ∑ n = 1 N y n × l o g ( p n ) + ( 1 − y n ) × l o g ( 1 − p n ) loss=-\frac{1}{N}\sum_{n=1}^{N}y_n×log(p_n)+(1-y_n)×log(1-p_n) loss=N1n=1Nyn×log(pn)+(1yn)×log(1pn)
那么对于多标签分类呢?BCELoss会计算每一个维度上的损失然后求平均。
举个例子,假如模型某个输出是[0.2,0.6,0.8],真实值是[0,0,1],那么该样本损失可以计算如下:
a = 0 × l n ( 0.2 ) + 1 × l n ( 1 − 0.2 ) b = 0 × l n ( 0.6 ) + 1 × l n ( 1 − 0.6 ) c = 1 × l n ( 0.8 ) + 0 × l n ( 1 − 08 ) l o s s = ( a + b + c ) / 3 a=0×ln(0.2)+1×ln(1-0.2)\\ b=0×ln(0.6)+1×ln(1-0.6)\\ c=1×ln(0.8)+0×ln(1-08)\\ loss=(a+b+c)/3 a=0×ln(0.2)+1×ln(10.2)b=0×ln(0.6)+1×ln(10.6)c=1×ln(0.8)+0×ln(108)loss=(a+b+c)/3
这只是单个样本的损失,最后还需要求所有样本损失的平均值。但是你就不用管了,只需要知道多标签分类用Sigmoid+BCELoss就可以完成损失计算。还有一个函数叫BCEWithLogitsLoss,是Sigmoid和BCELoss的结合。如果损失函数用这个,Sigmoid就可以不用。

2 准确率计算

依然是上面的例子,模型的输出是[0.2,0.6,0.8],真实值是[0,0,1]。准确率该怎么计算呢?

pred = torch.tensor([0.2, 0.6, 0.8])
y = torch.tensor([0, 0, 1])
accuracy = (pred.ge(0.5) == y).all().int().item()
accuracy
# output : 0

首先ge函数将pred中大于等于0.5的转化为True,小于0.5的转化成False,再比较pred和y(必须所有维度都相同才算分类准确),最后将逻辑值转化为整数输出即可。
训练时都是按照一个batch计算的,那就写一个循环吧。

pred = torch.tensor([[0.2, 0.5, 0.8], [0.4, 0.7, 0.1]])
y = torch.tensor([[0, 0, 1], [0, 1, 0]])
accuracy = sum(row.all().int().item() for row in (pred.ge(0.5) == y))
accuracy
# output : 1

3 完整代码

from sklearn.datasets import make_multilabel_classification
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_splitdef get_dataset():X, y = make_multilabel_classification(n_samples=1000,n_features=10,n_classes=3,n_labels=2,random_state=1)return X,yn_inputs, n_outputs = X.shape[1], y.shape[1]
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.33,random_state=42)
X_train = torch.from_numpy(X_train).float()
X_test = torch.from_numpy(X_test).float()
y_train = torch.from_numpy(y_train).float()
y_test = torch.from_numpy(y_test).float()train_data=[(X,y) for X,y in zip(X_train,y_train)]
train_loader = DataLoader(train_data, batch_size=64,shuffle=True)class MLP(nn.Module):def __init__(self, n_inputs, n_outputs, num_hiddens):super(MLP, self).__init__()self.linear_relu_stack = nn.Sequential(nn.Linear(n_inputs, num_hiddens),nn.ReLU(),nn.Linear(num_hiddens, n_outputs), nn.Sigmoid())def forward(self, x):outputs = self.linear_relu_stack(x)return outputsnum_hiddens = 30
model = MLP(n_inputs, n_outputs, num_hiddens)
print(model)loss = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)def evaluate_accuracy(X, y, model):pred = model(X)correct = sum(row.all().int().item() for row in (pred.ge(0.5) == y))n = y.shape[0]return correct / ndef train(train_loader, X_test, y_test, model, loss, num_epochs, batch_size,optimizer):batch_count = 0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_loader:pred = model(X)l = loss(pred, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.item()train_acc_sum += sum(row.all().int().item()for row in (pred.ge(0.5) == y))n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(X_test, y_test, model)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,test_acc))num_epochs, batch_size = 20, 64
train(train_loader, X_test, y_test, model, loss, num_epochs, batch_size,optimizer)

在这里插入图片描述


http://chatgpt.dhexx.cn/article/3J7PqtUs.shtml

相关文章

caffe训练分类模型教程

caffe训练分类模型教程 1.已有图像存放在train和val下,book和not-book(两类)的图片数量相同 在caffe/data下新建一個myself文件夾,并新建两个文件夹分别命名为train和val 批量重命名图片 # -*- coding:utf8 -*- import os class …

金融风控实战——模型融合

过采样方法使用条件 (1)负样本可以代表样本空间 (2)数据是足够干净的(样本、特征没有噪声) 过拟合 (1)增多数据 (2)特征筛选 (3)调参…

【推荐算法】ctr预估模型总结(LR、FM、FFM、NFM、AFM、WDL、DCN、DeepFM、FwFM、FLEN)

文章目录 前言LRPOLY2FM(Factorization Machine)FFM(Field-aware Factorization Machine)AFM(Attention Factorization Machine)NFM(Neural Factorization Machine)WDL(w…

概率图模型 —— 串连 NB、LR、MEM、HMM、CRF

概率图模型(PGM),作为机器学习的重要分支,能串连起很多传统模型,比如 NB、LR、MEM、HMM、CRF、DBN 等。本篇文章,从串连多个模型的角度,来谈谈 PGM,顺便把这些模型回顾下。 1 Why PG…

基于GBDT+LR模型的深度学习推荐算法

GBDTLR算法最早是由Facebook在2014年提出的一个推荐算法,该算法分两部分构成,第一部分是GBDT,另一部分是LR.下面先介绍GBDT算法,然后介绍如何将GBDT和LR算法融合 1.1 GBDT算法 GBDT的全称是 Gradient Boosting Decision Tree&am…

Logistic逻辑回归模型(LR)基础

逻辑回归(Logistic Regression, LR)模型其实仅在线性回归的基础上,套用了一个逻辑函数,但也就由于这个逻辑函数,使得逻辑回归模型成为了机器学习领域一颗耀眼的明星,更是计算广告学的核心。本文主要详述逻辑回归模型的基础&#x…

模型压缩一-知识蒸馏

一、知识蒸馏简介 知识蒸馏是模型压缩方法中的一个大类,是一种基于“教师-学生网络(teacher-student-network)思想”的训练方法, 其主要思想是拟合教师模型(teacher-model)的泛化性等(如输出概率…

推荐系统之GBDT+LR

前言 前面讲过的FM与FFM模型虽然增强了模型的交叉能力,但是不管怎样都只能做二阶的交叉,如果想要继续加大特征交叉的维度,那就会出大计算爆炸的情况。所以Facebook提出了梯度提升树(GBDT)逻辑回归(LR&…

使用Keras进行单模型多标签分类

原文:https://www.pyimagesearch.com/2018/05/07/multi-label-classification-with-keras/ 作者:Adrian Rosebrock 时间:2018年5月7日 源码:https://pan.baidu.com/s/1x7waggprAHQDjalkA-ctvg (wa61) 译者&…

LR模型常见问题小议

 LR模型常见问题小议 标签: LR机器学习 2016-01-10 23:33 671人阅读 评论(0) 收藏 举报 本文章已收录于: 分类: 机器学习(10) 作者同类文章 X 版权声明:本文为博主原创文章&…

信用评分卡(A卡) 基于LR模型的数据处理及建模过程

数据来自:魔镜杯风控算法大赛(拍拍贷)。有关数据的具体描述可以看比赛页面。 0. 数据集的关键字段及描述: Master:每一行代表一个样本(一笔成功成交借款),每个样本包含200多个各类…

机器分类---LR分类+模型评估

文章目录 数据集ROC曲线与AUC理论知识曲线理解实例计算 代码 更详细的数据集介绍(有图形分析,应该比较好理解) https://blog.csdn.net/weixin_42567027/article/details/107416002 数据集 数据集有三个类别,每个类别有50个样本。…

python机器学习算法(赵志勇)学习笔记( Logistic Regression,LR模型)

Logistic Regression(逻辑回归) 分类算法是典型的监督学习,分类算法通过对训练样本的学习,得到从样本特征到样本的标签之间的映射关系,也被称为假设函数,之后可利用该假设函数对新数据进行分类。 通过训练数据中的正负样本,学习样本特征到样本标签之间的假设函数,Log…

推荐系统实战中LR模型训练(二)

背景: 上一篇推荐系统实战中LR模型训练(一) 中完成了LR模型训练的代码部分。本文中将详细讲解数据准备部分,即将文本数据数值化为稀疏矩阵的形式。 文本数据: 稀疏矩阵: 实现过程: 文本数据格…

机器学习 | LR逻辑回归模型

逻辑回归(Logistic Regression,简称LR)名为“回归”却是用来分类工作、在线性数据上表现优异的分类器。 视频教程:第07讲:逻辑回归是线性分类器的佼佼者 LR是数据挖掘领域常用的一种分类模型,常用于解决二分类问题,例如垃圾邮件判定、经济预测、疾病诊断(通过年龄、性…

推荐系统实战中LR模型训练(一)

背景: 在“批量导入数据到Redis” 中已经介绍了将得到的itema item1:score1,item2:score2…批量导入到Redis数据库中。本文的工作是运用机器学习LR技术,抽取相应的特征,进行点击率的估计。 点击率(Click-Through-Rate, CTR&#…

Prometheus TSDB存储原理

Python微信订餐小程序课程视频 https://blog.csdn.net/m0_56069948/article/details/122285951 Python实战量化交易理财系统 https://blog.csdn.net/m0_56069948/article/details/122285941 Prometheus 包含一个存储在本地磁盘的时间序列数据库,同时也支持与远程…

数据库必知必会:TiDB(8)TiDB 数据库 SQL 执行流程

数据库必知必会:TiDB(8)TiDB 数据库 SQL 执行流程 数据库 SQL 执行流程DML语句读流程概述SQL的Parse与CompileSQL的Execute DML语句写流程概述执行 DDL语句流程概要执行 知识点回顾 数据库 SQL 执行流程 在TiDB中三个重要组件: …

时不我待,TSDB崛起正当时

近期有小伙伴问Jesse,为什么你们要在现在这个时点做TSDB,这是个好时点吗?我认为这是个挺好的问题,因为再强的个人也比不上一个团队,再牛的团队也需要顺势而为。我们其实一直在深度思考“Why Now”的问题,因…

时间序列数据库TSDB排名

DB-Engines 中时序列数据库排名 我们先来看一下DB-Engines中关于时序列数据库的排名,这是当前(2016年2月的)排名情况: 下面,我们就按照这个排名的顺序,简单介绍一下这些时序列数据库中的一些。下面要介绍的…