iDP3复现代码模型训练全流程(二)——train_policy.py
在 train_policy.sh 接收命令行参数后,通过此脚本加载配置文件、初始化工作空间并运行训练流程
详细解释如下:
目录
1 模块导入
import os
# 用于操作系统交互,比如设置环境变量和路径管理
from diffusion_policy_3d.workspace.base_workspace import BaseWorkspace
# 引入 BaseWorkspace 类
import pathlib
# 提供跨平台的路径管理,用于处理文件和目录路径
from omegaconf import OmegaConf
# 引入配置管理库,用于加载和解析 YAML 配置文件,支持动态变量解析
import hydra
# 引入 Hydra 配置框架
from termcolor import cprint
# 用于打印带颜色的终端输出
import sys
# 处理标准输入输出流,这里用于调整日志输出的缓冲方式
os: 用于操作系统级别的环境变量和文件路径设置
BaseWorkspace: 模型基类
pathlib: 提供对路径操作的跨平台支持
OmegaConf: 配置库,用来加载、解析和操作 yaml 配置文件
hydra: 用于动态管理配置文件,支持运行时参数调整
termcolor: 用来在终端中打印带颜色的文本
sys: 用于调整标准输入输出的缓冲模式
2 设置标准输出的缓冲模式
# 为标准输出和标准错误设置行缓冲模式,这样每一行日志都会实时输出
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
将标准输出(stdout)和标准错误(stderr)设为行缓冲模式:每行日志或错误会立即输出到终端,便于实时调试和日志记录
3 设置环境变量
# 设置环境变量 'WANDB_SILENT',禁用 Weights & Biases 的日志打印
os.environ['WANDB_SILENT'] = "True"
禁用了 Weights & Biases 的自动日志输出
4 配置自定义求值器
# (可选)启用故障处理器,用于捕捉和调试段错误等低级别错误
# import faulthandler
# faulthandler.enable()
# cprint("[fault handler enabled]", "cyan")
# 注册一个新的配置解析器 "eval",允许在配置文件中直接运行 Python 表达式。
OmegaConf.register_new_resolver("eval", eval, replace=True)
注册新的解析器 eval,允许在配置文件中动态运行 Python 表达式
5 @hydra.main 主函数装饰器
# 使用 Hydra 的主函数装饰器,定义配置文件的路径
@hydra.main(
config_path=str(pathlib.Path(__file__).parent.joinpath(
'diffusion_policy_3d', 'config')) # 配置文件路径,位于 'diffusion_policy_3d/config' 目录中。
)
指定配置文件的路径,配置文件存放在 diffusion_policy_3d/config 文件夹下
配合 --config-name 参数,加载特定的配置文件
6 主函数 main(cfg)
def main(cfg: OmegaConf):
# **解析配置**: 解析并填充所有动态变量,确保配置文件中的占位符被替换为具体值
OmegaConf.resolve(cfg)
# **动态加载类**: 从配置文件中获取 `_target_` 字段,动态加载并初始化工作空间类
cls = hydra.utils.get_class(cfg._target_)
workspace: BaseWorkspace = cls(cfg) # 将工作空间实例化
# **运行工作流**: 调用工作空间的 `run` 方法,执行训练或其他核心逻辑
workspace.run()
(1)解析配置文件
# **解析配置**: 解析并填充所有动态变量,确保配置文件中的占位符被替换为具体值
OmegaConf.resolve(cfg)
解析配置文件 cfg,确保所有动态变量和表达式(如 ${now} 或 `${eval:"..."})都在运行时被替换为具体值
(2)获取类并实例化工作空间
# **动态加载类**: 从配置文件中获取 `_target_` 字段,动态加载并初始化工作空间类
cls = hydra.utils.get_class(cfg._target_)
workspace: BaseWorkspace = cls(cfg) # 将工作空间实例化
通过配置文件的 _target_ 字段,动态加载并实例化 BaseWorkspace 或其子类
此处的 _target_ 是类的路径,hydra.utils.get_class 动态导入类
(3)运行工作空间逻辑
# **运行工作流**: 调用工作空间的 `run` 方法,执行训练或其他核心逻辑
workspace.run()
调用 BaseWorkspace 的 run 方法,执行工作空间逻辑。这个方法通常包括以下步骤:
加载数据 -> 初始化模型和训练参数 -> 开始训练或评估模型 -> 保存训练结果
7 脚本入口
# 脚本入口:当脚本被独立运行时(而不是作为模块导入),执行 main 函数
if __name__ == "__main__":
main()
原文地址:https://blog.csdn.net/qq_28912651/article/details/144729245
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!