PPO算法(附pytorch代码)

article/2025/9/21 8:35:05

这里写目录标题

  • 一、PPO算法
      • (1)简介
      • (2)On-policy?
      • (3)GAE (Generalized Advantage Estimation)
  • 三、代码
    • 代码解析:

一、PPO算法

(1)简介

  • PPO算法是一种强化学习中的策略梯度方法,它的全称是Proximal Policy Optimization,即近端策略优化1。PPO算法的目标是在与环境交互采样数据后,使用随机梯度上升优化一个“替代”目标函数,从而改进策略。PPO算法的特点是可以进行多次的小批量更新,而不是像标准的策略梯度方法那样每个数据样本只进行一次梯度更新12。
  • PPO算法有两种主要的变体:PPO-Penalty和PPO-Clip。PPO-Penalty类似于TRPO算法,它使用KL散度作为一个约束条件,但是将KL散度作为目标函数的一个惩罚项,而不是一个硬性约束,并且自动调整惩罚系数,使其适应数据的规模12。PPO-Clip则没有KL散度项,也没有约束条件,而是使用一种特殊的裁剪技术,在目标函数中消除了新策略远离旧策略的动机。
  • PPO同样使用了AC框架,不过相比DPG更加接近传统的PG算法,采用的是随机分布式的策略函数(Stochastic Policy),智能体(agent)每次决策时都要从策略函数输出的分布中采样,得到的样本作为最终执行的动作,因此天生具备探索环境的能力,不需要为了探索环境给决策加上扰动;PPO的重心会放到actor上,仅仅将critic当做一个预测状态好坏(在该状态获得的期望收益)的工具,策略的调整基准在于获取的收益,不是critic的导数。

(2)On-policy?

  • PPO 算法是一个 on-policy 的算法。123

  • PPO算法是一种基于策略的强化学习算法,它可以处理连续动作空间的问题。PPO算法是一种在线算法,也就是说它需要用当前的策略产生数据,并用这些数据更新策略。PPO算法不能直接使用经验回放,因为经验回放中的数据可能是由不同的策略产生的,这会导致策略梯度的偏差12。

    但是,PPO算法可以使用一种技术叫做重要性采样(importance sampling),来利用之前的数据进行多步更新23。重要性采样的思想是给每个数据加上一个权重,表示目标策略和行为策略的比例24。这样,PPO算法可以在一定程度上提高数据的利用效率,而不影响策略梯度的正确性23。

  • PPO 算法的原理是在每一步更新策略时,尽量减小代价函数,同时保证新策略和旧策略的差异不要太大。为了做到这一点,PPO 算法使用了一个特殊的目标函数,它包含了一个截断的比率因子,用来限制新策略和旧策略的比例。1
    在这里插入图片描述
    在这里插入图片描述
    参考链接点击

  • 重要性采样是一种调整数据权重的方法,它可以用于在线算法,也可以用于离线算法。PPO算法使用了重要性采样,但它并没有使用经验回放中的随机抽样和淘汰机制,而是按照时间顺序使用数据,并在一定次数后丢弃数据 。
    在这里插入图片描述

  • PPO 算法只使用当前策略产生的经验来更新网络,而不使用历史策略产生的经验。这意味着 PPO 算法需要在每次更新策略后丢弃之前收集的经验,因为它们不再适用于新的策略。这样做的好处是 PPO 算法可以保证策略和值函数之间的一致性,也就是说,值函数可以很好地估计当前策略的性能。2

  • 虽然 PPO 算法可以使用经验回放缓冲区来存储和重用历史经验,但这并不改变它是一个 on-policy 的算法。因为 PPO 算法在使用缓冲区中的数据时,仍然需要计算新策略和旧策略的比例,并且使用截断的比率因子来限制更新幅度。这样做相当于对 off-policy 的数据进行了校正,使得它们更接近 on-policy 的数据。4

  • 如果 PPO 算法使用经验回放,那么它需要对 off-policy 的数据进行一些校正,以减小偏差和方差。一种常用的校正方法是使用重要性采样权重 (ISW),它可以衡量数据的分布和当前策略的分布之间的差异,并对损失函数或者优势函数进行加权。12

  • 另一种校正方法是使用优先级经验回放 (PER),它可以根据数据的价值或者优势来给数据分配优先级,并按照优先级来采样数据。这样可以使得更有价值或者更有优势的数据被更频繁地采样,从而提高学习效率。2

(3)GAE (Generalized Advantage Estimation)

  • 它是一种用于估计优势函数的方法。12
  • 优势函数是指在某个状态下,采取某个动作比按照当前策略采取动作所能获得的期望回报的差值。优势函数可以用来减少策略梯度估计的方差,提高学习效率。
  • GAE 的思想是利用值函数来对优势函数进行多步估计,同时使用一个衰减因子来平滑不同步长的估计,从而得到一个既有较小偏差又有较小方差的优势函数估计。具体来说,GAE 使用以下公式来计算优势函数估计:
    在这里插入图片描述
    具体实现:
# 定义折扣因子和平滑因子
gamma = 0.99
lambda = 0.95# 初始化优势函数估计和回报估计为空列表
advantages = []
returns = []# 初始化时间差分误差为0
delta = 0# 从最后一个时间步开始反向遍历经验
for state, action, reward, next_state, done in reversed(experiences):# 如果是终止状态,那么下一个状态的值为0,否则用值函数网络预测if done:next_value = 0else:next_value = value_network.predict(next_state)# 计算当前状态的值value = value_network.predict(state)# 计算时间差分误差delta = reward + gamma * next_value - value# 计算优势函数估计,使用时间差分误差和上一个时间步的优势函数估计advantage = delta + gamma * lambda * advantage# 计算回报估计,使用奖励和下一个状态的值return = reward + gamma * next_value# 将优势函数估计和回报估计插入到列表的开头advantages.insert(0, advantage)returns.insert(0, return)# 将优势函数估计和回报估计转换为张量
advantages = torch.tensor(advantages)
returns = torch.tensor(returns)

注意一下几点:

  1. 为什么要从最后一个时间步开始反向遍历
    答:为了利用之前计算的优势函数估计和回报估计来加速计算。如果从第一个时间步开始正向遍历,那么每个时间步都需要计算一个无限级数的和,这样会很慢。而如果从最后一个时间步开始反向遍历,那么每个时间步只需要用一个时间差分误差和上一个时间步的优势函数估计来计算当前的优势函数估计,这样会很快。同理,回报估计也可以用奖励和下一个状态的值来递推计算,而不需要用一个无限级数的和。
  2. 代码中returns有什么作用?
    答:returns 是用来存储每个时间步的回报估计的,回报估计是指从当前状态开始,按照当前策略采取动作所能获得的未来折扣奖励之和的期望。回报估计可以用来更新值函数网络,使其更接近真实的状态值。回报估计也可以用来计算优势函数,如果没有值函数网络的话。

三、代码

'''
@Author  :Yan JP
@Created on Date:2023/4/19 17:31 
'''
# https://blog.csdn.net/dgvv4/article/details/129496576?spm=1001.2101.3001.6650.3&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-129496576-blog-117329002.235%5Ev29%5Epc_relevant_default_base3&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-129496576-blog-117329002.235%5Ev29%5Epc_relevant_default_base3&utm_relevant_index=6
# 代码用于离散环境的模型
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F# ----------------------------------- #
# 构建策略网络--actor
# ----------------------------------- #class PolicyNet(nn.Module):def __init__(self, n_states, n_hiddens, n_actions):super(PolicyNet, self).__init__()self.fc1 = nn.Linear(n_states, n_hiddens)self.fc2 = nn.Linear(n_hiddens, n_actions)def forward(self, x):x = self.fc1(x)  # [b,n_states]-->[b,n_hiddens]x = F.relu(x)x = self.fc2(x)  # [b, n_actions]x = F.softmax(x, dim=1)  # [b, n_actions]  计算每个动作的概率return x# ----------------------------------- #
# 构建价值网络--critic
# ----------------------------------- #class ValueNet(nn.Module):def __init__(self, n_states, n_hiddens):super(ValueNet, self).__init__()self.fc1 = nn.Linear(n_states, n_hiddens)self.fc2 = nn.Linear(n_hiddens, 1)def forward(self, x):x = self.fc1(x)  # [b,n_states]-->[b,n_hiddens]x = F.relu(x)x = self.fc2(x)  # [b,n_hiddens]-->[b,1]  评价当前的状态价值state_valuereturn x# ----------------------------------- #
# 构建模型
# ----------------------------------- #class PPO:def __init__(self, n_states, n_hiddens, n_actions,actor_lr, critic_lr, lmbda, epochs, eps, gamma, device):# 实例化策略网络self.actor = PolicyNet(n_states, n_hiddens, n_actions).to(device)# 实例化价值网络self.critic = ValueNet(n_states, n_hiddens).to(device)# 策略网络的优化器self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)# 价值网络的优化器self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gamma  # 折扣因子self.lmbda = lmbda  # GAE优势函数的缩放系数self.epochs = epochs  # 一条序列的数据用来训练轮数self.eps = eps  # PPO中截断范围的参数self.device = device# 动作选择def take_action(self, state):# 维度变换 [n_state]-->tensor[1,n_states]state = torch.tensor(state[np.newaxis, :]).to(self.device)# 当前状态下,每个动作的概率分布 [1,n_states]probs = self.actor(state)# 创建以probs为标准的概率分布action_list = torch.distributions.Categorical(probs)# 依据其概率随机挑选一个动作action = action_list.sample().item()return action# 训练def learn(self, transition_dict):# 提取数据集states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).to(self.device).view(-1, 1)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).to(self.device).view(-1, 1)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).to(self.device).view(-1, 1)# 目标,下一个状态的state_value  [b,1]next_q_target = self.critic(next_states)# 目标,当前状态的state_value  [b,1]td_target = rewards + self.gamma * next_q_target * (1 - dones)# 预测,当前状态的state_value  [b,1]td_value = self.critic(states)# 目标值和预测值state_value之差  [b,1]td_delta = td_target - td_value# 时序差分值 tensor-->numpy  [b,1]td_delta = td_delta.cpu().detach().numpy()advantage = 0  # 优势函数初始化advantage_list = []# 计算优势函数for delta in td_delta[::-1]:  # 逆序时序差分值 axis=1轴上倒着取 [], [], []# 优势函数GAE的公式 :计算优势函数估计,使用时间差分误差和上一个时间步的优势函数估计advantage = self.gamma * self.lmbda * advantage + deltaadvantage_list.append(advantage)# 正序advantage_list.reverse()# numpy --> tensor [b,1]advantage = torch.tensor(advantage_list, dtype=torch.float).to(self.device)# 策略网络给出每个动作的概率,根据action得到当前时刻下该动作的概率old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()# 一组数据训练 epochs 轮for _ in range(self.epochs):# 每一轮更新一次策略网络预测的状态log_probs = torch.log(self.actor(states).gather(1, actions))# 新旧策略之间的比例ratio = torch.exp(log_probs - old_log_probs)# 近端策略优化裁剪目标函数公式的左侧项surr1 = ratio * advantage# 公式的右侧项,ratio小于1-eps就输出1-eps,大于1+eps就输出1+epssurr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage# 策略网络的损失函数actor_loss = torch.mean(-torch.min(surr1, surr2))# 价值网络的损失函数,当前时刻的state_value - 下一时刻的state_valuecritic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))# 梯度清0self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()# 反向传播actor_loss.backward()critic_loss.backward()# 梯度更新self.actor_optimizer.step()self.critic_optimizer.step()import matplotlib.pyplot as plt
import gym
import torch
if __name__ == '__main__':device = torch.device('cuda') if torch.cuda.is_available() \else torch.device('cpu')# ----------------------------------------- ## 参数设置# ----------------------------------------- #num_episodes = 300  # 总迭代次数gamma = 0.9  # 折扣因子actor_lr = 1e-3  # 策略网络的学习率critic_lr = 1e-2  # 价值网络的学习率n_hiddens = 16  # 隐含层神经元个数env_name = 'CartPole-v0'return_list = []  # 保存每个回合的return# ----------------------------------------- ## 环境加载# ----------------------------------------- #env = gym.make(env_name)n_states = env.observation_space.shape[0]  # 状态数 4n_actions = env.action_space.n  # 动作数 2# ----------------------------------------- ## 模型构建# ----------------------------------------- #agent = PPO(n_states=n_states,  # 状态数n_hiddens=n_hiddens,  # 隐含层数n_actions=n_actions,  # 动作数actor_lr=actor_lr,  # 策略网络学习率critic_lr=critic_lr,  # 价值网络学习率lmbda=0.95,  # 优势函数的缩放因子epochs=10,  # 一组序列训练的轮次eps=0.2,  # PPO中截断范围的参数gamma=gamma,  # 折扣因子device=device)# ----------------------------------------- ## 训练--回合更新 on_policy# ----------------------------------------- #for i in range(num_episodes):state = env.reset()  # 环境重置done = False  # 任务完成的标记episode_return = 0  # 累计每回合的reward# 构造数据集,保存每个回合的状态数据transition_dict = {'states': [],'actions': [],'next_states': [],'rewards': [],'dones': [],}while not done:action = agent.take_action(state)  # 动作选择next_state, reward, done, _ = env.step(action)  # 环境更新# 保存每个时刻的状态\动作\...transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)# 更新状态state = next_state# 累计回合奖励episode_return += reward# 保存每个回合的returnreturn_list.append(episode_return)# 模型训练agent.learn(transition_dict)# 打印回合信息print(f'iter:{i}, return:{np.mean(return_list[-10:])}')# -------------------------------------- ## 绘图# -------------------------------------- #plt.plot(return_list)plt.title('return')plt.show()

代码解析:

  1. log_probs = torch.log(self.actor(states).gather(1, actions))
  • self.actor(states) 是一个策略网络,它接受一批状态作为输入,输出每个状态下每个动作的概率分布,假设有 N 个状态,M 个动作,那么输出的形状是 (N, M)。
  • .gather(1, actions) 是一个张量操作,它根据 actions 中的索引从第一个维度上选取元素,actions 是一个形状为 (N, 1) 的张量,表示每个状态下实际采取的动作的索引,那么输出的形状也是 (N, 1),表示每个状态下实际采取的动作的概率。
  • torch.log() 是一个张量操作,它对输入的每个元素取对数,输出的形状和输入相同,表示每个状态下实际采取的动作的对数概率。
    因此,这行代码最终得到一个形状为 (N, 1) 的张量,表示每个时间步的动作的对数概率。
  1. critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
    注意这里需要detach一下!!!!!!!!!
  2. PPO有价值网络critic,可以用目标网络吗?
  • PPO 有一个价值网络,用来估计状态值,并用于计算优势函数或者回报估计。理论上,PPO 也可以使用目标网络来生成目标状态值,从而稳定价值网络的训练。但是,PPO 是一个 on-policy 的算法,它只使用当前策略产生的经验来更新网络,而不使用历史策略产生的经验。这意味着 PPO 需要在每次更新策略后丢弃之前收集的经验,因为它们不再适用于新的策略。这样做的好处是 PPO 可以保证策略和值函数之间的一致性,也就是说,值函数可以很好地估计当前策略的性能。

  • 如果 PPO 使用目标网络,那么目标网络的参数会滞后于价值网络的参数,这可能导致目标状态值和价值网络输出之间的不一致性,也就是说,目标状态值可能不能很好地估计当前策略的性能。这样做的代价是 PPO 可能降低学习效率,因为它不能及时反映环境和策略的变化。

  • 因此,PPO 通常不使用目标网络,而是直接使用价值网络来生成目标状态值。这样做的好处是 PPO 可以提高学习效率,因为它可以及时反映环境和策略的变化。


http://chatgpt.dhexx.cn/article/8sMx1ca5.shtml

相关文章

论文笔记之PPO

15年OpenAI发表了TRPO算法,一直策略单调提升的算法;17年DeepMind基于TRPO发表了一篇Distributed-PPO,紧接着OpenAI发表了这篇PPO。可以说TRPO是PPO的前身,PPO在TRPO的基础上进行改进,使得算法可读性更高,实…

PPO实战学习总结

PPO used in go-bigger 前段时间一直在学习ppo算法,写了 一点总结,记录一下自己对ppo算法的一些理解与RL实战时候容易遇到的一些问题。代码地址如下,需要的可以自取: https://github.com/FLBa9762/PPO_used_in_Gobigger.git一般…

PPO算法

在线学习和离线学习 在线学习:和环境互动的Agent以及和要学习的Agent是同一个, 同一个Agent,一边和环境做互动,一边在学习。离线学习: 和环境互动及的Agent以和要学习的Agent不是同一个,学习的Agent通过看别人完来学习。 利用新的…

PPO2代码 pytorch框架

PPO2代码玩gym库的Pendulum环境 2022-8-02更新 我发现这篇文章浏览量惨淡啊。 咋滴,是不相信的我代码能用是吗? 所以,我给出reward的收敛曲线图: 开玩笑,出来混,我能卖你生瓜码子吗? ———…

PPO实战

哈哈初学,复现龙龙老师的实例! state:是平衡小车上的杆子,观测状态由 4 个连续的参数组成:推车位置 [-2.4,2.4],车速 [-∞,∞],杆子角度 [~-41.8&#xff0c…

PyTorch实现PPO代码

原理:Proximal Policy Optimization近端策略优化(PPO) 视频:Proximal Policy Optimization (PPO) is Easy With PyTorch | Full PPO Tutorial 代码来自github: Youtube-Code-Repository EasyRL 网站:Neural…

优化PPO

优化PPO 介绍core implementation details1.Vectorized architecture 量化结构Orthogonal Initialization of Weights and Constant Initialization of biases 算法权重的初始化以及恒定偏差的初始化The Adam Optimizer’s Epsilon Parameter Adam优化器的ε参数Adam Learning …

PPO Algorithm

‘‘目录 PPO ALGORITHM 进行看别人文章: 如何直观理解PPO算法?[理论篇] - 知乎 (zhihu.com) 【强化学习8】PPO - 知乎 (zhihu.com) PPO(OpenAI) Proximal Policy Optimization(PPO)算法原理及实现! - 简书 (jianshu.com) 1-Critic的作用与效果.m…

PPO算法实战

原理简介 PPO是一种on-policy算法,具有较好的性能,其前身是TRPO算法,也是policy gradient算法的一种,它是现在 OpenAI 默认的强化学习算法,具体原理可参考PPO算法讲解。PPO算法主要有两个变种,一个是结合K…

Proximal Policy Optimization(近端策略优化)(PPO)原理详解

本节开始笔者针对自己的研究领域进行RL方面的介绍和笔记总结,欢迎同行学者一起学习和讨论。本文笔者来介绍RL中比较出名的算法PPO算法,读者需要预先了解Reinforcement-Learning中几个基础定义才可以阅读,否则不容易理解其中的内容。不过笔者尽…

【强化学习PPO算法】

强化学习PPO算法 一、PPO算法二、伪代码三、相关的简单理论1.ratio2.裁断3.Advantage的计算4.loss的计算 四、算法实现五、效果六、感悟 最近再改一个代码,需要改成PPO方式的,由于之前没有接触过此类算法,因此进行了简单学习,论文…

【深度强化学习】(6) PPO 模型解析,附Pytorch完整代码

大家好,今天和各位分享一下深度强化学习中的近端策略优化算法(proximal policy optimization,PPO),并借助 OpenAI 的 gym 环境完成一个小案例,完整代码可以从我的 GitHub 中获得: https://gith…

autoit连接mysql数据库

原链接点我 一,准备工作 1, 下载mysql.au3(这个点击就下载了) 把mysql.au3放入到autoit的include目录下 2, 下载mysql驱动(根据自己系统选,下载完之后,双击运行会自动安装,一路next就行) 二,使用 #include "mysql.au3" #include <Array.au3> ;弹窗 Func aler…

AutoIt-v3的安装,和robotframework-autoitlibrary的导入

AutoIt 最新是v3版本&#xff0c;这是一个使用类似BASIC脚本语言的免费软件,它设计用于Windows GUI&#xff08;图形用户界面)中进行自动化操作。它利用模拟键盘按键&#xff0c;鼠标移动和窗口/控件的组合来实现自动化任务。而这是其它语言不可能做到或无可靠方法实现的。 Au…

selenium 上传下载调用windows窗口--AutoIT

AutoIT解决自动化上传下载文件调用Windows窗口 AutoIT下载安装使用AotuIt 操作windows上传窗口1. 打开AutoIt定位窗口组件2. 定位上传窗口属性 &#xff08;鼠标选中Finder Tool 拖拽至属性窗口&#xff09;3. 打开autoIt编辑器&#xff0c;编写代码4. 将脚本文件转成exe文件5.…

软件质量保证与测试 实验十一:AutoIt的使用

目录 实验概述实验内容1. 下载安装AutoIT。2. 测试win系统自带计算器程序&#xff0c; 246&#xff0c;是否正确&#xff1f; 写出Script。&#xff08;小提示&#xff1a;使用WinGetText获得输出&#xff09;3.测试win系统自带计算器程序&#xff0c; 写出3个以上的测试用例的…

selenium 用autoIT上传下载文件

一、下载安装AutoIT 下载并安装AutoIT&#xff0c;下载链接&#xff1a;https://www.autoitscript.com/site/autoit/AutoIT安装成功后&#xff0c;可以在开始菜单下看到AutoIT的所有工具&#xff0c;如下图所示&#xff1a; 其中分为几类&#xff0c;AutoIT Window Info用来识…

selenium autoit java_selenium+java利用AutoIT实现文件上传

转载自&#xff1a;https://www.cnblogs.com/yunman/p/7112882.html?utm_sourceitdadao&utm_mediumreferral 1、AutoIT介绍 AutoIT是一个类似脚本语言的软件&#xff0c;利用此软件我们可以方便的实现模拟键盘、鼠标、窗口等操作&#xff0c;实现自动化。 2、实现原理 利用…

autoIT 自动化上传/下载文件图文详解【python selenium】

情景&#xff1a; 在用selenium进行web页面自动化时&#xff0c;时不时会遇到上传附件的情况&#xff0c;常见的情况就是一个上传按钮&#xff0c;点击后弹出windows窗口&#xff0c;选择文件后上传&#xff0c;如下图1所示 图1 这种情况超出了selenium的能力范围&#xff0c;需…

AutoIt介绍

AutoIt的下载网址&#xff1a; https://www.autoitscript.com/site/autoit/downloads/ AutoIt在线文档&#xff1a;http://www.autoit3.cn/Doc/ AutoIt的优势&#xff1a; 简单易懂的类BASIC 表达式模拟键盘,鼠标动作事件操作窗口与进程直接与窗口的”标准控件”交互(设置/获…