融合transformer和对抗学习的多变量时间序列异常检测算法TranAD论文和代码解读...

article/2025/8/15 3:35:39

一、前言

今天的文章来自VLDB

TranAD: Deep Transformer Networks for Anomaly Detection in Multivariate Time Series Data

e96a54bcf5f8b7e651cc9e0266bff553.png
  • 论文链接:https://arxiv.org/pdf/2201.07284v6.pdf

  • 代码地址:https://github.com/imperial-qore/TranAD

二、问题

在文章中提出了对于多变量异常检测的几个有挑战性的问题

  1. 缺乏异常的label

  2. 大数据量

  3. 在现实应用中需要尽可能少的推理时间(实时速度要求高)

在本文中,提出了基于transformer的模型TranAD,该模型使用基于注意力机制的序列编码器,利用数据中更广泛的时间趋势快速推断。TranAD使用基于score的自适应来实现鲁棒的多模态特征提取以及通过adversarial training以获得稳定性。此外,模型引入元学习(MAML)允许我们使用有限的数据来训练模型。

三、方法

3.1 问题定义

一个时间序列d28716097400c3fad1dad3139bbcdbeb.png因为是多变量时间序列,每一个X是一个大小为m的向量,即该序列有m个特征。

该工作定义了两种任务

  1. Anomaly Detection(检测):给予一个序列来预测目前时刻的异常情况(0或者1),1代表该数据点是异常的。

  2. Anomaly Diagnosis(诊断): 文中这块用denote which of the modes of the datapoint at the 𝑡-th timestamp are anomalous.来描述,其实就是判断是哪几个维度的特征(mode)导致的实体的异常,诊断到维度模式的程度。

3.2 数据预处理

对数据做normalize,数据的保存形式,是一个实体一个npy文件,维度是(n, featureNum)

2f719176b829e3f3c0524496a91bc6d1.png

对数据进行滑窗,这里对于windowSize之前的数据并不舍去,而是用前面的数据直接复制,代码如下:

windows = []; w_size = model.n_windowfor i, g in enumerate(data): if i >= w_size: w = data[i-w_size:i]else: w = torch.cat([data[0].repeat(w_size-i, 1), data[0:i]])

3.3 模型

先看一下transformer的模型图。

432eb676e9e9acc1121b905f40c7a091.png

模型本身和TranAD除了有两个decoder其他基本上完全一样,这里结构不赘述了,具体看

  • transformer的论文: https://arxiv.org/pdf/1706.03762.pdf

  • 代码:https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html (推荐看官方的trasnformer代码,直接看源码实现)

TranAD也省去了decoder中feed forward后的add&Norm,softmax也更改为sigmoid,文中提到sigmoid是把输出的数据拟合到输入数据的归一化状态中,即【0, 1】范围内。89f2292bd662e00ac2b8d3489356d255.png

这个图其实十分清晰,这个方法最大的创新不在模型本身,甚至模型没什么改动,主要引入了对抗训练的思想 解释下其中的变量

  • W为输入的窗口数据(前面数据预处理页提到了窗口数据的生成)

  • Focus Score是和W一样维度的变量,在第一阶段为0矩阵,第二阶段是通过W和O1的计算得出

  • C W的最后一个窗口数据

这个C变量,文中这样说。

f4ef95e9ca715e59531b05542498e5bc.png

其实看给的代码最好理解

elif 'TranAD' in model.name:l = nn.MSELoss(reduction = 'none')data_x = torch.DoubleTensor(data); dataset = TensorDataset(data_x, data_x)bs = model.batch if training else len(data)dataloader = DataLoader(dataset, batch_size = bs)n = epoch + 1; w_size = model.n_windowl1s, l2s = [], []if training:for d, _ in dataloader:local_bs = d.shape[0]window = d.permute(1, 0, 2) // 这个就是Welem = window[-1, :, :].view(1, local_bs, feats)  // 这个就是Cz = model(window, elem)l1 = l(z, elem) if not isinstance(z, tuple) else (1 / n) * l(z[0], elem) + (1 - 1/n) * l(z[1], elem)if isinstance(z, tuple): z = z[1]l1s.append(torch.mean(l1).item())loss = torch.mean(l1)optimizer.zero_grad()loss.backward(retain_graph=True)optimizer.step()scheduler.step()tqdm.write(f'Epoch {epoch},\tL1 = {np.mean(l1s)}')return np.mean(l1s), optimizer.param_groups[0]['lr']

这里比较大的创新在于第一阶段和第二阶段的训练。

  • 第一阶段:为了更好的重构序列数据,和大部分encoder-decoder模型的作用没有什么不同

  • 第二阶段:引入对抗性训练的思想。

解读这个训练阶段之前,先把模型代码过一下。

class TranAD(nn.Module):def __init__(self, feats):super(TranAD, self).__init__()self.name = 'TranAD'self.lr = lrself.batch = 128self.n_feats = featsself.n_window = 10self.n = self.n_feats * self.n_windowself.pos_encoder = PositionalEncoding(2 * feats, 0.1, self.n_window)encoder_layers = TransformerEncoderLayer(d_model=2 * feats, nhead=feats, dim_feedforward=16, dropout=0.1)self.transformer_encoder = TransformerEncoder(encoder_layers, 1)decoder_layers1 = TransformerDecoderLayer(d_model=2 * feats, nhead=feats, dim_feedforward=16, dropout=0.1)self.transformer_decoder1 = TransformerDecoder(decoder_layers1, 1)decoder_layers2 = TransformerDecoderLayer(d_model=2 * feats, nhead=feats, dim_feedforward=16, dropout=0.1)self.transformer_decoder2 = TransformerDecoder(decoder_layers2, 1)self.fcn = nn.Sequential(nn.Linear(2 * feats, feats), nn.Sigmoid())def encode(self, src, c, tgt):src = torch.cat((src, c), dim=2)src = src * math.sqrt(self.n_feats)src = self.pos_encoder(src)memory = self.transformer_encoder(src)tgt = tgt.repeat(1, 1, 2)return tgt, memorydef forward(self, src, tgt):# Phase 1 - Without anomaly scoresc = torch.zeros_like(src)x1 = self.fcn(self.transformer_decoder1(*self.encode(src, c, tgt)))# Phase 2 - With anomaly scoresc = (x1 - src) ** 2x2 = self.fcn(self.transformer_decoder2(*self.encode(src, c, tgt)))return x1, x2

从代码可以看出来,虽然图中画的Window Encoder对于Decoder1和2来说是shared,但其实是分开的。

分别用transformer_decoder1和transformer_decoder2实现的。

第一阶段   Input Reconstruction

很多传统的encoder-decoder模型经常不能够获取short-term的趋势,会错过一些偏差很小的异常,为了解决这个问题,采用两个阶段的方式来进行重构。

在第一阶段,模型的目标是生成和输入窗口数据近似的reconstruction。对于这个推断的偏差,称之为focus score,有助于Transformer Encoder内部的注意力网络提取时间趋势,关注偏差高的子序列。(理解其实就是用与真实值的残差去拟合偏差高的序列,在第二阶段)

第一阶段下的focus score是0矩阵,与输入windows数据一致,输出O,与W计算L2 loss作为第一阶段的loss函数。

2ee10d698e3b66ff20609c77c17f6105.png

注意代码中的encode的部分,是将C和focus score直接concat再乘以sqrt(featureNum),之后再经过位置编码,其实文中只顺带说了一下而已,没有说为什么这么做,我个人倾向于为了将C和focus score信息放在一起,已concat常见的尝试揉在了一起。

第二阶段

将第一阶段O1与W的L2 loss作为focus score,在进行之前的步骤。这可以在第二阶段调整注意力权重,并为特定输入子序列提供更高的神经网络激活,以提取短期的时间趋势(这句话是文中说的,和我前面的那里理解基本上差不多)。

之后算是比较重点的地方了,引入对抗训练的思想,来设计第二阶段的loss这个地方十分的绕,建议大家读读原文的3.4 Offline Two-Phase Adversarial Training的Evolving Training Objective部分

  • 第二个decoder尽力去区分输入和第一阶段decoder1的重建,所以是max O2这个L2 loss(这里就会有疑问,decoder1的重建不是O1吗?为啥要max ||O2-W||2? 这里我从作者的角度去理解,其实他是一种条件近似转移,因为O1和O2再第一阶段都是近似W,那区别于第一阶段的decoder1的重建O1,其实就是区别于O2)

  • 第一个decoder尽力去通过创建一个接近W的O1来迷惑decoder2(其实就是想让这个focus score接近于0矩阵),其实就是min(W和O1),其实转移也就是min(W和O2),所以得到了下面这个公式

2ae2b47e9710fd39b61c12bad4f9eab5.png

可以分解为:

d5b73e777f0dfc8e5b4676c4adc13844.png

再加上第一阶段的loss,总loss就是:

844c8d9998b1b21cd5548463df708bae.png

这里又对decoder1和decoder2做了明确的解释:

3b8261aa5591fa582c6c4ebdf0586cf1.png

看代码理解

之后我们看下代码里的区别,代码里其实根本就没有算O2,只算了O1

bb4f2ab624f85fb1a65845ccf11e0106.png

在做反向传播,优化参数的时候也只算了L1的总loss,没有算L2

19c91253d3d0b32e2ae04ee27ab28d07.png这里有个z的type判别,因为做了很多消融实验,不是最终的模型,最终就是后面那个L1loss,其中前面有个参数,是epoch+1,参数会随着训练轮数的增加,倒数慢慢变小,即前面的第一阶段的loss慢慢权重减小,而第二阶段的loss慢慢权重增大。

其实这也给出了一个我个人感觉非常合理的解释,因为第二阶段要附属于第一阶段的训练,应该先让O1和O2接近于W,之后才能去用对抗训练,这样才会让第二阶段训练有效,否则就混乱了。

2cde1b13a88a60a7b470371c292f9093.png

现在再理解下这个,就完全明白了:

c5756c23ab8242d3e973bcfe713d8a6f.png测试阶段,引入了阈值自动选择(POT,但代码中没有看到这的设置),以及score的计算

143a7438a0a21e57fb536cc8eb65b1e4.png

四、结果

f232bbd41c5344a537f05a2ed9883c49.png

对于异常的定义,是score大于阈值就是异常。

8a0893e192de2bef97a8cdbb6b015a61.png任意一个维度有异常就算作是异常,感觉这样描述本质上还是单序列的异常检测,没有从根本解决多变量的问题。

4.1 数据集

2a755abb0cc962145332fff14ae29827.png大部分常用数据集

4.2 结果

每一个维度都有reconstruction,并且每个维度都有对应的score,不得不说这个图示还是很清晰的,构建很清晰易懂。278c1f4c022bd4780d165f1a0e5b6f1d.png

后面还有各种实验,参数灵敏度、数据集等实验,这篇paper实验部分还是很满的,整体来说,工作量还是拉满的。

本文也对两种任务都做了实验,异常检测部分不用说了,常规操作,诊断部分采用 HitRate and NDCG两种指标进行root cause的检验。

五、总结和思考

  1. 对于代码,给出了每个对比模型和数据集,可以为后续实验做参考,并给出了整体消融实验的代码,代码还是很全面的,虽然有一些杂乱,但对于一个要做这个方向的同学来说,还是相当于巨人的肩膀的。

  2. 把transformer和对抗性训练放在一起,确实是很新颖的想法

  3. 在代码处O2部分为何省略存疑,以及第二阶段的loss,其实有点套的生硬,为何不都引到O1上,假设引到O1上,那loss1就只剩下L2 loss了,可能在公式上就并非这种对称了。

  4. 存疑的点就是 代码中在训练过程中并未对L2进行训练,这样的话O2是否像理论说的那样输出工作?

  5. 诊断的定义和诊断任务的探索,其实是有一些生硬的,并且也没完全说清楚,当然这篇文章标题是anomaly detection,其实并未将diagnosis算重点,所以这个也可以接受。

  6. 文章的标题是多变量的异常检测,其实虽然可以应用在多变量上,但实际还是单变量的异常来判别是否是多变量的整体实体的异常,本质还是用单变量问题解决多变量(这里可以探究一下,因为最终的score是由loss决定,而loss本身的维度是和输入的window数据一样的维度,意思就是每一个特征维度有一个score,所以其实得到的score还是单变量的score而并非实体的score,所以这里作者也没有探究多变量pattern的情况,可能存在多个变量异常,但只是一个跳变的pattern,不足以让整体异常的情况。)

  7. 代码里也没给出POT的相关代码。

推荐阅读:我的2022届互联网校招分享我的2021总结浅谈算法岗和开发岗的区别互联网校招研发薪资汇总
2022届互联网求职现状,金9银10快变成铜9铁10!!公众号:AI蜗牛车保持谦逊、保持自律、保持进步发送【蜗牛】获取一份《手把手AI项目》(AI蜗牛车著)
发送【1222】获取一份不错的leetcode刷题笔记发送【AI四大名著】获取四本经典AI电子书

http://chatgpt.dhexx.cn/article/OnWZC6I8.shtml

相关文章

【深度学习】深入浅出对抗机器学习(AI攻防)

【深度学习】深入浅出对抗机器学习(AI攻防) 文章目录 1 Attack ML Model概述 2 基本概念 3 攻击分类 4 经典的对抗性样本生成算法 5 经典的对抗防御方法 6 人工智能安全现状概析1 Attack ML Model概述 随着AI时代机器学习模型在实际业务系统中愈发无处不在,模型的安全性也变…

基于噪声伪标签和对抗学习的医学图像分割标注高效学习

目录 背景: 面临问题: 解决方案: 一 没有图像标注对的学习 二 为训练图像生成伪标签 2.1 为训练图像生成伪标签 2.2 VAE-Based鉴别器 2.3 鉴别器引导的发生器信道校准 这里有不太理解 (未写完) 三 从嘈杂的伪标…

【论文推荐】了解《对抗学习》必看的6篇论文(附打包下载地址)

论文推荐 “SFFAI139期来自美国莱斯大学的傅泳淦推荐的文章主要关注于基础研究的对抗学习领域,你可以认真阅读讲者推荐的论文,来与讲者及同行线上交流哦。” 关注文章公众号 回复"SFFAI139"获取本主题精选论文 01 The Lottery Ticket Hypothes…

对抗机器学习——Universal adversarial perturbations

代码地址: https://github.com/LTS4/universal 核心思想: 本文提出一种 universal对抗扰动,universal是指同一个扰动加入到不同的图片中,能够使图片被分类模型误分类,而不管图片到底是什么。示意图: 形…

【强化学习】模仿学习:生成式对抗模仿学习

★★★ 本文源自AI Studio社区精品项目,【点击此处】查看更多精品内容 >>> 模仿学习– 生成式对抗模仿学习 1. 模仿学习 模仿学习(imitation learning)不是强化学习,而是强化学习的一种替代品。模仿学习与强化学习有相同…

最新综述:图像分类中的对抗机器学习

目录 1.引言 2.论文贡献 3.卷积神经网络简介 4.对抗样本和对抗攻击 4.1.1 对抗扰动范围 4.1.2 对抗扰动的可见性 4.1.3 对抗扰动的测量 4.2 对抗攻击的分类 4.2.1 攻击者的影响力 4.2.2 攻击者的知识 4.2.3 安全入侵 4.2.4 攻击的特异性 4.2.5 攻击方法 5.2 防御…

对抗学习DCGAN网络

文章目录 DCGAN教程1. 简介2. 生成对抗网络(Generative Adversarial Networks)2.1 什么是 GAN2.2 什么是 DCGAN 3. DCGAN实现过程3.1 输入3.2 数据3.3 实现3.3.1 权重初始化3.3.2 生成器3.3.3 判别器3.3.4 损失函数和优化器3.3.4 训练3.3.5 结果 DCGAN教…

NLP中的对抗学习VS对比学习-1

文章目录 1 对抗学习的目的是什么?2 embedding是什么?3 对抗训练4 常见的对抗训练方式4.1 FGM4.2 PGD4.3 FreeAT4.4 FreeLB5 对抗训练和constractive learning6 对比学习的history and achievement思维导图链接:https://www.processon.com/mindmap/64159f9ff502f062b5d616be…

深度学习对抗样本的防御方法

作者: 19届 lz 论文:《深度学习对抗样本的防御方法综述》 问题 2013年 ,Szegedy 等 人 [1]首先通过添加轻微扰动来干扰输入样本,使基于深度神经网络(Deep neural network, DNN)的图片识别系统输出攻击者想…

对抗机器学习模型

重磅推荐专栏: 《Transformers自然语言处理系列教程》 手把手带你深入实践Transformers,轻松构建属于自己的NLP智能应用! 1. Attack ML Model 随着AI时代机器学习模型在实际业务系统中愈发无处不在,模型的安全性也变得日渐重要。…

PyTorch 生成对抗网络 01.生成对抗网络

1. 简介 本教程通过一个例子来对 DCGANs 进行介绍。我们将会训练一个生成对抗网络(GAN)用于在展示了许多真正的名人的图片后产生新的名人。 这里的大部分代码来自pytorch/examples中的 dcgan 实现,本文档将对实现进行进行全面 的介绍&#x…

机器学习中火爆的对抗学习是什么,有哪些应用?

1、什么是对抗学习? 机器学习这一技术自出现之始就以优异的性能应用于各个领域。近年来,随着机器学习的快速发展与广泛应用,这一领域更是得到前所未有的蓬勃发展。 目前, 机器学习在计算机视觉、语音识别、自然语言处理等复杂任务中取得了公…

对抗学习概念、基本思想、方法综述

代码实现篇 对抗学习常见方法代码实现篇 对抗学习的基本概念 要认识对抗训练,首先要了解 “对抗样本”,在论文 Intriguing properties of neural networks 之中有关于对抗样本的阐述。简单来说,它是指对于人类来说 “看起来” 几乎一样&am…

【Python】MySQL数据库(安装MySQL、创建数据库、在Python中使用MySQL数据库)

MySQL是一个小巧的多用户、多线程SQL数据库服务器。MySQ是以客户机/服务器结构来实现的,它由一个服务器守护进程和客户程序组成。在Python中,可以使用pymysql模块连接到数据库,对MySQL数据库进行操作。 本文内容: 一、安装MySQL…

Python结合MySQL数据库编写简单信息管理系统

1,项目整体逻辑及使用工具 1.1 项目整体逻辑 本项目主要是使用Python进行编写,利用Python中的pymysql库进行连接数据库,将信息存入MySQL数据库中,然后实现对信息进行增删改查等一系列操作。 1.2 使用工具 (1&#…

十四、python学习之MySQL数据库(一):安装MySQL数据库

一、数据库概述: 1.数据库概述: 数据库是在数据管理和程序开发过程中,一种非常重要的数据管理软件,通过数据库,可以非常方便的对数据进行管理操作。 2.什么是数据: 数据用来描述事物的特征,…

Python操作MySQL库结(MySQL详细下载、安装、操控及第三方库中的使用)

​ 活动地址:CSDN21天学习挑战赛 学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。 学习日记(4) 目录 学习日记(4) 一、下载和安装MySQL 1、下载MySQL 2…

阿里云钉钉应用python后端开发之安装MySQL数据库

阿里云钉钉应用python后端开发之安装mysqlclient 在本系列文章中,项目需要选择MySQL作为默认数据库。 本篇为在Windows上安装mysqlclient。 在python后端开发中,可以选择的数据库有PostgreSQL, MariaDB, MySQL, or Oracle等,一般情况下&…

使用python对mysql数据库进行添加数据的操作

使用python连接mysql进行添加数据的操作 使用的是python3.6pymysql 1、导入pymysql,并创建数据库连接 import pymysql# 使用python连接mysql数据库,并对数据库添加数据的数据的操作 # 创建连接,数据库主机地址 数据库用户名称 密码 数据库…

基于PYTHON语言的工资管理系统制作(一)--MYSQL数据库的下载和安装

去官网下载MySQL Community Server社区免费版,网址如下:MySQL :: Download MySQL Community Serverhttps://dev.mysql.com/downloads/mysql/ 因为我的开发环境是WINDOWS64位操作系统,所以我选了Windows版。 下载完毕后直接傻瓜化无脑全部安装…