Pytorch之Optim(优化器)

article/2025/9/16 13:16:50

使用优化器,接收损失函数的结果,并调整网络参数,完成反向传播

根据示例

optimizer = torch.optim.SGD(module.parameters(), lr=0.01, momentum=0.9)

然后根据提示,清空梯度>网络前传>计算损失>反向传播>优化网络参数

在运行区域引入库和之前的Module

if __name__ == '__main__':module = Module()loss = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(module.parameters(), lr=0.01, momentum=0.9)running_loss = 0.0for imgs, targets in dataloader:optimizer.zero_grad()outputs = module(imgs)result_loss = loss(outputs, targets)result_loss.backward()optimizer.step()running_loss = running_loss + result_lossprint(running_loss)

再因为优化器一般不只是优化一次,迭代完所有训练集只是完成了网络(对于该数据集)的一次优化,优化的次数就是俗称的epoch,一般都是在外面再写个循环完成迭代

if __name__ == '__main__':module = Module()loss = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(module.parameters(), lr=0.01, momentum=0.9)for epoch in range(12):running_loss = 0.0for imgs, targets in dataloader:optimizer.zero_grad()outputs = module(imgs)result_loss = loss(outputs, targets)result_loss.backward()optimizer.step()running_loss = running_loss + result_lossprint(running_loss)

运行获得以下结果

然后由于CPU实在是太慢,加入GPU

if __name__ == '__main__':module = Module()loss = nn.CrossEntropyLoss()if torch.cuda.is_available():module = module.cuda()loss = loss.cuda()optimizer = torch.optim.SGD(module.parameters(), lr=0.01, momentum=0.9)for epoch in range(12):running_loss = 0.0for imgs, targets in dataloader:if torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()optimizer.zero_grad()outputs = module(imgs)result_loss = loss(outputs, targets)result_loss.backward()optimizer.step()running_loss = running_loss + result_lossprint(running_loss)

最后放一下完整的代码

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset, batch_size=1)class Module(nn.Module):def __init__(self):super(Module, self).__init__()self.model = Sequential(Conv2d(3, 16, 5),MaxPool2d(2, 2),Conv2d(16, 32, 5),MaxPool2d(2, 2),Flatten(),  # 注意一下,线性层需要进行展平处理Linear(32*5*5, 120),Linear(120, 84),Linear(84, 10))def forward(self, x):x = self.model(x)return x

http://chatgpt.dhexx.cn/article/8bEv4edN.shtml

相关文章

torch.optim

torch.optim是一个实现了各种优化算法的库。 1. 如何使用optimizer 我们需要构建一个optimizer对象。这个对象能够保持当前参数状态并基于计算得到的梯度进行参数更新。 1.1 构建 为了构建一个Optimizer,你需要给它一个包含了需要优化的参数(必须都是…

torch.optim.lr_scheduler.CosineAnnealingWarmRestarts、OneCycleLR定义与使用

torch中有多种余弦退火学习率调整方法,包括:OneCycleLR、CosineAnnealingLR和CosineAnnealingWarmRestarts。 CosineAnnealingWarmRestarts(带预热的余弦退火)学习率方法定义 torch.optim.lr_scheduler.CosineAnnealingWarmRest…

torch.optim优化算法理解之optim.Adam()

torch.optim是一个实现了多种优化算法的包,大多数通用的方法都已支持,提供了丰富的接口调用,未来更多精炼的优化算法也将整合进来。 为了使用torch.optim,需先构造一个优化器对象Optimizer,用来保存当前的状态&#x…

Python-torch.optim优化算法理解之optim.Adam()

目录 简介 分析 使用 Adam算法 参数 论文理解 torch.optim.adam源码理解 Adam的特点 转载torch.optim优化算法理解之optim.Adam() 官方手册:torch.optim — PyTorch 1.11.0 documentation 其他参考 pytorch中优化器与学习率衰减方法总结 Adam和学习率衰减…

R 语言 optim 使用

stats中的optim函数是解决优化问题的一个简易的方法。 Univariate Optimization f function(x,a) (x-a)^2 xmin optimize(f,interval c(0,1),a1/3) xminGeneral Optimization optim函数包含了几种不同的算法。 算法的选择依赖于求解导数的难易程度,通常最好提…

PyTorch基础(六)-- optim模块

PyTorch的optim是用于参数优化的库(可以说是花式梯度下降),optim文件夹主要包括1个核心的父类(optimizer)、1个辅助类(lr_scheduler)以及10个常用优化算法的实现类。optim中内置的常用算法包括a…

pytorch基础(四):使用optim优化函数

文章目录 前言一、问题描述二、官方文档代码三、optimizer的工作原理总结 前言 本系列主要是对pytorch基础知识学习的一个记录,尽量保持博客的更新进度和自己的学习进度。本人也处于学习阶段,博客中涉及到的知识可能存在某些问题,希望大家批评…

HTTP协议之GET与POST区别

GET和POST是HTTP请求的两种基本方式,对于这两种请求方式的区别,只要是接触过Web开发的就能说出一二:GET把参数包含在URL中,POST通过正文传参! 而我想深入了解以下的时候,就去了w3cschool,这是w…

Get和Post区别是什么

附上原文地址:https://www.cnblogs.com/logsharing/p/8448446.html GET和POST是HTTP请求的两种基本方法,要说它们的区别,接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中,POST通过request body传递参数。…

get与post区别(很全~)

get与post区别 GET: (1)从指定的资源请求数据 (2)请求数据有长度限制(不同每个浏览器限制长度可能不一样) (3)请求url会在浏览器地址栏中显示 (4&#x…

Web中get和post区别

99% 的人都理解错了 HTTP 中 GET 与 POST 的区别(转) 转自:WebTechGarden 微信公众号GET 和 POST 是 HTTP 请求的两种基本方法,要说它们的区别,接触过 WEB 开发的人都能说出一二。最直观的区别就是 GET 把参数包含在 U…

get,post区别

Http定义了与服务器交互的不同方法,最基本的方法有4种,分别是GET,POST,PUT,DELETE。URL全称是资源描述符,我们可以这样认为:一个URL地址,它用于描述一个网络上的资源,而H…

GET 与 POST 区别

GET和POST是HTTP请求的两种基本方法,要说它们的区别,接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中,POST通过request body传递参数。 你可能自己写过无数个GET和POST请求,或者已经看过很多权威网站总结出…

GET和POST区别及缓存问题

2.就是get和post区别的缓存问题。 首先要了解什么是缓存。 HTTP缓存的基本目的就是使应用执行的更快,更易扩展,但是HTTP缓存通常只适用于idempotent request(可以理解为查询请求,也就是不更新服务端数据的请求)&#x…

Get与Post区别与范例讲解

林炳文Evankaka原创作品。转载请注明出处http://blog.csdn.net/evankaka 一、 J2EE WEB应用文件目录结构 Java Web应用由一组静态HTML页、Servlet、JSP和其他相关的class组成,它们一起构成一个大的工程项目。每种组件在Web应用中都有固定的存放目录。Web应用的配…

GET和POST区别详解

GET和POST是HTTP请求的两种基本方法,要说它们的区别,接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中,POST通过request body传递参数。 你可能自己写过无数个GET和POST请求,或者已经看过很多权威网站总结出…

get和post区别

GET和POST的安全性 1、GET是通过URL方式请求,可以直接看到,明文传输 2、POST是通过请求header请求,可以开发者工具或者抓包可以看到,同样也是明文的 3、GET请求会保存在浏览器历史纪录中,还可能会保存在Web的日志中 G…

post和get区别

GET和POST是HTTP请求的两种基本方法,要说它们的区别,接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中,POST通过request body传递参数。 你可能自己写过无数个GET和POST请求,或者已经看过很多权威网站总结出…

GET 和 POST 有什么区别?

GET 和 POST 是 HTTP 请求中最常用的两种请求方法,在日常开发的 RESTful 接口中,都能看到它们的身影。而它们之间的区别,也是一道常见且经典的面试题,所以我们本文就来详细的聊聊。 HTTP 协议定义的方法类型总共有以下 10 种&…

查找(一)——静态查找表

目录 一、查找的基本概念 二、顺序查找 (线性查找) 1、基本思想 2、核心代码 3、顺序查找设置哨兵 4、顺序查找的优点: 5、顺序查找的缺点: 6、折半查找 7、折半查找判定树 8、线性表查找的特点 三、索引顺序表&#x…