自学内容网 自学内容网

iDP3复现代码模型训练全流程(二)——train_policy.py

在 train_policy.sh 接收命令行参数后,通过此脚本加载配置文件、初始化工作空间并运行训练流程

详细解释如下:

目录

1 模块导入

2 设置标准输出的缓冲模式

3 设置环境变量

4 配置自定义求值器

5 @hydra.main 主函数装饰器

6 主函数 main(cfg)

7 脚本入口


1 模块导入

import os
# 用于操作系统交互,比如设置环境变量和路径管理

from diffusion_policy_3d.workspace.base_workspace import BaseWorkspace
# 引入 BaseWorkspace 类

import pathlib
# 提供跨平台的路径管理,用于处理文件和目录路径

from omegaconf import OmegaConf
# 引入配置管理库,用于加载和解析 YAML 配置文件,支持动态变量解析

import hydra
# 引入 Hydra 配置框架

from termcolor import cprint
# 用于打印带颜色的终端输出

import sys
# 处理标准输入输出流,这里用于调整日志输出的缓冲方式

os: 用于操作系统级别的环境变量和文件路径设置

BaseWorkspace: 模型基类

pathlib: 提供对路径操作的跨平台支持

OmegaConf: 配置库,用来加载、解析和操作 yaml 配置文件

hydra: 用于动态管理配置文件,支持运行时参数调整

termcolor: 用来在终端中打印带颜色的文本

sys: 用于调整标准输入输出的缓冲模式

2 设置标准输出的缓冲模式

# 为标准输出和标准错误设置行缓冲模式,这样每一行日志都会实时输出
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

将标准输出(stdout)和标准错误(stderr)设为行缓冲模式:每行日志或错误会立即输出到终端,便于实时调试和日志记录

3 设置环境变量

# 设置环境变量 'WANDB_SILENT',禁用 Weights & Biases 的日志打印
os.environ['WANDB_SILENT'] = "True"

禁用了 Weights & Biases 的自动日志输出

4 配置自定义求值器

# (可选)启用故障处理器,用于捕捉和调试段错误等低级别错误
# import faulthandler
# faulthandler.enable()
# cprint("[fault handler enabled]", "cyan")

# 注册一个新的配置解析器 "eval",允许在配置文件中直接运行 Python 表达式。
OmegaConf.register_new_resolver("eval", eval, replace=True)

注册新的解析器 eval,允许在配置文件中动态运行 Python 表达式

5 @hydra.main 主函数装饰器

# 使用 Hydra 的主函数装饰器,定义配置文件的路径
@hydra.main(
    config_path=str(pathlib.Path(__file__).parent.joinpath(
        'diffusion_policy_3d', 'config'))  # 配置文件路径,位于 'diffusion_policy_3d/config' 目录中。
)

指定配置文件的路径,配置文件存放在 diffusion_policy_3d/config 文件夹下

配合 --config-name 参数,加载特定的配置文件

6 主函数 main(cfg)

def main(cfg: OmegaConf):
    # **解析配置**: 解析并填充所有动态变量,确保配置文件中的占位符被替换为具体值
    OmegaConf.resolve(cfg)

    # **动态加载类**: 从配置文件中获取 `_target_` 字段,动态加载并初始化工作空间类
    cls = hydra.utils.get_class(cfg._target_)
    workspace: BaseWorkspace = cls(cfg)  # 将工作空间实例化

    # **运行工作流**: 调用工作空间的 `run` 方法,执行训练或其他核心逻辑
    workspace.run()

(1)解析配置文件

    # **解析配置**: 解析并填充所有动态变量,确保配置文件中的占位符被替换为具体值
    OmegaConf.resolve(cfg)

解析配置文件 cfg,确保所有动态变量和表达式(如 ${now} 或 `${eval:"..."})都在运行时被替换为具体值

(2)获取类并实例化工作空间

    # **动态加载类**: 从配置文件中获取 `_target_` 字段,动态加载并初始化工作空间类
    cls = hydra.utils.get_class(cfg._target_)
    workspace: BaseWorkspace = cls(cfg)  # 将工作空间实例化

通过配置文件的 _target_ 字段,动态加载并实例化 BaseWorkspace 或其子类

此处的 _target_ 是类的路径,hydra.utils.get_class 动态导入类

(3)运行工作空间逻辑

    # **运行工作流**: 调用工作空间的 `run` 方法,执行训练或其他核心逻辑
    workspace.run()

调用 BaseWorkspace 的 run 方法,执行工作空间逻辑。这个方法通常包括以下步骤:

加载数据 -> 初始化模型和训练参数 -> 开始训练或评估模型 -> 保存训练结果

7 脚本入口

# 脚本入口:当脚本被独立运行时(而不是作为模块导入),执行 main 函数
if __name__ == "__main__":
    main()

原文地址:https://blog.csdn.net/qq_28912651/article/details/144729245

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!