基于DANN的图像分类任务迁移学习

article/2025/11/8 16:07:47

注:本博客的数据和任务来自NTU-ML2020作业,Kaggle网址为Kaggle.

数据预处理

我们要进行迁移学习的对象是10000张32x32x3的有标签正常照片,共有10类,和另外100000张人类画的手绘图,28x28x1黑白照片,类别也是10类但无标签。我们希望做到,让模型从有标签的原始分布数据中学到的知识能应用于无标签的,相似但与原始分布不相同的目标分布中,并提高黑白手绘图的正确率。
为此,训练前还要对数据做预处理。首先让原始分布的图像和目标分布的图像尽可能相似,我们要做有色图转灰度图,然后做边缘检测。为了模型的输入维度相同,要把28x28转为32x32.此外还可以增加一些平移旋转来让学习更鲁棒。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt# 在transform中使用转灰度-canny边缘提取-水平移动-小幅度旋转-转张量操作source_transform = transforms.Compose([transforms.Grayscale(),transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),transforms.ToPILImage(),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15, fill=(0,)),transforms.ToTensor(),
])
target_transform = transforms.Compose([transforms.Grayscale(),transforms.Resize((32, 32)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15, fill=(0,)),transforms.ToTensor(),
])# 读取数据集,分为source和target两部分source_dataset = ImageFolder('E:/real_or_drawing/train_data', transform=source_transform)
target_dataset = ImageFolder('E:/real_or_drawing/test_data', transform=target_transform)source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

DANN

Domain-Adversarial Training of NNs,值域对抗学习。这种算法是我们这里将要用的迁移学习方法,它被提出的起因是让CNN能够同时用于不同分布的数据,如果模型直接接收原值域的数据分布进行训练,即使原分布和目标分布有类似的地方,在接收目标值域的数据时,也会出现相当异常的特征提取和分类结果。我们可以理解为是模型在源数据分布上出现了过拟合(并不是对数据的过拟合),在接收一些没有见到过的数据时自然会表现不佳。
在这里插入图片描述
解决这个问题最好的办法就是让模型在训练时也接收目标数据分布的数据。但是目标数据分布是无标签的,我们要用什么标准来训练模型呢?回忆CNN的架构,CNN使用卷积-池化的特征提取层来提取图片特征,后接全连接层进行预测。我们只需要让特征提取层既能提取原数据分布的特征,又能提取目标数据分布的特征,这样全连接层就能对两种值域但具有相同特征的数据进行同样的分类,从而目标数据分布的输入也很有可能被正确分类。
在这里插入图片描述
那么问题就变成了如何训练输入两个不同分布的数据,输出却是同种分布的特征提取层。回忆GAN的架构,我们让分布朝着源数据分布发展的方法是建立判别器,让判别器能分辨两种数据,而让生成器改变参数骗过判别器。这里也可以用同样的思想,我们建立能分辨原始分布和目标分布的二分类判别器,把特征提取层和二分类判别层接在一起。首先训练判别器,让判别器能分辨两类数据分布。然后训练特征提取层,逆梯度更新让特征提取层生成能骗过判别器的数据(目标输出0.5).如此训练多次直到特征提取层能把两种值域的输入变成同种分布的输出。
在这里插入图片描述
但是只是用GAN方法train特征提取层并不明智,因为我们的目标输出只有0-1的二分类,训练很有可能只是让特征提取层提取到一些没有用的特征。因此我们要一边训练正常的标签预测任务,一边训练判别器的判别任务和混淆两类输入的任务。这可能需要自己定义特殊的loss function

最后,我们就获得了能同时提取两个值域的特征的特征提取层,它后面的多分类层就可以对目标分布的数据做出还算称心如意的预测。

模型、训练、测试代码

这里使用类VGG(用多个3x3的卷积核代替大型卷积核以节约参数)的搭建方式,写一个高度卷积的特征提取层

class FeatureExtractor(nn.Module):def __init__(self):super(FeatureExtractor, self).__init__()self.conv = nn.Sequential(nn.Conv2d(1, 64, 3, 1, 1),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(64, 128, 3, 1, 1),nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(128, 256, 3, 1, 1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(256, 256, 3, 1, 1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(256, 512, 3, 1, 1),nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2))def forward(self, x):x = self.conv(x).squeeze()return x#值域分类器,即GAN中的discriminator
class DomainClassifier(nn.Module):def __init__(self):super(DomainClassifier, self).__init__()self.layer = nn.Sequential(nn.Linear(512, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Linear(512, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Linear(512, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Linear(512, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Linear(512, 1),)def forward(self, h):y = self.layer(h)return y#标签预测器,对特征作进一步分类
class LabelPredictor(nn.Module):def __init__(self):super(LabelPredictor, self).__init__()self.layer = nn.Sequential(nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, h):c = self.layer(h)return c
feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()# 多分类使用交叉熵损失进行训练
class_criterion = nn.CrossEntropyLoss()
# domain_classifier的输出是1维,要先sigmoid转概率再计算交叉熵,使用BCEWithlogits
domain_criterion = nn.BCEWithLogitsLoss()# 使用adam训练
optimizer_F = optim.Adam(feature_extractor.parameters())
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())

我们训练200个epoch,让数据尽量收敛

def train_epoch(source_dataloader, target_dataloader, lamb):'''Args:source_dataloader: source data的dataloadertarget_dataloader: target data的dataloaderlamb: 对抗的lamb系数'''# D loss: Domain Classifier的loss# F loss: Feature Extrator & Label Predictor的loss# total_hit: 計算目前對了幾筆 total_num: 目前經過了幾筆running_D_loss, running_F_loss = 0.0, 0.0total_hit, total_num = 0.0, 0.0for i, ((source_data, source_label), (target_data, _)) in enumerate(zip(source_dataloader, target_dataloader)):source_data = source_data.cuda()source_label = source_label.cuda()target_data = target_data.cuda()# 把source data和target data混在一起,否则batch_norm会出错mixed_data = torch.cat([source_data, target_data], dim=0)# 设置判别器的目标标签domain_label = torch.zeros([source_data.shape[0] + target_data.shape[0], 1]).cuda()domain_label[:source_data.shape[0]] = 1# Step 1 : 训练Domain Classifierfeature = feature_extractor(mixed_data)# 这里detach feature,因为不需要更新extractor的参数domain_logits = domain_classifier(feature.detach())loss = domain_criterion(domain_logits, domain_label)running_D_loss+= loss.item()loss.backward()optimizer_D.step()# Step 2 : 训练Feature Extractor和Domain Classifierclass_logits = label_predictor(feature[:source_data.shape[0]])domain_logits = domain_classifier(feature)# 这里使用的loss是原值域数据的任务分类交叉熵损失减去,原值域数据和目标值域数据的判别损失# 因为我们想让extractor骗过判别器,判别损失加负号,而且为了调控训练使用lambda作为系数loss = class_criterion(class_logits, source_label) - lamb * domain_criterion(domain_logits, domain_label)running_F_loss+= loss.item()loss.backward()optimizer_F.step()optimizer_C.step()optimizer_D.zero_grad()optimizer_F.zero_grad()optimizer_C.zero_grad()total_hit += torch.sum(torch.argmax(class_logits, dim=1) == source_label).item()total_num += source_data.shape[0]print(i, end='\r')return running_D_loss / (i+1), running_F_loss / (i+1), total_hit / total_num# 训练50 epochs
for epoch in range(50):train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, lamb=0.1)torch.save(feature_extractor.state_dict(), f'extractor_model.bin')torch.save(domain_classifier.state_dict(), f'domain_model.bin')torch.save(label_predictor.state_dict(), f'predictor_model.bin')print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))

训练好之后可以看见原值域上的训练集正确率有98以上,想看手绘图片的正确率可以在Kaggle上提交一下。我们这里随便打印一些手绘图片和模型预测的标签。

feature_extractor.load_state_dict(torch.load('extractor_model.bin'))
domain_classifier.load_state_dict(torch.load('domain_model.bin'))
label_predictor.load_state_dict(torch.load('predictor_model.bin'))for i, (data, _) in enumerate(test_dataloader):breakclass_logits = label_predictor(feature_extractor(data.cuda()))#我们看50张手绘图的预测def no_axis_show(img, title='', cmap=None):fig = plt.imshow(img, interpolation='nearest', cmap=cmap)fig.axes.get_xaxis().set_visible(False)fig.axes.get_yaxis().set_visible(False)plt.title(title)titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
data = data.cuda()
for i in range(50):plt.subplot(5, 10, i+1)label = torch.argmax(class_logits[i]).cpu().detach().numpy()img = data[i].cpu().detach().numpy().reshape(32,32)fig = no_axis_show(img, title=titles[label])

在这里插入图片描述
正确率不能说有多高,但是模型似乎学会了分辨一些特征比较明显的图片。

值域

把特征提取层得到的特征用PCA降维可以在2D平面上看到值域的分布。
在这里插入图片描述
在不使用DANN时,原值域和目标值域是分开的,这样的特征投入全连接层必然不work。但是当我们强制让模型把两种数据的特征混在一起,就变成右图,这时目标值域的特征有机会被正确分类。


http://chatgpt.dhexx.cn/article/fgRaaanu.shtml

相关文章

【ICML 2015迁移学习论文阅读】Unsupervised Domain Adaptation by Backpropagation (DANN) 无监督领域自适应

会议:ICML 2015 论文题目:Unsupervised Domain Adaptation by Backpropagation 论文地址:http://proceedings.mlr.press/v37/ganin15.pdf 论文代码:https://github.com/fungtion/DANN 问题描述:深度学习的模型在source…

【深度域自适应】二、利用DANN实现MNIST和MNIST-M数据集迁移训练

前言 在前一篇文章【深度域自适应】一、DANN与梯度反转层(GRL)详解中,我们主要讲解了DANN的网络架构与梯度反转层(GRL)的基本原理,接下来这篇文章中我们将主要复现DANN论文Unsupervised Domain Adaptation…

【深度域适配】一、DANN与梯度反转层(GRL)详解

CSDN博客原文链接:https://blog.csdn.net/qq_30091945/article/details/104478550 知乎专栏原文链接:https://zhuanlan.zhihu.com/p/109051269 前言 在当前人工智能的如火如荼在各行各业得到广泛应用,尤其是人工智能也因此从各个方面影响当前…

【ICML 2015迁移学习论文阅读】Unsupervised Domain Adaptation by Backpropagation (DANN) 反向传播的无监督领域自适应

会议:ICML 2015 论文题目:Unsupervised Domain Adaptation by Backpropagation 论文地址: http://proceedings.mlr.press/v37/ganin15.pdf 论文代码: GitHub - fungtion/DANN: pytorch implementation of Domain-Adversarial Trai…

Domain Adaptation(领域自适应,MMD,DANN)

Domain Adaptation 现有深度学习模型都不具有普适性,即在某个数据集上训练的结果只能在某个领域中有效,而很难迁移到其他的场景中,因此出现了迁移学习这一领域。其目标就是将原数据域(源域,source domain)尽…

【迁移学习】深度域自适应网络DANN模型

DANN Domain-Adversarial Training of Neural Networks in Tensorflow域适配:目标域与源域的数据分布不同但任务相同下的迁移学习。 模型建立 DANN假设有两种数据分布:源域数据分布 S ( x , y ) \mathcal{S}(x,y) S(x,y)和目标域数据分布 T ( x , y ) …

【深度域自适应】一、DANN与梯度反转层(GRL)详解

前言 在当前人工智能的如火如荼在各行各业得到广泛应用,尤其是人工智能也因此从各个方面影响当前人们的衣食住行等日常生活。这背后的原因都是因为如CNN、RNN、LSTM和GAN等各种深度神经网络的强大性能,在各个应用场景中解决了各种难题。 在各个领域尤其…

Domain-Adversarial Training of Neural Networks

本篇是迁移学习专栏介绍的第十八篇论文,发表在JMLR2016上。 Abstrac 提出了一种新的领域适应表示学习方法,即训练和测试时的数据来自相似但不同的分布。我们的方法直接受到域适应理论的启发,该理论认为,要实现有效的域转移&#…

DANN:Domain-Adversarial Training of Neural Networks

DANN原理理解 DANN中源域和目标域经过相同的映射来实现对齐。 DANN的目标函数分为两部分: 1. 源域分类损失项 2. 源域和目标域域分类损失项 1.源域分类损失项 对于一个m维的数据点X,通过一个隐含层Gf,数据点变为D维: 然后经…

DaNN详解

1.摘要 本文提出了一个简单的神经网络模型来处理目标识别中的域适应问题。该模型将最大均值差异(MMD)度量作为监督学习中的正则化来减少源域和目标域之间的分布差异。从实验中,本文证明了MMD正则化是一种有效的工具,可以为特定图像数据集的SURF特征建立良好的域适应模型。本…

[Tensorflow2] 梯度反转层(GRL)与域对抗训练神经网络(DANN)的实现

文章目录 概述原理回顾 (可跳过)GRL 层实现使用 GRL 的域对抗(DANN)模型实现DANN 的使用案例 !!!后记 概述 域对抗训练(Domain-Adversarial Training of Neural Networks,DANN)属于广义迁移学习的一种, 可以矫正另一个域的数据集的分布, 也可…

DANN 领域迁移

DANN(Domain Adaptation Neural Network,域适应神经网络)是一种常用的迁移学习方法,在不同数据集之间进行知识迁移。本教程将介绍如何使用DANN算法实现在MNIST和MNIST-M数据集之间进行迁移学习。 首先,我们需要了解两个…

DANN-经典论文概念及源码梳理

没错,我就是那个为了勋章不择手段的屑(手动狗头)。快乐的假期结束了哭哭... DANN 对抗迁移学习 域适应Domain Adaption-迁移学习;把具有不同分布的源域(Source Domain)和目标域(Target Domain…

EHCache 单独使用

参考: http://macrochen.blogdriver.com/macrochen/869480.html 1. EHCache 的特点,是一个纯Java ,过程中(也可以理解成插入式)缓存实现,单独安装Ehcache ,需把ehcache-X.X.jar 和相关类库方到classpath中…

ehcache 的使用

http://my.oschina.net/chengjiansunboy/blog/70974 在开发高并发量,高性能的网站应用系统时,缓存Cache起到了非常重要的作用。本文主要介绍EHCache的使用,以及使用EHCache的实践经验。 笔者使用过多种基于Java的开源Cache组件,其…

Ehcache 的简单使用

文章目录 Ehcache 的简单使用背景使用版本配置配置项编程式配置XML 配置自定义监听器 验证示例代码 改进代码 备注完整示例代码官方文档 Ehcache 的简单使用 背景 当一个JavaEE-Java Enterprise Edition应用想要对热数据(经常被访问,很少被修改的数据)进行缓存时&…

SpringBoot 缓存(EhCache 使用)

SpringBoot 缓存(EhCache 使用) 源文链接:http://blog.csdn.net/u011244202/article/details/55667868 SpringBoot 缓存(EhCache 2.x 篇) SpringBoot 缓存 在 Spring Boot中,通过EnableCaching注解自动化配置合适的缓存管理器(CacheManager…

shiro框架04会话管理+缓存管理+Ehcache使用

目录 一、会话管理 1.基础组件 1.1 SessionManager 1.2 SessionListener 1.3 SessionDao 1.4 会话验证 1.5 案例 二、缓存管理 1、为什么要使用缓存 2、什么是ehcache 3、ehcache特点 4、ehcache入门 5、shiro与ehcache整合 1)导入相关依赖&#xff0…

使用Ehcache的两种方式(代码、注解)

Ehcache,一个开源的缓存机制,在一些小型的项目中可以有效的担任缓存的角色,分担数据库压力此外,ehcache在使用上也是极为简单, 下面是简单介绍一下ehcahce的本地使用的两种方式: 1,使用代码编写的方式使用…

EhCache常用配置详解和持久化硬盘配置

一、EhCache常用配置 EhCache 给我们提供了丰富的配置来配置缓存的设置; 这里列出一些常见的配置项: cache元素的属性: name:缓存名称 maxElementsInMemory:内存中最大缓存对象数 maxElementsOnDisk&#xff…