【用pytorch进行LSTM模型的学习】

article/2025/10/2 12:57:30

用pytorch进行LSTM模型的学习

  • LSTM模型
  • 用pytorch,采用LSTM对seaborn数据集做预测
    • 基本步骤
      • 数据的观察
      • 特殊数据处理
      • 数据归一化
      • 模型的构建与选择
      • 模型的保存
  • 飞机航班流量预测示例

LSTM模型

LSTM模型长下面这样,主要用在时间序列的预测,具有比RNN较好的性能。原因在于内部增加了很多门,用来控制前序信息的继续、遗忘、更新等,比RNN更好的表达了特征。
在这里插入图片描述

用pytorch,采用LSTM对seaborn数据集做预测

基本步骤

一般而言,进行深度学习的训练与应用包含大概如下步骤

=========工作流程=========- 数据读取与基本处理* 数据集读取* 数据的观察-画图* 特殊数据处理-空值、奇异值等- 数据集构建* 归一化* 训练集、验证集、测试集划分    - 模型建模* 基础模型架构* 损失函数* 优化器选择- 模型训练* 模型训练 与各种超参* 训练过程观察 * 训练中模型保存* 模型训练指标记录- 测试验证* 模型性能验证* 结果可视化* 测试性能指标记录

下面就流程中的几个重点进行说明

数据的观察

在拿到数据的时候,我们首先要对数据进行观察,观察的方法根据数据的类型略有不同,但是总体可以概括为

  • 肉眼观察:打开数据文件夹或者文件进行查看,比如文件个数有多少个,数据的大小是多少。
  • 数据展示观察:对于一些不好直接观察的,可以通过数据展示看一下,如打印dataframe结构的前几行,可以看到列名等信息,方便数据处理。
  • 画图观察:对于一些时序信息,可以通过作图的方式,看看数据的分布情况,是否有异常点等等。

为什么要对数据进行观察?主要有以下几个原因

  • 获取数据的基本信息,知道我们要处理的数据大概是怎样的。
  • 对原始数据有个感觉,数据的情况可能会影响我们模型的选择。以及模型训练的策略。比如小样本数据,样本数的多少会影响下一步的决策,如是否数据增强,是否迁移等等。
  • 观察到异常情况,如空值,奇异点,为下一步数据处理做准备。

特殊数据处理

机器学习处理的是数据的一般情况,即反映数据的一般规律和一般分布,对于奇异值或者特殊值,机器学习模型没有能力处理或者需要付出很大的代价才能处理。机器学习是帮助我们解决一般问题或者共性问题,对于一些特殊的问题,并不是这个学科的主要研究方向。当然,只有一个方向除外,即异常检测。
一般需要特殊处理的,有空值、错误值、奇异值。基本的处理方式有

  • 删除,即删除特殊值
  • 补全,补全空值
  • 修正,更改错误值

数据归一化

在一般情况下,尤其是时序数据,需要进行归一化,即把数据压缩到0-1之间。目的是使得数据有相同的尺度。例如,在一个数据集中,包含样本的年龄信息,收入信息等,这两个信息的度量尺度是不同的,如果不做归一化,那么由于年龄与收入在数值上相差很大,那么年龄的特征不能在模型中发挥很好的作用。

模型的构建与选择

针对不同的任务选择不同的模型,有pytorch内置了很多基础模型,因此模型结构的构建变得简单容易,需要注意的是模型的输入参数要求以及维度匹配,这就需要我们学习pytorch内置模型的接口函数,做一个合格的调包侠

模型的保存

在训练过程中,模型是不断更新的,每一次迭代后模型的参数就会不同。在这个过程中有必要有条件地保存下当前模型,主要有如下几个用途

  • 防止训练突然崩掉,重新训练浪费资源。在较长时间的训练过程中,由于种种原因,训练可能会崩溃,如突然掉电,机器故障灯,如果没有保存训练过程中的模型,则需要重新训练,那么浪费时间,浪费资源,尤其是接近训练完成的时候发生崩溃,人就更崩溃了。如果保存了模型,那么可以重新加载模型,断点续训练。
  • 根据过程中保存下来的模型,我们可以查看模型演变过程,进行过程的考察。
  • 测试验证用,保存模型,尤其是保存最后的或者最好的模型,在测试验证时,可以直接加载进行验证,不必再次训练

那么模型该如何保存呢? 模型保存的格式:pytorch中最常见的模型保存使用 .pt 或者是 .pth 作为模型文件扩展名。

pytorch模型保存的两种方式:

  • 一种是保存整个模型,
torch.save(model, "my_model.pth") # 保存整个模型` 
  • 另一种是只保存模型的参数,该方法速度快,占用空间少
torch.save(model.state_dict(), "my_model.pth") # 只保存模型的参数

相应的,加载也有两种方式

  • 加载整个模型
new_model = torch.load(PATH) 
  • 先构架模型架构,然后加载参数
new_model = Model()                          
new_model.load_state_dict(torch.load(PATH))   

飞机航班流量预测示例

完整代码如下

# -*- coding: utf-8 -*-
# @Time    : 2023/03/10 10:23
# @Author  : HelloWorld!
# @FileName: seq.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.8import torch
import torch.nn as nn
import argparse
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math# 数据读取与基本处理
class LoadData:def __init__(self,data_path ):self.ori_data = pd.read_csv(data_path)def data_observe(self):self.ori_data.head()self.draw_data(self.ori_data)def draw_data(self, data):print(data.head())fig_size = plt.rcParams["figure.figsize"]fig_size[0] = 15fig_size[1] = 5plt.rcParams["figure.figsize"] = fig_sizeplt.title('Month vs Passenger')plt.ylabel('Total Passengers')plt.xlabel('Months')plt.grid(True)plt.autoscale(axis='x', tight=True)plt.plot(data['passengers'])plt.show()#数据预处理,归一化def data_process(self):flight_data = self.ori_data.drop(['year'], axis=1)  # 删除不需要的列flight_data = flight_data.drop(['month'], axis=1)  # 删删除不需要的列flight_data = flight_data.dropna()  # 滤除缺失数据dataset = flight_data.values  # 获得csv的值dataset = dataset.astype('float32')dataset=self.data_normalization(dataset)return datasetdef data_normalization(self,x):'''数据归一化(0,1):param x::return:'''max_value = np.max(x)min_value = np.min(x)scale = max_value - min_valuey = (x - min_value) / scalereturn y#构建数据集,训练集、测试集
class CreateDataSet:def __init__(self, dataset,look_back=2):dataset = np.asarray(dataset)data_inputs, data_target = [], []for i in range(len(dataset) - look_back):a = dataset[i:(i + look_back)]data_inputs.append(a)data_target.append(dataset[i + look_back])self.data_inputs = np.array(data_inputs).reshape((-1, look_back))self.data_target = np.array(data_target).reshape((-1, 1))def split_train_test_data(self, rate=0.7):# 划分训练集和测试集,70% 作为训练集train_size = math.ceil(len(self.data_inputs) * rate)  #math.ceil()向上取整train_inputs = self.data_inputs[:train_size]train_target = self.data_target[:train_size]test_inputs = self.data_inputs[train_size:]test_target = self.data_target[train_size:]return train_inputs, train_target, test_inputs, test_target
# 构建模型
class LSTMModel(nn.Module):''' 定义LSTM模型,由于pytorch已经集成LSTM,直接用即可'''def __init__(self, input_size, hidden_size=4, num_layers=2, output_dim=1):''':param input_size:  输入数据的特征维数,通常就是embedding_dim(词向量的维度):param hidden_size: LSTM中隐层的维度:param num_layers: 循环神经网络的层数:param output_dim:'''super(LSTMModel,self).__init__()self.lstm_layer=nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)self.linear_layer=nn.Linear(hidden_size,output_dim)def forward(self,x):x,_=self.lstm_layer(x)s, b, h = x.shapex = x.view(s * b, h)  # 转换成线性层的输入格式x=self.linear_layer(x)x= x.view(s, b, -1)return x
#模型训练
class Trainer:def __init__(self,args):self.num_epoch =args.num_epochself. look_back=args.look_backself.batch_size=args.batch_sizeself.save_modelpath=args.save_modelpath #保存模型的位置load_data = LoadData(args.filepath)  # 加载数据self.dataset = load_data.data_process()  # 数据预处理dataset = CreateDataSet(self.dataset , look_back=args.look_back)  # 数据集开始构建self.train_inputs,  self.train_target,  self.test_inputs,  self.test_target = dataset.split_train_test_data()  # 拆分数据集为训练集、测试集self.data_inputs = dataset.data_inputs#改变下输入形状self.train_inputs = self.train_inputs.reshape(-1, self.batch_size, self.look_back)self.train_target = self.train_target.reshape(-1, self.batch_size, 1)self.test_inputs = self.test_inputs.reshape(-1, self.batch_size, self.look_back)self.data_inputs = self.data_inputs.reshape(-1, self.batch_size, self.look_back)self.model=self.build_model()self.loss =nn.MSELoss()self.optimizer=torch.optim.Adam(self.model.parameters(), lr=1e-2)def build_model(self):model=LSTMModel(input_size=self.look_back)return  model#训练过程def train(self):#把数据转成torch形式的inputs= torch.from_numpy(self.train_inputs)target=torch.from_numpy(self.train_target)self.model.train() #训练模式#开始训练for epoch in range(self.num_epoch):#前向传播out=self.model(inputs)#计算损失loss=self.loss(out,target)#反向传播self.optimizer.zero_grad()  #梯度清零loss.backward()  #反向传播self.optimizer.step() #更新权重参数if epoch % 100 == 0:  # 每 100 次输出结果print('Epoch: {}, Loss: {:.5f}'.format(epoch, loss.item()))torch.save(self.model,self.save_modelpath+'/model'+str(epoch)+'.pth')torch.save(self.model, self.save_modelpath + '/model_last' +  '.pth')self.test()def test(self,load_model=False):if not load_model:self.model.eval()  # 转换成测试模式inputs = torch.from_numpy(self.data_inputs)# inputs = torch.from_numpy(self.test_inputs)output = self.model(inputs)  # 测试集的预测结果else:model=torch.load(self.save_modelpath+ '/model_last' +  '.pth')inputs = torch.from_numpy(self.data_inputs)# inputs = torch.from_numpy(self.test_inputs)output =model(inputs)  # 测试集的预测结果# 改变输出的格式output = output.view(-1).data.numpy() #把tensor摊平# 画出实际结果和预测的结果plt.plot(output, 'r', label='prediction')plt.plot(self.dataset, 'g', label='real')# plt.plot(self.dataset[1:], 'b', label='real')plt.legend(loc='best')plt.show()if __name__ == '__main__':filepath ='seaborn-data-master/flights.csv'save_modelpath='model-path'parser = argparse.ArgumentParser(description=__doc__)parser.add_argument('--num_epoch',type=int, default=1000, help='训练的轮数' )parser.add_argument('--filepath',type=str, default=filepath, help='数据文件')parser.add_argument('--look_back', type=int, default=2, help='根据前几个数据预测')parser.add_argument('--batch_size', type=int, default=2, help='batch size')parser.add_argument('--save_modelpath',type=str, default=save_modelpath, help='训练中模型要保存的位置')args=parser.parse_args()train=Trainer(args)train.train()train.test(load_model=True)

结果如下
在这里插入图片描述


http://chatgpt.dhexx.cn/article/BEl507Ll.shtml

相关文章

LSTM模型预测新冠

LSTM是RNN的改进型,传统RNN模型会随着时间区间的增长,对早期的因素的权重越来越低,有可能会损失重要数据。而LSTM模型通过遗忘门、输入门、输出门三个逻辑,来筛选和保留数据。 原理详解可以参考如何从RNN起步,一步一步…

LSTM模型结构讲解

人类并不是每时每刻都从一片空白的大脑开始他们的思考。在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义。我们不会将所有的东西都全部丢弃,然后用空白的大脑进行思考。我们的思想拥有持久性。 传统的神经网络…

LSTM 模型实践一

简单介绍 原因:普通的RNN(Recurrent Neural Network)对于长期依赖问题效果比较差,当序列本身比较长时,神经网络模型的训练是采用backward进行,在梯度链式法则中容易出现梯度消失和梯度爆炸的问题。 解决&…

理解LSTM模型

写在前面:这是翻译自colah的一篇博客,原文关于LSTM神经网络模型的理解写的非常直观、简单易懂,所以翻译过来帮助大家学习理解LSTM模型。 当然我不是按照原文一字不落的翻译,而是摘出其中对模型理解最有帮助的部分,然后…

LSTM模型预测时间序列(快速上手)

写在前面 LSTM模型的一个常见用途是对长时间序列数据进行学习预测,例如得到了某商品前一年的日销量数据,我们可以用LSTM模型来预测未来一段时间内该商品的销量。但对于不熟悉神经网络或者对没有了解过RNN模型的人来说,想要看懂LSTM模型的原理…

基于优化LSTM 模型的股票预测

LSTM自诞生以来,便以其在处理时间序列方面的优越性能在预测回归,语音翻译等领域广受青睐。今天,主要研究的是通过对LSTM模型的优化来实现股票预测。其实,关于股票预测,LSTM模型已经表现的相当成熟,然而&…

【机器学习】LSTM模型原理

文章目录 1. 序言2. RNN的基本概念与网络结构2.1 概念2.2 结构2.3 要素 3. LSTM的基本概念与网络结构3.1 概念3.2 结构3.3 要素 4. LSTM网络结构的说明5. 补充 1. 序言 临渊羡鱼不如归而结网 学习的目的是为了应用 2. RNN的基本概念与网络结构 LSTM是在RNN的基础上演进而来…

LSTM模型

LSTM比RNN复杂很多,RNN只有一个参数矩阵A,LSTM有4个(遗忘门,输入门,更新值,输出门) LSTM有一个非常重要的传输带Ct,过去的信息通过这个传输带送给下一时刻,不会发生太大变…

时间序列预测——LSTM模型(附代码实现)

目录 模型原理 模型实现 导入所需要的库 设置随机数种子 导入数据集 打印前五行数据进行查看 数据处理 归一化处理 查看归一化处理后的数据 将时间序列转换为监督学习问题 打印数据前五行 划分训练集和测试集 查看划分后的数据维度 搭建LSTM模型 得到损失图 模型…

phpstorm汉化操作

问题描述:如何进行phpstorm的汉化(原始为英文) 解决办法: 1.下载汉化包—resources_cn; 2.将汉化包添加到phpstorm文件夹下的lib文件夹内(其中的resources_en可以不删除); 3.重启Ph…

php 教程 phpstorm

目录 php开发流程 php 编辑工具 phpstudy phpstorm如何配置php环境 php 语言 什么是URI URL和URI差别: 一、HTTP和HTTPS的基本概念 经典类型和自定义 实现本机域映射​ php开发流程 1、下载php语言包 php作为一门语言,本身可以是一个纯绿色版的…

【PHP】Phpstorm环境配置与应用

一. Phpstorm环境配置 (1)点击左上端File,选择下拉框中的setting,进入环境配置页面,如下图 (2)如下图点击1处,选中下拉框中的Deployment,Type(图示2处),下拉框中选择Local or mounted folder …

PHP开发工具PhpStorm v2022.3——完全支持PHP 8.2

PhpStorm是一个轻量级且便捷的PHP IDE,其旨在提高用户效率,可深刻理解用户的编码,提供智能代码补全,快速导航以及即时错误检查。可随时帮助用户对其编码进行调整,运行单元测试或者提供可视化debug功能。 PhpStorm v20…

phpstorm10.0.3汉化方法:

PhpStorm10.0.3汉化方法: 1、安装原版PhpStorm10.0.3,在打开最新的PhpStorm10汉化包下载地址: http://pan.baidu.com/s/1bouoyF9 2、双击用压缩软件打开resources_cn.jar(注意是打开而不是解压出来),将下载的汉化包…

PHPSTORM 中文版/汉化 即常用快捷键和配置

PHPStorm配置和快捷键大全(最新版)支持Win和Mac http://blog.csdn.net/fenglailea/article/details/53350080 推荐 1.汉化/中文版 使用的是开源的,翻译的还好,因为是开源,有些人不自觉,在翻译过程中加入广告 开源汉化地址&…

phpstorm10安装并汉化

一、下载phpstorm 下载地址:https://pan.baidu.com/s/1R64ZROVP1ljGbYfCwWjwxA 二、一直点击下一步安装即可 注意:第3步的时候选择一下支持的后缀 三、安装完毕,进行汉化 1、来到安装目录,PhpStorm 10.0.3\lib 目录下 2、…

PhpStorm 中文设置教程

本文仅供学习交流使用,如侵立删!demo下载见文末 Pycharm中文设置教程 1.首先打开PhpStorm ,点击file-settings.找到plugins,搜索Marketplace,然后搜索chinese。 2.找到之后直接点击安装. 3.安装完成之后点击Restart。…

PHPStorm运行PHP代码(新手教程)

PHPStorm是流行对PHP及前端开发IDE,在开发者初次使用写PHP代码时该怎么用呢~ 1、Create New Project 2、选择PHP Empty Project,并新建一个空目录(名字建议为英文,目录不要放在C盘!!!&#xff…

PHP教程二:开发工具 phpstorm 的下载、安装与激活

接着上一章节,我们继续开发工具的安装 phpstorm 的概括:PhpStorm 是 JetBrains 公司开发的一款商业的 PHP 集成开发工具,旨在提高用户效率,可深刻理解用户的编码,提供智能代码补全,快速导航以及即时错误检…

下载phpstorm2021汉化包

网址 网址:https://plugins.jetbrains.com/plugin/13710-chinese-simplified-language-pack---- 点安装到phpstorm 打开phpstorm2021版本 点file-------seeting-------k~(省略了)-----搜索chinese--------最后点安装就可以了。