Pytorch创建多任务学习模型

article/2025/11/8 13:26:03

在机器学习中,我们通常致力于针对单个任务,也就是优化单个指标。但是多任务学习(MTL)在机器学习的许多应用中都取得了成功,从自然语言处理和语音识别到计算机视觉和药物发现。

MTL最著名的例子可能是特斯拉的自动驾驶系统。在自动驾驶中需要同时处理大量任务,如物体检测、深度估计、3D重建、视频分析、跟踪等,你可能认为需要10个以上的深度学习模型,但事实并非如此。

HydraNet介绍

一般来说多任务学的模型架构非常简单:一个骨干网络作为特征的提取,然后针对不同的任务创建多个头。利用单一模型解决多个任务。

上图可以看到,特征提取模型提取图像特征。输出最后被分割成多个头,每个头负责一个特定的情况,由于它们彼此独立可以单独进行微调!

特斯拉的讲演中详细的说明这个模型(youtube:v=3SypMvnQT_s)

多任务学习项目

在本文中,我们将介绍如何在Pytorch中实现一个更简单的HydraNet。这里将使用UTK Face数据集,这是一个带有3个标签(性别、种族、年龄)的分类数据集。

我们的HydraNet将有三个独立的头,它们都是不同的,因为年龄的预测是一个回归任务,种族的预测是一个多类分类问题,性别的预测是一个二元分类任务。

每一个Pytorch 的深度学习的项目都应该从定义Dataset和DataLoader开始。

在这个数据集中,通过图像的名称定义了这些标签,例如UTKFace/30_0_3_20170117145159065.jpg.chip.jpg

  • 30岁是年龄
  • 0为性别(0:男性,1:女性)
  • 3是种族(0:白人,1:黑人,2:亚洲人,3:印度人,4:其他)

所以我们的自定义Dataset可以这样写:

 class UTKFace(Dataset):def __init__(self, image_paths):self.transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])self.image_paths = image_pathsself.images = []self.ages = []self.genders = []self.races = []for path in image_paths:filename = path[8:].split("_")if len(filename)==4:self.images.append(path)self.ages.append(int(filename[0]))self.genders.append(int(filename[1]))self.races.append(int(filename[2]))def __len__(self):return len(self.images)def __getitem__(self, index):img = Image.open(self.images[index]).convert('RGB')img = self.transform(img)age = self.ages[index]gender = self.genders[index]eth = self.races[index]sample = {'image':img, 'age': age, 'gender': gender, 'ethnicity':eth}return sample

简单的做个介绍:

__init__

方法初始化我们的自定义数据集,负责初始化各种转换和从图像路径中提取标签。

__get_item__

将:它将加载一张图像,应用必要的转换,获取标签,并返回数据集的一个元素,也就是说这个方法会返回数据集中的单条数据(单个样本)

然后我们定义dataloader

 train_dataloader = DataLoader(UTKFace(train_dataset), shuffle=True, batch_size=BATCH_SIZE)val_dataloader = DataLoader(UTKFace(valid_dataset), shuffle=False, batch_size=BATCH_SIZE)

下面我们定义模型,这里使用一个预训练的模型作为骨干,然后创建3个头。分别代表年龄,性别和种族。

 class HydraNet(nn.Module):def __init__(self):super().__init__()self.net = models.resnet18(pretrained=True)self.n_features = self.net.fc.in_featuresself.net.fc = nn.Identity()self.net.fc1 = nn.Sequential(OrderedDict([('linear', nn.Linear(self.n_features,self.n_features)),('relu1', nn.ReLU()),('final', nn.Linear(self.n_features, 1))]))self.net.fc2 = nn.Sequential(OrderedDict([('linear', nn.Linear(self.n_features,self.n_features)),('relu1', nn.ReLU()),('final', nn.Linear(self.n_features, 1))]))self.net.fc3 = nn.Sequential(OrderedDict([('linear', nn.Linear(self.n_features,self.n_features)),('relu1', nn.ReLU()),('final', nn.Linear(self.n_features, 5))]))def forward(self, x):age_head = self.net.fc1(self.net(x))gender_head = self.net.fc2(self.net(x))ethnicity_head = self.net.fc3(self.net(x))return age_head, gender_head, ethnicity_head

forward方法返回每个头的结果。

损失作为优化的基础时十分重要的,因为它将会影响到模型的性能,我们能想到的最简单的事就是地把损失相加:

 L = L1 + L2 + L3

但是我们的模型中

L1:与年龄相关的损失,如平均绝对误差,因为它是回归损失。

L2:与种族相关的交叉熵,它是一个多类别的分类损失。

L3:性别有关的损失,例如二元交叉熵。

这里损失的计算最大问题是损失的量级是不一样的,并且损失的权重也是不相同的,这是一个一直在被深入研究的问题,我们这里暂不做讨论,我们只使用简单的相加,所以我们的一些超参数如下:

 model = HydraNet().to(device=device)ethnicity_loss = nn.CrossEntropyLoss()gender_loss = nn.BCELoss()age_loss = nn.L1Loss()sig = nn.Sigmoid()optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.09)

然后我们训练的循环如下:

 for epoch in range(n_epochs):model.train()total_training_loss = 0for i, data in enumerate(tqdm(train_dataloader)):inputs = data["image"].to(device=device)age_label = data["age"].to(device=device)gender_label = data["gender"].to(device=device)eth_label = data["ethnicity"].to(device=device)optimizer.zero_grad()age_output, gender_output, eth_output = model(inputs)loss_1 = ethnicity_loss(eth_output, eth_label)loss_2 = gender_loss(sig(gender_output), gender_label.unsqueeze(1).float())loss_3 = age_loss(age_output, age_label.unsqueeze(1).float())loss = loss_1 + loss_2 + loss_3loss.backward()optimizer.step()total_training_loss += loss

这样我们最简单的多任务学习的流程就完成了

关于损失的优化

多任务学习的损失函数,对每个任务的损失进行权重分配,在这个过程中,必须保证所有任务同等重要,而不能让简单任务主导整个训练过程。手动的设置权重是低效而且不是最优的,因此,自动的学习这些权重是十分必要的,

Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics cvpr_2018

这篇论文提出,将不同的loss拉到统一尺度下,这样就容易统一,具体的办法就是利用同方差的不确定性,将不确定性作为噪声,进行训练

End-to-End Multi-Task Learning with Attention cvpr_2019

这篇论文提出了一种可以自动调节权重的机制( Dynamic Weight Average),使得权重分配更加合理,大概的意思是每个任务首先计算前个epoch对应损失的比值,然后除以一个固定的值T,进行exp映射后,计算各个损失所占比

最后如果你对多任务学习感兴趣,可以先看看这篇论文:

A Survey on Multi-Task Learning arXiv 1707.08114

从算法建模、应用和理论分析的角度对MTL进行了调查,是入门的最好的资料。

https://avoid.overfit.cn/post/57d4e8712c634fe887247ce66e694f8f

作者:Alessandro Lamberti


http://chatgpt.dhexx.cn/article/IK3UHu4v.shtml

相关文章

多任务学习 Pytorch实现

多任务学习MTL的简单实现,主要是为了理解MTL 代码写得挺烂的,有时间回来整理一下 import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms import numpy as np import matplotlib.pyplot as plt import ma…

【综述】多任务学习

前言 本文对多任务学习(multi-task learning, MTL)领域近期的综述文章进行整理,从模型结构和训练过程两个层面回顾了其发展变化,旨在提供一份 MTL 入门指南,帮助大家快速了解多任务学习的进化史。 1. 什么是多任务学习? 多任务学习…

多任务学习原理与优化

文章目录 一、什么是多任务学习二、为什么我们需要多任务学习三、多任务学习模型演进Hard shared bottom 硬共享Soft shared bottom 软共享软共享: MOE & MMOE软共享: CGC & PLE加入FMMMOE/PLE 的调参ESMM 四、 loss权重1, 利用任务的…

【多任务学习-Multitask Learning概述】

多任务学习-Multitask Learning概述 1.单任务学习VS多任务学习多任务学习的提出多任务学习和单任务学习对比 2.多任务学习共享表示shared representation:多任务学习的优点那么如何衡量两个任务是否相关呢? 当任务之间相关性弱多任务MLP特点总结多任务学…

多任务学习(Multi-Task Learning, MTL)

目录 [显示] 1 背景2 什么是多任务学习?3 多任务学习如何发挥作用? 3.1 提高泛化能力的潜在原因3.2 多任务学习机制3.3 后向传播多任务学习如何发现任务是相关的4 多任务学习可被广泛应用? 4.1 使用未来预测现在4.2 多种表示和度量4.3 时间序…

Tensorflow 多任务学习

之前在caffe上实现了两个标签的多任务学习,如今换到了tensorflow,也想尝试一下,总的来说也不是很复杂。 建立多任务图 多任务的一个特点是单个tensor输入(X),多个输出(Y_1,Y_2...)。因此在定义占位符时要定义多个输出。同样也需要…

多任务学习:Multi-Task Learning as Multi-Objective Optimization

前言 最近在写一篇文章,是一篇深度学习与安全相结合的文章,模型的输出会交给两个损失函数(availability & security)进行损失计算,进而反向传播。起初的想法是直接将两项损失进行加权平均,共同进行反向…

深度学习中的多任务学习(一)

任务学习-Multitask Learning概述 Reference https://blog.csdn.net/u010417185/article/details/83065506 1、单任务学习VS多任务学习 单任务学习:一次只学习一个任务(task),大部分的机器学习任务都属于单任务学习。多任务学习…

C# 多线程八 任务Task的简单理解与运用二

目录 一.Task 1.AsyncState 2.CompletedTask 3.CreationOptions 4.CurrentId 5.Exception 6.Factory 7.Id 8.IsCanceled 9.IsCompleted 10.IsFaulted 11.Status 二.Task<TResult> 1.Result 上篇&#xff1a; C#…

多任务学习(一)

多任务学习 单任务学习 样本之间没有关联性。 缺点&#xff1a;训练出来的模型不具有泛化性&#xff1b;不共享信息使得学习能力下降。 多任务学习 多任务学习的构建原则 建模任务之间的相关性同时对多个任务的模型参数进行联合学习&#xff0c;挖掘其中的共享信息&#…

多任务学习-Multitask Learning概述

2020-02-22 09:59:48 1、单任务学习VS多任务学习 单任务学习&#xff1a;一次只学习一个任务&#xff08;task&#xff09;&#xff0c;大部分的机器学习任务都属于单任务学习。 多任务学习&#xff1a;把多个相关&#xff08;related&#xff09;的任务放在一起学习&#x…

深度学习之----多任务学习

介绍 在机器学习&#xff08;ML&#xff09;中&#xff0c;通常的关注点是对特定度量进行优化&#xff0c;度量有很多种&#xff0c;例如特定基准或商业 KPI 的分数。为了做到这一点&#xff0c;我们通常训练一个模型或模型组合来执行目标任务。然后&#xff0c;我们微调这些模…

深度学习中的多任务学习介绍

在2017年有一篇关于在深度神经网络中多任务学习概述的论文&#xff1a;《An Overview of Multi-Task Learning in Deep Neural Networks》&#xff0c;论文链接为&#xff1a;https://arxiv.org/pdf/1706.05098.pdf&#xff0c;它介绍了在深度学习中多任务学习(Multi-task Lear…

C# 多线程七 任务Task的简单理解与运用一

目录 一.Task 二.Task中的全局队列和本地队列 三.TaskCreationOptions 枚举 四.CancellationTokenSource/CancellationToken 1.延时取消线程 2.立即取消&#xff1a; 五.Task的三种调用方式 为了防止大家被标题误导 写在前面&#xff1a; Task并不是线程 Task的执行需要…

整理学习之多任务学习

如果有n个任务&#xff08;传统的深度学习方法旨在使用一种特定模型仅解决一项任务&#xff09;&#xff0c;而这n个任务或它们的一个子集彼此相关但不完全相同&#xff0c;则称为多任务学习&#xff08;MTL&#xff09; 通过使用所有n个任务中包含的知识&#xff0c;将有助于改…

多任务学习优化总结 Multi-task learning(附代码)

目录 一、多重梯度下降multiple gradient descent algorithm (MGDA) 二、Gradient Normalization (GradNorm) 三、Uncertainty 多任务学习的优势不用说了&#xff0c;主要是可以合并模型&#xff0c;减小模型体积&#xff0c;只用一次推理也可以加快速度。对于任务表现的提升…

经验 | 训练多任务学习(Multi-task Learning)方法总结

点击上方“小白学视觉”&#xff0c;选择加"星标"或“置顶” 重磅干货&#xff0c;第一时间送达 转载于&#xff1a;知乎Anticoder https://zhuanlan.zhihu.com/p/59413549 背景&#xff1a;只专注于单个模型可能会忽略一些相关任务中可能提升目标任务的潜在信息&…

多任务学习(Multi-Task Learning)

转自&#xff1a;https://www.cnblogs.com/zeze/p/8244357.html 1. 前言 多任务学习&#xff08;Multi-task learning&#xff09;是和单任务学习&#xff08;single-task learning&#xff09;相对的一种机器学习方法。在机器学习领域&#xff0c;标准的算法理论是一次学习一…

多任务学习 | YOLOP,一个网络同时完成三大任务

关注并星标 从此不迷路 计算机视觉研究院 公众号ID&#xff5c;ComputerVisionGzq 学习群&#xff5c;扫码在主页获取加入方式 paper: https://arxiv.org/abs/2108.11250 code: https://github.com/hustvl/YOLOP 计算机视觉研究院专栏 作者&#xff1a;Edison_G YOLOP: You Onl…

多任务学习概述

文章目录 前言1 文章信息2 背景、目的、结论2.1 背景2.1.1 多任务的类型分类2.1.1.1 相关任务的分类2.1.1.2 将输入变输出的逆多任务学习2.1.1.3 对抗性多任务学习2.1.1.4 辅助任务提供注意力特征的多任务学习2.1.1.5 附加预测性辅助任务的多任务学习 3 内容与讨论3.1 多任务学…