十四、OPTIM

article/2025/11/9 13:15:36

一、torch.optim

torch.optim.Optimizer(params, defaults)优化器官网说明
在这里插入图片描述
由官网给的使用说明打开看出来优化器实验步骤:

①构造选择优化器

例如采用随机梯度下降优化器SGD
torch.optim.SGD(beyond.parameters(),lr=0.01),放入beyond模型的参数parameters;学习率learning rate;
每个优化器都有其特定独有的参数

②把网络中所有的可用梯度全部设置为0

optim.zero_grad()
梯度为tensor中的一个属性,这就是为啥神经网络传入的数据必须是tensor数据类型的原因,grad这个属性其实就是求导,常用在反向传播中,也就是通过先通过正向传播依次求出结果,再通过反向传播求导来依次倒退,其目的主要是对参数进行调整优化,详细的学习了解可自行百度。

③通过反向传播获取损失函数的梯度

result_loss.backward()
这里使用的损失函数为loss,其对象为result_loss,当然也可以使用其他的损失函数
从而得到每个可以调节参数的梯度

④调用step方法,对每个梯度参数进行调优更新

optim.step()
使用优化器的step方法,会利用之前得到的梯度grad,来对模型中的参数进行更新

二、优化器的使用

使用CIFAR-10数据集的测试集,使用之前实现的网络模型,二、复现网络模型训练CIFAR-10数据集

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset_testset = torchvision.datasets.CIFAR10("CIFAR_10",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset_testset,batch_size=2)class Beyond(nn.Module):def __init__(self):super(Beyond,self).__init__()self.model = torch.nn.Sequential(torch.nn.Conv2d(3,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,64,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Flatten(),torch.nn.Linear(1024,64),torch.nn.Linear(64,10))def forward(self,x):x = self.model(x)return x
loss = nn.CrossEntropyLoss()#构建选择损失函数为交叉熵
beyond = Beyond()
#print(beyond)
optim = torch.optim.SGD(beyond.parameters(),lr=0.01)for epoch in range(30):#进行30轮训练sum_loss = 0.0for data in dataloader:imgs, targets = dataoutput = beyond(imgs)# print(output)# print(targets)result_loss = loss(output, targets)# print(result_loss)optim.zero_grad()#把网络模型中所有的梯度都设置为0result_loss.backward()#反向传播获得每个参数的梯度从而可以通过优化器进行调优optim.step()#print(result_loss)sum_loss = sum_loss + result_lossprint(sum_loss)"""
tensor(9431.9678, grad_fn=<AddBackward0>)
tensor(7715.2842, grad_fn=<AddBackward0>)
tensor(6860.3115, grad_fn=<AddBackward0>)
......"""

在optim.zero_grad()及其下面三行处,左击打个断点,进入Debug模式(Shift+F9)下,
网络模型名称---Protected Attributes---__modules---0-8随便选一个,例如'0'---weight---grad就是参数的梯度

在这里插入图片描述

三、自动调整学习速率设置

torch.optim.lr_scheduler.ExponentialLR(optimizer=optim,gamma=0.1)
optimizer为优化器的名称,gamma表示每次都会将原来的lr乘以gamma
使用optim优化器,每次就会在原来的学习速率的基础上乘以0.1

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset_testset = torchvision.datasets.CIFAR10("CIFAR_10",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset_testset,batch_size=2)class Beyond(nn.Module):def __init__(self):super(Beyond,self).__init__()self.model = torch.nn.Sequential(torch.nn.Conv2d(3,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,64,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Flatten(),torch.nn.Linear(1024,64),torch.nn.Linear(64,10))def forward(self,x):x = self.model(x)return x
loss = nn.CrossEntropyLoss()#构建选择损失函数为交叉熵
beyond = Beyond()
#print(beyond)
optim = torch.optim.SGD(beyond.parameters(),lr=0.01)
scheduler = ExponentialLR(optimizer=optim,gamma=0.1)#在原来的lr上乘以gammafor epoch in range(30):#进行30轮训练sum_loss = 0.0for data in dataloader:imgs, targets = dataoutput = beyond(imgs)# print(output)# print(targets)result_loss = loss(output, targets)# print(result_loss)optim.zero_grad()#把网络模型中所有的梯度都设置为0result_loss.backward()#反向传播获得每个参数的梯度从而可以通过优化器进行调优optim.step()#print(result_loss)sum_loss = sum_loss + result_lossscheduler.step()#这里就需要不能用优化器,而是使用自动学习速率的优化器print(sum_loss)"""
tensor(9469.4385, grad_fn=<AddBackward0>)
tensor(7144.1514, grad_fn=<AddBackward0>)
tensor(6734.8311, grad_fn=<AddBackward0>)
......"""

http://chatgpt.dhexx.cn/article/36GB0lip.shtml

相关文章

GPU开发环境搭建(CUDA和 OptiX)

Optix是英伟达一直推出的闭源光线跟踪&#xff08;rayTracing&#xff09;引擎 CUDA&#xff08;Compute Unified Device Architecture&#xff09;&#xff0c;是显卡厂商NVIDIA推出的运算平台。 CUDA™是一种由NVIDIA推出的通用并行计算架构&#xff0c;该架构使GPU能够解决复…

Intel OpenImageDenoise VS Nvidia Optix 降噪结果对比

说明&#xff1a;原始图像&#xff08;Raytracing的直接输出结果&#xff0c;每一幅的左图&#xff09;为PPM格式&#xff0c; 一、OIDN 按照官方文档提示&#xff0c;先用ImageMagick转换成pfm格式&#xff0c;再将其作为oidn的输入&#xff0c;输出亦为pfm。 magick conve…

OTN技术及华为OTN设备简介

OTN技术及华为OTN设备简介 城域波分环四环五即将进行建设&#xff0c;本次工程采用华为华为下一代智能光传送平台OTN设备OptiX OSN 8800和OptiX OSN 6800。本文主要对OTN技术涉及的网络结构、复用方式、帧结构、ROADM技术和OptiX OSN 8800和OptiX OSN 6800设备特点及本次工程配…

【OptiX】第0个示例 OptixHello 学习Optix的工程配置以及基本框架

首先需要查看本博客的这篇文章&#xff1a;【Optix】Optix介绍与示例编译 把该安装的工程都安装好。可以按照本文所说的顺序创建和理解代码&#xff0c;也可以在本文末尾下载到已经配置好的代码。建议首先在本文末尾处下载代码&#xff0c;编译通过&#xff0c;这样配合文件看心…

OptiX-7入门教程

OptiX是英伟达专为光线追踪打造的SDK&#xff0c;但是他的官方案例都比较复杂&#xff0c;包含了大量初始化相关的代码&#xff0c;初学容易一头雾水。 本人跟着Github上的optiX7course一步步学习才算入门。这个课程是Siggraph 2019/2020上的OptiX课程&#xff0c;有源码&…

optix入门

射线追踪是embarrassingly parallel/perfectly parallel/pleasingly parallel的问题&#xff0c;就是说基本不用费劲就可以并行化。 射线追踪是指从某点发射射线&#xff0c;判断其与几何结构的交点&#xff0c;根据交点对图像进行渲染&#xff0c;或者计算。 nvidia optix是基…

jwt *

目录 一、jwt出现的原因及工作原理 1. JWT是什么 2. 为什么使用JWT 3. JWT的工作原理 4、jwt解决不需要登录就能直接访问的问题&#xff1a; 解决登录后树形菜单未出现的问题 &#xff1a; 二、jwt工具类介绍&#xff0c;三种场景 1、jwt工具类 2、三种场景 三、jwt…

JWT JWT

JWT&#xff08;JSON WEB TOKEN&#xff09; JWT的组成 header&#xff08;头部&#xff09;&#xff1a;中主要存储了两个字段 alg&#xff0c;typ。 alg表示加密的算法默认&#xff08;HMAC SHA256&#xff09;&#xff0c;typ表示这个令牌的类型默认为JWT。 payload&#…

JWT__

文章目录 JWT什么是JWT&#xff1f;JWT能做什么&#xff1f;认证流程JWT的结构是什么&#xff1f;使用代码要做一个JWT的例子引入pom依赖生成一个Token令牌验证令牌并从令牌中取出信息 JWT 什么是JWT&#xff1f; 官网地址:https://jwt.io/introduction/ 官方文档 JSON Web T…

JWT 和 JJWT 还傻傻的分不清吗

JWTs是JSON对象的编码表示。JSON对象由零或多个名称/值对组成&#xff0c;其中名称为字符串&#xff0c;值为任意JSON值。 JWT有助于在clear(例如在URL中)发送这样的信息&#xff0c;可以被信任为不可读(即加密的)、不可修改的(即签名)和URL - safe(即Base64编码的)。 JSON W…

【编码实战】2022年还在用jjwt操作jwt?,推荐你使用nimbus-jose-jwt,爽到飞起~

什么是nimbus-jose-jwt&#xff1f; nimbus-jose-jwt是基于Apache2.0开源协议的JWT开源库&#xff0c;支持所有的签名(JWS)和加密(JWE)算法。 对于JWT、JWS、JWE介绍 JWT是一种规范&#xff0c;它强调了两个组织之间传递安全的信息JWS是JWT的一种实现&#xff0c;包含三部分hea…

什么是JWT??

一、什么是JWT JWT(JSON WEB TOKEN)&#xff0c;通过数字签名的方式&#xff0c;以json对象为载体&#xff0c;在不同的服务终端之间安全的传输信息&#xff0c;用来解决传统session的弊端。 JWT在前后端分离系统&#xff0c;或跨平台系统中&#xff0c;通过JSON形式作为WEB应用…

JJWT实现令牌Token

登录实现方式 Session 详情&#xff1a; https://www.cnblogs.com/andy-zhou/p/5360107.html 会话的概念 会话就好比打电话&#xff0c;一次通话可以理解为一次会话。我们登录一个网站&#xff0c;在一个网站上不同的页面浏览&#xff0c;最后退出这个网站&#xff0c;也是…

3.JJWT

目录 1.JWT简介 2.JWT的结构 3.基于服务器的传统身份认证 4.基于token的身份认证 5. JWT的优势 6.Java中使用JJWT实现JWT 1.JWT简介 Json web token (JWT)&#xff0c; 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准(RFC 7519)。该token被设计为紧凑且安全的…

JWT技术

JWT 一、 JWT 实现无状态 Web 服务 1、什么是有状态 有状态服务&#xff0c;即服务端需要记录每次会话的客户端信息&#xff0c;从而识别客户端身份&#xff0c;根据用户身份进行请求的处理&#xff0c;典型的设计如tomcat中的session。 例如登录&#xff1a;用户登录后&am…

token学习笔记(JWT、jjwt的使用及案例实现)

文章目录 1. 首先、了解什么是会话2. 会话跟踪的主要技术3. Token 令牌学习3.1 流程图3.2 token3.3 JWT(JSON web Tokens)Json web 令牌(规范)3.4 JWT结构3.5 JWT需要的依赖3.6 JWT的获取与验证流程3.7JWT的使用方式3.8 jjwt的使用&#xff08;创建JWT方式&#xff09;1. jjwt需…

JWT 进阶 -- JJWT

###jwt是什么? JWTs是JSON对象的编码表示。JSON对象由零或多个名称/值对组成&#xff0c;其中名称为字符串&#xff0c;值为任意JSON值。JWT有助于在clear(例如在URL中)发送这样的信息&#xff0c;可以被信任为不可读(即加密的)、不可修改的(即签名)和URL - safe(即Base64编码…

JWT详解和使用(jjwt)

JWT详解和使用 JWT是啥 JWT&#xff08;JSON Web Token&#xff09;是一个开放标准(RFC 7519)&#xff0c;它定义了一种紧凑的、自包含的方式&#xff0c;用于作为JSON对象在各方之间安全地传输信息。该信息可以被验证和信任&#xff0c;因为它是数字签名的。 下列场景中使用…

JJWT 实现JWT

1什么是JJWT JJWT 是一个提供端到端的 JWT 创建和验证的 Java 库。永远免费和开源 (Apache License&#xff0c;版本2.0)&#xff0c;JJWT 很容易使用和理解。它被设计成一个以建筑为中心的流畅界面&#xff0c;隐藏了它的大部分复杂性。 2JJWT快速入门 2.1token的创建 2.1…

什么是JWT?

在HTTP接口调用的时候&#xff0c;服务端经常需要对调用方做认证&#xff0c;以保证安全性。一种常见的认证方式是使用JWT(Json Web Token)&#xff0c;采用这种方式时&#xff0c;经常在header传入一个authorization字段&#xff0c;值为对应的jwt_token&#xff0c;或者也有图…