自学内容网 自学内容网

【强化学习】异步优势Actor-Critic, A3C算法(对比AC、A2C)

        📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:

       【强化学习】- 【单智能体强化学习】(8)---《异步优势Actor-Critic, A3C算法》

异步优势Actor-Critic, A3C算法

目录

A3C算法介绍

简单类比:团队合作

A3C 的核心要点

A3C的公式推导

A3C 的核心公式

[Python] A3C的代码实现

1. 导入必要的库 

 2.  创建共享的全局网络

3. 创建本地网络和训练逻辑

4. 多线程运行

[Notice]  代码解析

优势

总结

AC、A2C和A3C三种算法对比


A3C算法介绍

      Asynchronous Advantage Actor-Critic,  A3C(异步优势Actor-Critic)算法可以用通俗的方式解释为一种“团队协作”的强化学习方法,它的核心思想是通过多个线程(“团队成员”)同时工作,快速学习一个任务的最佳策略。

简单类比:团队合作

想象一下:

  • 你有一个团队,每个人(线程)都在同一个任务的不同部分上工作,比如不同的房间探索宝藏。
  • 每个成员会记录自己的经验(例如,哪些房间有宝藏,哪些是陷阱),并定期与队长(全局神经网络)分享。
  • 队长会根据所有成员提供的信息更新任务的全局计划。
  • 每个人(线程)根据这个全局计划调整自己的探索方式,继续在自己的区域工作。

        A3C 的“异步”特性意味着这些线程可以同时运行,但不需要等待彼此完成。这种方式避免了资源竞争,效率非常高。


A3C 的核心要点

  1. 两种网络:策略网络(Actor)和价值网络(Critic)

    • 策略网络(Actor)决定“做什么动作”。
    • 价值网络(Critic)评估当前的动作“好不好”。
  2. 异步更新:团队成员的智慧汇总

    • 每个线程在独立的环境中运行,收集自己的经验。
    • 线程会定期把它的经验提交给“队长”(共享的神经网络),用来更新全局策略。
  3. 优势函数:优化学习

    • A3C 引入了“优势函数”,用来衡量当前动作的好坏,帮助更高效地调整策略。

A3C的公式推导

        A3C 是一种结合 策略梯度方法价值函数估计 的强化学习算法,其核心通过多线程并行和异步更新机制高效学习。以下是 A3C 的公式:

A3C 的核心公式

策略梯度更新: 策略更新基于策略梯度法:

\nabla_{\theta} J(\theta) = \mathbb{E}{s, a \sim \pi\theta} \left[ \nabla_{\theta} \log \pi_\theta(a|s) A(s, a) \right]

\pi_\theta(a|s): 策略函数,表示在状态 s 下选择动作 a 的概率。

A(s, a): 优势函数,用来衡量动作 (a) 相对于当前策略的好坏。

\nabla_{\theta} \log \pi_\theta(a|s): 策略的梯度,用于指导策略的改进。

优势函数: 优势函数 (A(s, a)) 通常用 时间差分(TD)误差 近似:

A(s, a) = R_t + \gamma V(s_{t+1}) - V(s_t)

R_t: 当前奖励。

V(s): 状态值函数。

\gamma: 折扣因子,衡量未来奖励的影响。

价值函数更新: 使用平方误差来更新价值函数:

L_V = \left( R_t + \gamma V(s_{t+1}) - V(s_t) \right)^2

总损失函数: A3C 的总损失函数是策略损失和价值损失的加权和,同时加上熵正则化项(鼓励策略的探索):

L = L_\text{policy} + \beta L_V - \alpha H(\pi)

H(\pi): 策略的熵,增加探索性。


[Python] A3C的代码实现

        A3C 的优势在于通过异步多线程高效利用资源,适合实时性和复杂环境的强化学习任务。以下是 Python 实现 A3C 的主要部分:

 项目代码我已经放入GitCode里面,可以通过下面链接跳转:🔥

【强化学习】--- A3C算法代码

后续相关单智能体强化学习算法也会不断在【强化学习】项目里更新,如果该项目对你有所帮助,请帮我点一个星星✨✨✨✨✨,鼓励分享,十分感谢!!!

若是下面代码复现困难或者有问题,也欢迎评论区留言

1. 导入必要的库 

"""《A3C算法项目》
    时间:2024.12
    作者:不去幼儿园
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import gym
import threading
import time

 2.  创建共享的全局网络

class GlobalNetwork(models.Model):
    def __init__(self, state_shape, action_size):
        super(GlobalNetwork, self).__init__()
        self.shared_dense = layers.Dense(128, activation='relu')
        self.policy_logits = layers.Dense(action_size)
        self.value = layers.Dense(1)

    def call(self, inputs):
        x = self.shared_dense(inputs)
        return self.policy_logits(x), self.value(x)

3. 创建本地网络和训练逻辑

class Worker:
    def __init__(self, env_name, global_network, optimizer, worker_id, gamma=0.99):
        self.env = gym.make(env_name)
        self.global_network = global_network
        self.optimizer = optimizer
        self.worker_id = worker_id
        self.gamma = gamma

        # 本地网络
        self.local_network = GlobalNetwork(self.env.observation_space.shape, self.env.action_space.n)
        self.local_network.set_weights(self.global_network.get_weights())

    def compute_loss(self, states, actions, rewards, next_states, dones):
        logits, values = self.local_network(np.array(states))
        _, next_values = self.local_network(np.array(next_states))

        # 计算 TD 目标
        target = np.array(rewards) + self.gamma * np.array(next_values) * (1 - np.array(dones))
        advantage = target - np.array(values)

        # 策略损失
        actions_one_hot = tf.one_hot(actions, self.env.action_space.n)
        policy_loss = -tf.reduce_mean(tf.math.log(tf.reduce_sum(actions_one_hot * tf.nn.softmax(logits), axis=1)) * advantage)

        # 值损失
        value_loss = tf.reduce_mean(tf.square(advantage))

        # 熵正则化
        entropy_loss = tf.reduce_mean(-tf.reduce_sum(tf.nn.softmax(logits) * tf.math.log(tf.nn.softmax(logits)), axis=1))

        return policy_loss + 0.5 * value_loss - 0.01 * entropy_loss

    def train(self):
        while True:
            # 初始化环境
            states, actions, rewards, next_states, dones = [], [], [], [], []
            state = self.env.reset()
            done = False

            # 生成一批样本
            while not done:
                logits, _ = self.local_network(np.expand_dims(state, axis=0))
                action = np.random.choice(self.env.action_space.n, p=tf.nn.softmax(logits)[0].numpy())
                next_state, reward, done, _ = self.env.step(action)

                # 记录经验
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                next_states.append(next_state)
                dones.append(done)

                state = next_state

            # 计算损失并更新全局网络
            with tf.GradientTape() as tape:
                total_loss = self.compute_loss(states, actions, rewards, next_states, dones)

            grads = tape.gradient(total_loss, self.local_network.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.global_network.trainable_variables))
            self.local_network.set_weights(self.global_network.get_weights())

4. 多线程运行

def run_workers(env_name, num_workers=4):
    global_network = GlobalNetwork((4,), 2)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    workers = [Worker(env_name, global_network, optimizer, i) for i in range(num_workers)]
    threads = []

    for worker in workers:
        thread = threading.Thread(target=worker.train)
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

# 启动 A3C
run_workers('CartPole-v1')


[Notice]  代码解析

全局网络:

        存储策略和值函数的权重,所有线程共享。

本地网络:

        每个线程有独立的副本,用于与环境交互并更新。

多线程训练:

        每个线程独立运行,并异步更新全局网络。

损失函数:

        包括策略损失、值函数损失和熵正则化。

​# 环境配置
Python                  3.11.5
torch                   2.1.0
torchvision             0.16.0
gym                     0.26.2

        由于博文主要为了介绍相关算法的原理和应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳或者无法运行,一是算法不适配上述环境,二是算法未调参和优化,三是没有呈现完整的代码,四是等等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。


优势

计算效率高:

        利用多线程并行处理,大大加速了学习过程。

探索多样性:

        每个线程在不同的环境中工作,探索不同的可能性,提高学习的全面性。

收敛速度快:

        异步更新减少了多个线程之间的竞争,使得训练更加稳定。


总结

        A3C 是一种高效、灵活的强化学习算法,它通过“异步团队合作”的方式加速了学习,同时也确保了探索的多样性。在许多任务中,尤其是需要实时决策的应用中,A3C 是一种非常强大的工具。


AC、A2C和A3C三种算法对比

许多朋友可能会好奇,

        以下是AC(Actor-Critic)、A2C(Advantage Actor-Critic)和A3C(Asynchronous Advantage Actor-Critic)三种强化学习算法的对比表格:

属性/特性AC (Actor-Critic)A2C (Advantage Actor-Critic)A3C (Asynchronous Advantage Actor-Critic)
算法结构基于Actor和Critic两部分的组合在AC基础上引入优势函数 (Advantage Function)基于A2C,增加多线程异步更新
同步性单线程、同步更新单线程、同步更新多线程、异步更新
计算效率较低,因同步性导致资源利用率不高提高了效率,但仍然单线程高效,多线程同时采样、训练
实现复杂度简单较为简单复杂,涉及线程管理
样本利用样本利用效率较低利用优势函数改善样本效率样本利用效率高,因多线程采样
收敛性收敛较慢,容易受高方差影响相较AC更稳定收敛更快,因多线程降低了方差
优势函数的引入有,利用优势函数减少方差有,同样利用优势函数
硬件需求低,适合资源有限的场景中等,仍然适合单机训练高,需要多线程与高计算资源
性能稳定性较低,容易受策略更新噪声影响较高,优势函数降低噪声高,异步更新进一步减少了噪声影响
典型应用场景小型环境或资源受限场景中等规模问题大规模问题、需要高效并行计算的任务
训练效率中等快,因并行线程加速采样与训练

关键差异解释:

  1. 同步与异步:

    • AC和A2C均为同步算法,意味着训练和策略更新必须等待所有数据收集完毕。
    • A3C通过多线程异步执行,无需等待所有线程完成即可进行局部更新,从而提升效率。
  2. 优势函数 (Advantage Function):

    • A2C和A3C通过计算优势函数(奖励减去状态值)减少方差,使策略梯度估计更加准确。
    • AC不直接使用优势函数,因此训练波动较大。
  3. 线程机制:

    • A3C通过多个线程采样环境,更新模型,能更快地探索和收敛。
    • AC和A2C因单线程限制,计算资源无法充分利用。

参考文献:

A2C(Advantage Actor-Critic):

        A2C本质上是AC的同步改进版本,正式作为一种实现没有单独的论文,但它基于以下基础:

        Paper: Schulman, J., Levine, S., Moritz, P., Jordan, M., Abbeel, P. "High-Dimensional Continuous Control Using Generalized Advantage Estimation."

        论文中引入了“优势函数估计”的概念,为A2C的实现奠定了基础。

        链接: GAE Paper (arXiv)

A3C(Asynchronous Advantage Actor-Critic):

        Paper: Mnih, V., Badia, A. P., Mirza, M., Graves, A., Lillicrap, T., Harley, T., Silver, D., Kavukcuoglu, K. "Asynchronous Methods for Deep Reinforcement Learning."

        这是A3C的核心论文,提出了异步多线程的训练方法。

        链接: A3C Paper (arXiv)

 更多强化学习文章,请前往:【强化学习(RL)】专栏


        博客都是给自己看的笔记,如有误导深表抱歉。文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:Rainbook_2,联系作者。✨


原文地址:https://blog.csdn.net/qq_51399582/article/details/144649161

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!