DANN 领域迁移

article/2025/11/8 19:00:57

DANN(Domain Adaptation Neural Network,域适应神经网络)是一种常用的迁移学习方法,在不同数据集之间进行知识迁移。本教程将介绍如何使用DANN算法实现在MNIST和MNIST-M数据集之间进行迁移学习。

首先,我们需要了解两个数据集:MNIST和MNIST-M。MNIST是一个标准的手写数字图像数据集,包含60000个训练样本和10000个测试样本。MNIST-M是从MNIST数据集中生成的带噪声的手写数字数据集,用于模拟真实场景下的图像分布差异。

接下来,我们将分为以下步骤来完成这个任务:

1、加载MNIST和MNIST-M数据集

2、构建DANN模型

3、定义损失函数

4、定义优化器

5、训练模型

6、评估模型

加载MNIST和MNIST-M数据集

首先,我们需要下载并加载MNIST和MNIST-M数据集。你可以使用PyTorch内置的数据集类来完成这项任务。

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import RandomSampler, Dataset, DataLoader
from torch.autograd import Functionfrom torchvision import datasets, transformsfrom PIL import Image
from tqdm import tqdm
import numpy as np
import shutil
import os# 工具函数
def adjust_learning_rate(optimizer, epoch):lr = 0.001 * 0.1 ** (epoch // 10)for param_group in optimizer.param_groups:param_group['lr'] = lrreturn lrdef accuracy(output, target, topk=(1,)):maxk = max(topk)batch_size = target.size(0)_, pred = output.topk(maxk, 1, True, True)pred = pred.t()correct = pred.eq(target.view(1, -1).expand_as(pred))res = []for k in topk:correct_k = correct[:k].view(-1).float().sum(0)res.append(correct_k.mul_(100 / batch_size))return resclass mnist_m(Dataset):def __init__(self, root, label_file):super(mnist_m, self).__init__()self.transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])with open(label_file, "r") as f:self.imgs = []self.labels = []for line in f.readlines():line = line.strip("\n").split(" ")img_name, label = line[0], int(line[1])img = Image.open(root + os.sep + img_name)self.imgs.append(self.transform(img.convert("RGB")))self.labels.append(label)def __len__(self):return len(self.labels)def __getitem__(self, index):return self.imgs[index], self.labels[index]def __add__(self, other):passclass AverageMeter(object):def __init__(self):self.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count# Tensorboard
log_dir = "minist_experiment_1"
remove_log_dir = True
if remove_log_dir and os.path.exists(log_dir):shutil.rmtree(log_dir)# 读取数据
image_size = 28
batch_size = 128
transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])
train_ds = datasets.MNIST(root="mnist", train=True, transform=transform, download=True)
test_ds = datasets.MNIST(root="mnist", train=False, transform=transform, download=True)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
root_path = os.path.join("dataset", "mnist_m")
train_m_ds = mnist_m(os.path.join(root_path, "mnist_m_train"), os.path.join(root_path, "mnist_m_train_labels.txt"))
test_m_ds = mnist_m(os.path.join(root_path, "mnist_m_test"), os.path.join(root_path, "mnist_m_test_labels.txt"))
train_m_dl = DataLoader(train_m_ds, batch_size=batch_size, shuffle=True)
test_m_dl = DataLoader(test_m_ds, batch_size=batch_size, shuffle=False)# 在源域上独立训练CNN模型
class CNN(nn.Module):def __init__(self, num_classes=10):super(CNN, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 32, 5),nn.ReLU(inplace=True),nn.MaxPool2d(2),nn.Conv2d(32, 48, 5),nn.ReLU(inplace=True),nn.MaxPool2d(2),)self.avgpool = nn.AdaptiveAvgPool2d((5, 5))self.classifier = nn.Sequential(nn.Linear(48 * 5 * 5, 100),nn.ReLU(inplace=True),nn.Linear(100, 100),nn.ReLU(inplace=True),nn.Linear(100, num_classes))def forward(self, x):x = x.expand(x.data.shape[0], 3, image_size, image_size)x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x# 用一个5层的神经网络在mnist上使用Adam训练,准确率约为99.3%
cnn_model = CNN()
optimizer = Adam(cnn_model.parameters(), lr=0.001)
Loss = nn.CrossEntropyLoss()
epochs = 5
train_loss = AverageMeter()
test_loss = AverageMeter()
test_top1 = AverageMeter()
train_top1 = AverageMeter()
train_cnt = AverageMeter()
print_freq = 200
cnn_model.cuda()
for epoch in range(epochs):lr = adjust_learning_rate(optimizer, epoch)# writer.add_scalar("lr",lr,epoch)print("lr, epoch", lr, epoch)train_loss.reset()train_top1.reset()train_cnt.reset()test_top1.reset()test_loss.reset()for images, labels in tqdm(train_dl):images = images.cuda()labels = labels.cuda()optimizer.zero_grad()predict = cnn_model(images)losses = Loss(predict, labels)train_loss.update(losses.data, images.size(0))top1 = accuracy(predict.data, labels, topk=(1,))[0]train_top1.update(top1, images.size(0))train_cnt.update(images.size(0), 1)losses.backward()optimizer.step()if train_cnt.count % print_freq == 0:print("Epoch:{}[{}/{}],Loss:[{:.3f},{:.3f}],prec[{:.4f},{:.4f}]".format(epoch, train_cnt.count, len(train_dl),train_loss.val, train_loss.avg,train_top1.val, train_top1.avg))for images, labels in tqdm(test_dl):images = images.cuda()labels = labels.cuda()predict = cnn_model(images)losses = Loss(predict, labels)test_loss.update(losses.data, images.size(0))top1 = accuracy(predict.data, labels, topk=(1,))[0]test_top1.update(top1, images.size(0))print("Epoch:{},val,Loss:[{:.3f}],prec[{:.4f}]".format(epoch, test_loss.avg, test_top1.avg))# writer.add_scalar("train_loss", train_loss.avg, epoch)# writer.add_scalar("test_loss", test_loss.avg, epoch)# writer.add_scalar("train_top1", train_top1.avg, epoch)# writer.add_scalar("test_top1", test_top1.avg, epoch)# 直接用mnist数据集训练的网络识别mnist_m数据集,准确率约为58%.可以看作领域适应方法准确率的下界。test_m_top1 = AverageMeter()test_m_loss = AverageMeter()for images, labels in tqdm(test_m_dl):images = images.cuda()labels = labels.cuda()predict = cnn_model(images)losses = Loss(predict, labels)test_m_loss.update(losses.data, images.size(0))top1 = accuracy(predict.data, labels, topk=(1,))[0]test_m_top1.update(top1, images.size(0))print("Epoch:{},val,Loss:[{:.3f}],prec[{:.4f}]".format(epoch, test_m_loss.avg, test_m_top1.avg))# 直接使用mnist_m训练,准确率约为96%,可以看坐领域适应方法准确率的上界。
train_loss = AverageMeter()
test_loss = AverageMeter()
test_top1 = AverageMeter()
train_top1 = AverageMeter()
train_cnt = AverageMeter()
print_freq = 100
cnn_model.cuda()
epochs = 5
for epoch in range(epochs):lr = adjust_learning_rate(optimizer, epoch)# writer.add_scalar("lr",lr,epoch)train_loss.reset()train_top1.reset()train_cnt.reset()test_top1.reset()test_loss.reset()for images, labels in tqdm(train_m_dl):images = images.cuda()labels = labels.cuda()optimizer.zero_grad()predict = cnn_model(images)losses = Loss(predict, labels)train_loss.update(losses.data, images.size(0))top1 = accuracy(predict.data, labels, topk=(1,))[0]train_top1.update(top1, images.size(0))train_cnt.update(images.size(0), 1)losses.backward()optimizer.step()if train_cnt.count % print_freq == 0:print("Epoch:{}[{}/{}],Loss:[{:.3f},{:.3f}],prec[{:.4f},{:.4f}]".format(epoch, train_cnt.count, len(train_dl),train_loss.val, train_loss.avg,train_top1.val, train_top1.avg))for images, labels in tqdm(test_m_dl):images = images.cuda()labels = labels.cuda()predict = cnn_model(images)losses = Loss(predict, labels)test_loss.update(losses.data, images.size(0))top1 = accuracy(predict.data, labels, topk=(1,))[0]test_top1.update(top1, images.size(0))print("Epoch:{},val,Loss:[{:.3f}],prec[{:.4f}]".format(epoch, test_loss.avg, test_top1.avg))# writer.add_scalar("train_loss", train_loss.avg, epoch)# writer.add_scalar("test_loss", test_loss.avg, epoch)# writer.add_scalar("train_top1", train_top1.avg, epoch)# writer.add_scalar("test_top1", test_top1.avg, epoch)# GRL
# 梯度反转层,这一层正向表现为恒等变换,反向传播是改变梯度的符号,alpha用来平衡域损失的权重。
class GRL(Function):@staticmethoddef forward(ctx, x, alpha):ctx.alpha = alphareturn x.view_as(x)@staticmethoddef backward(ctx, grad_output):output = grad_output.neg() * ctx.alphareturn output, None# DANN
class DANN(nn.Module):def __init__(self, num_classes=10):super(DANN, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 32, 5),nn.ReLU(inplace=True),nn.MaxPool2d(2),nn.Conv2d(32, 48, 5),nn.ReLU(inplace=True),nn.MaxPool2d(2),)self.avgpool = nn.AdaptiveAvgPool2d((5, 5))self.task_classifier = nn.Sequential(nn.Linear(48 * 5 * 5, 100),nn.ReLU(inplace=True),nn.Linear(100, 100),nn.ReLU(inplace=True),nn.Linear(100, num_classes))self.domain_classifier = nn.Sequential(nn.Linear(48 * 5 * 5, 100),nn.ReLU(inplace=True),nn.Linear(100, 2))self.GRL = GRL()def forward(self, x, alpha):x = x.expand(x.data.shape[0], 3, image_size, image_size)x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)task_predict = self.task_classifier(x)x = GRL.apply(x, alpha)domain_predict = self.domain_classifier(x)return task_predict, domain_predict# 使用DANN进行领域迁移训练,使用mnist上的有标签数据和mnist_m上的无标签数据,准确率约为84%.
train_loss = AverageMeter()
train_domain_loss = AverageMeter()
train_task_loss = AverageMeter()
test_loss = AverageMeter()
test_top1 = AverageMeter()
test_domain_acc = AverageMeter()
train_top1 = AverageMeter()
train_cnt = AverageMeter()print_freq = 200
domain_model = DANN()
domain_model.cuda()
domain_loss = nn.CrossEntropyLoss()
task_loss = nn.CrossEntropyLoss()
lr = 0.001
optimizer = Adam(domain_model.parameters(), lr=lr)
epochs = 100for epoch in range(epochs):# lr=adjust_learning_rate(optimizer,epoch)# writer.add_scalar("lr", lr, epoch)train_loss.reset()train_domain_loss.reset()train_task_loss.reset()train_top1.reset()train_cnt.reset()test_top1.reset()test_loss.reset()for source, target in zip(train_dl, train_m_dl):train_cnt.update(images.size(0), 1)p = float(train_cnt.count + epoch * len(train_dl)) / (epochs * len(train_dl))alpha = torch.tensor(2. / (1. + np.exp(-10 * p)) - 1)src_imgs = source[0].cuda()src_labels = source[1].cuda()dst_imgs = target[0].cuda()optimizer.zero_grad()src_predict, src_domains = domain_model(src_imgs, alpha)src_label_loss = task_loss(src_predict, src_labels)src_domain_loss = domain_loss(src_domains, torch.ones(len(src_domains)).long().cuda())_, dst_domains = domain_model(dst_imgs, alpha)dst_domain_loss = domain_loss(dst_domains, torch.zeros(len(dst_domains)).long().cuda())losses = src_label_loss + src_domain_loss + dst_domain_losstrain_loss.update(losses.data, images.size(0))train_domain_loss.update(dst_domain_loss.data, images.size(0))train_task_loss.update(src_label_loss.data, images.size(0))top1 = accuracy(src_predict.data, src_labels, topk=(1,))[0]train_top1.update(top1, images.size(0))losses.backward()optimizer.step()if train_cnt.count % print_freq == 0:print("Epoch:{}[{}/{}],Loss:[{:.3f},{:.3f}],domain loss:[{:.3f},{:.3f}],label loss:[{:.3f},{:.3f}],prec[{:.4f},{:.4f}],alpha:{}".format(epoch, train_cnt.count, len(train_dl), train_loss.val, train_loss.avg,train_domain_loss.val, train_domain_loss.avg,train_task_loss.val, train_task_loss.avg, train_top1.val, train_top1.avg, alpha))for images, labels in tqdm(test_m_dl):images = images.cuda()labels = labels.cuda()predicts, domains = domain_model(images, 0)losses = task_loss(predicts, labels)test_loss.update(losses.data, images.size(0))top1 = accuracy(predicts.data, labels, topk=(1,))[0]domain_acc = accuracy(domains.data, torch.zeros(len(domains)).long().cuda(), topk=(1,))[0]test_top1.update(top1, images.size(0))test_domain_acc.update(domain_acc, images.size(0))print("Epoch:{},val,Loss:[{:.3f}],prec[{:.4f}],domain_acc[{:.4f}]".format(epoch, test_loss.avg, test_top1.avg,test_domain_acc.avg))# writer.add_scalar("train_loss", train_loss.avg, epoch)# writer.add_scalar("test_loss", test_loss.avg, epoch)# writer.add_scalar("train_top1", train_top1.avg, epoch)# writer.add_scalar("test_top1", test_top1.avg, epoch)# writer.add_scalar("test_domain", test_domain_acc.avg, epoch)

运行结果
在这里插入图片描述
需要数据私聊我


http://chatgpt.dhexx.cn/article/4bjWDf4e.shtml

相关文章

DANN-经典论文概念及源码梳理

没错,我就是那个为了勋章不择手段的屑(手动狗头)。快乐的假期结束了哭哭... DANN 对抗迁移学习 域适应Domain Adaption-迁移学习;把具有不同分布的源域(Source Domain)和目标域(Target Domain…

EHCache 单独使用

参考: http://macrochen.blogdriver.com/macrochen/869480.html 1. EHCache 的特点,是一个纯Java ,过程中(也可以理解成插入式)缓存实现,单独安装Ehcache ,需把ehcache-X.X.jar 和相关类库方到classpath中…

ehcache 的使用

http://my.oschina.net/chengjiansunboy/blog/70974 在开发高并发量,高性能的网站应用系统时,缓存Cache起到了非常重要的作用。本文主要介绍EHCache的使用,以及使用EHCache的实践经验。 笔者使用过多种基于Java的开源Cache组件,其…

Ehcache 的简单使用

文章目录 Ehcache 的简单使用背景使用版本配置配置项编程式配置XML 配置自定义监听器 验证示例代码 改进代码 备注完整示例代码官方文档 Ehcache 的简单使用 背景 当一个JavaEE-Java Enterprise Edition应用想要对热数据(经常被访问,很少被修改的数据)进行缓存时&…

SpringBoot 缓存(EhCache 使用)

SpringBoot 缓存(EhCache 使用) 源文链接:http://blog.csdn.net/u011244202/article/details/55667868 SpringBoot 缓存(EhCache 2.x 篇) SpringBoot 缓存 在 Spring Boot中,通过EnableCaching注解自动化配置合适的缓存管理器(CacheManager…

shiro框架04会话管理+缓存管理+Ehcache使用

目录 一、会话管理 1.基础组件 1.1 SessionManager 1.2 SessionListener 1.3 SessionDao 1.4 会话验证 1.5 案例 二、缓存管理 1、为什么要使用缓存 2、什么是ehcache 3、ehcache特点 4、ehcache入门 5、shiro与ehcache整合 1)导入相关依赖&#xff0…

使用Ehcache的两种方式(代码、注解)

Ehcache,一个开源的缓存机制,在一些小型的项目中可以有效的担任缓存的角色,分担数据库压力此外,ehcache在使用上也是极为简单, 下面是简单介绍一下ehcahce的本地使用的两种方式: 1,使用代码编写的方式使用…

EhCache常用配置详解和持久化硬盘配置

一、EhCache常用配置 EhCache 给我们提供了丰富的配置来配置缓存的设置; 这里列出一些常见的配置项: cache元素的属性: name:缓存名称 maxElementsInMemory:内存中最大缓存对象数 maxElementsOnDisk&#xff…

EhCache初体验

一、简介 EhCache 是一个纯Java的进程内缓存框架,具有快速、精干等特点。Ehcache是一种广泛使用的开源Java分布式缓存。主要面向通用缓存,Java EE和轻量级容器。它具有内存和磁盘存储,缓存加载器,缓存扩展,缓存异常处理程序,一个gzip缓存servlet过滤器,支…

setw()使用方法

使用setw(n)之前&#xff0c;要使用头文件iomanip 使用方法: #include<iomanip> 1、setw&#xff08;int n&#xff09;只是对直接跟在<<后的输出数据起作用&#xff0c;而在之后的<<需要在之前再一次使用setw&#xff1b; &#xff08;Sets the number of…

c语言iomanip头文件的作用,iomanip头文件的作用

在c程序里面经常见到下面的头文件 #include io代表输入输出&#xff0c;manip是manipulator(操纵器)的缩写(在c上只能通过输入缩写才有效。) 作用(推荐学习&#xff1a;C语言视频教程) 主要是对cin,cout之类的一些操纵运算子&#xff0c;比如setfill,setw,setbase,setprecisio…

QT学习C++(6)

立方体的类设计 设计立方体类&#xff0c;求出立方体的面积(2ad2ac2bc)和体积(a*b*c)&#xff0c;分别用全局函数和成员函数判断两个立方体是否相等&#xff1f; #include <iostream>using namespace std; class Cube{ private://数据&#xff0c;长宽高int c_l;int c_w…

C++中使用setw()使用方法

setw(int n)是c中在输出操作中使用的字段宽度设置&#xff0c;设置输出的域宽&#xff0c;n表示字段宽度。只对紧接着的输出有效&#xff0c;紧接着的输出结束后又变回默认的域宽。当后面紧跟着的输出字段长度小于n的时候&#xff0c;在该字段前面用空格补齐&#xff1b;当输出…

关系代数表达式的优化

查询的处理的代价通常取决于磁盘访问&#xff0c;磁盘访问比内存访问速度慢很多。 在这里由于计算机原理的知识的欠缺&#xff0c;理解起来有点费劲&#xff0c;例如不知道关系的连接在哪里进行&#xff0c;连接的中间结果放在哪里&#xff0c;计算后的结果怎么处理&#xff0c…

关系代数1

转自链接&#xff1a; https://blog.csdn.net/Flora_SM/article/details/84190119 1.查询选修了2号课程的学生的学号。 2.查询至少选修了一门其直接先行课为5号课程的学生姓名 因为是选修直接先行课&#xff0c;所以在Course表里&#xff0c;而学生姓名在Student表里&#xff…

关系代数和SQL语法

数据分析的语言接口 OLAP计算引擎是一架机器&#xff0c;而操作这架机器的是编程语言。使用者通过特定语言告诉计算引擎&#xff0c;需要读取哪些数据、以及需要进行什么样的计算。编程语言有很多种&#xff0c;任何人都可以设计出一门编程语言&#xff0c;然后设计对应的编译…

关系代数表达式练习(针对难题)

教师关系T&#xff08;T#,TNAME,TITLE&#xff09;课程关系C(C#,CNAME,TNO)学生关系S(S#,SNAME,AGE,SEX)选课关系SC(S#,C#,SCORE) 检索至少选修了C2,C4两门课程的学生学号&#xff1a; 这里的下标可以这样理解&#xff0c;课程表C取了别名SC1,SC2,SC1的第一个元素&#xff08;…

怎样用关系代数表达式表示查询要求?求过程

怎样用关系代数表达式表示查询要求&#xff1f; 用一个例子来讲述一下 题目&#xff1a;查询至少选修了全部课程的学生学号和姓名 题目所用到的表如下 题目&#xff1a;查询至少选修了全部课程的学生学号和姓名&#xff1f; ① 找出题目中暗含属性、以及它们所在的表 ② 根据…

关系代数与sql语句

关系代数定义&#xff1a; 关系代数是以关系为运算对象的一组高级运算的集合。关系代数的运算有集合运算&#xff08;集合<表>与集合<表>之间的运算&#xff09;和关系运算&#xff08;集合<表>内部的运算&#xff09; 集合运算&#xff1a; 并运算&#xf…

关系代数2

转载链接&#xff1a; https://blog.csdn.net/Bruce_why/article/details/46389603 题A 设有如下所示的关系S(S#,SNAME,AGE,SEX)、C(C#,CNAME,TEACHER)和SC(S#,C#,GRADE)&#xff0c;用关系代数表达式表示下列查询语句&#xff1a; (1) 检索“程军”老师所授课程的课程号(C#)…