Pytorch教程[10]完整模型训练套路

article/2025/10/29 18:27:03

一般的模型构建都是按照下图这样的流程
在这里插入图片描述


下面分享一个自己手动搭建的网络

在这里插入图片描述

from model import *
import torchvision
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torch import nn
from torch.utils.data import DataLoader#数据增强
data_transforms = transforms.Compose([transforms.RandomRotation(45),transforms.ToTensor(),])#准备数据集
#train_data = torchvision.datasets.CIFAR10(root="D:\pythonProject_pytorchstudy", train=True, transform=torchvision.transforms.ToTensor(), download=False)
#test_data = torchvision.datasets.CIFAR10(root="D:\pythonProject_pytorchstudy", train=False, transform=torchvision.transforms.ToTensor(), download=False)
train_data = torchvision.datasets.CIFAR10(root="D:\pythonProject_pytorchstudy", train=True, transform=data_transforms, download=False)
test_data = torchvision.datasets.CIFAR10(root="D:\pythonProject_pytorchstudy", train=False, transform=torchvision.transforms.ToTensor(), download=False)#数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练集的长度为:{}".format(train_data_size))
print("测试集的长度为:{}".format(test_data_size))#利用Dataloader加载数据集
train_dataloader =DataLoader(train_data,batch_size=64)
test_dataloader =DataLoader(test_data,batch_size=64)#搭建神经网络
#model.py#创建网络模型
Yolo = My_Model()################################
if torch.cuda.is_available():  #Yolo = My_Model().cuda()   #
#################################损失函数
loss_fn = nn.CrossEntropyLoss()################################
if torch.cuda.is_available():  #loss_fn = loss_fn.cuda()   #
#################################优化器
learning_rate = 0.01 #1e-2 = 1 x (10)^(-2) =1/100 =0.01
optimizer  = torch.optim.SGD(Yolo.parameters(), lr = learning_rate, )#设置训练网络的参数
total_train_step = 0
#记录测试次数
total_test_step = 0
#训练轮数
epoch = 10#添加tensorboard
writer = SummaryWriter("D:\pythonProject_pytorchstudy\cifar-10-batches-py\logs_train")for i in range(epoch):print("第{}轮训练开始".format(i+1))#训练步骤开始Yolo.train()for data in train_dataloader:imgs,targets = data################################if torch.cuda.is_available():  #imgs = imgs.cuda()         #targets = targets.cuda()   #################################outputs = Yolo(imgs)loss  = loss_fn(outputs,targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 30 ==0:print("Iteration:{},loss:{}".format(total_train_step,loss.item()))writer.add_scalar("train_loss", loss.item(),total_train_step)#测试步骤开始Yolo.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad(): #让网络中的梯度没有for data in test_dataloader:imgs, targets = data################################if torch.cuda.is_available():  #imgs = imgs.cuda()         #targets = targets.cuda()   #################################outputs = Yolo(imgs)loss = loss_fn(outputs,targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_train_step += 1torch.save(Yolo,"YOLO_{}".format(i+1))#torch.save(Yolo.state_dict(),"Yolo_{}.pth".format(i+1))print("模型已保存")writer.close()
import torch
from torch import nnclass My_Model(nn.Module):def __init__(self):super(My_Model, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x# Yolo = My_Model()# input = torch.ones(64,3,32,32)# output = Yolo(input)# print(output.shape)

http://chatgpt.dhexx.cn/article/NsUdAH3B.shtml

相关文章

PaddleOCR学习(二)PaddleOCR检测模型训练

这一部分主要介绍,如何使用自己的数据库去训练PaddleOCR的文本检测模型。 官方教程https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md 一、准备训练数据 首先你需要有自己的数据,如果没有自己的数据,推荐使…

迁移学习的模型训练

用深度学习解决目标检测有两个重要工作: 1、设计、实现、训练和验证模型 模型如果设计模型如何编程实现如何收集足够的数据来训练并验证模型是否符合预期 从头开始设计、实现、训练和验证模型是需要有众多深度学习算法人才做支撑,并且极其耗时耗力 2、…

TF2.0模型训练

TF2.0模型训练 概述数据集介绍1、通过fit方法训练模型准备数据创建模型编译模型训练模型 2、通过fit_generator方法训练模型构建生成器创建模型编译模型训练模型 3、自定义训练准备数据创建模型定义损失函数及优化器训练模型 下一篇TF2.0模型保存 概述 这是TF2.0入门笔记【TF2…

TensorFlow 2.0 —— 模型训练

目录 1、Keras版本模型训练1.1 构造模型(顺序模型、函数式模型、子类模型)1.2 模型训练:model.fit()1.3 模型验证:model.evaluate()1.4 模型预测:model.predict()1.5 使用样本加权和类别加权1.6 回调函数1.6.1 EarlySt…

如何在jupyter上运行Java代码(适用LINUX)

如何在jupyter上运行Java代码 1.下载必须软件 下载JDK且JDK版本必须 ≥ 9 ≥9 ≥9从github上下载ijava 附 : ijava下载链接.装有jupyter,我在LINUX上是直接装的anaconda 安装过程 将下载的ijava压缩包解压出来,并在此路径用该命令 : sudo…

Java单元测试介绍

文章目录 单元测试单元测试基本介绍单元测试快速入门单元测试常用注解 单元测试 单元测试基本介绍 单元测试: 单元测试就是针对最小的功能单元编写测试代码,Java程序最小的功能单元是方法,因此,单元测试就是针对Java方法的测试,…

Jupyter 配置 Java环境,写Java代码,测试成功

本次简单诉说下怎么通过jupyter安装iJava,写Java代码。 安装Java的不说了 我使用的是Java15 然后去:https://github.com/SpencerPark/IJava/releases 下载zip,不要下载其他的 得到就是一个py文件 下面就是一个 python install.py 我这里就…

java调用python执行脚本,附代码

最近有个功能需要java调用python脚本实现一些功能,前期需要做好的准备:配置好python环境,如下: 以下展示的为两种,一种为生成图片,另一种为生成字符串。 package com.msdw.tms.common.utils.py;import ja…

Selenium Java自动化测试环境搭建

IDE用的是Eclipse。 步骤1:因为是基于Java,所以首先要下载与安装JDK(Java Development Kit) 下载: 点击这里下载JDK 安装:按照默认安装一路点next就可以了。 验证:安装完成后,在命…

java单元测试(Junit)

相关代码下载链接: http://download.csdn.net/detail/stevenhu_223/4884357 在有些时候,我们需要对我们自己编写的代码进行单元测试(好处是,减少后期维护的精力和费用),这是一些最基本的模块测试。当然&…

Java单元测试工具:JUnit4(一)——概述及简单例子

(一)JUnit概述及一个简单例子 看了慕课网的JUnit视频教程: http://www.imooc.com/learn/356,总结笔记。 这篇笔记记录JUnit的概述,以及一个快速入门的例子。 1.概述 1.1 什么是JUnit ①JUnit是用于编写可复用测试集的…

Linux下执行Python脚本

1.Linux Python环境 Linux系统一般集成Python,如果没有安装,可以手动安装,联网状态下可直接安装。Fedora下使用yum install,Ubuntu下使用apt-get install,前提都是root权限。安装完毕,可将Python加入环境变…

python pytest脚本执行工具

pytest脚本执行工具 支持获取当前路径下所有.py脚本 添加多个脚本,一起执行 import tkinter as tk from tkinter import filedialog import subprocess import os from datetime import datetimedef select_script():script_path filedialog.askopenfilename(fil…

linux上运行python(简单版)

linux上运行python(简单版) 一、前提准备1.centOS72.挂载yum源[http://t.csdn.cn/Isf0i](http://t.csdn.cn/Isf0i) 二、安装python3三、运行程序 一、前提准备 1.centOS7 2.挂载yum源http://t.csdn.cn/Isf0i 在终端进行安装python3 二、安装python3 …

linux怎么运行python脚本?

linux运行python脚本的方法: 1、命令行执行: 建立一个test.py文档,在其中书写python代码。之后,在命令行执行:python test.py 说明:其中python可以写成python的绝对路径。使用which python进行查询。 注…

java实现远程执行Linux下的shell脚本

java实现远程执行Linux下的shell脚本 背景导入Jar包第一步:远程连接第二步:开启Session第三步:新建测试脚本文件结果报错 背景 最近有个项目,需要在Linux下的服务器内写了一部分Python脚本,业务处理却是在Java内&…

Java运行Python脚本

前段时间遇到了在JavaWeb项目中嵌入运行Python脚本的功能的需求。想到的方案有两种,一种是使用Java技术(Jython或Runtime.exec)运行Python脚本,另一种是搭建一个Python工程对外提供相应http或webservice接口。两种方案我都有实现&…

Java项目分层

MVC模式 在实际的开发中有一种项目的程序组织架构方案叫做MVC模式,按照程序 的功能将他们分成三个层,如下图:Modle层(模型层)、View层(显示层)、Controller层(控制层)。…

java项目收获总结_java开发项目收获心得

1 java开发项目收获心得 it行业现在的发展如日中天,很多人都纷纷走进这个行业,而java作为跨平台的编程语言更是受欢迎。java其实相对其他语言来说的确很有优势,但是也有点缺陷,但是以后发展到什么程度,谁都不知道。那么下面小编给大家说说java开发项目收获心得,希望能对你…

java查看jar包依赖_java项目开发中如何查找到项目依赖的jar包?

不管是java普通工程,还是java web项目,甚至是android项目,依赖包的管理有2种: 1.直接依赖jar包 这种方式简单直白,项目下载后在正确的ide或者稍微做转换就可以运行起来。比如java web工程的WEB-INF/lib下 只要按这个步骤Java Build Path=>Add Libraty=>Web App Libr…