torch.optim优化算法理解之optim.Adam()

article/2025/9/16 20:31:56

torch.optim是一个实现了多种优化算法的包,大多数通用的方法都已支持,提供了丰富的接口调用,未来更多精炼的优化算法也将整合进来。
为了使用torch.optim,需先构造一个优化器对象Optimizer,用来保存当前的状态,并能够根据计算得到的梯度来更新参数。
要构建一个优化器optimizer,你必须给它一个可进行迭代优化的包含了所有参数(所有的参数必须是变量s)的列表。 然后,您可以指定程序优化特定的选项,例如学习速率,权重衰减等。

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)
self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

Optimizer还支持指定每个参数选项。 只需传递一个可迭代的dict来替换先前可迭代的Variable。dict中的每一项都可以定义为一个单独的参数组,参数组用一个params键来包含属于它的参数列表。其他键应该与优化器接受的关键字参数相匹配,才能用作此组的优化选项。

optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)

如上,model.base.parameters()将使用1e-2的学习率,model.classifier.parameters()将使用1e-3的学习率。0.9的momentum作用于所有的parameters。
优化步骤:
所有的优化器Optimizer都实现了step()方法来对所有的参数进行更新,它有两种调用方法:

optimizer.step()

这是大多数优化器都支持的简化版本,使用如下的backward()方法来计算梯度的时候会调用它。

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()
optimizer.step(closure)

一些优化算法,如共轭梯度和LBFGS需要重新评估目标函数多次,所以你必须传递一个closure以重新计算模型。 closure必须清除梯度,计算并返回损失。

for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

Adam算法:

adam算法来源:Adam: A Method for Stochastic Optimization

Adam(Adaptive Moment Estimation)本质上是带有动量项的RMSprop,它利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。它的优点主要在于经过偏置校正后,每一次迭代学习率都有个确定范围,使得参数比较平稳。其公式如下:

这里写图片描述

其中,前两个公式分别是对梯度的一阶矩估计和二阶矩估计,可以看作是对期望E|gt|,E|gt^2|的估计;
公式3,4是对一阶二阶矩估计的校正,这样可以近似为对期望的无偏估计。可以看出,直接对梯度的矩估计对内存没有额外的要求,而且可以根据梯度进行动态调整。最后一项前面部分是对学习率n形成的一个动态约束,而且有明确的范围

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

参数:

params(iterable):可用于迭代优化的参数或者定义参数组的dictslr (float, optional) :学习率(默认: 1e-3)
betas (Tuple[float, float], optional):用于计算梯度的平均和平方的系数(默认: (0.9, 0.999))
eps (float, optional):为了提高数值稳定性而添加到分母的一个项(默认: 1e-8)
weight_decay (float, optional):权重衰减(如L2惩罚)(默认: 0)
step(closure=None)函数:执行单一的优化步骤
closure (callable, optional):用于重新评估模型并返回损失的一个闭包 

torch.optim.adam源码:

import math
from .optimizer import Optimizerclass Adam(Optimizer):def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,weight_decay=0):defaults = dict(lr=lr, betas=betas, eps=eps,weight_decay=weight_decay)super(Adam, self).__init__(params, defaults)def step(self, closure=None):loss = Noneif closure is not None:loss = closure()for group in self.param_groups:for p in group['params']:if p.grad is None:continuegrad = p.grad.datastate = self.state[p]# State initializationif len(state) == 0:state['step'] = 0# Exponential moving average of gradient valuesstate['exp_avg'] = grad.new().resize_as_(grad).zero_()# Exponential moving average of squared gradient valuesstate['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']beta1, beta2 = group['betas']state['step'] += 1if group['weight_decay'] != 0:grad = grad.add(group['weight_decay'], p.data)# Decay the first and second moment running average coefficientexp_avg.mul_(beta1).add_(1 - beta1, grad)exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)denom = exp_avg_sq.sqrt().add_(group['eps'])bias_correction1 = 1 - beta1 ** state['step']bias_correction2 = 1 - beta2 ** state['step']step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1p.data.addcdiv_(-step_size, exp_avg, denom)return loss

Adam的特点有:
1、结合了Adagrad善于处理稀疏梯度和RMSprop善于处理非平稳目标的优点;
2、对内存需求较小;
3、为不同的参数计算不同的自适应学习率;
4、也适用于大多非凸优化-适用于大数据集和高维空间。


http://chatgpt.dhexx.cn/article/kwK9i93s.shtml

相关文章

Python-torch.optim优化算法理解之optim.Adam()

目录 简介 分析 使用 Adam算法 参数 论文理解 torch.optim.adam源码理解 Adam的特点 转载torch.optim优化算法理解之optim.Adam() 官方手册:torch.optim — PyTorch 1.11.0 documentation 其他参考 pytorch中优化器与学习率衰减方法总结 Adam和学习率衰减…

R 语言 optim 使用

stats中的optim函数是解决优化问题的一个简易的方法。 Univariate Optimization f function(x,a) (x-a)^2 xmin optimize(f,interval c(0,1),a1/3) xminGeneral Optimization optim函数包含了几种不同的算法。 算法的选择依赖于求解导数的难易程度,通常最好提…

PyTorch基础(六)-- optim模块

PyTorch的optim是用于参数优化的库(可以说是花式梯度下降),optim文件夹主要包括1个核心的父类(optimizer)、1个辅助类(lr_scheduler)以及10个常用优化算法的实现类。optim中内置的常用算法包括a…

pytorch基础(四):使用optim优化函数

文章目录 前言一、问题描述二、官方文档代码三、optimizer的工作原理总结 前言 本系列主要是对pytorch基础知识学习的一个记录,尽量保持博客的更新进度和自己的学习进度。本人也处于学习阶段,博客中涉及到的知识可能存在某些问题,希望大家批评…

HTTP协议之GET与POST区别

GET和POST是HTTP请求的两种基本方式,对于这两种请求方式的区别,只要是接触过Web开发的就能说出一二:GET把参数包含在URL中,POST通过正文传参! 而我想深入了解以下的时候,就去了w3cschool,这是w…

Get和Post区别是什么

附上原文地址:https://www.cnblogs.com/logsharing/p/8448446.html GET和POST是HTTP请求的两种基本方法,要说它们的区别,接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中,POST通过request body传递参数。…

get与post区别(很全~)

get与post区别 GET: (1)从指定的资源请求数据 (2)请求数据有长度限制(不同每个浏览器限制长度可能不一样) (3)请求url会在浏览器地址栏中显示 (4&#x…

Web中get和post区别

99% 的人都理解错了 HTTP 中 GET 与 POST 的区别(转) 转自:WebTechGarden 微信公众号GET 和 POST 是 HTTP 请求的两种基本方法,要说它们的区别,接触过 WEB 开发的人都能说出一二。最直观的区别就是 GET 把参数包含在 U…

get,post区别

Http定义了与服务器交互的不同方法,最基本的方法有4种,分别是GET,POST,PUT,DELETE。URL全称是资源描述符,我们可以这样认为:一个URL地址,它用于描述一个网络上的资源,而H…

GET 与 POST 区别

GET和POST是HTTP请求的两种基本方法,要说它们的区别,接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中,POST通过request body传递参数。 你可能自己写过无数个GET和POST请求,或者已经看过很多权威网站总结出…

GET和POST区别及缓存问题

2.就是get和post区别的缓存问题。 首先要了解什么是缓存。 HTTP缓存的基本目的就是使应用执行的更快,更易扩展,但是HTTP缓存通常只适用于idempotent request(可以理解为查询请求,也就是不更新服务端数据的请求)&#x…

Get与Post区别与范例讲解

林炳文Evankaka原创作品。转载请注明出处http://blog.csdn.net/evankaka 一、 J2EE WEB应用文件目录结构 Java Web应用由一组静态HTML页、Servlet、JSP和其他相关的class组成,它们一起构成一个大的工程项目。每种组件在Web应用中都有固定的存放目录。Web应用的配…

GET和POST区别详解

GET和POST是HTTP请求的两种基本方法,要说它们的区别,接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中,POST通过request body传递参数。 你可能自己写过无数个GET和POST请求,或者已经看过很多权威网站总结出…

get和post区别

GET和POST的安全性 1、GET是通过URL方式请求,可以直接看到,明文传输 2、POST是通过请求header请求,可以开发者工具或者抓包可以看到,同样也是明文的 3、GET请求会保存在浏览器历史纪录中,还可能会保存在Web的日志中 G…

post和get区别

GET和POST是HTTP请求的两种基本方法,要说它们的区别,接触过WEB开发的人都能说出一二。 最直观的区别就是GET把参数包含在URL中,POST通过request body传递参数。 你可能自己写过无数个GET和POST请求,或者已经看过很多权威网站总结出…

GET 和 POST 有什么区别?

GET 和 POST 是 HTTP 请求中最常用的两种请求方法,在日常开发的 RESTful 接口中,都能看到它们的身影。而它们之间的区别,也是一道常见且经典的面试题,所以我们本文就来详细的聊聊。 HTTP 协议定义的方法类型总共有以下 10 种&…

查找(一)——静态查找表

目录 一、查找的基本概念 二、顺序查找 (线性查找) 1、基本思想 2、核心代码 3、顺序查找设置哨兵 4、顺序查找的优点: 5、顺序查找的缺点: 6、折半查找 7、折半查找判定树 8、线性表查找的特点 三、索引顺序表&#x…

查找表结构

查找表介绍 在日常生活中,几乎每天都要进行一些查找的工作,在电话簿中查阅某个人的电话号码;在电脑的文件夹中查找某个具体的文件等等。本节主要介绍用于查找操作的数据结构——查找表。 查找表是由同一类型的数据元素构成的集合。例如电话号…

数据结构 第八章 查找(静态查找表)

集合 1、集合中的数据元素除了属于同一集合外,没有任何的逻辑关系 2、在集合中,每个数据元素都有一个区别于其他元素的唯一标识(键值或者关键字值) 3、集合的运算: 1 查找某一元素是否存在(内部查找、外部查找) 2 将集合中的元素按照它的唯一标识进行排序4、集合的…

9.1 查找表:静态查找表

9.1 查找表:静态查找表. 9.2 查找表:动态查找表. 9.3 查找表:哈希表. 9.1 查找表:静态查找表 1 基本概念2 抽象数据类型3 顺序查找表3.1 顺序存储结构模块中的实现3.2 分析顺序查找的时间性能. 4 有序查找表4.1 代码实现4.2 折半查…