Attention Visualization


Attention matrix:

https://github.com/rockingdingo/deepnlp/blob/r0.1.6/deepnlp/textsum/eval.py
The plot_attention(data, X_label=None, Y_label=None) function:

#!/usr/bin/python
# -*- coding:utf-8 -*-
"""
Evaluation methods for summarization tasks, including BLEU and ROUGE scores.
Visualization of the attention mask matrix: plot_attention() method.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import matplotlib
matplotlib.use('Agg')  # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt  # drawing heat map of attention weights
plt.rcParams['font.sans-serif'] = ['SimSun']  # set font family
import time


def evaluate(X, Y, method="rouge_n", n=2):
    score = 0.0
    if (method == "rouge_n"):
        score = eval_rouge_n(X, Y, n)
    elif (method == "rouge_l"):
        score = eval_rouge_l(X, Y)
    elif (method == "bleu"):
        score = eval_bleu(X, Y, n)
    else:
        print("method not found")
        score = 0.0
    return score


def eval_bleu(y_candidate, y_reference, n=2):
    '''Args:
        y_candidate: list of words, machine generated prediction
        y_reference: list of list, [[], [],], human generated reference lines
    Return:
        bleu score: double, modified n-gram precision of the candidate against the references
    '''
    if (type(y_reference[0]) != list):
        print('y_reference should be list of list')
        return
    m = len(y_reference)
    bleu_score = 0.0
    ngram_cand = generate_ngrams(y_candidate, n)
    total_cand_count = len(ngram_cand)
    ngram_ref_list = []  # list of ngrams for each reference sentence
    for i in range(m):
        ngram_ref_list.append(generate_ngrams(y_reference[i], n))
    total_clip_count = 0
    for tuple in set(ngram_cand):
        # for each unique n-gram tuple in ngram_cand, calculate the clipped count
        cand_count = count_element(ngram_cand, tuple)
        max_ref_count = 0  # max_ref_count for this tuple in the reference sentences
        for i in range(m):  # tuple count in reference sentence i
            num = count_element(ngram_ref_list[i], tuple)
            max_ref_count = num if max_ref_count < num else max_ref_count  # compare max_ref_count and num
        total_clip_count += min(cand_count, max_ref_count)
    bleu_score = total_clip_count / total_cand_count
    return bleu_score


def count_element(list, element):
    if element in list:
        return list.count(element)
    else:
        return 0


def eval_rouge_n(y_candidate, y_reference, n=2):
    '''Args:
        y_candidate: list of words, machine generated prediction
        y_reference: list of list, [[], [],], human generated reference lines
    Return:
        rouge_n score: double, maximum of pairwise rouge-n score
    '''
    if (type(y_reference[0]) != list):
        print('y_reference should be list of list')
        return
    m = len(y_reference)
    rouge_score = []
    ngram_cand = generate_ngrams(y_candidate, n)
    for i in range(m):
        ngram_ref = generate_ngrams(y_reference[i], n)
        num_match = count_match(ngram_cand, ngram_ref)
        rouge_score.append(num_match / len(ngram_ref))
    return max(rouge_score)


def generate_ngrams(input_list, n):
    '''zip(x, x[1:], x[2:], ..., x[n-1:]), ends with the shortest list'''
    return zip(*[input_list[i:] for i in range(n)])


def count_match(listA, listB):
    match_list = [tuple for tuple in listA if tuple in listB]
    return len(match_list)


def eval_rouge_l(y_candidate, y_reference):
    '''Args:
        y_candidate: list of words, machine generated prediction
        y_reference: list of list, [[], [],], human generated reference lines
    Return:
        rouge_l score: double, F1 score of longest common sequence
    '''
    if (type(y_reference[0]) != list):
        print('y_reference should be list of list')
        return
    K = len(y_reference)
    lcs_count = 0.0
    total_cand = len(y_candidate)  # total of candidate words
    total_ref = 0.0  # total of reference words
    for k in range(K):
        cur_lcs = LCS(y_candidate, y_reference[k])
        lcs_count += len(cur_lcs)
        total_ref += len(y_reference[k])
    recall = lcs_count / total_ref
    precision = lcs_count / total_cand
    beta = 8.0  # coefficient
    f1 = (1 + beta * beta) * precision * recall / (recall + beta * beta * precision)
    return f1


def LCS(X, Y):
    '''Get the elements of the longest common sequence'''
    length, flag = calc_LCS(X, Y)
    common_seq_rev = []  # reversed sequence
    # starting from the end of X and Y
    start_token = "START"
    X_new = [start_token] + list(X)
    Y_new = [start_token] + list(Y)
    i = len(X_new) - 1
    j = len(Y_new) - 1
    while (i >= 0 and j >= 0):
        if (flag[i][j] == 1):
            common_seq_rev.append(X_new[i])
            i -= 1
            j -= 1
        elif (flag[i][j] == 2):
            i -= 1  # i -> i-1
        else:
            j -= 1  # flag[i][j] == 3, j -> j-1
    common_seq = [common_seq_rev[len(common_seq_rev) - 1 - i] for i in range(len(common_seq_rev))]
    return common_seq


def calc_LCS(X, Y):
    '''Calculate Longest Common Sequence.
    Get the length[][] matrix and flag[][] matrix of X and Y;
    length[i][j]: longest common sequence length up to X[i] and Y[j];
    flag[i][j]: path of LCS, (1,2,3) 1: jump diagonal, 2: jump down i-1 -> i, 3: jump right j-1 -> j
    '''
    start_token = "START"
    X_new = [start_token] + list(X)  # adding start token to X sequence
    Y_new = [start_token] + list(Y)
    m = len(X_new)
    n = len(Y_new)
    # length and flag matrices have size m * n, i.e. (len(X) + 1) * (len(Y) + 1) including the START token
    length = [[0 for j in range(n)] for i in range(m)]
    flag = [[0 for j in range(n)] for i in range(m)]
    for i in range(1, m):
        for j in range(1, n):
            if (X_new[i] == Y_new[j]):  # compare string
                length[i][j] = length[i-1][j-1] + 1
                flag[i][j] = 1  # diagonal
            else:
                if (length[i-1][j] > length[i][j-1]):
                    length[i][j] = length[i-1][j]
                    flag[i][j] = 2  # (i-1) -> i
                else:
                    length[i][j] = length[i][j-1]
                    flag[i][j] = 3  # (j-1) -> j
    return length, flag


def plot_attention(data, X_label=None, Y_label=None):
    '''Plot the attention model heatmap
    Args:
        data: attn_matrix with shape [ty, tx], cut before 'PAD'
        X_label: list of size tx, encoder tags
        Y_label: list of size ty, decoder tags
    '''
    fig, ax = plt.subplots(figsize=(20, 8))  # set figure size
    heatmap = ax.pcolor(data, cmap=plt.cm.Blues, alpha=0.9)

    # Set axis labels
    if X_label != None and Y_label != None:
        X_label = [x_label.decode('utf-8') for x_label in X_label]
        Y_label = [y_label.decode('utf-8') for y_label in Y_label]

        xticks = range(0, len(X_label))
        ax.set_xticks(xticks, minor=False)  # major ticks
        ax.set_xticklabels(X_label, minor=False, rotation=45)  # labels should be 'unicode'

        yticks = range(0, len(Y_label))
        ax.set_yticks(yticks, minor=False)
        ax.set_yticklabels(Y_label, minor=False)  # labels should be 'unicode'

        ax.grid(True)

    # Save Figure
    plt.title(u'Attention Heatmap')
    timestamp = int(time.time())
    file_name = 'img/attention_heatmap_' + str(timestamp) + ".jpg"
    print("Saving figure %s" % file_name)
    fig.savefig(file_name)  # save the figure to file
    plt.close(fig)  # close the figure


def test():
    #strA = "ABCBDAB"
    #strB = "BDCABA"
    #m = LCS(strA, strB)
    #listA = ['但是', '我', '爱', '吃', '肉夹馍']
    #listB = ['我', '不是', '很', '爱', '肉夹馍']
    #m = LCS(listA, listB)
    y_candidate = ['我', '爱', '吃', '北京', '烤鸭']
    y_reference = [['我', '爱', '吃', '北京', '小吃', '烤鸭'],
                   ['他', '爱', '吃', '北京', '烤鹅'],
                   ['但是', '我', '很', '爱', '吃', '西湖', '醋鱼']]
    p1 = eval_rouge_l(y_candidate, y_reference)
    print("ROUGE-L score %f" % p1)
    p2 = eval_rouge_n(y_candidate, y_reference, 2)
    print("ROUGE-N score %f" % p2)
    p3 = eval_bleu(y_candidate, y_reference, 2)
    print("BLEU score %f" % p3)


if __name__ == "__main__":
    test()
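As a quick sanity check, here is a minimal, hypothetical driver for plot_attention. The attention matrix and token lists are invented for illustration; the sketch assumes the script above is saved as eval.py, that an img/ directory exists for the output file, and Python 2 (the labels are byte strings because plot_attention calls .decode('utf-8') on them).

# -*- coding:utf-8 -*-
# Hypothetical driver for plot_attention; toy data only.
import os
import numpy as np
from eval import plot_attention      # assumes the script above is saved as eval.py

if not os.path.exists('img'):
    os.makedirs('img')               # plot_attention writes img/attention_heatmap_<timestamp>.jpg

X_label = ['我', '爱', '吃', '北京', '烤鸭']   # encoder tokens, tx = 5 (utf-8 byte strings under Python 2)
Y_label = ['I', 'love', 'Beijing', 'duck']     # decoder tokens, ty = 4

attn = np.random.rand(len(Y_label), len(X_label))   # stand-in for a real [ty, tx] attention matrix
attn = attn / attn.sum(axis=1, keepdims=True)        # normalize each decoder step to sum to 1

plot_attention(attn, X_label=X_label, Y_label=Y_label)

In a real summarization run, attn would be the decoder's attention weights for one example, cut off before the 'PAD' positions as the docstring notes.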

Self-attention:

https://github.com/kaushalshetty/Structured-Self-Attention/tree/master/visualization

#Credits to Lin Zhouhan(@hantek) for the complete visualization code
import random, os, numpy, scipy
from codecs import open
def createHTML(texts, weights, fileName):
    """
    Creates a html file with text heat.
    weights: attention weights for visualizing
    texts: text on which attention weights are to be visualized
    """
    fileName = "visualization/" + fileName
    fOut = open(fileName, "w", encoding="utf-8")
    part1 = """
    <html lang="en">
    <head>
    <meta http-equiv="content-type" content="text/html; charset=utf-8">
    <style>
    body {
        font-family: Sans-Serif;
    }
    </style>
    </head>
    <body>
    <h3>Heatmaps</h3>
    </body>
    <script>
    """
    part2 = """
    var color = "255,0,0";
    var ngram_length = 3;
    var half_ngram = 1;
    for (var k=0; k < any_text.length; k++) {
        var tokens = any_text[k].split(" ");
        var intensity = new Array(tokens.length);
        var max_intensity = Number.MIN_SAFE_INTEGER;
        var min_intensity = Number.MAX_SAFE_INTEGER;
        for (var i = 0; i < intensity.length; i++) {
            intensity[i] = 0.0;
            for (var j = -half_ngram; j < ngram_length-half_ngram; j++) {
                if (i+j < intensity.length && i+j > -1) {
                    intensity[i] += trigram_weights[k][i + j];
                }
            }
            if (i == 0 || i == intensity.length-1) {
                intensity[i] /= 2.0;
            } else {
                intensity[i] /= 3.0;
            }
            if (intensity[i] > max_intensity) {
                max_intensity = intensity[i];
            }
            if (intensity[i] < min_intensity) {
                min_intensity = intensity[i];
            }
        }
        var denominator = max_intensity - min_intensity;
        for (var i = 0; i < intensity.length; i++) {
            intensity[i] = (intensity[i] - min_intensity) / denominator;
        }
        if (k%2 == 0) {
            var heat_text = "<p><br><b>Example:</b><br>";
        } else {
            var heat_text = "<b>Example:</b><br>";
        }
        var space = "";
        for (var i = 0; i < tokens.length; i++) {
            heat_text += "<span style='background-color:rgba(" + color + "," + intensity[i] + ")'>" + space + tokens[i] + "</span>";
            if (space == "") {
                space = " ";
            }
        }
        //heat_text += "<p>";
        document.body.innerHTML += heat_text;
    }
    </script>
    </html>"""
    putQuote = lambda x: "\"%s\"" % x
    textsString = "var any_text = [%s];\n" % (",".join(map(putQuote, texts)))
    weightsString = "var trigram_weights = [%s];\n" % (",".join(map(str, weights)))
    fOut.write(part1)
    fOut.write(textsString)
    fOut.write(weightsString)
    fOut.write(part2)
    fOut.close()
    return
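A minimal, hypothetical call to createHTML, with the function defined as above. The sentences and weights are made up for illustration; each weight list should carry one value per space-separated token of its sentence, and the visualization/ directory must already exist, since the function only prefixes the file name.

# Hypothetical driver for createHTML; invented sentences and weights.
import os

if not os.path.exists("visualization"):
    os.makedirs("visualization")   # createHTML writes to visualization/<fileName>

texts = ["the movie was surprisingly good",
         "acting was weak but visuals stunning"]
weights = [[0.05, 0.10, 0.15, 0.45, 0.25],          # 5 tokens -> 5 weights
           [0.05, 0.40, 0.20, 0.15, 0.10, 0.10]]    # 6 tokens -> 6 weights

createHTML(texts, weights, "attention_heatmap.html")
# Open visualization/attention_heatmap.html in a browser; a deeper red background means a higher attention weight.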



