LSTM -长短期记忆网络(RNN循环神经网络)

article/2025/9/22 12:11:38

文章目录

      • 基本概念及其公式
        • 输入门、输出门、遗忘门
        • 候选记忆元
        • 记忆元
        • 隐状态
      • 从零开始实现 LSTM
        • 初始化模型参数
        • 定义模型
        • 训练和预测
      • 简洁实现
      • 小结

基本概念及其公式

LSTM,即(long short-term Memory)长短期记忆网络,也是RNN循环神经网络的一种改进方法,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的,在NLP领域具有很重要的作用。

LSTM 模型同 GRU 模型思想相像,也是依靠逻辑门思想来试图解决序列依赖的一种方法。不过 LSTM 的逻辑门实现方法与 GRU 模型有所不同。

LSTM 模型中共需要以下的逻辑门与记忆信息:

  • 输入门 I t I_{t} It
  • 输出门 O t O_{t} Ot
  • 遗忘门 F t F_{t} Ft
  • 候选记忆元 C t c a n C_{t}^{can} Ctcan
  • 记忆元 C t C_t Ct
  • 隐状态 H t H_{t} Ht

通过对以上参数进行组合即可实现LSTM模型,以下将详解这6个逻辑门与记忆信息的功能与计算方法。

输入门、输出门、遗忘门

如同在门控循环单元 GRU 中一样, 当前时间步的输入前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如下图所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在 ( 0 , 1 ) (0, 1) (0,1) 的范围内。

在这里插入图片描述

输入门、遗忘门和输出门的计算公式:

{ 输 入 门 I t = σ ( X i W x i + H t − 1 W h i + b i ) 输 出 门 O t = σ ( X i W x o + H t − 1 W h o + b o ) 遗 忘 门 F t = σ ( X i W x f + H t − 1 W h f + b f ) \begin {cases} 输入门 \quad I_{t} = \sigma(X_{i}W_{xi} + H_{t-1}W_{hi} + b_{i}) \\ 输出门 \quad O_{t} = \sigma(X_{i}W_{xo} + H_{t-1}W_{ho} + b_{o}) \\ 遗忘门 \quad F_{t} = \sigma(X_{i}W_{xf} + H_{t-1}W_{hf} + b_{f}) \\ \end {cases} It=σ(XiWxi+Ht1Whi+bi)Ot=σ(XiWxo+Ht1Who+bo)Ft=σ(XiWxf+Ht1Whf+bf)

候选记忆元

由于还没有指定各种门的操作,所以先介绍候选记忆元(candidate memory cell)$\tilde{\mathbf{C}}_t。它的计算与上面描述的三个门的计算类似,但是使用 tanh ⁡ \tanh tanh 函数作为激活函数,函数的值范围为 ( − 1 , 1 ) (-1, 1) (1,1) 。下面导出在时间步 t t t 处的方程:

候 选 记 忆 元 C ~ t = tanh ⁡ ( X i W x c + H i − 1 W h c + b c ) 候选记忆元 \quad \tilde{C}_{t} = \tanh(X_{i}W_{xc} + H_{i-1}W_{hc} + b_{c}) C~t=tanh(XiWxc+Hi1Whc+bc)

候选记忆元的图示如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hB076I5Z-1666138018315)(attachment:QQ%E6%88%AA%E5%9B%BE20221018220816.png)]

记忆元

在门控循环单元 GRU 中,有一种机制来控制输入和遗忘(或跳过)。类似地,在长短期记忆网络中,也有两个门用于这样的目的:输入门 I t \mathbf{I}_t It控制采用多少来自 C ~ t \tilde{\mathbf{C}}_t C~t的新数据,而遗忘门 F t \mathbf{F}_t Ft控制保留多少过去的记忆元 C t − 1 \mathbf{C}_{t-1} Ct1的内容。 使用按元素乘法,得出:

C t = F t ⊙ C t − 1 + I t ⊙ C ~ t C_{t} = F_{t} \odot C_{t-1} + I_{t} \odot \tilde{C}_{t} Ct=FtCt1+ItC~t

如果遗忘门始终为 1 1 1且输入门始终为 0 0 0,则过去的记忆元 C t − 1 \mathbf{C}_{t-1} Ct1将随时间被保存并传递到当前时间步。引入这种设计是为了缓解梯度消失问题,
并更好地捕获序列中的长距离依赖关系。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fcLlSgMQ-1666138018316)(attachment:QQ%E6%88%AA%E5%9B%BE20221018221445.png)]

隐状态

最后,我们定义如何计算隐状态 H t \mathbf{H}_t Ht,这就是输出门发挥作用的地方。在长短期记忆网络中,它仅仅是记忆元的 tanh ⁡ \tanh tanh的门控版本。这就确保了 H t \mathbf{H}_t Ht的值始终在区间 ( − 1 , 1 ) (-1, 1) (1,1)内:

H t = O t ⊙ t a n h ( C t ) H_{t} = O_{t} \odot tanh(C_{t}) Ht=Ottanh(Ct)

只要输出门接近 1 1 1,我们就能够有效地将所有记忆信息传递给预测部分,而对于输出门接近 0 0 0,我们只保留记忆元内的所有信息,而不需要更新隐状态。

其图示如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EJf98ogf-1666138018317)(attachment:QQ%E6%88%AA%E5%9B%BE20221018221913.png)]

从零开始实现 LSTM

现在,我们从零开始实现长短期记忆网络。 与之前 RNN 模型的实验相同, 我们首先加载时光机器的数据集(目的是通过训练能够自动补全句子)。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35             #批量大小32,序列步数35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

我们现在来定义和初始化模型参数。如之前一致,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差 0.01 的高斯分布初始化权重,并将偏置项设为 0。

def get_lstm_params(vocab_size, num_hiddens, device):#输入和输出一致num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))#逻辑门参数W_xi, W_hi, b_i = three()                #输入门参数W_xf, W_hf, b_f = three()                #遗忘门参数W_xo, W_ho, b_o = three()                #输出门参数W_xc, W_hc, b_c = three()                #候选记忆元参数#输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)#列表存储参数信息params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]#增加梯度for param in params:param.requires_grad_(True)return params

定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元 C t C_t Ct不直接参与输出计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:#逻辑门相关计算I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)#候选记忆元计算C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)#记忆元计算C = F * C + I * C_tilda#隐状态计算H = O * torch.tanh(C)#输出结果计算Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化之前引入的RNNModelScratch类详情查看这里(RNN从零开始实现)来训练一个长短期记忆网络, 就如我们在之前所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 8249.8 tokens/sec on cpu
time traveller for so it will be convenient to speak one wroch a
travelleryou can show black is white by argument said filby

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jSMF2xW8-1666138018318)(output_31_1.svg)]

简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_sizelstm_layer = nn.LSTM(num_inputs, num_hiddens)      #定义LSTM层
model = d2l.RNNModel(lstm_layer, len(vocab))       #定义RNN模型
model = model.to(device)
#训练模型
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 7963.8 tokens/sec on cpu
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dP4nTR0Q-1666138018319)(output_34_1.svg)]

小结

1、长短期记忆网络有三种类型的门:输入门、遗忘门和输出门

2、长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

3、长短期记忆网络可以缓解梯度消失梯度爆炸


http://chatgpt.dhexx.cn/article/pw1LSo32.shtml

相关文章

机器学习之LSTM的Python实现

什么是LSTM? LSTM(长短期记忆人工神经网络),是一种可以学习长期依赖特殊的RNN(循环神经网络)。 传统循环网络RNN虽然可以通过记忆体,实现短期记忆,进行连续数据的预测。但是当连续…

神经网络:LSTM基础学习

1、LSTM简介 在时间序列数据学习中,传统的循环神经网络(RNN)存在较多的学习瓶颈和技术缺陷,而长短时记忆(LSTM)神经网络克服了循环神经网络的缺陷,使其在长时间序列数据学习训练中能克服梯度爆炸…

Lstm(循环神经网络)

算法模型Lstm(循环神经网络): 简介 LSTM和RNN相似,它们都是在前向传播的过程中处理流经细胞的数据,不同之处在于 LSTM 中细胞的结构和运算有所变化。 LSTM结构: 遗忘门: 遗忘门的功能是决定应丢弃或保留哪些信息。…

基于MATLAB的LSTM神经网络时序预测

参考博客及文献:4 Strategies for Multi-Step Time Series Forecasting Multivariate Time Series Forecasting with LSTMs in Keras (machinelearningmastery.com) LSTM进阶:使用LSTM进行多维多步的时间序列预测_lstm多维多部预测_一只小EZ的博客-CSD…

LSTM神经网络图解

LSTM神经网络图详解 (1)遗忘门,用于计算信息的遗忘(保留)程度,通过sigmoid处理后为0到1的值,1表示全部保留,0表示全部忘记。 f t σ ( W f ⋅ [ h t − 1 , x t ] b f ) f_{t}\si…

【神经网络】LSTM

1.什么是LSTM 长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,相比普通的RNN,LSTM能够在更长的序列中有更好的表现。 LSTM区别于RNN地方…

[深入浅出] LSTM神经网络

由来 人类并不是每时每刻都从一片空白的大脑开始他们的思考。在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义。我们不会将所有的东西都全部丢弃,然后用空白的大脑进行思考。我们的思想拥有持久性。 传统的神经…

简单理解LSTM神经网络

递归神经网络 在传统神经网络中,模型不会关注上一时刻的处理会有什么信息可以用于下一时刻,每一次都只会关注当前时刻的处理。举个例子来说,我们想对一部影片中每一刻出现的事件进行分类,如果我们知道电影前面的事件信息&#xf…

LSTM神经网络

LSTM被广泛用于许多序列任务(包括天然气负荷预测,股票市场预测,语言建模,机器翻译),并且比其他序列模型(例如RNN)表现更好,尤其是在有大量数据的情况下。 LSTM经过精心设…

(神经网络深度学习)--循环神经网络LSTM

一、什么是LSTM: 如果你经过上面的文章看懂了RNN的内部原理,那么LSTM对你来说就很简单了,首先大概介绍一下LSTM,是四个单词的缩写,Long short-term memory,翻译过来就是长短期记忆,是RNN的一种…

机器学习——人工神经网络模型LSTM

LSTM的学习 学习目标: 1理解什么是人工神经网络。2深入理解LSTM(长短期记忆网络)3Code 浅析人工神经网络: 在谈人工神经网络模型之前我们先来了解一下生理上的神经网络。 下面是一张对比图: Neural Science Compute…

LSTM神经网络详解

LSTM 长短时记忆网络(Long Short Term Memory Network, LSTM),是一种改进之后的循环神经网络,可以解决RNN无法处理长距离的依赖的问题,目前比较流行。 长短时记忆网络的思路: 原始 RNN 的隐藏层只有一个状态,即h&am…

LSTM神经网络介绍

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 引言一、介绍1.1什么是LSTM?1.2基础知识1.2.1门控机制1.2.2 相关激活函数1.2.3网络参数介绍 二、LSTM网络架构2.1架构图 三、LSTM的门3.1遗忘门3.2输入门3.3输出门…

大白话理解LSTM神经网络(附实例讲解)

前言 本文章为个人学习笔记整理,所学习的内容来自b站up主老弓的学习日记,附有实例讲解。 归类 长短期记忆神经网络(LSTM)是一种特殊的循环神经网络(RNN)。原始的RNN在训练中,随着训练时间的加长以及网络层数的增多&a…

Idea如何导入一个SpringBoot项目

最近公司要求开发工具要用Idea,作为一个eclipse的老员工,记录一下Idea中遇到的坑 刚开始用Idea从Git上导入一个项目时,遇到了很多坑,网上有很多方法,我不多做介绍。只说明一下我使用的方法。 1.本地新建一个文件夹&a…

idea导入项目框架的方法

学习时,使用IDEA的时候,经常需要导入项目框架,下面操作介绍如何导入项目框架。 打开需要导入的项目 打开方式: 打开 idea ,选择 Import Project 也可以进入idea后,选择 Flie --> New --> Project …

IDEA导入Eclipse项目

背景:用习惯了idea再去用eclipse实在用的不习惯,于是将老的eclipse项目导入到eclipse,网上有很多教程,看了很多博客都不行,一直报错,各种报错,现在终于好了,我们一起来看看怎么将ecl…

关于新版idea如何导入项目

现今有很多同学都发现idea怎么找不到import project这个按钮了,我也遇到了这个问题,经过研究发现,之前使用import project最关键还是在于project form Existing Sources。 而就在打开项目后,File-->New-->Project form Exi…

idea导入项目后没有被识别为maven项目的解决办法

开发中遇到了idea导入项目后没有被识别为maven项目,使用下面方法即可 1、首先点击工具栏最左边的 Help 再点击 Find Action ;或者使用快捷键 CtrlShiftA 2、接着在输入框中输入 maven projects ,会弹出一个 Add Maven Projects 选项&#xf…

IDEA导入web项目并启动

导入项目 依次点击idea左上角的File->Project Structure->project 修改SDK、Language level,选择自己电脑对应的jdk版本,为web的运行提供jdk的环境 第二步,依次点击Facts->Web 点击Artifacts->Web Application:Exploded->…