PyTorch实现PPO代码

article/2025/9/21 9:01:56

原理:Proximal Policy Optimization近端策略优化(PPO)
视频:Proximal Policy Optimization (PPO) is Easy With PyTorch | Full PPO Tutorial
代码来自github:
Youtube-Code-Repository
EasyRL
网站:Neuralnet.ai

Package

import gym
import os 
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

Memory

  • sample():memory也就是一个batch分成多个mini batch
  • push():存储env.step后的trace信息,包括state,action,prob,val,reward,done
  • clear():更新完后清空memory,存放新的trace
class PPOmemory:def __init__(self, mini_batch_size):self.states = []  # 状态self.actions = []  # 实际采取的动作self.probs = []  # 动作概率self.vals = []  # critic输出的状态值self.rewards = []  # 奖励self.dones = []  # 结束标志self.mini_batch_size = mini_batch_size  # minibatch的大小def sample(self):n_states = len(self.states)  # memory记录数量=20batch_start = np.arange(0, n_states, self.mini_batch_size)  # 每个batch开始的位置[0,5,10,15]indices = np.arange(n_states, dtype=np.int64)  # 记录编号[0,1,2....19]np.random.shuffle(indices)  # 打乱编号顺序[3,1,9,11....18]mini_batches = [indices[i:i + self.mini_batch_size] for i in batch_start]  # 生成4个minibatch,每个minibatch记录乱序且不重复return np.array(self.states), np.array(self.actions), np.array(self.probs), \np.array(self.vals), np.array(self.rewards), np.array(self.dones), mini_batches# 每一步都存储trace到memorydef push(self, state, action, prob, val, reward, done):self.states.append(state)self.actions.append(action)self.probs.append(prob)self.vals.append(val)self.rewards.append(reward)self.dones.append(done)# 固定步长更新完网络后清空memorydef clear(self):self.states = []self.actions = []self.probs = []self.vals = []self.rewards = []self.dones = []

Actor

  • input:state
  • output:动作分布Categorical

actor网络即策略网络,输入state,输出action概率,使用Categorical生成动作分布

# actor:policy network
class Actor(nn.Module):def __init__(self, n_states, n_actions, hidden_dim):super(Actor, self).__init__()self.actor = nn.Sequential(nn.Linear(n_states, hidden_dim),nn.Tanh(),nn.Linear(hidden_dim, hidden_dim),nn.Tanh(),nn.Linear(hidden_dim, n_actions),nn.Softmax(dim=-1))def forward(self, state):dist = self.actor(state)dist = Categorical(dist)entropy = dist.entropy()return dist, entropy

Critic

  • input:state
  • output:状态值函数

critic网络即值网络,输入state,输出state-value

# critic:value network
class Critic(nn.Module):def __init__(self, n_states, hidden_dim):super(Critic, self).__init__()self.critic = nn.Sequential(nn.Linear(n_states, hidden_dim),nn.Tanh(),nn.Linear(hidden_dim, hidden_dim),nn.Tanh(),nn.Linear(hidden_dim, 1))def forward(self, state):value = self.critic(state)return value

Agent

  • choose_action():输入state,输出随机action,记录state的value以及action的对数prob
  • learn():更新actor和critic的网络参数

(1)计算GAE优势函数
(2)获取每个mini batch更新后的新策略
(3)执行clip操作得到actor loss
(4)更新估计状态值函数得到critic loss
(5)反向传播更新参数

class Agent:def __init__(self, n_states, n_actions, cfg):self.gamma = cfg.gammaself.policy_clip = cfg.policy_clipself.n_epochs = cfg.n_epochsself.gae_lambda = cfg.gae_lambdaself.device = cfg.deviceself.actor = Actor(n_states, n_actions, cfg.hidden_dim).to(self.device)self.critic = Critic(n_states, cfg.hidden_dim).to(self.device)self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=cfg.actor_lr)self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=cfg.critic_lr)self.memory = PPOmemory(cfg.mini_batch_size)self.loss = 0def choose_action(self, state):state = torch.tensor(state, dtype=torch.float).to(self.device)dist, entropy = self.actor(state)value = self.critic(state)action = dist.sample()prob = torch.squeeze(dist.log_prob(action)).item()action = torch.squeeze(action).item()value = torch.squeeze(value).item()return action, prob, valuedef learn(self):for _ in range(self.n_epochs):state_arr, action_arr, old_prob_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.sample()values = vals_arr[:]# 计算GAEadvantage = np.zeros(len(reward_arr), dtype=np.float32)for t in range(len(reward_arr) - 1):discount = 1a_t = 0for k in range(t, len(reward_arr) - 1):a_t += discount * (reward_arr[k] + self.gamma * values[k + 1] * (1 - int(dones_arr[k])) - values[k])discount *= self.gamma * self.gae_lambdaadvantage[t] = a_tadvantage = torch.tensor(advantage).to(self.device)# mini batch 更新values = torch.tensor(values).to(self.device)for batch in batches:states = torch.tensor(state_arr[batch], dtype=torch.float).to(self.device)old_probs = torch.tensor(old_prob_arr[batch]).to(self.device)actions = torch.tensor(action_arr[batch]).to(self.device)# 计算新的策略分布dist, entropy = self.actor(states)critic_value = torch.squeeze(self.critic(states))new_probs = dist.log_prob(actions)prob_ratio = new_probs.exp() / old_probs.exp()# actor_lossweighted_probs = advantage[batch] * prob_ratioweighted_clipped_probs = torch.clamp(prob_ratio, 1 - self.policy_clip,1 + self.policy_clip) * advantage[batch]actor_loss = -torch.min(weighted_probs, weighted_clipped_probs).mean()# critic_lossreturns = advantage[batch] + values[batch]critic_loss = (returns - critic_value) ** 2critic_loss = critic_loss.mean()# 更新entropy_loss = entropy.mean()total_loss = actor_loss + 0.5 * critic_loss - entropy_loss * 0.01self.loss = total_lossself.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()total_loss.backward()torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)self.actor_optimizer.step()self.critic_optimizer.step()self.memory.clear()return self.lossdef save(self, path):actor_checkpoint = os.path.join(path, 'ppo_actor.pt')critic_checkpoint = os.path.join(path, 'ppo_critic.pt')torch.save(self.actor.state_dict(), actor_checkpoint)torch.save(self.critic.state_dict(), critic_checkpoint)def load(self, path):actor_checkpoint = os.path.join(path, 'ppo_actor.pt')critic_checkpoint = os.path.join(path, 'ppo_critic.pt')self.actor.load_state_dict(torch.load(actor_checkpoint))self.critic.load_state_dict(torch.load(critic_checkpoint))

参数

def get_args():parser = argparse.ArgumentParser(description="hyper parameters")parser.add_argument('--algo_name', default='PPO', type=str, help="name of algorithm")parser.add_argument('--env_name', default='CartPole-v1', type=str, help="name of environment")parser.add_argument('--train_eps', default=200, type=int, help="episodes of training")parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")parser.add_argument('--gamma', default=0.99, type=float, help="discounted factor")parser.add_argument('--mini_batch_size', default=5, type=int, help='mini batch size')parser.add_argument('--n_epochs', default=4, type=int, help='update number')parser.add_argument('--actor_lr', default=0.0003, type=float, help="learning rate of actor net")parser.add_argument('--critic_lr', default=0.0003, type=float, help="learning rate of critic net")parser.add_argument('--gae_lambda', default=0.95, type=float, help='GAE lambda')parser.add_argument('--policy_clip', default=0.2, type=float, help='policy clip')parser.add_argument('-batch_size', default=20, type=int, help='batch size')parser.add_argument('--hidden_dim', default=256, type=int, help='hidden dim')parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")args = parser.parse_args()return args

训练

def train(cfg, env, agent):print('开始训练!')print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')rewards = []steps = 0for i_ep in range(cfg.train_eps):state = env.reset()done = Falseep_reward = 0while not done:action, prob, val = agent.choose_action(state)state_, reward, done, _ = env.step(action)steps += 1ep_reward += rewardagent.memory.push(state, action, prob, val, reward, done)if steps % cfg.batch_size == 0:agent.learn()state = state_rewards.append(ep_reward)if (i_ep + 1) % 10 == 0:print(f"回合:{i_ep + 1}/{cfg.train_eps},奖励:{ep_reward:.2f}")print('完成训练!')

环境

def env_agent_config(cfg, seed=1):env = gym.make(cfg.env_name)n_states = env.observation_space.shape[0]n_actions = env.action_space.nagent = Agent(n_states, n_actions, cfg)if seed != 0:torch.manual_seed(seed)env.seed(seed)np.random.seed(seed)return env, agent

运行

cfg = get_args()
env, agent = env_agent_config(cfg, seed=1)
train(cfg, env, agent)

结果

开始训练!
环境:CartPole-v1, 算法:PPO, 设备:cpu
回合:10/200,奖励:12.00
回合:20/200,奖励:52.00
回合:30/200,奖励:101.00
回合:40/200,奖励:141.00
回合:50/200,奖励:143.00
回合:60/200,奖励:118.00
回合:70/200,奖励:84.00
回合:80/200,奖励:500.00
回合:90/200,奖励:112.00
回合:100/200,奖励:149.00
回合:110/200,奖励:252.00
回合:120/200,奖励:500.00
回合:130/200,奖励:500.00
回合:140/200,奖励:500.00
回合:150/200,奖励:500.00
回合:160/200,奖励:500.00
回合:170/200,奖励:500.00
回合:180/200,奖励:500.00
回合:190/200,奖励:500.00
回合:200/200,奖励:500.00
完成训练!

在这里插入图片描述


http://chatgpt.dhexx.cn/article/rUXblue8.shtml

相关文章

优化PPO

优化PPO 介绍core implementation details1.Vectorized architecture 量化结构Orthogonal Initialization of Weights and Constant Initialization of biases 算法权重的初始化以及恒定偏差的初始化The Adam Optimizer’s Epsilon Parameter Adam优化器的ε参数Adam Learning …

PPO Algorithm

‘‘目录 PPO ALGORITHM 进行看别人文章: 如何直观理解PPO算法?[理论篇] - 知乎 (zhihu.com) 【强化学习8】PPO - 知乎 (zhihu.com) PPO(OpenAI) Proximal Policy Optimization(PPO)算法原理及实现! - 简书 (jianshu.com) 1-Critic的作用与效果.m…

PPO算法实战

原理简介 PPO是一种on-policy算法,具有较好的性能,其前身是TRPO算法,也是policy gradient算法的一种,它是现在 OpenAI 默认的强化学习算法,具体原理可参考PPO算法讲解。PPO算法主要有两个变种,一个是结合K…

Proximal Policy Optimization(近端策略优化)(PPO)原理详解

本节开始笔者针对自己的研究领域进行RL方面的介绍和笔记总结,欢迎同行学者一起学习和讨论。本文笔者来介绍RL中比较出名的算法PPO算法,读者需要预先了解Reinforcement-Learning中几个基础定义才可以阅读,否则不容易理解其中的内容。不过笔者尽…

【强化学习PPO算法】

强化学习PPO算法 一、PPO算法二、伪代码三、相关的简单理论1.ratio2.裁断3.Advantage的计算4.loss的计算 四、算法实现五、效果六、感悟 最近再改一个代码,需要改成PPO方式的,由于之前没有接触过此类算法,因此进行了简单学习,论文…

【深度强化学习】(6) PPO 模型解析,附Pytorch完整代码

大家好,今天和各位分享一下深度强化学习中的近端策略优化算法(proximal policy optimization,PPO),并借助 OpenAI 的 gym 环境完成一个小案例,完整代码可以从我的 GitHub 中获得: https://gith…

autoit连接mysql数据库

原链接点我 一,准备工作 1, 下载mysql.au3(这个点击就下载了) 把mysql.au3放入到autoit的include目录下 2, 下载mysql驱动(根据自己系统选,下载完之后,双击运行会自动安装,一路next就行) 二,使用 #include "mysql.au3" #include <Array.au3> ;弹窗 Func aler…

AutoIt-v3的安装,和robotframework-autoitlibrary的导入

AutoIt 最新是v3版本&#xff0c;这是一个使用类似BASIC脚本语言的免费软件,它设计用于Windows GUI&#xff08;图形用户界面)中进行自动化操作。它利用模拟键盘按键&#xff0c;鼠标移动和窗口/控件的组合来实现自动化任务。而这是其它语言不可能做到或无可靠方法实现的。 Au…

selenium 上传下载调用windows窗口--AutoIT

AutoIT解决自动化上传下载文件调用Windows窗口 AutoIT下载安装使用AotuIt 操作windows上传窗口1. 打开AutoIt定位窗口组件2. 定位上传窗口属性 &#xff08;鼠标选中Finder Tool 拖拽至属性窗口&#xff09;3. 打开autoIt编辑器&#xff0c;编写代码4. 将脚本文件转成exe文件5.…

软件质量保证与测试 实验十一:AutoIt的使用

目录 实验概述实验内容1. 下载安装AutoIT。2. 测试win系统自带计算器程序&#xff0c; 246&#xff0c;是否正确&#xff1f; 写出Script。&#xff08;小提示&#xff1a;使用WinGetText获得输出&#xff09;3.测试win系统自带计算器程序&#xff0c; 写出3个以上的测试用例的…

selenium 用autoIT上传下载文件

一、下载安装AutoIT 下载并安装AutoIT&#xff0c;下载链接&#xff1a;https://www.autoitscript.com/site/autoit/AutoIT安装成功后&#xff0c;可以在开始菜单下看到AutoIT的所有工具&#xff0c;如下图所示&#xff1a; 其中分为几类&#xff0c;AutoIT Window Info用来识…

selenium autoit java_selenium+java利用AutoIT实现文件上传

转载自&#xff1a;https://www.cnblogs.com/yunman/p/7112882.html?utm_sourceitdadao&utm_mediumreferral 1、AutoIT介绍 AutoIT是一个类似脚本语言的软件&#xff0c;利用此软件我们可以方便的实现模拟键盘、鼠标、窗口等操作&#xff0c;实现自动化。 2、实现原理 利用…

autoIT 自动化上传/下载文件图文详解【python selenium】

情景&#xff1a; 在用selenium进行web页面自动化时&#xff0c;时不时会遇到上传附件的情况&#xff0c;常见的情况就是一个上传按钮&#xff0c;点击后弹出windows窗口&#xff0c;选择文件后上传&#xff0c;如下图1所示 图1 这种情况超出了selenium的能力范围&#xff0c;需…

AutoIt介绍

AutoIt的下载网址&#xff1a; https://www.autoitscript.com/site/autoit/downloads/ AutoIt在线文档&#xff1a;http://www.autoit3.cn/Doc/ AutoIt的优势&#xff1a; 简单易懂的类BASIC 表达式模拟键盘,鼠标动作事件操作窗口与进程直接与窗口的”标准控件”交互(设置/获…

AutoIt的应用

少数情况下需要操作系统级的弹窗&#xff0c;可以使用AutoIt。 AutoIt现在最新版是V3版本&#xff0c;这是一个类似BASIC脚本语言的免费软件&#xff0c;用于Windows GUI中进行自动化操作。利用模拟键盘按键&#xff0c;鼠标移动&#xff0c;窗口和控件的组合来实现自动化任务…

java 调用autoit_java和autoit连接

autoit可以实现本机文件的上传&#xff0c;修改&#xff0c;新建&#xff0c;也可以实现网页上文件下载到本地 连接步骤&#xff1a; (1)下载autoitx4java 包&#xff0c;地址在code.google.com/p/autoitx4java。解压后直接将jar包添加到工程里面。然后需要使用jacob包&#xf…

AutoIt在线使用手册地址

AutoIt 在线文档https://autoitx.com/Doc/

AutoIt3.0

autoIt主要用于窗口自动化&#xff0c;结合python&#xff0c;可解决web自动化&#xff0c;页面调出窗口的问题 autoIt脚本代码例子&#xff1a; 1.打开Windows 任务管理器 2.依次点击【应用程序、进程、服务、性能、联网、用户】按钮 3.再次点击应用程序按钮 4.点选第二个…

Python + Selenium + AutoIt 模拟键盘实现另存为、上传、下载操作详解

前言 在web页面中&#xff0c;可以使用selenium的定位方式来识别元素&#xff0c;从而来实现页面中的自动化&#xff0c;但对于页面中弹出的文件选择框&#xff0c;selenium就实现不了了&#xff0c;所以就需引用AutoIt工具来实现。 AutoIt介绍 AutoIt简单介绍下&#xff0c…

autoit 下载图片验证码

autoit 下载图片验证码 自动化测试中&#xff0c;我做了验证码识别的功能&#xff0c;那么接下来就是怎么获取验证码图片了&#xff0c;还好autoit 里面提供了一些方法。下面就介绍一下怎样利用autoit 下载验证码图片&#xff1a; 先说思路&#xff1a; 右键点击验证码 使用…