pytorch nn.Embedding的用法和理解

article/2025/10/20 15:45:15

(2021.05.26补充)nn.Embedding.from_pretrained()的使用:
在这里插入图片描述

>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([1])
>>> embedding(input)
tensor([[ 4.0000,  5.1000,  6.3000]])

首先来看official docs对nn.Embedding的定义:
是一个lookup table,存储了固定大小的dictionary(的word embeddings)。输入是indices,来获取指定indices的word embedding向量。
在这里插入图片描述
!](https://img-blog.csdnimg.cn/20210325180944841.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM5NTQwNDU0,size_16,color_FFFFFF,t_70)
习惯性地,(1)把从单词到索引的映射存储在word_to_idx的字典中。(2)索引embedding表时,必须使用torch.LongTensor(因为索引是整数)

官方文档的示例:

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],[-0.6431,  0.0748,  0.6969],[ 1.4970,  1.3448, -0.9685],[-0.3677, -2.7265, -0.1685]],[[ 1.4970,  1.3448, -0.9685],[ 0.4362, -0.4004,  0.9400],[-0.6431,  0.0748,  0.6969],[ 0.9124, -2.3616,  1.1151]]])

我不太懂的是定义完nn.Embedding(num_embeddings-词典长度,embedding_dim-向量维度)之后,为什么就可以直接使用embedding(input)进行输入。
我们来仔细看看:

>>> embedding = nn.Embedding(10, 3)      

构造一个(假装)vocab size=10,每个vocab用3-d向量表示的table

>>> embedding.weight
Parameter containing:                   
tensor([[ 1.2402, -1.0914, -0.5382],[-1.1031, -1.2430, -0.2571],[ 1.6682, -0.8926,  1.4263],[ 0.8971,  1.4592,  0.6712],[-1.1625, -0.1598,  0.4034],[-0.2902, -0.0323, -2.2259],[ 0.8332, -0.2452, -1.1508],[ 0.3786,  1.7752, -0.0591],[-1.8527, -2.5141, -0.4990],[-0.6188,  0.5902, -0.0860]], requires_grad=True)

可以看做每行是一个词汇的向量表示!

>>> embedding.weight.size
torch.Size([10, 3])           

和nn.Embedding处的定义一致

>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> input                     
tensor([[1, 2, 4, 5],[4, 3, 2, 9]])

牢记:input是indices

>>> input.shape              
torch.Size([2, 4])

Input size表示这批有2个句子,每个句子由4个单词构成

>>> a = embedding(input)      
>>> a
tensor([[[-1.1031, -1.2430, -0.2571],     [ 1.6682, -0.8926,  1.4263],    [-1.1625, -0.1598,  0.4034],[-0.2902, -0.0323, -2.2259]],[[-1.1625, -0.1598,  0.4034],[ 0.8971,  1.4592,  0.6712],[ 1.6682, -0.8926,  1.4263],[-0.6188,  0.5902, -0.0860]]], grad_fn=<EmbeddingBackward>)

a=embedding(input)是去embedding.weight中取对应index的词向量!
看a的第一行,input处index=1,对应取出weight中index=1的那一行。其实就是按index取词向量!

>>> a.size()
torch.Size([2, 4, 3])

取出来之后变成了2*4*3的张量。

终于弄懂了,爽了


http://chatgpt.dhexx.cn/article/IVO30pbI.shtml

相关文章

EMBEDDING层作用

embedding层作用&#xff1a;①降维②对低维的数据进行升维时&#xff0c;可能把一些其他特征给放大了&#xff0c;或者把笼统的特征给分开了。 Embedding其实就是一个映射&#xff0c;从原先所属的空间映射到新的多维空间中&#xff0c;也就是把原先所在空间嵌入到一个新的空…

彻底理解embedding

本文转载自https://blog.csdn.net/weixin_42078618/article/details/84553940&#xff0c;版权问题请联系博主删除 首先&#xff0c;我们有一个one-hot编码的概念。 假设&#xff0c;我们中文&#xff0c;一共只有10个字。。。只是假设啊&#xff0c;那么我们用0-9就可以表示…

深度学习中的embedding

整理翻译自google developer的机器学习入门课程&#xff0c;介绍了embedding的应用方式和如何计算embedding&#xff0c;后面还配有通过tensorflow DNN训练embedding练习加深理解。 分类输入数据(Categorical Input Data) 分类数据是指表示来自有限选择集的一个或多个离散项的…

【文本分类】深入理解embedding层的模型、结构与文本表示

[1] 名词理解 embedding层&#xff1a;嵌入层&#xff0c;神经网络结构中的一层&#xff0c;由embedding_size个神经元组成&#xff0c;[可调整的模型参数]。是input输入层的输出。 词嵌入&#xff1a;也就是word embedding…根据维基百科&#xff0c;被定义为自然语言处理NLP中…

用万字长文聊一聊 Embedding 技术

作者&#xff1a;qfan&#xff0c;腾讯 WXG 应用研究员 随着深度学习在工业届不断火热&#xff0c;Embedding 技术便作为“基本操作”广泛应用于推荐、广告、搜索等互联网核心领域中。Embedding 作为深度学习的热门研究方向&#xff0c;经历了从序列样本、图样本、再到异构的多…

Embedding技术

1、Embedding 是什么 Embedding是用一个低维稠密的向量来“表示”一个对象&#xff08;这里的对象泛指一切可推荐的事物&#xff0c;比如商品、电影、音乐、新闻等&#xff09;&#xff0c;同时表示一词意味着Embedding能够表达相应对象的某些特征&#xff0c;同时向量之间的距…

什么是embedding?

本文转自&#xff1a;https://www.jianshu.com/p/6c977a9a53de    简单来说&#xff0c;embedding就是用一个低维的向量表示一个物体&#xff0c;可以是一个词&#xff0c;或是一个商品&#xff0c;或是一个电影等等。这个embedding向量的性质是能使距离相近的向量对应的物体…

Pairwise-ranking loss代码实现对比

Multi-label classification中Pairwise-ranking loss代码 定义 在多标签分类任务中&#xff0c;Pairwise-ranking loss中我们希望正标记的得分都比负标记的得分高&#xff0c;所以采用以下的形式作为损失函数。其中 c c_ c​是正标记&#xff0c; c − c_{-} c−​是负标记。…

【论文笔记】API-Net:Learning Attentive Pairwise Interaction for Fine-Grained Classification

API-Net 简介创新点mutual vector learning&#xff08;互向量学习&#xff09;gate vector generation&#xff08;门向量生成器&#xff09;pairwise interaction&#xff08;成对交互&#xff09; 队构造&#xff08;Pair Construction&#xff09;实验结果总结 简介 2020年…

白话点云dgcnn中的pairwise_distance

点云DGCNN中对于代码中pairwise_distance的分析与理解 2021年5月7日&#xff1a;已经勘误&#xff0c;请各位大佬不惜赐教。 一点一点读&#xff0c;相信我&#xff0c;我能讲清楚。 这个是本篇文章所要讨论的代码段 总体上把握&#xff0c;这个代码计算出了输入点云每对点之…

推荐系统[四]:精排-详解排序算法LTR (Learning to Rank): poitwise, pairwise, listwise相关评价指标,超详细知识指南。

搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排)、系统架构、常见问题、算法项目实战总结、技术细节以及项目实战(含码源) 专栏详细介绍:搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排)、系统架构、常见问题、算法项目实战总结、技术…

【torch】torch.pairwise_distance分析

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 记录torch.pairwise_distance 1. 一维 1.1 元素个数相同 1.1.1 元素个数为1 生成代码: t torch.randn(1) f torch.randn(1)计算代码&#xff0c;下…

pairwise损失_triplet损失_提升精排模型的trick

01标签 import torch import torch.nn as nn# 输入x是一个二维张量&#xff0c;每一行表示一个样本的分数&#xff0c;每一列表示一个特征或维度 x torch.tensor([[0.5, 0.7], [0.9, 0.8], [0.6, 0.4], [0.3, 0.6], [0.8, 0.7], [0.4, 0.5]])# 标签y是一个一维张量&#xff0c…

LTR (Learning to Rank): 排序算法 poitwise, pairwise, listwise常见方案总结

目录 1 Learing to Rank介绍2 The Pointwise Approach3 The Pairwise Approach3.1 RankNet 4 The Listwise Approach4.1 直接优化评测指标4.1.1 LambdaRank4.1.2 LambdaMART 4.2 定义Listwise损失函数4.2.1 ListNet4.2.2 ListMLE 5 排序评估指标5.1 Mean Reciprocal Rank (MRR)…

【论文精读】Pairwise learning for medical image segmentation

Published in: Medical Image Analysis 2020 论文&#xff1a;https://www.sciencedirect.com/science/article/abs/pii/S1361841520302401 代码&#xff1a;https://github.com/renzhenwang/pairwise_segmentation 目录 Published in: Medical Image Analysis 2020 摘要 一…

pairwise相似度计算

做了一个比赛&#xff0c;其中为了更好的构建负样本&#xff0c;需要计算不同句子之间的相似性&#xff0c;句子大概有100w&#xff0c;句子向量是300维&#xff0c;中间踩了很多坑&#xff0c;记录一下。 暴力计算 最简单的idea是预分配一个100w x 100w的矩阵&#xff0c;一…

如何计算 Pairwise correlations

Pairwise Correlation的定义是啥&#xff1f;配对相关性&#xff1f;和pearson correlations有什么区别&#xff1f; Pairwise Correlation顾名思义&#xff0c;用来计算两个变量间的相关性&#xff0c;而pearson correlations只是计算相关性的一种方法罢了。 1、pearson相关系…

再谈排序算法的pairwise,pointwise,listwise

NewBeeNLP干货 作者&#xff1a;DOTA 大家好&#xff0c;这里是 NewBeeNLP。 最近因为工作上的一些调整&#xff0c;好久更新文章和个人的一些经验总结了&#xff0c;下午恰好有时间&#xff0c;看了看各渠道的一些问题和讨论&#xff0c;看到一个熟悉的问题&#xff0c;在这…

【推荐】pairwise、pointwise 、 listwise算法是什么?怎么理解?主要区别是什么?

写在前面&#xff1a;写博客当成了学习笔记&#xff0c;容易找到去完善&#xff0c;不用于商业用途。通过各种途径网罗到知识汇总与此&#xff0c;如有侵权&#xff0c;请联系我&#xff0c;我下掉该内容~~ 排序学习的模型通常分为单点法&#xff08;Pointwise Approach&#…

软件测试用例设计之Pairwise算法

Pairwise算法简介 Pairwise是L. L. Thurstone(29 May1887 – 30 September 1955)在1927年首先提出来的。他是美国的一位心理统计学家。Pairwise也正是基于数学统计和对传统的正交分析法进行优化后得到的产物。 测试过程中&#xff0c;对于多参数参数多值的情况进行测试用例组…