YOLOv10-1.1部分代码阅读笔记-build.py
build.py
ultralytics\data\build.py
目录
2.class InfiniteDataLoader(dataloader.DataLoader):
5.def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
6.def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
8.def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
1.所需的库和模块
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import random
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from torch.utils.data import dataloader, distributed
from ultralytics.data.loaders import (
LOADERS,
LoadImagesAndVideos,
LoadPilAndNumpy,
LoadScreenshots,
LoadStreams,
LoadTensor,
SourceTypes,
autocast_list,
)
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file
from .dataset import YOLODataset
from .utils import PIN_MEMORY
2.class InfiniteDataLoader(dataloader.DataLoader):
# 这段代码定义了一个名为 InfiniteDataLoader 的类,它继承自PyTorch的 DataLoader 类。这个类的目的是创建一个可以无限迭代的数据加载器,即使数据集中的数据有限。这在某些训练场景中非常有用,例如当需要进行长时间的训练并且不希望在每个epoch结束时停止训练时。
# 定义类 InfiniteDataLoader ,继承自 DataLoader 。
class InfiniteDataLoader(dataloader.DataLoader):
# 重用工作器的数据加载器。
# 使用与原始数据加载器相同的语法。
"""
Dataloader that reuses workers.
Uses same syntax as vanilla DataLoader.
"""
# 定义初始化方法,接收任意数量的位置参数 *args 和关键字参数 **kwargs ,这些参数将传递给父类 DataLoader 的初始化方法。
def __init__(self, *args, **kwargs):
# 无限循环使用 worker 的 Dataloader,继承自 DataLoader。
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
# 调用父类 DataLoader 的初始化方法,初始化数据加载器。
super().__init__(*args, **kwargs)
# object.__setattr__(name, value)
# 在Python中, object.__setattr__() 是一个特殊方法,用于设置对象的属性。它是 object 类的一个方法,而 object 是Python中所有类的基类。 __setattr__() 方法在设置对象属性时被自动调用,但也可以在子类中被重写以自定义属性赋值的行为。
# 参数 :
# name :要设置的属性的名称。
# value :属性的值。
# 行为 :
# 当对一个对象的属性进行赋值操作时,例如 obj.attr = value ,Python会自动调用该对象的 __setattr__() 方法。这个方法的默认实现会设置一个名为 name 的属性,其值为 value 。
# 为什么使用 object.__setattr__ :
# 在某些情况下,你可能需要直接调用 __setattr__() 方法,特别是当你需要绕过属性赋值的默认行为时。例如,你可能想要在设置属性之前执行一些额外的检查或操作。
# 注意事项 :
# 使用 object.__setattr__() 时,应该谨慎,因为它会绕过属性的正常赋值机制,包括可能的属性监视器或装饰器。
# 在大多数情况下,直接使用 obj.attr = value 就足够了,除非有特殊需求需要自定义属性赋值的行为。
# 使用 object.__setattr__ 方法设置 batch_sampler 属性,将其替换为 _RepeatSampler 实例。 _RepeatSampler 是一个自定义的采样器,用于重复采样,使得数据加载器可以无限迭代。
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
# 创建一个 迭代器 self.iterator ,用于 迭代数据加载器 。
self.iterator = super().__iter__()
# 定义 __len__ 方法,返回数据加载器的长度。
def __len__(self):
# 返回批量采样器的采样器的长度。
"""Returns the length of the batch sampler's sampler."""
# 返回 batch_sampler 中的 sampler 的长度,即 数据集的大小 。
return len(self.batch_sampler.sampler)
# 定义 __iter__ 方法,用于迭代数据加载器。
def __iter__(self):
# 创建一个无限重复的采样器。
"""Creates a sampler that repeats indefinitely."""
# 在 __iter__ 方法中,使用一个循环迭代 len(self) 次,每次调用 next(self.iterator) 获取下一个批次的数据,并使用 yield 关键字返回。这使得数据加载器可以无限迭代,即使数据集中的数据有限。
for _ in range(len(self)):
yield next(self.iterator)
# 定义 reset 方法,用于重置数据加载器。
def reset(self):
# 重置迭代器。
# 当我们想要在训练时修改数据集的设置时,这很有用。
"""
Reset iterator.
This is useful when we want to modify settings of dataset while training.
"""
# 调用 self._get_iterator() 方法重新创建一个迭代器,重置数据加载器的迭代状态。
self.iterator = self._get_iterator()
# InfiniteDataLoader 类通过继承 DataLoader 并修改其行为,实现了无限迭代的功能。这在需要长时间训练且不希望在每个epoch结束时停止训练的场景中非常有用。通过使用 _RepeatSampler 采样器,数据加载器可以在数据集结束时重新开始,从而实现无限迭代。 reset 方法允许用户在需要时重置数据加载器的迭代状态。
3.class _RepeatSampler:
# 这段代码定义了一个名为 _RepeatSampler 的类,其目的是创建一个可以无限重复的采样器,用于与 InfiniteDataLoader 类结合,实现数据加载器的无限迭代功能。
# 定义类 _RepeatSampler 。
class _RepeatSampler:
# 永远重复的采样器。
"""
Sampler that repeats forever.
Args:
sampler (Dataset.sampler): The sampler to repeat.
"""
# 定义初始化方法,接收一个参数。
# 1.sampler :这是一个采样器对象,通常用于 DataLoader 中以决定数据的加载顺序。
def __init__(self, sampler):
# 初始化一个无限重复给定采样器的对象。
"""Initializes an object that repeats a given sampler indefinitely."""
# 将传入的 sampler 对象存储为实例属性,以便在后续的迭代中使用。
self.sampler = sampler
# 定义 __iter__ 方法,这是 Python 迭代器协议的一部分,用于使类的实例能够被迭代。
def __iter__(self):
# 迭代‘采样器’并产生其内容。
"""Iterates over the 'sampler' and yields its contents."""
# 使用一个无限循环 while True ,确保迭代过程可以无限进行。在每次循环中,使用 yield from 语句来委托迭代过程给 self.sampler 。 iter(self.sampler) 会获取 sampler 的迭代器, yield from 则会从这个迭代器中逐个产生元素,直到迭代器耗尽。由于 while True 的存在,一旦 sampler 的迭代器耗尽,循环会重新开始,从而实现无限重复。
while True:
yield from iter(self.sampler)
# _RepeatSampler 类通过简单的迭代器委托和无限循环,实现了采样器的无限重复功能。当与 InfiniteDataLoader 结合使用时,它允许数据加载器在数据集结束时自动重新开始,从而实现无限迭代,这对于某些需要长时间持续训练的机器学习任务非常有用。
4.def seed_worker(worker_id):
# 这段代码定义了一个名为 seed_worker 的函数,其目的是为 PyTorch 数据加载器的每个工作进程设置随机种子。这有助于确保在多进程数据加载时,每个工作进程生成的随机数序列是可重现的,从而提高实验的可重复性。
# 定义函数 seed_worker ,接收一个参数。 # noqa 是一个注释,用于告诉代码检查工具忽略这一行的检查。
# 1.worker_id :这是工作进程的唯一标识符。
def seed_worker(worker_id): # noqa
# 设置数据加载器工作器种子 https://pytorch.org/docs/stable/notes/randomness.html#dataloader。
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
# torch.initial_seed() -> int
# torch.initial_seed() 函数返回用于生成随机数的初始种子,类型为 Python 的 long 整数。这个函数通常用于获取当前随机数生成器的初始种子值,以便在需要时可以重现相同的随机数序列。
# 返回值 :
# 返回一个 long 整数,表示当前随机数生成器的初始种子。
# 使用场景 :
# 多进程数据加载 :在使用 DataLoader 的多进程数据加载时,每个工作进程需要设置不同的随机种子,以确保每个进程生成的随机数序列是不同的。 torch.initial_seed() 可以用于获取当前工作进程的初始种子,并在此基础上设置其他随机数生成器的种子。
# 调试和重现 :在调试和重现实验结果时,获取和设置初始种子可以帮助确保每次运行的随机数序列是相同的。
# 注意事项 :
# torch.initial_seed() 返回的是当前随机数生成器的初始种子,这个值在每次调用 torch.manual_seed() 时会被设置。
# 在多进程环境中,每个工作进程的初始种子是不同的,通常是由主进程的初始种子加上工作进程的 ID 计算得到的。
# 在子进程中运行 torch.initial_seed() ,返回的就是 torch 当前的随机数种子,即 base_seed + worker_id 。
# 因为每个 epoch 开始时,主进程都会重新生成一个 base_seed,所以 base_seed 是随 epoch 变化而变化的随机数。
# 此外,torch.initial_seed() 返回的是 long int 类型,而Numpy只接受 uint 类型([0, 2**32 - 1]),所以需要对 2**32 取模。
# 使用 torch.initial_seed() 获取当前工作进程的初始随机种子。 torch.initial_seed() 返回一个长整型数,通常是一个非常大的数。为了确保种子值在 32 位整数范围内,使用模运算 % 2**32 将其限制在 0 到 2^32-1 之间。
worker_seed = torch.initial_seed() % 2**32
# numpy.random.seed(seed=None)
# np.random.seed() 是 NumPy 库中的一个函数,用于设置随机数生成器的种子。通过设置种子,可以确保每次运行代码时生成的随机数序列是相同的,这对于调试和重现实验结果非常有用。
# 参数 :
# seed : int, array_like, or None 种子值。可以是一个整数、一个数组或 None 。如果为 None ,则从操作系统提供的随机源中获取种子值。
# 使用场景 :
# 调试 :在调试过程中,设置固定的随机种子可以确保每次运行代码时生成的随机数序列相同,便于定位和修复问题。
# 重现实验结果 :在科学研究和机器学习实验中,设置固定的随机种子可以确保实验结果的可重现性。
# 多进程数据加载 :在多进程环境中,每个工作进程需要设置不同的随机种子,以确保每个进程生成的随机数序列是不同的。这可以通过结合工作进程的 ID 和主进程的种子值来实现。
# 注意事项 :
# np.random.seed() 只影响 NumPy 的随机数生成器,不影响 Python 标准库 random 模块的随机数生成器。
# 在多进程环境中,每个工作进程应独立设置随机种子,以避免生成相同的随机数序列。
# 使用 np.random.seed(worker_seed) 设置 NumPy 的随机种子为 worker_seed 。这确保了在该工作进程中,所有使用 NumPy 生成的随机数都是可重现的。
np.random.seed(worker_seed)
# random.seed(a=None, version=2)
# random.seed() 是 Python 标准库 random 模块中的一个函数,用于初始化随机数生成器的种子。这个函数确保了随机数生成器产生的随机数序列是可重复的,即在相同的种子下,每次运行程序时产生的随机数序列都是相同的。
# 参数 :
# a :种子值,可以是任何 hashable(可哈希)对象。如果为 None ,则使用当前时间作为种子。
# version :随机数生成器的版本,可以是 1 或 2。默认为 2。版本 2 在 Python 3.2.3 和 3.3.0 中引入,提供了更好的随机性。
# 作用 :
# 当你提供一个特定的种子值时, random.seed() 会重置随机数生成器的状态,使得随后的随机数生成可以预测。
# 如果不提供种子值(或为 None ),则随机数生成器将使用一个不可预测的值(通常是当前时间)作为种子,这使得每次程序运行时产生的随机数序列都是不同的。
# 使用 random.seed(worker_seed) 设置 Python 标准库 random 模块的随机种子为 worker_seed 。这确保了在该工作进程中,所有使用 random 模块生成的随机数也是可重现的。
random.seed(worker_seed)
# seed_worker 函数通过为每个工作进程设置相同的随机种子,确保了在多进程数据加载时,每个工作进程生成的随机数序列是可重现的。这对于需要使用随机性(如数据增强、随机采样等)的机器学习任务非常有用,特别是在多进程环境中,可以显著提高实验的可重复性。通常,这个函数会在创建 DataLoader 时通过 worker_init_fn 参数传递,以确保每个工作进程在启动时正确设置随机种子。
5.def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
# 这段代码定义了一个名为 build_yolo_dataset 的函数,其目的是根据给定的配置和参数构建一个用于 YOLO 模型训练或验证的 YOLODataset 实例。这个函数封装了 YOLODataset 的创建过程,使得数据集的构建更加灵活和方便。
# 定义函数 build_yolo_dataset ,接收以下参数 :
# 1.cfg :配置对象,包含多种配置参数。
# 2.img_path :图像路径,可以是单个文件路径或包含多个文件路径的列表。
# 3.batch :批量大小。
# 4.data :数据配置,通常包含数据集的路径和其他相关信息。
# 5.mode :模式,可以是 "train" 或 "val" ,默认为 "train" 。
# 6.rect :是否使用矩形批次, 默认为 False 。
# 7.stride :模型的步长,默认为 32。
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
# 构建 YOLO 数据集。
"""Build YOLO Dataset."""
# 返回一个 YOLODataset 实例,配置如下。
# class YOLODataset(BaseDataset):
# -> 用于处理YOLO模型的数据集。
# -> def __init__(self, *args, data=None, task="detect", **kwargs):
return YOLODataset(
# 图像路径,可以是单个文件路径或包含多个文件路径的列表。
img_path=img_path,
# 图像大小,从 cfg 中获取。
imgsz=cfg.imgsz,
# 批量大小,从 batch 参数中获取。
batch_size=batch,
# 是否进行数据增强,仅在训练模式下为 True 。
augment=mode == "train", # augmentation
# 超参数配置,直接使用 cfg 。
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
# 是否使用矩形批次,从 cfg.rect 或 rect 参数中获取。
rect=cfg.rect or rect, # rectangular batches
# 缓存设置,从 cfg.cache 中获取,如果未设置则为 None 。
cache=cfg.cache or None,
# 是否为单类别训练,从 cfg.single_cls 中获取,如果未设置则为 False 。
single_cls=cfg.single_cls or False,
# 模型的步长,从 stride 参数中获取并转换为整数。
stride=int(stride),
# 填充比例,训练模式下为 0.0 ,验证模式下为 0.5 。
pad=0.0 if mode == "train" else 0.5,
# 日志前缀,使用 colorstr 函数添加颜色,显示模式信息。
prefix=colorstr(f"{mode}: "),
# 任务类型,从 cfg.task 中获取。
task=cfg.task,
# 类别信息,从 cfg.classes 中获取。
classes=cfg.classes,
# 数据配置,包含数据集的路径和其他相关信息。
data=data,
# 数据集的使用比例,训练模式下从 cfg.fraction 中获取,验证模式下为 1.0 。
fraction=cfg.fraction if mode == "train" else 1.0,
)
# build_yolo_dataset 函数通过封装 YOLODataset 的创建过程,提供了一个灵活且方便的方式来构建 YOLO 模型的训练或验证数据集。通过传入不同的参数,可以轻松地调整数据集的配置,适用于不同的训练和验证场景。这个函数特别适用于需要动态构建数据集的场景,如超参数调优、不同数据集的实验等。
6.def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
# 这段代码定义了一个名为 build_dataloader 的函数,其目的是根据给定的参数构建一个用于训练或验证的数据加载器( DataLoader )。这个函数可以返回一个 InfiniteDataLoader 或 DataLoader ,具体取决于是否需要无限迭代数据集。这个函数特别适用于分布式训练和多进程数据加载的场景。
# 定义函数 build_dataloader ,接收以下参数 :
# 1.dataset :数据集对象,继承自 torch.utils.data.Dataset 。
# 2.batch :批量大小。
# 3.workers :工作进程数。
# 4.shuffle :是否打乱数据,默认为 True 。
# 5.rank :分布式训练中的进程排名,默认为 -1 ,表示非分布式训练。
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
# 返回用于训练或验证集的 InfiniteDataLoader 或 DataLoader。
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
# 将批量大小 batch 限制为 数据集长度的最小值 ,确保批量大小不会超过数据集的大小。
batch = min(batch, len(dataset))
# 获取可用的 CUDA 设备数量。
nd = torch.cuda.device_count() # number of CUDA devices
# 计算 实际使用的工作进程数 nw ,确保不超过系统可用的 CPU 核心数除以 CUDA 设备数量(至少为1),并且不超过用户指定的工作进程数 workers 。
nw = min([os.cpu_count() // max(nd, 1), workers]) # number of workers
# torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)
# torch.utils.data.distributed.DistributedSampler 类的构造函数用于创建一个新的分布式采样器实例,它主要用于分布式训练环境中,以确保每个进程只处理数据集的一部分,从而实现数据的均匀分配。
# 参数 :
# dataset ( Dataset ) :要采样的数据集对象。
# num_replicas ( int ,可选) :分布式环境中的总副本(进程)数量。默认值为 None ,在这种情况下,它会尝试从当前的分布式环境变量中获取 world_size 。
# rank ( int ,可选) :当前进程的排名或ID。默认值为 None ,在这种情况下,它会尝试从当前的分布式环境变量中获取 rank 。
# shuffle ( bool ) :是否在每个epoch开始时打乱数据集的采样顺序。默认值为 True 。
# seed ( int ) :用于打乱数据集的随机种子。确保在所有进程中使用相同的种子以获得一致的打乱结果。默认值为 0 。
# drop_last ( bool ) :如果为 True ,则在数据集不能被均匀分配时,丢弃最后一部分数据以确保每个进程处理相同数量的数据。如果为 False ,则可能有些进程会处理更多的数据。默认值为 False 。
# 返回值 :
# 返回一个新的 DistributedSampler 实例。
# DistributedSampler 类在 PyTorch 中用于分布式训练,以下是它的一些常用属性和方法 :
# 属性 :
# dataset : 返回与采样器关联的数据集。
# num_replicas : 返回分布式环境中的总副本(进程)数量。
# rank : 返回当前进程的排名或ID。
# epoch : 返回当前的epoch数。这个属性在每个epoch开始时通过调用 set_epoch() 方法更新。
# 方法 :
# set_epoch(epoch) : 设置当前的epoch数。这对于确保在每个epoch中数据被打乱是必要的,特别是在 shuffle=True 时。
# __iter__() : 返回一个迭代器,该迭代器产生当前epoch中被采样器选中的数据集索引。
# __len__() : 返回当前epoch中被采样器选中的数据集索引的数量。
# update() : 更新采样器的状态,这个方法在 PyTorch 的某些版本中存在,用于重新配置采样器的参数。
# DistributedSampler 的主要作用是确保在分布式训练中,每个进程都能够处理数据集的不同部分,从而提高数据加载的效率和训练的可扩展性。通过在每个epoch开始时调用 set_epoch() 方法,可以确保数据在每个epoch中都被重新打乱,这对于模型的训练是非常重要的。
# 如果 rank 为 -1 ,表示非分布式训练, sampler 为 None 。否则,使用 DistributedSampler 进行分布式采样,确保每个进程处理不同的数据子集。
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
# torch.Generator(device='cpu')
# 在PyTorch中, torch.Generator 是一个用于生成随机数的伪随机数生成器(PRNG)的类。它主要用于生成与特定种子或设备相关的随机数,以确保实验的可重复性。
# 参数 :
# device : 指定生成器所在的设备,可以是'cpu'或'cuda'设备。
# 主要方法 :
# manual_seed(seed) : 设置生成器的种子。 seed : 一个整数,用于初始化生成器。
# seed() : 自动设置生成器的种子。 这将生成一个随机种子,确保每次运行代码时生成的随机数不同。
# get_state() : 返回生成器的当前状态。
# set_state(state) : 设置生成器的状态。 state : 一个张量,表示生成器的状态。
# initial_seed() : 返回生成器的初始种子。
# 注意事项 :
# 当使用多个设备(如CPU和GPU)时,需要为每个设备创建一个独立的 Generator 实例,并设置不同的种子,以确保随机数的生成是独立的。
# 在分布式训练或多线程环境中,正确管理生成器的状态非常重要,以避免随机数生成的冲突。
# torch.Generator 提供了一种灵活的方式来控制随机数的生成,使得实验和模型训练更加可重复和可控。
# 创建一个 torch.Generator 对象,并设置一个固定的随机种子,确保每个工作进程生成的随机数序列是可重现的。 RANK 是一个全局变量,表示当前进程的排名。
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
# 返回一个 InfiniteDataLoader 实例,配置如下。
return InfiniteDataLoader(
# 数据集对象。
dataset=dataset,
# 批量大小。
batch_size=batch,
# 是否打乱数据,仅在 sampler 为 None 时生效。
shuffle=shuffle and sampler is None,
# 工作进程数。
num_workers=nw,
# 采样器对象,用于分布式训练。
sampler=sampler,
# 是否使用 pinned memory,提高数据传输效率。
pin_memory=PIN_MEMORY,
# 数据批处理函数,从数据集对象中获取。
collate_fn=getattr(dataset, "collate_fn", None),
# 工作进程初始化函数,用于设置每个工作进程的随机种子。
worker_init_fn=seed_worker,
# 随机数生成器,用于生成可重现的随机数序列。
generator=generator,
)
# build_dataloader 函数通过封装 InfiniteDataLoader 的创建过程,提供了一个灵活且方便的方式来构建数据加载器。这个函数特别适用于分布式训练和多进程数据加载的场景,确保每个工作进程生成的随机数序列是可重现的,并且可以灵活地配置数据加载器的各种参数。通过这个函数,可以轻松地构建适用于训练和验证的数据加载器,提高代码的复用性和灵活性。
7.def check_source(source):
# 这段代码定义了一个名为 check_source 的函数,其目的是检查和处理输入的图像或视频源,并确定其类型。这个函数返回处理后的源以及一些标志,这些标志指示源的类型,如是否为摄像头、屏幕截图、内存中的图像、文件路径或PyTorch张量。
# 定义函数 check_source ,接收一个参数。
# 1.source :它可以是多种类型,包括字符串、整数、路径、列表、元组、PIL图像对象、NumPy数组或PyTorch张量。
def check_source(source):
# 检查源类型并返回相应的标志值。
"""Check source type and return corresponding flag values."""
# 初始化几个布尔标志,用于表示源的类型。
webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
# 如果 source 是字符串、整数或路径。
if isinstance(source, (str, int, Path)): # int for local usb camera
# 将 source 转换为字符串。
source = str(source)
# 检查 source 是否为文件路径,通过检查文件扩展名是否在支持的图像或视频格式中。
is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)
# 检查 source 是否为URL,通过检查是否以特定协议开头。
is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
# 如果 source 是数字、以 .streams 结尾或是一个URL但不是文件,则认为它是摄像头源。
webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
# 如果 source 是 "screen" ,则认为它是屏幕截图源。
screenshot = source.lower() == "screen"
# 如果 source 是URL且是文件,则下载文件。
if is_url and is_file:
# def check_file(file, suffix="", download=True, hard=True):
# -> 用于检查文件的存在性,并根据需要下载或搜索文件。如果满足上述条件之一,则返回文件名。返回下载后的文件名。根据搜索结果返回文件名或抛出错误。
# -> return file / return file / return files[0] if len(files) else [] if hard else file # return file
source = check_file(source) # download
# 如果 source 是 LOADERS 类型之一,认为它是内存中的图像源。
# LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots)
elif isinstance(source, LOADERS):
in_memory = True
# 如果 source 是列表或元组,调用 autocast_list 函数将所有元素转换为PIL图像对象或NumPy数组,并设置 from_img 标志。
elif isinstance(source, (list, tuple)):
# def autocast_list(source): -> 将一个包含不同类型的图像源的列表转换为一个统一的图像对象列表,这些图像对象可以是PIL图像或NumPy数组,以便进行后续的图像处理或分析。返回 转换后的图像对象列表 files 。 -> return files
source = autocast_list(source) # convert all list elements to PIL or np arrays
from_img = True
# 如果 source 是PIL图像对象或NumPy数组,设置 from_img 标志。
elif isinstance(source, (Image.Image, np.ndarray)):
from_img = True
# 如果 source 是PyTorch张量,设置 tensor 标志。
elif isinstance(source, torch.Tensor):
tensor = True
# 如果 source 的类型不匹配上述任何一种,抛出 TypeError 异常,提示不支持的图像类型。
else:
raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict") # 不支持的图像类型。有关支持的类型,请参阅 https://docs.ultralytics.com/modes/predict。
# 返回处理后的 source 以及几个布尔标志(webcam摄像头, screenshot屏幕截图, from_img, in_memory, tensor),这些标志指示源的类型。
return source, webcam, screenshot, from_img, in_memory, tensor
# check_source 函数通过检查输入的 source ,确定其类型,并返回处理后的源和相应的类型标志。这使得函数能够灵活地处理多种类型的图像或视频源,包括文件路径、URL、摄像头、屏幕截图、内存中的图像和PyTorch张量。通过这个函数,可以轻松地在应用程序中集成多种数据源的处理逻辑,提高代码的复用性和灵活性。
8.def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
# 这段代码定义了一个名为 load_inference_source 的函数,其目的是根据给定的输入源 source 加载相应的数据集,并返回一个数据加载器。这个函数通过调用 check_source 函数来确定输入源的类型,并根据类型选择合适的加载器类来处理数据。
# 定义函数 load_inference_source ,接收以下参数 :
# 1.source :输入源,可以是多种类型,如字符串、路径、列表、元组、PIL图像对象、NumPy数组或PyTorch张量。
# 2.batch :批量大小,默认为1。
# 3.vid_stride :视频帧率步长,默认为1。
# 4.buffer :是否使用缓冲区,默认为 False 。
def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
# 加载用于对象检测的推理源并应用必要的转换。
"""
Loads an inference source for object detection and applies necessary transformations.
Args:
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
batch (int, optional): Batch size for dataloaders. Default is 1.
vid_stride (int, optional): The frame interval for video sources. Default is 1.
buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.
Returns:
dataset (Dataset): A dataset object for the specified input source.
"""
# 调用 check_source 函数,检查输入源 source 并返回处理后的源以及多个布尔标志,这些标志指示源的类型。
source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
# 根据 in_memory 标志,确定 source_type 。如果 in_memory 为 True ,则从 source 中获取 source_type ;否则,使用 SourceTypes 枚举类创建 source_type 。
# class SourceTypes:
# -> 用于表示预测输入源的各种类型。 dataclass 装饰器可以自动生成特殊方法,如 __init__ 、 __repr__ 和 __eq__ ,使得类的定义更加简洁和易于维护。
source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)
# 根据输入源的类型,选择合适的加载器类来处理数据。
# Dataloader
# 如果 tensor 为 True ,使用 LoadTensor 类。
if tensor:
dataset = LoadTensor(source)
# 如果 in_memory 为 True ,直接使用 source 作为数据集。
elif in_memory:
dataset = source
# 如果 stream 为 True ,使用 LoadStreams 类。
elif stream:
dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
# 如果 screenshot 为 True ,使用 LoadScreenshots 类。
elif screenshot:
dataset = LoadScreenshots(source)
# 如果 from_img 为 True ,使用 LoadPilAndNumpy 类。
elif from_img:
dataset = LoadPilAndNumpy(source)
# 其他情况下,使用 LoadImagesAndVideos 类。
else:
dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)
# Attach source types to the dataset
# 将 source_type 附加到数据集对象上,以便后续使用。
setattr(dataset, "source_type", source_type)
# 返回构建好的数据集对象 dataset 。
return dataset
# load_inference_source 函数通过调用 check_source 函数确定输入源的类型,并根据类型选择合适的加载器类来处理数据。这个函数提供了一个灵活且方便的方式来加载多种类型的输入源,适用于不同的推理场景。通过这个函数,可以轻松地构建适用于训练和验证的数据加载器,提高代码的复用性和灵活性。
原文地址:https://blog.csdn.net/m0_58169876/article/details/145193265
免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!