
Implementing the Actor-Critic Algorithm in Python with gym.utils

Published: 2024-01-06 01:53:10

The Actor-Critic algorithm is a commonly used reinforcement-learning method that combines a policy with a value function: an Actor network generates the action policy, while a Critic network learns a value function used to evaluate the actor's decisions.
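
Concretely, in the Monte-Carlo advantage form used in the example below, each episode yields discounted returns R_t, and the two networks are trained with the following losses:

    critic loss:  (R_t - V(s_t))^2
    actor loss:   -log π(a_t | s_t) * (R_t - V(s_t))

where V is the critic's value estimate, π is the actor's policy, and the advantage R_t - V(s_t) is treated as a constant when updating the actor.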

Below is a simple example that implements the Actor-Critic algorithm on a gym environment:

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from gym import wrappers

# Actor network: maps a state to a probability distribution over actions
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=1)

# Critic network: maps a state to a scalar estimate of the state value
class Critic(nn.Module):
    def __init__(self, state_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Actor-Critic agent: handles action selection and parameter updates
class ActorCritic:
    def __init__(self, state_dim, action_dim):
        self.actor = Actor(state_dim, action_dim)
        self.critic = Critic(state_dim)
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=0.01)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=0.01)
        self.gamma = 0.99

    def select_action(self, state):
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
        probs = self.actor(state)
        dist = Categorical(probs)
        action = dist.sample()
        # Return the sampled action together with its log-probability,
        # which the actor update needs later
        return action.item(), dist.log_prob(action)

    def update(self, rewards, log_probs, values):
        R = 0
        returns = []
        critic_losses = []
        actor_losses = []
        # Compute discounted returns by walking backwards through the episode
        for r in rewards[::-1]:
            R = r + self.gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns, dtype=torch.float)
        # Normalize the returns to stabilize training
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)

        for log_prob, value, R in zip(log_probs, values, returns):
            # value.item() detaches the critic, so the advantage acts as a
            # constant weight in the actor's policy-gradient loss
            advantage = R - value.item()

            # The critic regresses its value estimate toward the return
            critic_loss = (R - value.squeeze()).pow(2)
            critic_losses.append(critic_loss)

            actor_loss = -log_prob * advantage
            actor_losses.append(actor_loss)

        self.optimizer_critic.zero_grad()
        critic_loss = torch.stack(critic_losses).sum()
        critic_loss.backward()
        self.optimizer_critic.step()

        self.optimizer_actor.zero_grad()
        actor_loss = torch.stack(actor_losses).sum()
        actor_loss.backward()
        self.optimizer_actor.step()


# Create the CartPole environment
env = gym.make('CartPole-v1')
# Monitor records results to ./gym-results (it requires an older gym release;
# newer gym versions replaced it with gym.wrappers.RecordVideo)
env = wrappers.Monitor(env, "./gym-results", force=True)

# Initialize the Actor-Critic agent (CartPole has a 4-dimensional state and 2 actions)
actor_critic = ActorCritic(state_dim=4, action_dim=2)

# Training loop
for i_episode in range(1000):
    state = env.reset()  # with the classic gym API, reset() returns just the observation
    rewards = []
    log_probs = []
    values = []
    done = False
    while not done:
        action, log_prob = actor_critic.select_action(state)
        # classic gym API: step() returns (obs, reward, done, info)
        next_state, reward, done, _ = env.step(action)
        rewards.append(reward)
        log_probs.append(log_prob)
        values.append(actor_critic.critic(torch.tensor(state, dtype=torch.float).unsqueeze(0)))

        state = next_state

    actor_critic.update(rewards, log_probs, values)

    if i_episode % 100 == 0:
        print('Episode {} finished'.format(i_episode))

env.close()

The code above is a simple example of implementing the Actor-Critic algorithm on a gym environment. We first define an Actor network and a Critic network, then an ActorCritic class that provides the methods for selecting actions and updating the network parameters. Next we create a CartPole environment and train on it with the ActorCritic class. Finally, the Monitor wrapper saves the training results to the "./gym-results" folder.
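
Once training finishes, you might want to see how well the learned policy performs. The snippet below is a minimal evaluation sketch (not part of the original example): it creates a fresh environment, always picks the most probable action instead of sampling, and reports the total reward of one episode. It assumes the same classic gym reset/step API as the training code above.

# Evaluation sketch: run the trained policy greedily on a fresh environment
eval_env = gym.make('CartPole-v1')
state = eval_env.reset()
done = False
total_reward = 0.0
while not done:
    with torch.no_grad():
        probs = actor_critic.actor(torch.tensor(state, dtype=torch.float).unsqueeze(0))
    action = probs.argmax(dim=1).item()  # most probable action instead of sampling
    state, reward, done, _ = eval_env.step(action)
    total_reward += reward
print('Evaluation return: {}'.format(total_reward))
eval_env.close()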

This is a somewhat simplified example, meant to make the structure of an Actor-Critic implementation easier to follow. In practice, Actor-Critic methods include further improvements and optimizations, such as using multiple actor and critic networks or more refined advantage estimates, and the code can be adjusted and extended as needed for real applications.
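
One refinement that is easy to try is sharing the feature-extraction layers between the actor and the critic so that both heads learn from a common representation. The class below is only a sketch of that idea; the name SharedActorCritic and the layer sizes are illustrative (they mirror the networks above) and are not tuned.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch: one network with a shared trunk and two heads,
# a policy head (actor) and a value head (critic)
class SharedActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(SharedActorCritic, self).__init__()
        self.shared = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )
        self.policy_head = nn.Linear(128, action_dim)
        self.value_head = nn.Linear(128, 1)

    def forward(self, state):
        x = self.shared(state)
        probs = F.softmax(self.policy_head(x), dim=1)
        value = self.value_head(x)
        return probs, value

With a shared network like this, a single optimizer can update both heads by summing the actor loss and the critic loss (optionally plus an entropy bonus) into one objective.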