动手学强化学习 第 12 章 PPO 算法 训练代码
创始人
2024-11-14 22:04:57
0

基于 Hands-on-RL/第12章-PPO算法.ipynb at main · boyu-ai/Hands-on-RL · GitHub

理论 PPO 算法

修改了警告和报错

运行环境

Debian GNU/Linux 12 Python 3.9.19 torch 2.0.1 gym 0.26.2

运行代码

PPO.py

#!/usr/bin/env python   import gym import torch import torch.nn.functional as F import numpy as np import matplotlib.pyplot as plt import rl_utils   class PolicyNet(torch.nn.Module):     def __init__(self, state_dim, hidden_dim, action_dim):         super(PolicyNet, self).__init__()         self.fc1 = torch.nn.Linear(state_dim, hidden_dim)         self.fc2 = torch.nn.Linear(hidden_dim, action_dim)      def forward(self, x):         x = F.relu(self.fc1(x))         return F.softmax(self.fc2(x), dim=1)   class ValueNet(torch.nn.Module):     def __init__(self, state_dim, hidden_dim):         super(ValueNet, self).__init__()         self.fc1 = torch.nn.Linear(state_dim, hidden_dim)         self.fc2 = torch.nn.Linear(hidden_dim, 1)      def forward(self, x):         x = F.relu(self.fc1(x))         return self.fc2(x)   class PPO:     ''' PPO算法,采用截断方式 '''      def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,                  lmbda, epochs, eps, gamma, device):         self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)         self.critic = ValueNet(state_dim, hidden_dim).to(device)         self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),                                                 lr=actor_lr)         self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),                                                  lr=critic_lr)         self.gamma = gamma         self.lmbda = lmbda         self.epochs = epochs  # 一条序列的数据用来训练轮数         self.eps = eps  # PPO中截断范围的参数         self.device = device      def take_action(self, state):         state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)         probs = self.actor(state)         action_dist = torch.distributions.Categorical(probs)         action = action_dist.sample()         return action.item()      def update(self, transition_dict):         states = torch.tensor(np.array(transition_dict['states']),                               dtype=torch.float).to(self.device)         actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(             self.device)         rewards = torch.tensor(transition_dict['rewards'],                                dtype=torch.float).view(-1, 1).to(self.device)         next_states = torch.tensor(np.array(transition_dict['next_states']),                                    dtype=torch.float).to(self.device)         dones = torch.tensor(transition_dict['dones'],                              dtype=torch.float).view(-1, 1).to(self.device)         td_target = rewards + self.gamma * self.critic(next_states) * (1 -                                                                        dones)         td_delta = td_target - self.critic(states)         advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,                                                td_delta.cpu()).to(self.device)         old_log_probs = torch.log(self.actor(states).gather(1,                                                             actions)).detach()          for _ in range(self.epochs):             log_probs = torch.log(self.actor(states).gather(1, actions))             ratio = torch.exp(log_probs - old_log_probs)             surr1 = ratio * advantage             surr2 = torch.clamp(ratio, 1 - self.eps,                                 1 + self.eps) * advantage  # 截断             actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数             critic_loss = torch.mean(                 F.mse_loss(self.critic(states), td_target.detach()))             self.actor_optimizer.zero_grad()             self.critic_optimizer.zero_grad()             actor_loss.backward()             critic_loss.backward()             self.actor_optimizer.step()             self.critic_optimizer.step()   actor_lr = 1e-3 critic_lr = 1e-2 num_episodes = 500 hidden_dim = 128 gamma = 0.98 lmbda = 0.95 epochs = 10 eps = 0.2 device = torch.device("cuda") if torch.cuda.is_available() else torch.device(     "cpu")  env_name = 'CartPole-v1' env = gym.make(env_name) env.reset(seed=0) torch.manual_seed(0) state_dim = env.observation_space.shape[0] action_dim = env.action_space.n agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,             epochs, eps, gamma, device)  return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)  episodes_list = list(range(len(return_list))) plt.plot(episodes_list, return_list) plt.xlabel('Episodes') plt.ylabel('Returns') plt.title('PPO on {}'.format(env_name)) plt.show()  mv_return = rl_utils.moving_average(return_list, 9) plt.plot(episodes_list, mv_return) plt.xlabel('Episodes') plt.ylabel('Returns') plt.title('PPO on {}'.format(env_name)) plt.show() 

rl_utils.py 参考

动手学强化学习 第 11 章 TRPO 算法 训练代码-CSDN博客

相关内容

热门资讯

绝活儿辅助!广西老友玩老是输怎... 绝活儿辅助!广西老友玩老是输怎么办(辅助挂)都是真的有辅助app(讲解有挂)在进入广西老友玩老是输怎...
法门辅助!福建13水插件(辅助... 法门辅助!福建13水插件(辅助挂)一贯是有辅助技巧(有挂技术)1、许多玩家不知道福建13水插件辅助怎...
办法辅助!潮友会app下载官方... 办法辅助!潮友会app下载官方辅助器(辅助挂)真是真的是有辅助app(有挂教程)该软件可以轻松地帮助...
妙招辅助!邯郸胡乐挂辅助(辅助... 妙招辅助!邯郸胡乐挂辅助(辅助挂)好像存在有辅助插件(有挂方略)1、上手简单,内置详细流程视频教学,...
教程书辅助!乐酷辅助(辅助挂)... 教程书辅助!乐酷辅助(辅助挂)其实存在有辅助脚本(有挂细节)乐酷辅助能透视中分为三种模型:乐酷辅助模...
学习辅助!决战卡五星辅助(辅助... 学习辅助!决战卡五星辅助(辅助挂)本来真的是有辅助软件(有人有挂)学习辅助!决战卡五星辅助(辅助挂)...
绝活辅助!边锋嘉兴麻将辅助器(... 绝活辅助!边锋嘉兴麻将辅助器(辅助挂)真是真的有辅助神器(新版有挂)1、边锋嘉兴麻将辅助器公共底牌简...
举措辅助!枫叶辅助器(辅助挂)... 举措辅助!枫叶辅助器(辅助挂)本来存在有辅助技巧(竟然有挂)1、下载好枫叶辅助器正确养号方法之后点击...
讲义辅助!点我达辅助(辅助挂)... 讲义辅助!点我达辅助(辅助挂)一直存在有辅助技巧(有人有挂)1、点我达辅助辅助器安装包、点我达辅助辅...
模块辅助!威信茶馆有挂的吗(辅... 模块辅助!威信茶馆有挂的吗(辅助挂)一直真的是有辅助脚本(揭秘有挂)1、玩家可以在威信茶馆有挂的吗线...