[Tensorflow2] 梯度反转层(GRL)与域对抗训练神经网络(DANN)的实现

article/2025/11/8 19:02:41

文章目录

  • 概述
  • 原理回顾 (可跳过)
  • GRL 层实现
  • 使用 GRL 的域对抗(DANN)模型实现
  • DANN 的使用案例 !!!
  • 后记

概述

域对抗训练(Domain-Adversarial Training of Neural Networks,DANN)属于广义迁移学习的一种, 可以矫正另一个域的数据集的分布, 也可以看成是一种特殊的对抗式生成网络(GAN), 他也确实和 GAN 在差不多的时间发表. 此文主要记录如何在 Tensorflow2::keras API 下实现 DANN, 可以直接看使用案例.

原理回顾 (可跳过)

域对抗训练技术(Domain-Adversarial Training of Neural Networks,DANN)是一种简单高效的无监督域自适应方法.
DANN 架构
如上图, DANN 的做法是:首先通过一个深度网络来提取出某个域的输入数据的高层抽象特征(绿色部分),再通过一个域分类器(红色部分)对域进行分类;绿色部分的任务是想要学习到能骗过红色部分的特征,让域分类器不能区分输入数据来自于哪一个域,也就做到了把不同域的数据映射到同一个空间。绿色部分和红色部分可以看成分别构成了对抗神经网络的 Generator 与 Discriminator,不过在此 DANN中是插入了一个梯度反转层(Gradient Reversal Layer,GRL)对传播的梯度乘以一个负常数,来对绿色网络进行梯度上升并对红色部分实行梯度下降。然而网络很可能会学习到只输出一个常数向量的能力,此时不管什么域的输入数据都映射为了同一个向量,域分类器虽然分不出域,但对后续的分类任务非常不利;为了防止此类情况的发生,需要再引入标签分类器(蓝色部分)来约束特征提取网络,在训练阶段它需要对有标签的源域数据进行目标分类。

根据原文, 下面对 DANN 的损失函数与训练过程进行定量阐述。对于某一批输入 x i , i = 1 ⋯ N \boldsymbol{x}_i, i=1\cdots N xi,i=1N,假设前 n n n 个为有标签的源域数据,后 N − n N-n Nn 个为无标签的目标域数据。源域数据样本 x i \boldsymbol{x}_i xi 一次前向传播后,由上图中网络的两个输出部分可以得到两个损失:
L y i ( θ f , θ y ) = L y ( G y ( G f ( x i ; θ f ) ; θ y ) , y i ) L d i ( θ f , θ d ) = L d ( G d ( G f ( x i ; θ f ) ; θ d ) , d i ) \begin{array}{c} \mathcal{L}_{y}^{i}\left(\theta_{f}, \theta_{y}\right) = \mathcal{L}_{y}\left(G_{y}\left(G_{f}\left(\mathbf{x}_{i} ; \theta_{f}\right) ; \theta_{y}\right), y_{i}\right) \\ \mathcal{L}_{d}^{i}\left(\theta_{f}, \theta_{d}\right) = \mathcal{L}_{d}\left(G_{d}\left(G_{f}\left(\mathbf{x}_{i} ; \theta_{f}\right) ; \theta_{d}\right), d_{i}\right) \end{array} Lyi(θf,θy)=Ly(Gy(Gf(xi;θf);θy),yi)Ldi(θf,θd)=Ld(Gd(Gf(xi;θf);θd),di)
其中 G f , y , d ( ⋅ ) G_{f,y,d}(\cdot) Gf,y,d() 表示了三个网络,而 θ f , y , d \theta_{f,y,d} θf,y,d 为它们的参数; L y , d \mathcal{L}_{y,d} Ly,d 是标签预测网络和域分类网络的损失函数,一般选为交叉熵损失函数,将其带上上标 i i i 后表示为该样本一次传播的损失。对于目标域样本,则不含有 L y i \mathcal{L}_{y}^i Lyi 这一项。基于上述的对抗思想,总体的损失函数为:
E ( θ f , θ y , θ d ) = 1 n ∑ i = 1 n L y i ( θ f , θ y ) − λ ( 1 n ∑ i = 1 n L d i ( θ f , θ d ) + 1 n ′ ∑ i = n + 1 N L d i ( θ f , θ d ) ) E\left(\theta_{f}, \theta_{y}, \theta_{d}\right)=\frac{1}{n} \sum_{i=1}^{n} \mathcal{L}_{y}^{i}\left(\theta_{f}, \theta_{y}\right)-\lambda\left(\frac{1}{n} \sum_{i=1}^{n} \mathcal{L}_{d}^{i}\left(\theta_{f}, \theta_{d}\right)+\frac{1}{n^{\prime}} \sum_{i=n+1}^{N} \mathcal{L}_{d}^{i}\left(\theta_{f}, \theta_{d}\right)\right) E(θf,θy,θd)=n1i=1nLyi(θf,θy)λ(n1i=1nLdi(θf,θd)+n1i=n+1NLdi(θf,θd))
其中正数 λ \lambda λ 为梯度反转层的超参数,其值不能太大。如果 λ \lambda λ 过大相当于给红色部分给予很高的权重,驱使其将域分类完全分错,过犹不及,并不符合分不出域(即希望输出为一个均匀分布的向量)的初衷。综上,训练时 DANN 实际的优化目标为:
( θ ^ f , θ ^ y ) = argmin ⁡ θ f , θ y E ( θ f , θ y , θ ^ d ) θ ^ d = argmax ⁡ θ d E ( θ ^ f , θ ^ y , θ d ) \begin{aligned} \left(\hat{\theta}_{f}, \hat{\theta}_{y}\right) &=\underset{\theta_{f}, \theta_{y}}{\operatorname{argmin}} E\left(\theta_{f}, \theta_{y}, \hat{\theta}_{d}\right) \\ \hat{\theta}_{d} &=\underset{\theta_{d}}{\operatorname{argmax}} E\left(\hat{\theta}_{f}, \hat{\theta}_{y}, \theta_{d}\right) \end{aligned} (θ^f,θ^y)θ^d=θf,θyargminE(θf,θy,θ^d)=θdargmaxE(θ^f,θ^y,θd)
其中的最小化与最大化构成了一组对抗任务。
下图是作者原文中效果不错的一次实验结果:
t-SNE 特征降维可视化

GRL 层实现

可以直接用于 keras 的 Functional API 编程, 因为大家都是直接继承的 layers.Layer, 初始化时需要给定 lambda_ 即梯度要乘以的负数.

class GradientReversalLayer(tf.keras.layers.Layer):"""The gradient reversal layer is a layer that multiplies the gradient by a negative constant duringbackpropagation.Args:lambda_: Float32, the constant by which the gradient is multiplied. It should be a negative number."""def __init__(self, lambda_: float = -1):super().__init__(trainable=False, name="gradient_reversal_layer")self.lambda_ = tf.constant(lambda_, dtype=tf.float32)  # Normally, a negative valuedef call(self, x, **kwargs):return self.grad_reversed(x)@tf.custom_gradientdef grad_reversed(self, x):"""It returns input and a custom gradient function.Args:x: The input tensor.Returns:the input x and the custom gradient function."""def custom_gradient(dy):return self.lambda_ * dyreturn x, custom_gradientdef get_config(self):config = super().get_config().copy()config.update({"lambda": float(self.lambda_.numpy())})return config

主要参考:
https://stackoverflow.com/questions/56841166/how-to-implement-gradient-reversal-layer-in-tf-2-0
https://www.tensorflow.org/guide/eager#%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A2%AF%E5%BA%A6

使用 GRL 的域对抗(DANN)模型实现

修改两个分类器可以自己实现, 这里只用了一层全连接直接 softmax. 参数中需要给定一个 feature_extractor 即特征提取网络, 比如可以来自于 tf.keras.applications.ResNet50 并指定 include_top=false, 这样会去掉最上层的 softmax 而输出高维特征. 如果是自己的其他的提取器, 想修改的话, 可以见后面.

class DomainAdversarialModel:"""Domain-Adversarial Training of Neural Networks (DANN) in Tensorflow2.Args:feature_extractor: A model of tf.keras.Model, it would have attributes like .input, .outputnum_labels: Int, the number of labels.num_domains: Int, the number of domains.lambda_: Float32, the constant by which the gradient is multiplied. It should be a negative number.Attributes:output layer name of label classifier: "label_predict"output layer name of domain classifier: "domain_predict""""def __init__(self, feature_extractor: tf.keras.Model, num_labels, num_domains, lambda_: float = -1,name_label_classifier="label_predict", name_domain_classifier="domain_predict"):self.feature_extractor = feature_extractor  # has to be a tf.keras.Modelself.num_labels = num_labelsself.num_domains = num_domainsself.lambda_ = lambda_self.name_label_classifier = name_label_classifierself.name_domain_classifier = name_domain_classifierdef get_model(self):feature = self.feature_extractor.outputif len(self.feature_extractor.output_shape) != 2:# make sure feature has a shape of (None, feature_dim). Flatten is important for pytorch, maybe# not necessary for Tensorflow2.keras.feature = tf.keras.layers.Flatten()(feature)# output1 --> label_classifierlabel_predict = self.label_classifier(feature)# output2 --> domain_classifierdomain_predict = GradientReversalLayer(self.lambda_)(feature)domain_predict = self.domain_classifier(domain_predict)return tf.keras.Model(inputs=self.feature_extractor.input,outputs=[label_predict, domain_predict])def label_classifier(self, x):# x = tf.keras.layers.Dense(128, activation='relu')(x)return tf.keras.layers.Dense(self.num_labels, activation='softmax', name=self.name_label_classifier)(x)def domain_classifier(self, x):# x = tf.keras.layers.Dense(128, activation='relu')(x)return tf.keras.layers.Dense(self.num_domains, activation='softmax', name=self.name_domain_classifier)(x)

DANN 的使用案例 !!!

将上面两个模型复制到代码后, 可以如此调用. 首先需要一个特征提取网络, 并且属于 tf.keras.Model 对象. 如果你的模型并不是直接输出 feature 或想指定哪一层作为 feature 的输出, 则可以用第五行那样直接修改. keras 会自动追踪层间的计算关系. 剩下只需要给定标签个数和域的个数, 和梯度反转强度就行了.
如果想用自定义的域分类器和标签分类器, 则可以直接在上一节的两个 xxx_classifier 中自定义.

# Firstly, you should have a feature extraction model of tf.keras.Model, e.g., a tf.keras.applications.ResNet50 with `include_top=false`.
from .SE_ResNeXt_1DCNN import SEResNeXt 
Model = SEResNeXt(...).SEResNeXt50()  # A custom model to be specified by yourself# If the model doesn't output a feature, several top layers have to be removed like:
modified_model = tf.keras.Model(inputs=Model.input, outputs=Model.get_layer(index=-2).output)
DANN = DomainAdversarialModel(feature_extractor=modified_model, num_labels=4, num_domains=3, lambda_=-0.8).get_model()
DANN.summary()
DANN.compile()   # To be specified by yourselfdata = tf.random.normal((1, ...))  # To be specified by yourself
print(DANN(data))

后记

如果无法复制则可以到 github中复制或下载. 另外, 根据李宏毅的视频里提到, 像 GAN 那样交替训练可能会性能更佳, 即先训练一下域分类器, 再训练特征提取网络和标签分类器.


http://chatgpt.dhexx.cn/article/RmV3SqqJ.shtml

相关文章

DANN 领域迁移

DANN(Domain Adaptation Neural Network,域适应神经网络)是一种常用的迁移学习方法,在不同数据集之间进行知识迁移。本教程将介绍如何使用DANN算法实现在MNIST和MNIST-M数据集之间进行迁移学习。 首先,我们需要了解两个…

DANN-经典论文概念及源码梳理

没错,我就是那个为了勋章不择手段的屑(手动狗头)。快乐的假期结束了哭哭... DANN 对抗迁移学习 域适应Domain Adaption-迁移学习;把具有不同分布的源域(Source Domain)和目标域(Target Domain…

EHCache 单独使用

参考: http://macrochen.blogdriver.com/macrochen/869480.html 1. EHCache 的特点,是一个纯Java ,过程中(也可以理解成插入式)缓存实现,单独安装Ehcache ,需把ehcache-X.X.jar 和相关类库方到classpath中…

ehcache 的使用

http://my.oschina.net/chengjiansunboy/blog/70974 在开发高并发量,高性能的网站应用系统时,缓存Cache起到了非常重要的作用。本文主要介绍EHCache的使用,以及使用EHCache的实践经验。 笔者使用过多种基于Java的开源Cache组件,其…

Ehcache 的简单使用

文章目录 Ehcache 的简单使用背景使用版本配置配置项编程式配置XML 配置自定义监听器 验证示例代码 改进代码 备注完整示例代码官方文档 Ehcache 的简单使用 背景 当一个JavaEE-Java Enterprise Edition应用想要对热数据(经常被访问,很少被修改的数据)进行缓存时&…

SpringBoot 缓存(EhCache 使用)

SpringBoot 缓存(EhCache 使用) 源文链接:http://blog.csdn.net/u011244202/article/details/55667868 SpringBoot 缓存(EhCache 2.x 篇) SpringBoot 缓存 在 Spring Boot中,通过EnableCaching注解自动化配置合适的缓存管理器(CacheManager…

shiro框架04会话管理+缓存管理+Ehcache使用

目录 一、会话管理 1.基础组件 1.1 SessionManager 1.2 SessionListener 1.3 SessionDao 1.4 会话验证 1.5 案例 二、缓存管理 1、为什么要使用缓存 2、什么是ehcache 3、ehcache特点 4、ehcache入门 5、shiro与ehcache整合 1)导入相关依赖&#xff0…

使用Ehcache的两种方式(代码、注解)

Ehcache,一个开源的缓存机制,在一些小型的项目中可以有效的担任缓存的角色,分担数据库压力此外,ehcache在使用上也是极为简单, 下面是简单介绍一下ehcahce的本地使用的两种方式: 1,使用代码编写的方式使用…

EhCache常用配置详解和持久化硬盘配置

一、EhCache常用配置 EhCache 给我们提供了丰富的配置来配置缓存的设置; 这里列出一些常见的配置项: cache元素的属性: name:缓存名称 maxElementsInMemory:内存中最大缓存对象数 maxElementsOnDisk&#xff…

EhCache初体验

一、简介 EhCache 是一个纯Java的进程内缓存框架,具有快速、精干等特点。Ehcache是一种广泛使用的开源Java分布式缓存。主要面向通用缓存,Java EE和轻量级容器。它具有内存和磁盘存储,缓存加载器,缓存扩展,缓存异常处理程序,一个gzip缓存servlet过滤器,支…

setw()使用方法

使用setw(n)之前&#xff0c;要使用头文件iomanip 使用方法: #include<iomanip> 1、setw&#xff08;int n&#xff09;只是对直接跟在<<后的输出数据起作用&#xff0c;而在之后的<<需要在之前再一次使用setw&#xff1b; &#xff08;Sets the number of…

c语言iomanip头文件的作用,iomanip头文件的作用

在c程序里面经常见到下面的头文件 #include io代表输入输出&#xff0c;manip是manipulator(操纵器)的缩写(在c上只能通过输入缩写才有效。) 作用(推荐学习&#xff1a;C语言视频教程) 主要是对cin,cout之类的一些操纵运算子&#xff0c;比如setfill,setw,setbase,setprecisio…

QT学习C++(6)

立方体的类设计 设计立方体类&#xff0c;求出立方体的面积(2ad2ac2bc)和体积(a*b*c)&#xff0c;分别用全局函数和成员函数判断两个立方体是否相等&#xff1f; #include <iostream>using namespace std; class Cube{ private://数据&#xff0c;长宽高int c_l;int c_w…

C++中使用setw()使用方法

setw(int n)是c中在输出操作中使用的字段宽度设置&#xff0c;设置输出的域宽&#xff0c;n表示字段宽度。只对紧接着的输出有效&#xff0c;紧接着的输出结束后又变回默认的域宽。当后面紧跟着的输出字段长度小于n的时候&#xff0c;在该字段前面用空格补齐&#xff1b;当输出…

关系代数表达式的优化

查询的处理的代价通常取决于磁盘访问&#xff0c;磁盘访问比内存访问速度慢很多。 在这里由于计算机原理的知识的欠缺&#xff0c;理解起来有点费劲&#xff0c;例如不知道关系的连接在哪里进行&#xff0c;连接的中间结果放在哪里&#xff0c;计算后的结果怎么处理&#xff0c…

关系代数1

转自链接&#xff1a; https://blog.csdn.net/Flora_SM/article/details/84190119 1.查询选修了2号课程的学生的学号。 2.查询至少选修了一门其直接先行课为5号课程的学生姓名 因为是选修直接先行课&#xff0c;所以在Course表里&#xff0c;而学生姓名在Student表里&#xff…

关系代数和SQL语法

数据分析的语言接口 OLAP计算引擎是一架机器&#xff0c;而操作这架机器的是编程语言。使用者通过特定语言告诉计算引擎&#xff0c;需要读取哪些数据、以及需要进行什么样的计算。编程语言有很多种&#xff0c;任何人都可以设计出一门编程语言&#xff0c;然后设计对应的编译…

关系代数表达式练习(针对难题)

教师关系T&#xff08;T#,TNAME,TITLE&#xff09;课程关系C(C#,CNAME,TNO)学生关系S(S#,SNAME,AGE,SEX)选课关系SC(S#,C#,SCORE) 检索至少选修了C2,C4两门课程的学生学号&#xff1a; 这里的下标可以这样理解&#xff0c;课程表C取了别名SC1,SC2,SC1的第一个元素&#xff08;…

怎样用关系代数表达式表示查询要求?求过程

怎样用关系代数表达式表示查询要求&#xff1f; 用一个例子来讲述一下 题目&#xff1a;查询至少选修了全部课程的学生学号和姓名 题目所用到的表如下 题目&#xff1a;查询至少选修了全部课程的学生学号和姓名&#xff1f; ① 找出题目中暗含属性、以及它们所在的表 ② 根据…

关系代数与sql语句

关系代数定义&#xff1a; 关系代数是以关系为运算对象的一组高级运算的集合。关系代数的运算有集合运算&#xff08;集合<表>与集合<表>之间的运算&#xff09;和关系运算&#xff08;集合<表>内部的运算&#xff09; 集合运算&#xff1a; 并运算&#xf…