YOLOv10-1.1部分代码阅读笔记-dist.py
dist.py
ultralytics\utils\dist.py
目录
2.def find_free_network_port() -> int:
3.def generate_ddp_file(trainer):
4.def generate_ddp_command(world_size, trainer):
5.def ddp_cleanup(trainer, file):
1.所需的库和模块
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import shutil
import socket
import sys
import tempfile
from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9
2.def find_free_network_port() -> int:
# 这段代码定义了一个名为 find_free_network_port 的函数,用于查找一个可用的网络端口。它通过绑定一个临时的 TCP 套接字来实现这一功能。
# 定义了一个函数 find_free_network_port ,返回值为一个整数(表示可用的端口号)。
def find_free_network_port() -> int:
# 在本地主机上查找一个空闲端口。
# 当我们不想连接到真正的主节点但必须设置 `MASTER_PORT` 环境变量时,它在单节点训练中很有用。
"""
Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
"""
# 使用 socket.socket 创建一个 TCP 套接字( AF_INET 表示 IPv4 地址族, SOCK_STREAM 表示 TCP 协议)。 with 语句确保套接字在使用后自动关闭。
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
# 将套接字绑定到本地地址 127.0.0.1 (即本地回环地址)和端口号 0 。 端口号 0 是一个特殊值,表示让操作系统自动分配一个可用的端口。
s.bind(("127.0.0.1", 0))
# 调用 getsockname() 方法获取套接字的地址信息,返回一个元组 (host, port) 。 [1] 表示提取元组中的第二个元素,即端口号。 返回的端口号是一个整数,表示一个可用的端口。
return s.getsockname()[1] # port
# find_free_network_port 函数的作用是。动态分配端口:通过绑定一个临时套接字,让操作系统分配一个可用的端口。返回端口号:返回分配的端口号,供后续使用。这种方法简单高效,适用于需要动态分配端口的场景。
3.def generate_ddp_file(trainer):
# 这段代码定义了一个名为 generate_ddp_file 的函数,用于生成一个临时 Python 脚本文件,该文件用于启动分布式数据并行(DDP)训练。这个脚本文件包含了训练器的配置和训练逻辑,并且会在训练结束后被删除。
# 定义了 generate_ddp_file 函数,接受一个参数。
# 1.trainer :这是一个训练器对象,包含训练相关的配置和逻辑。
def generate_ddp_file(trainer):
# 生成 DDP 文件并返回其文件名。
"""Generates a DDP file and returns its file name."""
# 提取 trainer 类的 模块路径 和 类名 。
# trainer.__class__.__module__ :获取 trainer 类的模块路径(如 ultralytics.trainers.yolov8 )。
# trainer.__class__.__name__ :获取 trainer 类的名称(如 YOLOv8Trainer )。
# 使用 rsplit(".", 1) 将模块路径和类名分开,分别存储到 module 和 name 中。
module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
# 定义了临时脚本文件的内容。
content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}
if __name__ == "__main__":
from {module} import {name}
from ultralytics.utils import DEFAULT_CFG_DICT
cfg = DEFAULT_CFG_DICT.copy()
cfg.update(save_dir='') # handle the extra key 'save_dir'
trainer = {name}(cfg=cfg, overrides=overrides)
results = trainer.train()
"""
# vars(trainer.args) 将 trainer.args 的属性转换为字典,并嵌入到脚本中,用于覆盖默认配置。
# overrides = {vars(trainer.args)}
# 脚本中的主逻辑入口,确保脚本被直接运行时执行以下代码。
# if __name__ == "__main__":
# 动态导入 trainer 类所在的模块和类。
# from {module} import {name}
# 导入默认配置字典 DEFAULT_CFG_DICT ,用于初始化训练器。
# from ultralytics.utils import DEFAULT_CFG_DICT
# 复制默认配置字典,并更新 save_dir 为一个空字符串,以避免保存目录冲突。
# cfg = DEFAULT_CFG_DICT.copy()
# cfg.update(save_dir='') # handle the extra key 'save_dir'
# 使用动态导入的类和配置创建训练器实例。
# trainer = {name}(cfg=cfg, overrides=overrides)
# 调用训练器的 train 方法,开始训练过程。
# results = trainer.train()
# 创建一个目录用于存放临时脚本文件。 USER_CONFIG_DIR 是一个全局变量,表示用户的配置目录。
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
# tempfile.NamedTemporaryFile(mode='w+b', buffering=-1, encoding=None, newline=None, suffix=None, prefix=None, dir=None, delete=True)
# tempfile.NamedTemporaryFile 是 Python 标准库 tempfile 模块中的一个类,用于创建一个临时文件,并返回一个文件对象。这个文件在关闭时可以自动删除,也可以保留下来供其他程序使用。
# 参数 :
# mode :文件模式,默认为 'w+b' (二进制读写模式)。也可以是 'w' (文本写模式)、 'r+' (读写模式)等。
# buffering :缓冲区大小。默认为 -1 ,表示使用系统默认缓冲区大小。
# encoding :文件编码,仅在文本模式下有效。
# newline :换行符处理,仅在文本模式下有效。
# suffix :临时文件的后缀名(不包括点)。例如, suffix='.txt' 。
# prefix :临时文件的前缀名。默认为系统生成的随机前缀。
# dir :临时文件的存储目录。默认为系统默认的临时目录(如 /tmp )。
# delete :文件关闭时是否自动删除。默认为 True ,表示文件在关闭后自动删除。
# 返回值 :
# 返回一个文件对象,支持文件操作(如读写)。文件路径可以通过文件对象的 .name 属性访问。
# 主要用途 :
# 创建临时文件 :用于存储临时数据,文件在关闭时可以自动删除。
# 跨进程共享 :临时文件路径可以通过 .name 属性获取,供其他进程或程序访问。
# 灵活的文件操作 :支持多种文件模式(如文本模式、二进制模式)。
# 注意事项 :
# 文件路径 :文件路径可以通过 .name 属性访问。 文件路径是临时的,通常存储在系统的临时目录中(如 /tmp )。
# 文件删除 :如果 delete=True ,文件在关闭后自动删除。 如果 delete=False ,文件在关闭后不会自动删除,需要手动删除。
# 跨平台兼容性 : tempfile.NamedTemporaryFile 在 Windows 和 Unix 系统上都能正常工作。文件路径可能因操作系统而异。
# 安全性 :临时文件路径是随机生成的,避免了文件名冲突。 如果需要更高的安全性,可以使用 tempfile.mkstemp 或 tempfile.TemporaryDirectory 。
# tempfile.NamedTemporaryFile 是一个非常实用的工具,用于创建临时文件。它支持多种文件模式和灵活的配置选项,可以用于存储临时数据或跨进程共享文件。通过 delete 参数,可以控制文件是否在关闭后自动删除。
# 使用 tempfile.NamedTemporaryFile 创建一个临时文件。
with tempfile.NamedTemporaryFile(
# 文件名前缀。
prefix="_temp_",
# 文件名后缀,包含 trainer 的唯一标识符。
suffix=f"{id(trainer)}.py",
# 以读写模式打开文件。
mode="w+",
# 文件编码。
encoding="utf-8",
# 文件存储目录。
dir=USER_CONFIG_DIR / "DDP",
# 文件在关闭时不自动删除。
delete=False,
) as file:
# 将脚本内容写入临时文件。
file.write(content)
# 返回临时文件的路径。
return file.name
# generate_ddp_file 函数的作用是。动态生成脚本:根据 trainer 的配置和类信息,生成一个临时 Python 脚本文件。包含训练逻辑:脚本中包含了训练器的初始化和训练逻辑。支持分布式训练:生成的脚本可以被 PyTorch 分布式训练工具(如 torch.distributed.run )直接调用。返回文件路径:返回生成的临时文件路径,用于后续启动分布式训练。这种设计使得函数能够灵活地支持多 GPU 分布式训练,同时避免了硬编码脚本内容。
4.def generate_ddp_command(world_size, trainer):
# 这段代码定义了一个名为 generate_ddp_command 的函数,用于生成分布式数据并行(DDP)训练所需的命令。它主要用于设置 PyTorch 的分布式训练环境,并生成相应的启动命令。
# 定义了 generate_ddp_command 函数,接受两个参数。
# 1.world_size :表示每个节点上的进程数(通常等于 GPU 数量)。
# 2.trainer :一个训练器对象,包含与训练相关的配置和状态。
def generate_ddp_command(world_size, trainer):
# 生成并返回分布式训练的命令。
"""Generates and returns command for distributed training."""
# 导入 Python 的 __main__ 模块。这里提到的 URL 是一个已知的 PyTorch Lightning 问题的链接,可能与分布式训练相关。 # noqa 是一个注释,用于避免代码格式化工具的警告。
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
# 检查 trainer 是否处于恢复训练模式( resume )。如果不是,则执行以下操作。
if not trainer.resume:
# 删除 trainer.save_dir 目录。这通常是为了清理之前的训练保存目录,避免冲突。
shutil.rmtree(trainer.save_dir) # remove the save_dir
# 调用 generate_ddp_file 函数,生成一个与分布式训练相关的临时文件(可能是 Python 脚本或配置文件)。 generate_ddp_file 的作用是为 DDP 训练生成必要的文件。
file = generate_ddp_file(trainer)
# 根据 PyTorch 的版本选择分布式训练的启动命令。如果 PyTorch 版本 >= 1.9,使用 torch.distributed.run 。 否则,使用 torch.distributed.launch 。 TORCH_1_9 是一个未在代码中定义的变量,用来检查 PyTorch 版本是否大于或等于 1.9。
dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
# 调用 find_free_network_port 函数,查找一个可用的网络端口。这个端口将用于分布式训练中的通信。
port = find_free_network_port()
# 构造分布式训练的启动命令。
# sys.executable :Python 解释器的路径。
# -m :表示运行一个模块。
# dist_cmd :分布式训练的启动命令( torch.distributed.run 或 torch.distributed.launch )。
# --nproc_per_node :每个节点上的进程数( world_size )。
# --master_port :主节点的通信端口( port )。
# file :生成的临时文件路径。
cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
# 返回 构造的命令列表 cmd 和 生成的文件路径 file 。
return cmd, file
# generate_ddp_command 函数的作用是。清理保存目录:如果训练器不是恢复模式,删除之前的保存目录。生成临时文件:调用 generate_ddp_file 生成分布式训练所需的文件。构造启动命令:根据 PyTorch 版本和参数构造分布式训练的启动命令。返回命令和文件路径:返回构造的命令列表和生成的文件路径。这种设计使得函数能够灵活地启动分布式训练,同时支持不同版本的 PyTorch。
5.def ddp_cleanup(trainer, file):
# 这段代码定义了一个名为 ddp_cleanup 的函数,用于清理与分布式数据并行(DDP)训练相关的临时文件。
# 定义了一个函数 ddp_cleanup ,接受两个参数。
# 1.trainer :通常是训练器对象,可能包含与分布式训练相关的上下文信息。
# 2.file :一个字符串,表示文件路径。
def ddp_cleanup(trainer, file):
# 如果创建了临时文件则删除。
"""Delete temp file if created."""
# id(object)
# id() 是 Python 的内置函数,用于获取一个对象的唯一标识符(内存地址)。
# 参数 :
# object :任何 Python 对象,包括变量、列表、字典、类实例等。
# 返回值 :
# 返回一个整数,表示对象的唯一标识符(通常是对象在内存中的地址)。
# 作用 :
# id() 函数的主要作用是获取一个对象的唯一标识符。在 CPython 实现中,这个标识符通常是对象在内存中的地址。由于每个对象在内存中都有唯一的地址, id() 可以用来判断两个变量是否指向同一个对象。
# 注意事项 :
# 唯一性 :在程序运行期间, id() 返回的值是唯一的。但程序结束后,该内存地址可能会被释放并重新分配给其他对象。
# 不可变性 :对象的 id 在其生命周期内不会改变。即使对象的内容发生变化(如列表或字典的内容), id 仍然保持不变。
# 用途 :
# id() 通常用于调试,帮助开发者理解变量的引用关系。
# 它也可以用于判断两个变量是否指向同一个对象,但通常更推荐使用 is 关键字 : print(x is y) # 等价于 id(x) == id(y) 。
# id() 函数是一个非常实用的内置函数,用于获取对象的唯一标识符(内存地址)。它在调试和理解变量引用关系时非常有用,但通常不建议在生产代码中直接使用 id() 来比较对象,而是使用 is 关键字。
# 检查文件名是否包含特定的后缀,这个后缀是通过 id(trainer) 生成的唯一标识符,并附加 .py 扩展名。
# id(trainer) :获取 trainer 对象的内存地址(唯一标识符)。
# f"{id(trainer)}.py" :将标识符转换为字符串,并附加 .py 后缀。
# if ... in file :检查生成的后缀是否包含在文件名中。
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
# 如果文件名包含指定的后缀,则调用 os.remove(file) 删除该文件。 os.remove 是 Python 标准库中的函数,用于删除指定路径的文件。
os.remove(file)
# dp_cleanup 函数的作用是。检查文件名是否包含与 trainer 对象相关的唯一标识符后缀。如果匹配,则删除该文件。这种设计通常用于清理分布式训练过程中生成的临时文件,例如在 DDP(Distributed Data Parallel)训练中,每个进程可能生成临时文件,这些文件在训练完成后需要被清理。
原文地址:https://blog.csdn.net/m0_58169876/article/details/145267819
免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!