自学内容网 自学内容网

【强化学习】基础在线算法:Sarsa算法

        📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:

【强化学习】- 【单智能体强化学习】(3)---《基础在线算法:Sarsa算法》

基础在线算法:Sarsa算法

目录

1.Sarsa算法简介

2.核心思想

3.Sarsa算法的更新公式

4.算法步骤

5.Sarsa与Q-Learning的区别

6.算法的理论推导

[Python] Sarsa算法实现

参数设置

定义Q表

选择动作

环境反馈

环境更新

SARSA算法

主函数入口

[Results] 运行结果

[Notice]  代码功能概述:

7.Sarsa算法的应用场景

8.总结


1.Sarsa算法简介

        Sarsa算法是一种强化学习(Reinforcement Learning, RL)的经典算法,属于时序差分(Temporal Difference, TD)方法。它是一种基于策略的学习算法,用于解决马尔可夫决策过程(Markov Decision Process, MDP)中的问题。

        简单来说,Sarsa的目标是通过不断地交互,学习如何从当前状态选择最优动作,从而获得最大的累积奖励。


2.核心思想

        Sarsa的核心是估计状态-动作值函数(Q函数),然后根据这个函数选择动作。该值函数 Q(s, a)表示在状态 s 下采取动作 a 所能获得的期望回报。

        Sarsa算法的名字来源于它的更新过程涉及的五元组:

State (s), Action (a), Reward (r),next State (s'), next Action (a')


3.Sarsa算法的更新公式

        Sarsa使用以下公式来更新Q值:

Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma Q(s', a') - Q(s, a) \right]

  • s:当前状态
  • a:当前动作
  • r:当前奖励
  • s':下一状态
  • a':下一动作
  • \alpha:学习率,控制更新的步幅
  • \gamma:折扣因子,衡量未来奖励的重要性

4.算法步骤

  1. 初始化

    • 初始化 Q(s, a) 值为任意值(通常为0)。
    • 初始化学习率 \alpha、折扣因子\gamma
  2. 重复以下过程,直到收敛

    1. 在当前状态s根据策略(如 \epsilon-贪婪策略)选择动作 a
    2. 执行动作 a,观察到奖励r 和下一个状态 s'
    3. 在状态s'中,根据策略选择下一动作a'
    4. 使用更新公式更新 Q(s, a)Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma Q(s', a') - Q(s, a) \right]
    5. 更新状态和动作:s \leftarrow s', a \leftarrow a'
  3. 策略改进

    随着Q值的更新,逐渐改善选择动作的策略。

5.Sarsa与Q-Learning的区别

特点SarsaQ-Learning
策略类型基于当前策略(on-policy)基于最优策略(off-policy)
更新公式中的动作使用实际选择的动作a'使用最优动作 \max_a Q(s', a)
行为特点更安全、探索性强更快逼近最优,但可能冒险

直观理解

  • Sarsa更新考虑实际采取的动作,强调“过程导向”;
  • Q-Learning直接使用最优值更新,强调“结果导向”。

关于on-policy和off-policy的区别,下面这篇文章进行了较为详细的描述:

【MARL】深入理解多智能体近端策略优化(MAPPO)算法与调参


6.算法的理论推导

  1. 贝尔曼方程: 马尔可夫决策过程的目标是找到使得累积奖励期望最大化的策略。状态-动作值函数  Q 满足贝尔曼方程: Q^\pi(s, a) = \mathbb{E}\pi \left[ r_t + \gamma Q^\pi(s{t+1}, a_{t+1}) \mid s_t = s, a_t = a \right]

  2. 时序差分目标: 使用样本替代期望,构建近似目标:Q(s, a) \approx r + \gamma Q(s', a')

  3. 迭代更新: 使用梯度下降方法不断逼近目标,从而得到更新公式。


[Python] Sarsa算法实现

项目代码我已经放入GitCode里面,可以通过下面链接跳转:🔥

【强化学习】---Sarsa算法

后续相关单智能体强化学习算法也会不断在【强化学习】项目里更新,如果该项目对你有所帮助,请帮我点一个星星✨✨✨✨✨,鼓励分享,十分感谢!!!

若是下面代码复现困难或者有问题,也欢迎评论区留言

"""《Sarsa算法实现》
    时间:2024.12
    作者:不去幼儿园
"""
import numpy as np  # 导入numpy库,用于数组和矩阵运算
import pandas as pd  # 导入pandas库,用于数据处理和创建数据表
import matplotlib.pyplot as plt  # 导入matplotlib库,用于绘图
import time  # 导入time库,用于控制程序暂停时间

参数设置

# 定义强化学习的一些超参数
ALPHA = 0.1  # 学习率,控制更新Q值的幅度
GAMMA = 0.95  # 折扣因子,控制未来奖励的重要性
EPSILION = 0.9  # epsilon值,用于ε-贪婪策略,控制探索与利用的权衡
N_STATE = 6  # 状态的数量,表示状态空间的大小
ACTIONS = ['left', 'right']  # 可能的动作列表,表示智能体的可选行为
MAX_EPISODES = 200  # 最大的训练轮次,表示最大实验次数
FRESH_TIME = 0.1  # 控制环境更新的时间间隔,用于显示训练过程

定义Q表

# 定义Q表的构建函数
def build_q_table(n_state, actions):
    # 创建一个Q表,行代表状态,列代表动作,初始时所有Q值为0
    q_table = pd.DataFrame(
        np.zeros((n_state, len(actions))),  # 初始化一个全0的表格,大小为状态数x动作数
        np.arange(n_state),  # 状态的索引
        actions  # 动作的名称
    )
    return q_table  # 返回初始化后的Q表

选择动作

# 定义选择动作的函数
def choose_action(state, q_table):
    # epslion - greedy策略
    state_action = q_table.loc[state, :]  # 获取当前状态下所有可能动作的Q值
    if np.random.uniform() > EPSILION or (state_action == 0).all():  # 如果随机数大于epsilon或所有动作Q值为0
        action_name = np.random.choice(ACTIONS)  # 选择一个随机动作(探索)
    else:
        action_name = state_action.idxmax()  # 否则选择Q值最大的动作(利用)
    return action_name  # 返回选择的动作

 环境反馈

# 定义环境反馈的函数
def get_env_feedback(state, action):
    # 根据当前状态和动作来返回下一个状态和奖励
    if action == 'right':  # 如果选择向右移动
        if state == N_STATE - 2:  # 如果当前状态是倒数第二个状态
            next_state = 'terminal'  # 到达终止状态
            reward = 1  # 奖励为1
        else:
            next_state = state + 1  # 否则状态加1
            reward = -0.5  # 奖励为-0.5
    else:  # 如果选择向左移动
        if state == 0:  # 如果当前状态是最左边的状态
            next_state = 0  # 保持在原地
        else:
            next_state = state - 1  # 否则状态减1
        reward = -0.5  # 奖励为-0.5
    return next_state, reward  # 返回下一个状态和奖励

环境更新

# 定义环境更新的函数
def update_env(state, episode, step_counter):
    # 生成一个表示环境的字符串,'-'表示空地,'T'表示终止状态
    env = ['-'] * (N_STATE - 1) + ['T']
    if state == 'terminal':  # 如果到达终止状态
        print("Episode {}, the total step is {}".format(episode + 1, step_counter))  # 打印当前回合和步骤
        final_env = ['-'] * (N_STATE - 1) + ['T']  # 环境没有变化
        return True, step_counter  # 终止回合,返回True
    else:
        env[state] = '*'  # 将当前状态位置标记为'*'
        env = ''.join(env)  # 将环境列表转化为字符串
        print(env)  # 打印当前环境的状态
        time.sleep(FRESH_TIME)  # 暂停程序FRESH_TIME秒,模拟环境变化的延迟
        return False, step_counter  # 没有到达终止状态,返回False

SARSA算法

# 定义SARSA学习算法的函数
def sarsa_learning():
    q_table = build_q_table(N_STATE, ACTIONS)  # 创建一个Q表
    step_counter_times = []  # 用于记录每个回合的步骤数
    for episode in range(MAX_EPISODES):  # 进行最大回合数的学习
        state = 0  # 初始状态设为0
        is_terminal = False  # 初始状态不是终止状态
        step_counter = 0  # 初始步骤计数为0
        update_env(state, episode, step_counter)  # 更新环境并显示
        while not is_terminal:  # 当未到达终止状态时继续学习
            action = choose_action(state, q_table)  # 根据当前状态选择动作
            next_state, reward = get_env_feedback(state, action)  # 获取环境反馈(下一个状态和奖励)
            if next_state != 'terminal':  # 如果不是终止状态
                next_action = choose_action(next_state, q_table)  # 选择下一个状态的动作(SARSA更新方法)
            else:
                next_action = action  # 如果是终止状态,动作不再改变
            next_q = q_table.loc[state, action]  # 获取当前Q值

            if next_state == 'terminal':  # 如果到达终止状态
                is_terminal = True  # 设置为终止状态
                q_target = reward  # 目标Q值为奖励
            else:
                delta = reward + GAMMA * q_table.loc[next_state, next_action] - q_table.loc[state, action]  # SARSA更新公式
                q_table.loc[state, action] += ALPHA * delta  # 更新Q表中的值
            state = next_state  # 更新当前状态为下一个状态
            is_terminal, steps = update_env(state, episode, step_counter + 1)  # 更新环境并检查是否终止
            step_counter += 1  # 增加步骤计数
            if is_terminal:  # 如果到达终止状态,记录步骤数
                step_counter_times.append(steps)

主函数入口

# 主函数入口
if __name__ == '__main__':
    q_table, step_counter_times = sarsa_learning()  # 执行SARSA学习
    print("Q table\n{}\n".format(q_table))  # 打印最终的Q表
    print('end')  # 输出训练结束
    plt.plot(step_counter_times, 'g-')  # 绘制每回合的步骤数变化曲线
    plt.ylabel("steps")  # 设置y轴标签
    plt.title("Sarsa Algorithm")  # 设置图标题
    plt.show()  # 显示图形
    print("The step_counter_times is {}".format(step_counter_times))  # 打印每回合的步骤数

[Results] 运行结果


[Notice]  代码功能概述:

  1. Q-表构建:初始化一个包含所有状态和动作的Q表,每个元素初始化为0。
  2. ε-贪婪策略:用来在探索(随机选择动作)和利用(选择当前Q值最大的动作)之间做权衡。
  3. 环境反馈:根据智能体的动作和当前状态,反馈新的状态和奖励。
  4. SARSA学习算法:通过不断更新Q表来提高智能体的策略,每次从当前状态选择动作,执行动作并观察奖励,根据SARSA更新公式调整Q表。
  5. 绘图:最终绘制每回合所需的步数,帮助观察学习过程的效率。
# 环境配置
Python                  3.11.5
torch                   2.1.0
torchvision             0.16.0
gym                     0.26.2

7.Sarsa算法的应用场景

  1. 动态路径规划:如机器人导航。
  2. 游戏智能体:训练游戏中的角色实现特定任务。
  3. 推荐系统:根据用户行为调整推荐策略。

8.总结

Sarsa算法是强化学习领域的基石之一,其优点在于:

  • 简单易实现
  • 能适应动态环境
  • 对探索行为有天然支持

但在实际应用中,Sarsa的收敛速度较慢,需要良好的超参数调整。


        博客都是给自己看的笔记,如有误导深表抱歉。文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:Rainbook_2,联系作者。✨


原文地址:https://blog.csdn.net/qq_51399582/article/details/144221982

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!