torch.optim

article/2025/9/16 20:26:52

torch.optim是一个实现了各种优化算法的库。

1. 如何使用optimizer

我们需要构建一个optimizer对象。这个对象能够保持当前参数状态并基于计算得到的梯度进行参数更新。

1.1 构建

为了构建一个Optimizer,你需要给它一个包含了需要优化的参数(必须都是Variable对象)的iterable。然后,你可以设置optimizer的参 数选项,比如学习率,权重衰减,等等。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

1.2 为每个参数单独设置选项

Optimizer也支持为每个参数单独设置选项。此时,我们应该传入dict的iterable。每一个dict都分别定义了一组参数,并且包含一个param键,这个键对应参数的列表。

当你只想改动一个参数组的选项,但其他参数组的选项不变时,这是非常有用的。

例如,当我们想指定每一层的学习率时,

optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)

这意味着model.base的参数将会使用1e-2的学习率,model.classifier的参数将会使用1e-3的学习率,并且0.9的momentum将会被用于所有的参数。

1.3 进行单次优化

所有的optimizer都实现了step()方法,这个方法会更新所有的参数。它能按两种方式来使用:

1.optimizer.step()

这是大多数optimizer所支持的简化版本。一旦梯度被如backward()之类的函数计算好后,我们就可以调用这个函数。

例子:

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

2.optimizer.step(closure)

这个函数的用法没看懂,但用的也很少。

2. 优化器

class torch.optim.Optimizer(params, defaults)

这是所有优化器的基类。

参数

  • params (iterable) —— Variable 或者 dict的iterable。指定了什么参数应当被优化。
  • defaults —— (dict):包含了优化选项默认值的字典(一个参数组没有指定的参数选项将会使用默认值)。

(1)load_state_dict(state_dict)

加载optimizer状态。

参数

state_dict (dict) —— optimizer的状态。应当是一个调用state_dict()所返回的对象。

(2)state_dict()

dict返回optimizer的状态。

它包含两项。

state - 一个保存了当前优化状态的dict。optimizer的类别不同,state的内容也会不同。
param_groups - 一个包含了全部参数组的dict。

(3)step(closure)

进行单次优化 (参数更新)。

closure参数是可选的。

(4)zero_grad()

清空所有被优化过的Variable的梯度。

下面给出几个最常用的优化器。

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)[source]

参数:

  • params (iterable) – 待优化参数的iterable或者是定义了参数组的dict
  • lr (float, 可选) – 学习率(默认:1e-3)
  • betas (Tuple[float, float], 可选) – 用于计算梯度以及梯度平方的运行平均值的系数(默认:0.9,0.999)
  • eps (float, 可选) – 为了增加数值计算的稳定性而加到分母里的项(默认:1e-8)
  • weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认: 0)

class torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)[source]

参数:

  • params (iterable) – 待优化参数的iterable或者是定义了参数组的dict
  • lr (float, 可选) – 学习率(默认:1e-2)
  • momentum (float, 可选) – 动量因子(默认:0)
  • alpha (float, 可选) – 平滑常数(默认:0.99)
  • eps (float, 可选) – 为了增加数值计算的稳定性而加到分母里的项(默认:1e-8)
  • centered (bool, 可选) – 如果为True,计算中心化RMSProp,并且用它的方差预测值对梯度进行归一化
  • weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认: 0)

class torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)[source]

参数:

  • params (iterable) – 待优化参数的iterable或者是定义了参数组的dict
  • lr (float) – 学习率
  • momentum (float, 可选) – 动量因子(默认:0)
  • weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认:0)
  • dampening (float, 可选) – 动量的抑制因子(默认:0)
  • nesterov (bool, 可选) – 使用Nesterov动量(默认:False)

例子:

>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

3. 调整学习率

torch.optim.lr_scheduler 提供了几种方法来根据 epoch 的数量调整学习率。

学习率调整应该在优化器更新后应用,比如:

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler.step()

torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=- 1, verbose=False)

将每个参数组的学习率设置为初始 lr 乘以给定函数。当 last_epoch=-1 时,设置initial lr 为 lr。

l r e p o c h = l r i n i t i a l ∗ l r _ l a m b d a ( e p o c h ) lr_{epoch} =lr_{initial} * lr\_lambda(epoch) lrepoch=lrinitiallr_lambda(epoch)

在这里插入图片描述

# Assuming optimizer has two groups.
lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):train(...)validate(...)scheduler.step()

torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda, last_epoch=- 1, verbose=False)

将每个参数组的学习率乘以指定函数中给出的因子。 当 last_epoch=-1 时,设置初始 lr 为 lr。

l r e p o c h = l r e p o c h − 1 ∗ l r _ l a m b d a ( e p o c h ) lr_{epoch} =lr_{epoch - 1} * lr\_lambda(epoch) lrepoch=lrepoch1lr_lambda(epoch)

在这里插入图片描述

lmbda = lambda epoch: 0.95
scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
for epoch in range(100):train(...)validate(...)scheduler.step()

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=- 1, verbose=False)

step_size epochs 衰减每个参数组的学习率。 请注意,这种衰减可能与此调度程序外部对学习率的其他更改同时发生。 当 last_epoch=-1 时,设置初始 lr 为 lr。

在这里插入图片描述

# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.05     if epoch < 30
# lr = 0.005    if 30 <= epoch < 60
# lr = 0.0005   if 60 <= epoch < 90
# ...
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(100):train(...)validate(...)scheduler.step()

torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=- 1, verbose=False)

一旦 epoch 的数量达到milestones之一,通过 gamma 衰减每个参数组的学习率。 请注意,这种衰减可能与此调度程序外部对学习率的其他更改同时发生。 当 last_epoch=-1 时,设置初始 lr 为 lr。

在这里插入图片描述

# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.05     if epoch < 30
# lr = 0.005    if 30 <= epoch < 80
# lr = 0.0005   if epoch >= 80
scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
for epoch in range(100):train(...)validate(...)scheduler.step()

torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.3333333333333333, total_iters=5, last_epoch=- 1, verbose=False)

将每个参数组的学习率衰减一个小的常数因子,直到 epoch 的数量达到预定义的total_iters。 请注意,这种衰减可能与此调度程序外部对学习率的其他更改同时发生。 当 last_epoch=-1 时,设置初始 lr 为 lr。

在这里插入图片描述

# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.025   if epoch == 0
# lr = 0.025   if epoch == 1
# lr = 0.025   if epoch == 2
# lr = 0.025   if epoch == 3
# lr = 0.05    if epoch >= 4
scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
for epoch in range(100):train(...)validate(...)scheduler.step()

torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.3333333333333333, end_factor=1.0, total_iters=5, last_epoch=- 1, verbose=False)

通过线性改变小的乘法因子来衰减每个参数组的学习率,直到 epoch 的数量达到预定义的total_iters。 请注意,这种衰减可能与此调度程序外部对学习率的其他更改同时发生。 当 last_epoch=-1 时,设置初始 lr 为 lr。

在这里插入图片描述

# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.025    if epoch == 0
# lr = 0.03125  if epoch == 1
# lr = 0.0375   if epoch == 2
# lr = 0.04375  if epoch == 3
# lr = 0.05    if epoch >= 4
scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
for epoch in range(100):train(...)validate(...)scheduler.step()

http://chatgpt.dhexx.cn/article/DZBOeWeA.shtml

相关文章

torch.optim.lr_scheduler.CosineAnnealingWarmRestarts、OneCycleLR定义与使用

torch中有多种余弦退火学习率调整方法&#xff0c;包括&#xff1a;OneCycleLR、CosineAnnealingLR和CosineAnnealingWarmRestarts。 CosineAnnealingWarmRestarts&#xff08;带预热的余弦退火&#xff09;学习率方法定义 torch.optim.lr_scheduler.CosineAnnealingWarmRest…

torch.optim优化算法理解之optim.Adam()

torch.optim是一个实现了多种优化算法的包&#xff0c;大多数通用的方法都已支持&#xff0c;提供了丰富的接口调用&#xff0c;未来更多精炼的优化算法也将整合进来。 为了使用torch.optim&#xff0c;需先构造一个优化器对象Optimizer&#xff0c;用来保存当前的状态&#x…

Python-torch.optim优化算法理解之optim.Adam()

目录 简介 分析 使用 Adam算法 参数 论文理解 torch.optim.adam源码理解 Adam的特点 转载torch.optim优化算法理解之optim.Adam() 官方手册&#xff1a;torch.optim — PyTorch 1.11.0 documentation 其他参考 pytorch中优化器与学习率衰减方法总结 Adam和学习率衰减…

R 语言 optim 使用

stats中的optim函数是解决优化问题的一个简易的方法。 Univariate Optimization f function(x,a) (x-a)^2 xmin optimize(f,interval c(0,1),a1/3) xminGeneral Optimization optim函数包含了几种不同的算法。 算法的选择依赖于求解导数的难易程度&#xff0c;通常最好提…

PyTorch基础(六)-- optim模块

PyTorch的optim是用于参数优化的库&#xff08;可以说是花式梯度下降&#xff09;&#xff0c;optim文件夹主要包括1个核心的父类&#xff08;optimizer&#xff09;、1个辅助类&#xff08;lr_scheduler&#xff09;以及10个常用优化算法的实现类。optim中内置的常用算法包括a…

pytorch基础(四):使用optim优化函数

文章目录 前言一、问题描述二、官方文档代码三、optimizer的工作原理总结 前言 本系列主要是对pytorch基础知识学习的一个记录&#xff0c;尽量保持博客的更新进度和自己的学习进度。本人也处于学习阶段&#xff0c;博客中涉及到的知识可能存在某些问题&#xff0c;希望大家批评…

HTTP协议之GET与POST区别

GET和POST是HTTP请求的两种基本方式&#xff0c;对于这两种请求方式的区别&#xff0c;只要是接触过Web开发的就能说出一二&#xff1a;GET把参数包含在URL中&#xff0c;POST通过正文传参&#xff01; 而我想深入了解以下的时候&#xff0c;就去了w3cschool&#xff0c;这是w…

Get和Post区别是什么

附上原文地址&#xff1a;https://www.cnblogs.com/logsharing/p/8448446.html GET和POST是HTTP请求的两种基本方法&#xff0c;要说它们的区别&#xff0c;接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中&#xff0c;POST通过request body传递参数。…

get与post区别(很全~)

get与post区别 GET&#xff1a; &#xff08;1&#xff09;从指定的资源请求数据 &#xff08;2&#xff09;请求数据有长度限制&#xff08;不同每个浏览器限制长度可能不一样&#xff09; &#xff08;3&#xff09;请求url会在浏览器地址栏中显示 &#xff08;4&#x…

Web中get和post区别

99% 的人都理解错了 HTTP 中 GET 与 POST 的区别&#xff08;转&#xff09; 转自&#xff1a;WebTechGarden 微信公众号GET 和 POST 是 HTTP 请求的两种基本方法&#xff0c;要说它们的区别&#xff0c;接触过 WEB 开发的人都能说出一二。最直观的区别就是 GET 把参数包含在 U…

get,post区别

Http定义了与服务器交互的不同方法&#xff0c;最基本的方法有4种&#xff0c;分别是GET&#xff0c;POST&#xff0c;PUT&#xff0c;DELETE。URL全称是资源描述符&#xff0c;我们可以这样认为&#xff1a;一个URL地址&#xff0c;它用于描述一个网络上的资源&#xff0c;而H…

GET 与 POST 区别

GET和POST是HTTP请求的两种基本方法&#xff0c;要说它们的区别&#xff0c;接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中&#xff0c;POST通过request body传递参数。 你可能自己写过无数个GET和POST请求&#xff0c;或者已经看过很多权威网站总结出…

GET和POST区别及缓存问题

2.就是get和post区别的缓存问题。 首先要了解什么是缓存。 HTTP缓存的基本目的就是使应用执行的更快&#xff0c;更易扩展&#xff0c;但是HTTP缓存通常只适用于idempotent request&#xff08;可以理解为查询请求&#xff0c;也就是不更新服务端数据的请求&#xff09;&#x…

Get与Post区别与范例讲解

林炳文Evankaka原创作品。转载请注明出处http://blog.csdn.net/evankaka 一、 J2EE WEB应用文件目录结构 Java Web应用由一组静态HTML页、Servlet、JSP和其他相关的class组成&#xff0c;它们一起构成一个大的工程项目。每种组件在Web应用中都有固定的存放目录。Web应用的配…

GET和POST区别详解

GET和POST是HTTP请求的两种基本方法&#xff0c;要说它们的区别&#xff0c;接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中&#xff0c;POST通过request body传递参数。 你可能自己写过无数个GET和POST请求&#xff0c;或者已经看过很多权威网站总结出…

get和post区别

GET和POST的安全性 1、GET是通过URL方式请求&#xff0c;可以直接看到&#xff0c;明文传输 2、POST是通过请求header请求&#xff0c;可以开发者工具或者抓包可以看到&#xff0c;同样也是明文的 3、GET请求会保存在浏览器历史纪录中&#xff0c;还可能会保存在Web的日志中 G…

post和get区别

GET和POST是HTTP请求的两种基本方法&#xff0c;要说它们的区别&#xff0c;接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中&#xff0c;POST通过request body传递参数。 你可能自己写过无数个GET和POST请求&#xff0c;或者已经看过很多权威网站总结出…

GET 和 POST 有什么区别?

GET 和 POST 是 HTTP 请求中最常用的两种请求方法&#xff0c;在日常开发的 RESTful 接口中&#xff0c;都能看到它们的身影。而它们之间的区别&#xff0c;也是一道常见且经典的面试题&#xff0c;所以我们本文就来详细的聊聊。 HTTP 协议定义的方法类型总共有以下 10 种&…

查找(一)——静态查找表

目录 一、查找的基本概念 二、顺序查找 &#xff08;线性查找&#xff09; 1、基本思想 2、核心代码 3、顺序查找设置哨兵 4、顺序查找的优点&#xff1a; 5、顺序查找的缺点&#xff1a; 6、折半查找 7、折半查找判定树 8、线性表查找的特点 三、索引顺序表&#x…

查找表结构

查找表介绍 在日常生活中&#xff0c;几乎每天都要进行一些查找的工作&#xff0c;在电话簿中查阅某个人的电话号码&#xff1b;在电脑的文件夹中查找某个具体的文件等等。本节主要介绍用于查找操作的数据结构——查找表。 查找表是由同一类型的数据元素构成的集合。例如电话号…