【深度域自适应】二、利用DANN实现MNIST和MNIST-M数据集迁移训练

article/2025/11/8 19:01:23

前言

在前一篇文章【深度域自适应】一、DANN与梯度反转层(GRL)详解中,我们主要讲解了DANN的网络架构与梯度反转层(GRL)的基本原理,接下来这篇文章中我们将主要复现DANN论文Unsupervised Domain Adaptation by Backpropagation中MNIST和MNIST-M数据集的迁移训练实验。


一、MNIST和MNIST-M介绍

为了利用DANN实现MNIST和MNIST-M数据集的迁移训练,我们首先需要获取到MNIST和MNIST-M数据集。其中MNIST数据集很容易获取,官网下载链接为:MNSIT。需要下载的文件如下图所示蓝色的4个文件。
在这里插入图片描述
由于tensorflow和keras深度融合,我们可以通过keras的相关API进行MNIST数据集,如下:

from tensorflow.keras.datasets import mnist# 导入MNIST数据集
(X_train,y_train),(X_test,y_test) = mnist.load_data()

MNIST-M数据集由MNIST数字与BSDS500数据集中的随机色块混合而成。那么要像生成MNIST-M数据集,请首先下载BSDS500数据集。BSDS500数据集的官方下载地址为:BSDS500。
以下是BSDS500数据集官方网址相关截图,点击下图中蓝框的连接即可下载数据。
在这里插入图片描述
下载好BSDS500数据集后,我们必须根据MNIST和BSDS500数据集来生成MNIST-M数据集,生成数据集的脚本create_mnistm.py如下:

# -*- coding: utf-8 -*-
# @Time    : 2021/7/24 下午1:50
# @Author  : Dai Pu wei
# @Email   : 771830171@qq.com
# @File    : create_mnistm.py
# @Software: PyCharmfrom __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport os
import tarfile
import numpy as np
import pickle as pkl
import skimage.io
import skimage.transform
from tensorflow.keras.datasets import mnistrand = np.random.RandomState(42)def compose_image(mnist_data, background_data):"""这是将MNIST数据和BSDS500数据进行融合成MNIST-M数据的函数:param mnist_data: MNIST数据:param background_data: BDSD500数据,作为背景图像:return:"""# 随机融合MNIST数据和BSDS500数据w, h, _ = background_data.shapedw, dh, _ = mnist_data.shapex = np.random.randint(0, w - dw)y = np.random.randint(0, h - dh)bg = background_data[x:x + dw, y:y + dh]return np.abs(bg - mnist_data).astype(np.uint8)def mnist_to_img(x):"""这是实现MNIST数据格式转换的函数,0/1数据位转化为RGB数据集:param x: 0/1格式MNIST数据:return:"""x = (x > 0).astype(np.float32)d = x.reshape([28, 28, 1]) * 255return np.concatenate([d, d, d], 2)def create_mnistm(X,background_data):"""这是生成MNIST-M数据集的函数,MNIST-M数据集介绍可见:http://jmlr.org/papers/volume17/15-239/15-239.pdf:param X: MNIST数据集:param background_data: BSDS500数据集,作为背景:return:"""# 遍历所有MNIST数据集,生成MNIST-M数据集X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)for i in range(X.shape[0]):if i % 1000 == 0:print('Processing example', i)# 随机选择背景图像bg_img = rand.choice(background_data)# 0/1数据位格式MNIST数据转换为RGB格式mnist_image = mnist_to_img(X[i])# 将MNIST数据和BSDS500数据背景进行融合mnist_image = compose_image(mnist_image, bg_img)X_[i] = mnist_imagereturn X_def run_main():"""这是主函数"""# 初始化路径BST_PATH = os.path.abspath('./model_data/dataset/BSR_bsds500.tgz')mnist_dir = os.path.abspath("model_data/dataset/MNIST")mnistm_dir = os.path.abspath("model_data/dataset/MNIST_M")# 导入MNIST数据集(X_train,y_train),(X_test,y_test) = mnist.load_data()# 加载BSDS500数据集f = tarfile.open(BST_PATH)train_files = []for name in f.getnames():if name.startswith('BSR/BSDS500/data/images/train/'):train_files.append(name)print('Loading BSR training images')background_data = []for name in train_files:try:fp = f.extractfile(name)bg_img = skimage.io.imread(fp)background_data.append(bg_img)except:continue# 生成MNIST-M训练数据集和验证数据集print('Building train set...')train = create_mnistm(X_train,background_data)print(np.shape(train))print('Building validation set...')valid = create_mnistm(X_test,background_data)print(np.shape(valid))# 将MNIST数据集转化为RGB格式X_train = np.expand_dims(X_train,-1)X_test = np.expand_dims(X_test,-1)X_train = np.concatenate([X_train,X_train,X_train],axis=3)X_test = np.concatenate([X_test,X_test,X_test],axis=3)y_train = np.array(y_train).astype(np.int32)y_test = np.array(y_test).astype(np.int32)# 保存MNIST数据集为pkl文件if not os.path.exists(mnist_dir):os.mkdir(mnist_dir)with open(os.path.join(mnist_dir, 'mnist_data.pkl'), 'wb') as f:pkl.dump({'train': X_train,'train_label': y_train,'val': X_test,'val_label':y_test}, f, pkl.HIGHEST_PROTOCOL)# 保存MNIST-M数据集为pkl文件if not os.path.exists(mnistm_dir):os.mkdir(mnistm_dir)with open(os.path.join(mnistm_dir, 'mnist_m_data.pkl'), 'wb') as f:pkl.dump({'train': train,'train_label':y_train,'val': valid,'val_label':y_test}, f, pkl.HIGHEST_PROTOCOL)# 计算数据集平均值,用于数据标准化print(np.shape(X_train))print(np.shape(X_test))print(np.shape(train))print(np.shape(valid))print(np.shape(y_train))print(np.shape(y_test))pixel_mean = np.vstack([X_train,train,X_test,valid]).mean((0,1,2))print(np.shape(pixel_mean))print(pixel_mean)if __name__ == '__main__':run_main()

二、参数配置类config

由于整个DANN-MNIST网络的训练过程中涉及到很多超参数,因此为了整个项目的编程方便,我们利用面向对象的思想将所有的超参数放置到一个类中,即参数配置类config。这个参数配置类config的代码如下:

# -*- coding: utf-8 -*-
# @Time    : 2020/2/15 15:05
# @Author  : Dai PuWei
# @Email   : 771830171@qq.com
# @File    : config.py
# @Software: PyCharmimport osclass config(object):__defualt_dict__ = {"pre_model_path":None,"checkpoints_dir":os.path.abspath("./checkpoints"),"logs_dir":os.path.abspath("./logs"),"config_dir":os.path.abspath("./config"),"image_input_shape":(28,28,3),"image_size":28,"init_learning_rate": 1e-2,"momentum_rate":0.9,"batch_size":256,"epoch":500,"pixel_mean":[45.652287,45.652287,45.652287],}def __init__(self,**kwargs):"""这是参数配置类的初始化函数:param kwargs: 参数字典"""# 初始化相关配置参数self.__dict__.update(self. __defualt_dict__)# 根据相关传入参数进行参数更新self.__dict__.update(kwargs)if not os.path.exists(self.checkpoints_dir):os.makedirs(self.checkpoints_dir)if not os.path.exists(self.logs_dir):os.makedirs(self.logs_dir)if not os.path.exists(self.config_dir):os.makedirs(self.config_dir)def set(self,**kwargs):"""这是参数配置的设置函数:param kwargs: 参数字典:return:"""# 根据相关传入参数进行参数更新self.__dict__.update(kwargs)def save_config(self,time):"""这是保存参数配置类的函数:param time: 时间点字符串:return:"""# 更新相关目录self.checkpoints_dir = os.path.join(self.checkpoints_dir,time)self.logs_dir = os.path.join(self.logs_dir,time)self.config_dir = os.path.join(self.config_dir,time)if not os.path.exists(self.config_dir):os.makedirs(self.config_dir)if not os.path.exists(self.checkpoints_dir):os.makedirs(self.checkpoints_dir)if not os.path.exists(self.logs_dir):os.makedirs(self.logs_dir)config_txt_path = os.path.join(self.config_dir,"config.txt")with open(config_txt_path,'a') as f:for key,value in self.__dict__.items():if key in ["checkpoints_dir","logs_dir","config_dir"]:value = os.path.join(value,time)s = key+": "+value+"\n"f.write(s)

三、梯度反转层(GradientReversalLayer)

在DANN中比较重要的模块就是梯度反转层(Gradient Reversal Layer, GRL)的实现。GRL的tf2.x代码实现如下:

import tensorflow as tf
from tensorflow.keras.layers import Layer@tf.custom_gradient
def gradient_reversal(x,alpha=1.0):def grad(dy):return -dy * alpha, Nonereturn x, gradclass GradientReversalLayer(Layer):def __init__(self,**kwargs):"""这是梯度反转层的初始化函数:param kwargs: 参数字典"""super(GradientReversalLayer,self).__init__(kwargs)def call(self, x,alpha=1.0):"""这是梯度反转层的初始化函数:param x: 输入张量:param alpha: alpha系数,默认为1:return:"""return gradient_reversal(x,alpha)

在上述代码中@ops.RegisterGradient(grad_name)修饰 _flip_gradients(op, grad)函数,即自定义该层的梯度取反。同时gradient_override_map函数主要用于解决使用自己定义的函数方式来求梯度的问题,gradient_override_map函数的参数值为一个字典。即字典中value表示使用该值表示的函数代替key表示的函数进行梯度运算。


四、 DANN类代码

DANN论文Unsupervised Domain Adaptation by Backpropagation中给出MNIST和MNIST-M数据集的迁移训练实验的网络,网络架构图如下图所示。
在这里插入图片描述
接下来,我们将利用tensorflow2.4.0来搭建整个DANN-MNIST网络,DANN-MNIST网络结构代码如下:

# -*- coding: utf-8 -*-
# @Time    : 2020/2/14 20:27
# @Author  : Dai PuWei
# @Email   : 771830171@qq.com
# @File    : MNIST2MNIST_M.py
# @Software: PyCharmimport tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Activationdef build_feature_extractor():"""这是特征提取子网络的构建函数:param image_input: 图像输入张量:param name: 输出特征名称:return:"""model = tf.keras.Sequential([Conv2D(filters=32, kernel_size=5,strides=1),#tf.keras.layers.BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=48, kernel_size=5,strides=1),#tf.keras.layers.BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Flatten(),])return modeldef build_image_classify_extractor():"""这是搭建图像分类器模型的函数:param image_classify_feature: 图像分类特征张量:return:"""model = tf.keras.Sequential([Dense(100),#tf.keras.layers.BatchNormalization(),Activation('relu'),#tf.keras.layers.Dropout(0.5),Dense(100,activation='relu'),#tf.keras.layers.Dropout(0.5),Dense(10,activation='softmax',name="image_cls_pred"),])return modeldef build_domain_classify_extractor():"""这是搭建域分类器的函数:param domain_classify_feature: 域分类特征张量:return:"""# 搭建域分类器model = tf.keras.Sequential([Dense(100),#tf.keras.layers.BatchNormalization(),Activation('relu'),#tf.keras.layers.Dropout(0.5),Dense(2, activation='softmax', name="domain_cls_pred")])return model

六、实验结果

下面主要包括了MNIST和MNIST-M数据集在自适应训练过程中学习率、梯度反转层参数 λ \lambda λ、训练集和验证集的图像分类损失、域分类损失、图像分类精度、域分类精度和模型总损失的可视化。

首先是超参数学习率和梯度反转层参数 λ \lambda λ在训练过程中的数据可视化。
在这里插入图片描述

接着是训练数据集和验证数据集的图像分类精度和域分类精度在训练过程中的数据可视化,其中蓝色代表训练集,红色代表验证集。训练精度是在源域数据集即MNIST数据集上的统计结果,验证精度是在目标域数据集即MNIST-M数据集上的统计结果。 由于RTX30显卡的精度高,MNIST和MNIST-M数据集的自适应训练的训练结果稳定在86%左右,比原始论文的81.49%精度高出不少也就在情理之中。
在这里插入图片描述

最后是训练数据集和验证数据集的图像分类损失和域分类损失在训练过程中的数据可视化,其中蓝色代表训练集,红色代表验证集。
在这里插入图片描述

后记

最初实现DANN使用tf1.x框架,后期发现由于GRL的特殊性,tf1.和GRL与复杂网络结构,如YOLO v3之间的适配度较低,因此现已将代码全面升到tf2.x,未如有需要也会支持pytorch。原始tf1.x的项目代码地址为:DANN-MNIST的tf1分支,tf2.x的项目代码地址如下:

  • DANN-MNIST的tf2和master分支(tf2和master分支合并)
  • DANN-MNIST-tf2

欢迎大家在CSDN和Github上一键三连


http://chatgpt.dhexx.cn/article/ihcGpKZI.shtml

相关文章

【深度域适配】一、DANN与梯度反转层(GRL)详解

CSDN博客原文链接:https://blog.csdn.net/qq_30091945/article/details/104478550 知乎专栏原文链接:https://zhuanlan.zhihu.com/p/109051269 前言 在当前人工智能的如火如荼在各行各业得到广泛应用,尤其是人工智能也因此从各个方面影响当前…

【ICML 2015迁移学习论文阅读】Unsupervised Domain Adaptation by Backpropagation (DANN) 反向传播的无监督领域自适应

会议:ICML 2015 论文题目:Unsupervised Domain Adaptation by Backpropagation 论文地址: http://proceedings.mlr.press/v37/ganin15.pdf 论文代码: GitHub - fungtion/DANN: pytorch implementation of Domain-Adversarial Trai…

Domain Adaptation(领域自适应,MMD,DANN)

Domain Adaptation 现有深度学习模型都不具有普适性,即在某个数据集上训练的结果只能在某个领域中有效,而很难迁移到其他的场景中,因此出现了迁移学习这一领域。其目标就是将原数据域(源域,source domain)尽…

【迁移学习】深度域自适应网络DANN模型

DANN Domain-Adversarial Training of Neural Networks in Tensorflow域适配:目标域与源域的数据分布不同但任务相同下的迁移学习。 模型建立 DANN假设有两种数据分布:源域数据分布 S ( x , y ) \mathcal{S}(x,y) S(x,y)和目标域数据分布 T ( x , y ) …

【深度域自适应】一、DANN与梯度反转层(GRL)详解

前言 在当前人工智能的如火如荼在各行各业得到广泛应用,尤其是人工智能也因此从各个方面影响当前人们的衣食住行等日常生活。这背后的原因都是因为如CNN、RNN、LSTM和GAN等各种深度神经网络的强大性能,在各个应用场景中解决了各种难题。 在各个领域尤其…

Domain-Adversarial Training of Neural Networks

本篇是迁移学习专栏介绍的第十八篇论文,发表在JMLR2016上。 Abstrac 提出了一种新的领域适应表示学习方法,即训练和测试时的数据来自相似但不同的分布。我们的方法直接受到域适应理论的启发,该理论认为,要实现有效的域转移&#…

DANN:Domain-Adversarial Training of Neural Networks

DANN原理理解 DANN中源域和目标域经过相同的映射来实现对齐。 DANN的目标函数分为两部分: 1. 源域分类损失项 2. 源域和目标域域分类损失项 1.源域分类损失项 对于一个m维的数据点X,通过一个隐含层Gf,数据点变为D维: 然后经…

DaNN详解

1.摘要 本文提出了一个简单的神经网络模型来处理目标识别中的域适应问题。该模型将最大均值差异(MMD)度量作为监督学习中的正则化来减少源域和目标域之间的分布差异。从实验中,本文证明了MMD正则化是一种有效的工具,可以为特定图像数据集的SURF特征建立良好的域适应模型。本…

[Tensorflow2] 梯度反转层(GRL)与域对抗训练神经网络(DANN)的实现

文章目录 概述原理回顾 (可跳过)GRL 层实现使用 GRL 的域对抗(DANN)模型实现DANN 的使用案例 !!!后记 概述 域对抗训练(Domain-Adversarial Training of Neural Networks,DANN)属于广义迁移学习的一种, 可以矫正另一个域的数据集的分布, 也可…

DANN 领域迁移

DANN(Domain Adaptation Neural Network,域适应神经网络)是一种常用的迁移学习方法,在不同数据集之间进行知识迁移。本教程将介绍如何使用DANN算法实现在MNIST和MNIST-M数据集之间进行迁移学习。 首先,我们需要了解两个…

DANN-经典论文概念及源码梳理

没错,我就是那个为了勋章不择手段的屑(手动狗头)。快乐的假期结束了哭哭... DANN 对抗迁移学习 域适应Domain Adaption-迁移学习;把具有不同分布的源域(Source Domain)和目标域(Target Domain…

EHCache 单独使用

参考: http://macrochen.blogdriver.com/macrochen/869480.html 1. EHCache 的特点,是一个纯Java ,过程中(也可以理解成插入式)缓存实现,单独安装Ehcache ,需把ehcache-X.X.jar 和相关类库方到classpath中…

ehcache 的使用

http://my.oschina.net/chengjiansunboy/blog/70974 在开发高并发量,高性能的网站应用系统时,缓存Cache起到了非常重要的作用。本文主要介绍EHCache的使用,以及使用EHCache的实践经验。 笔者使用过多种基于Java的开源Cache组件,其…

Ehcache 的简单使用

文章目录 Ehcache 的简单使用背景使用版本配置配置项编程式配置XML 配置自定义监听器 验证示例代码 改进代码 备注完整示例代码官方文档 Ehcache 的简单使用 背景 当一个JavaEE-Java Enterprise Edition应用想要对热数据(经常被访问,很少被修改的数据)进行缓存时&…

SpringBoot 缓存(EhCache 使用)

SpringBoot 缓存(EhCache 使用) 源文链接:http://blog.csdn.net/u011244202/article/details/55667868 SpringBoot 缓存(EhCache 2.x 篇) SpringBoot 缓存 在 Spring Boot中,通过EnableCaching注解自动化配置合适的缓存管理器(CacheManager…

shiro框架04会话管理+缓存管理+Ehcache使用

目录 一、会话管理 1.基础组件 1.1 SessionManager 1.2 SessionListener 1.3 SessionDao 1.4 会话验证 1.5 案例 二、缓存管理 1、为什么要使用缓存 2、什么是ehcache 3、ehcache特点 4、ehcache入门 5、shiro与ehcache整合 1)导入相关依赖&#xff0…

使用Ehcache的两种方式(代码、注解)

Ehcache,一个开源的缓存机制,在一些小型的项目中可以有效的担任缓存的角色,分担数据库压力此外,ehcache在使用上也是极为简单, 下面是简单介绍一下ehcahce的本地使用的两种方式: 1,使用代码编写的方式使用…

EhCache常用配置详解和持久化硬盘配置

一、EhCache常用配置 EhCache 给我们提供了丰富的配置来配置缓存的设置; 这里列出一些常见的配置项: cache元素的属性: name:缓存名称 maxElementsInMemory:内存中最大缓存对象数 maxElementsOnDisk&#xff…

EhCache初体验

一、简介 EhCache 是一个纯Java的进程内缓存框架,具有快速、精干等特点。Ehcache是一种广泛使用的开源Java分布式缓存。主要面向通用缓存,Java EE和轻量级容器。它具有内存和磁盘存储,缓存加载器,缓存扩展,缓存异常处理程序,一个gzip缓存servlet过滤器,支…

setw()使用方法

使用setw(n)之前&#xff0c;要使用头文件iomanip 使用方法: #include<iomanip> 1、setw&#xff08;int n&#xff09;只是对直接跟在<<后的输出数据起作用&#xff0c;而在之后的<<需要在之前再一次使用setw&#xff1b; &#xff08;Sets the number of…