YOLOv8-ultralytics-8.2.103部分代码阅读笔记-base.py
base.py
ultralytics\data\base.py
目录
1.所需的库和模块
# Ultralytics YOLO 🚀, AGPL-3.0 license
import glob
import math
import os
import random
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import psutil
from torch.utils.data import Dataset
from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
2.class BaseDataset(Dataset):
# 这段代码是一个用于深度学习模型训练的图像数据集类 BaseDataset ,它继承自 PyTorch 的 Dataset 类。
class BaseDataset(Dataset):
# 用于加载和处理图像数据的基本数据集类。
"""
Base dataset class for loading and processing image data.
Args:
img_path (str): Path to the folder containing images.
imgsz (int, optional): Image size. Defaults to 640.
cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
augment (bool, optional): If True, data augmentation is applied. Defaults to True.
hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
prefix (str, optional): Prefix to print in log messages. Defaults to ''.
rect (bool, optional): If True, rectangular training is used. Defaults to False.
batch_size (int, optional): Size of batches. Defaults to None.
stride (int, optional): Stride. Defaults to 32.
pad (float, optional): Padding. Defaults to 0.0.
single_cls (bool, optional): If True, single class training is used. Defaults to False.
classes (list): List of included classes. Default is None.
fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
Attributes:
im_files (list): List of image file paths.
labels (list): List of label data dictionaries.
ni (int): Number of images in the dataset.
ims (list): List of loaded images.
npy_files (list): List of numpy file paths.
transforms (callable): Image transformation function.
"""
# 这段代码是 BaseDataset 类的构造函数 __init__ 的定义,它用于初始化类的实例。
# 参数解释.
# 1.self : 类实例的引用。
# 2.img_path : 存储图像文件的路径。
# 3.imgsz : 目标图像尺寸,默认为640。
# 4.cache : 是否缓存图像,可以是布尔值、"ram"、"disk"或None。
# 5.augment : 是否进行数据增强,默认为True。
# 6.hyp : 超参数配置,默认为 DEFAULT_CFG 。
# 7.prefix : 日志或命名的前缀。
# 8.rect : 是否使用矩形训练,默认为False。
# 9.batch_size : 批处理大小,默认为16。
# 10.stride : 步长,默认为32。
# 11.pad : 填充比例,默认为0.5。
# 12.single_cls : 是否为单类别训练,默认为False。
# 13.classes : 包含的类别列表。
# 14.fraction : 使用数据集的一部分,默认为1.0(即全部数据)。
def __init__(
self,
img_path,
imgsz=640,
cache=False,
augment=True,
# DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
# 使用 yaml_load 函数来加载一个配置文件(路径由 DEFAULT_CFG_PATH 变量指定),并将结果(一个字典)赋值给 DEFAULT_CFG_DICT 。
# DEFAULT_CFG 这个对象包含了 DEFAULT_CFG_DICT 字典中的所有键值对。
hyp=DEFAULT_CFG,
prefix="",
rect=False,
batch_size=16,
stride=32,
pad=0.5,
single_cls=False,
classes=None,
fraction=1.0,
):
# 使用给定的配置和选项初始化 BaseDataset 。
"""Initialize BaseDataset with given configuration and options."""
# 初始化父类。这行代码调用了基类 Dataset 的构造函数。
super().__init__()
# 设置成员变量。
# self.img_path 到 self.transforms 都是类的成员变量,用于存储配置和状态。
self.img_path = img_path
self.imgsz = imgsz
self.augment = augment
self.single_cls = single_cls
self.prefix = prefix
self.fraction = fraction
# 获取图像文件和标签。
# 调用 get_img_files 方法获取图像文件列表。
self.im_files = self.get_img_files(self.img_path)
# 调用 get_labels 方法获取标签。
self.labels = self.get_labels()
# 更新标签。根据 classes 参数更新标签,用于处理 单类别训练 或 包含特定类别 。
self.update_labels(include_class=classes) # single_cls and include_class
# 计算图像数量。
self.ni = len(self.labels) # number of images
self.rect = rect
self.batch_size = batch_size
self.stride = stride
self.pad = pad
# 设置矩形训练。如果 rect 为True,检查 batch_size 是否为None,并调用 self.set_rectangle() 方法设置矩形训练参数。
if self.rect:
assert self.batch_size is not None
self.set_rectangle()
# Buffer thread for mosaic images
# 缓冲区设置。
# 用于缓存马赛克图像的缓冲区,大小等于批处理大小。
self.buffer = [] # buffer size = batch size
# 根据 图像数量 、 批处理大小 和 固定值 1000 计算缓冲区的最大长度。
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
# Cache images (options are cache = True, False, None, "ram", "disk")
# 缓存图像。
# self.ims , self.im_hw0 , self.im_hw :用于存储图像数据和尺寸的列表。
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
# 将图像文件路径转换为.npy格式,用于缓存。
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
# 根据 cache 参数设置缓存策略。
self.cache = cache.lower() if isinstance(cache, str) else "ram" if cache is True else None
# 如果使用RAM缓存或磁盘缓存,根据 hyp.deterministic 参数判断是否发出警告,并调用 self.cache_images() 方法缓存图像。
if (self.cache == "ram" and self.check_cache_ram()) or self.cache == "disk":
if self.cache == "ram" and hyp.deterministic:
LOGGER.warning(
"WARNING ⚠️ cache='ram' may produce non-deterministic training results. " # 警告⚠️cache='ram' 可能会产生不确定的训练结果。
"Consider cache='disk' as a deterministic alternative if your disk space allows." # 如果您的磁盘空间允许,请考虑将 cache='disk' 作为确定性的替代方案。
)
self.cache_images()
# Transforms
# 构建变换。方法构建图像变换。
self.transforms = self.build_transforms(hyp=hyp)
# 这个构造函数为 BaseDataset 类提供了一个灵活的初始化过程,允许用户根据需要配置数据集的各种参数。通过这种方式,用户可以轻松地调整数据集的行为以适应不同的训练需求。
# 这段代码定义了一个名为 get_img_files 的方法,它是 BaseDataset 类的一部分。这个方法的目的是读取图像文件路径,并返回一个包含所有图像文件路径的列表。
# 参数.
# 1.self : 类实例的引用。
# 2.img_path : 可以是一个包含图像文件路径的列表,或者是一个单一的图像文件路径。
def get_img_files(self, img_path):
# 读取图像文件。
"""Read image files."""
try:
# 初始化图像文件列表。初始化一个空列表,用于存储图像文件路径。
f = [] # image files
# 处理图像路径。如果 img_path 是列表,则遍历列表中的每个路径;如果不是列表,则将其视为包含单个路径的列表。
for p in img_path if isinstance(img_path, list) else [img_path]:
# 检查路径类型。
# 将路径转换为 Path 对象,使其跨操作系统兼容。
p = Path(p) # os-agnostic
# 如果路径是一个目录,则使用 glob.glob 递归搜索所有文件。 recursive :一个布尔值,默认为 False ,表示是否递归搜索子目录。
if p.is_dir(): # dir
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
# F = list(p.rglob('*.*')) # pathlib
# 如果路径是一个文件,则读取该文件,并将其包含的路径转换为全局路径。
elif p.is_file(): # file
# 打开文件。使用 with 语句打开文件 p ,这样可以确保文件在读取后会被正确关闭。变量 t 是文件对象。
with open(p) as t:
# 读取文件内容。读取文件内容, strip() 方法用于去除字符串首尾的空白字符(包括换行符), splitlines() 方法将字符串分割成行,返回一个列表。
t = t.read().strip().splitlines()
# os.sep
# os.sep 是 Python 标准库 os 模块中的一个属性,它代表操作系统特定的路径分隔符。这个属性在不同的操作系统中有不同的值 :
# 在 Windows 系统中, os.sep 是反斜杠 \ 。
# 在 Unix 和 Unix-like 系统(包括 Linux 和 macOS)中, os.sep 是正斜杠 / 。
# os.sep 用于确保代码在不同操作系统中处理文件路径时具有兼容性。例如,当你需要在路径字符串中插入分隔符时,使用 os.sep 而不是硬编码分隔符可以使得代码更加可移植。
# 获取父目录路径。获取文件 p 的父目录路径,并将其转换为字符串。 os.sep 是操作系统特定的路径分隔符(例如,在Windows上是 \ ,在Unix/Linux上是 / )。
parent = str(p.parent) + os.sep
# 替换路径为全局路径。
# 遍历文件内容的每一行 x ,如果行以 ./ 开头,这意味着路径是相对于当前文件的目录的局部路径。 replace("./", parent) 将这些局部路径转换为全局路径。如果不以 ./ 开头,则假定路径已经是全局路径,不做替换。
# 添加到文件列表。将处理后的路径添加到 f 列表中。
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
# 注释掉的路径处理(使用Pathlib)。
# 这一行被注释掉了,但它提供了另一种使用 pathlib 模块将局部路径转换为全局路径的方法。 p.parent / x.lstrip(os.sep) 将父目录路径与去除开头的路径分隔符的局部路径组合起来。
# F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
# 如果路径既不是目录也不是文件,则抛出 FileNotFoundError 。
else:
raise FileNotFoundError(f"{self.prefix}{p} does not exist") # {self.prefix}{p}不存在
# 筛选图像文件。筛选出文件扩展名在 IMG_FORMATS 列表中的文件,并排序。
# 列表推导式 : x.replace("/", os.sep) for x in f 。
# 这部分是一个列表推导式,它遍历 f 列表中的每个元素 x 。对于每个 x ,它使用 replace 方法将路径中的正斜杠 / 替换为操作系统特定的路径分隔符 os.sep 。这样做是为了确保路径在不同操作系统中的兼容性。
# 条件筛选 : if x.split(".")[-1].lower() in IMG_FORMATS 。
# 这部分是列表推导式的条件语句。对于每个路径 x ,它执行以下操作 :
# x.split(".") :根据点 . 分割路径字符串,返回一个包含路径中各部分的列表。
# [-1] :选择列表的最后一个元素,即文件扩展名。
# lower() :将文件扩展名转换为小写,以确保比较时不区分大小写。
# in IMG_FORMATS :检查转换后的文件扩展名是否在 IMG_FORMATS 列表中。 IMG_FORMATS 是一个包含所有支持的图像文件扩展名的列表。
# 只有当文件扩展名在 IMG_FORMATS 列表中时,该路径才会被包含在最终的结果列表中。
# 排序 : sorted(...) 。
# sorted 函数对列表推导式生成的列表进行排序。排序后的列表将包含所有有效图像文件的路径,且路径中的斜杠已被替换为操作系统特定的分隔符。
# 综上所述,这行代码的作用是 :
# 遍历 f 列表中的所有路径。
# 将路径中的正斜杠替换为操作系统特定的路径分隔符。
# 筛选出文件扩展名在 IMG_FORMATS 列表中的路径。
# 对筛选后的路径列表进行排序。
# 最终, im_files 变量将包含一个排序后的、包含所有有效图像文件路径的列表,这些路径已适配当前操作系统的路径分隔符。
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
# 断言检查。确保至少找到了一个图像文件,否则抛出错误。
assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}" # {self.prefix}在{img_path}. {FORMATS_HELP_MSG}中未找到图像。
# 异常处理。捕获任何异常,并抛出 FileNotFoundError ,指示从 img_path 加载数据时出错。
except Exception as e:
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e # {self.prefix}从 {img_path}\n{HELP_URL} 加载数据时出错。
# 处理分数数据集。如果 fraction 小于1,则只保留数据集的一部分。
if self.fraction < 1:
im_files = im_files[: round(len(im_files) * self.fraction)] # retain a fraction of the dataset
# 返回值。返回一个包含所有图像文件路径的列表。
return im_files
# 这个方法提供了一个灵活的方式来处理图像文件路径,无论是单个文件、文件列表还是目录,都能正确地读取和返回图像文件路径列表。
# 这段代码定义了一个名为 update_labels 的方法,它是 BaseDataset 类的一部分。这个方法的目的是更新数据集中的标签,以便只包含特定的类别(如果提供了 include_class 参数)。
# 参数。
# 1.self : 类实例的引用。
# 2.include_class : 一个可选的列表,包含要包含的类别名称或索引。
def update_labels(self, include_class: Optional[list]):
# 更新标签以仅包含这些类(可选)。
"""Update labels to include only these classes (optional)."""
# 将 include_class 转换为 NumPy 数组。如果 include_class 不是 None ,则将其转换为 NumPy 数组,并将其重塑为一个行向量。
include_class_array = np.array(include_class).reshape(1, -1)
# 遍历所有标签。遍历 self.labels 列表中的每个标签,其中 self.labels 存储了数据集中所有图像的标签信息。
for i in range(len(self.labels)):
# 检查 include_class 是否提供。如果提供了 include_class 参数,则执行以下操作。
if include_class is not None:
# 提取当前标签信息。
# 从当前标签中提取 类别 、 边界框 、 线段 和 关键点 信息。
cls = self.labels[i]["cls"]
bboxes = self.labels[i]["bboxes"]
segments = self.labels[i]["segments"]
keypoints = self.labels[i]["keypoints"]
# 筛选类别。使用 NumPy 的广播和比较功能,找出当前标签中的类别是否在 include_class_array 中。
# 用于筛选出与 include_class 匹配的类别。
# 比较操作。
# cls == include_class_array :这是一个元素级别的比较操作,用于检查 cls 数组中的每个元素是否与 include_class_array 中的任何元素相等。
# cls 是一个一维数组,包含当前标签的类别索引,而 include_class_array 是一个二维数组(经过 reshape(1, -1) 处理),包含要包含的类别索引。
# 布尔数组。
# 比较操作的结果是一个布尔数组,其中每个元素表示 cls 中的相应元素是否与 include_class_array 中的任何元素相等。
# 任意值(any)。
# .any(1) :这是一个 NumPy 方法,用于沿着指定轴(这里是轴1,即列)检查布尔数组中是否有任何 True 值。如果有任何 True 值,则返回 True ;否则返回 False 。
# 作用 :
# j 是一个布尔数组,其长度与 cls 相同,表示 cls 中的每个类别是否在 include_class_array 中。
# 这种方法允许灵活地根据用户提供的类别列表筛选标签,使得数据集只包含特定的类别,这对于特定任务或数据子集的训练非常有用。
j = (cls == include_class_array).any(1)
# 更新标签信息。
# 根据筛选结果 j 更新 类别 和 边界框信息 。
self.labels[i]["cls"] = cls[j]
self.labels[i]["bboxes"] = bboxes[j]
# 更新线段和关键点信息。
# 如果存在线段或关键点信息,也根据筛选结果进行更新。
if segments:
# 这行代码是 update_labels 方法中的一部分,用于根据筛选结果 j 更新当前标签中的线段(segments)信息。
# 列表推导式。
# [segments[si] for si, idx in enumerate(j) if idx] :这是一个列表推导式,用于创建一个新的列表,其中只包含满足条件 if idx 的 segments 中的元素。
# 遍历和索引。
# enumerate(j) :这个函数返回 j 中的每个元素及其索引, si 是索引, idx 是对应的布尔值。
# if idx :这是一个条件语句,只有当 idx 为 True 时,才会将 segments[si] 添加到新列表中。
# 作用 :
# 这行代码的作用是筛选出 j 中为 True 的索引对应的 segments 元素,并将这些元素组成一个新的列表,赋值给 self.labels[i]["segments"] 。
# 这行代码确保只有与 include_class 匹配的类别对应的线段信息被保留在标签中,这对于处理特定类别的数据或清理数据集非常有用。通过这种方式,可以确保数据集中的每个标签只包含与特定任务相关的信息。
self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
if keypoints is not None:
self.labels[i]["keypoints"] = keypoints[j]
# 处理单类别情况。
# 如果 self.single_cls 为 True ,则将所有类别的索引设置为0,这通常用于单类别训练。
if self.single_cls:
self.labels[i]["cls"][:, 0] = 0
# update_labels 方法允许用户指定一个类别列表,只保留这些类别的标签信息,这对于特定任务或数据子集的训练非常有用。此外,它还支持单类别训练,通过将所有类别索引设置为0来实现。
# 这段代码定义了一个名为 load_image 的方法,它是 BaseDataset 类的一部分。这个方法的作用是从数据集中加载指定索引 i 的图像,并根据需要进行缩放,返回加载的图像以及它的原始尺寸和调整后的尺寸。
# 参数。
# 1.self : 类实例的引用。
# 2.i : 数据集中图像的索引。
# 3.rect_mode : 是否使用矩形模式进行缩放,默认为 True 。
def load_image(self, i, rect_mode=True):
# 从数据集索引“i”加载 1 个图像,返回(im,调整大小的 hw)。
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
# 检查缓存。从缓存中获取图像数据 im ,文件路径 f 和 .npy 缓存文件路径 fn 。
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
# 加载图像。
# 如果 im 是 None (即图像未缓存在 RAM 中)。
if im is None: # not cached in RAM
# 如果 .npy 文件存在,尝试从 .npy 文件加载图像。
if fn.exists(): # load npy
try:
im = np.load(fn)
# 如果加载失败,记录警告并删除损坏的 .npy 文件,然后从原始文件路径 f 使用 cv2.imread 加载图像。
except Exception as e:
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}") # {self.prefix}警告 ⚠️ 删除损坏的 *.npy 图像文件 {fn},原因是:{e}。
# Path.unlink(missing_ok=False)
# unlink() 函数是 pathlib 模块中 Path 类的一个方法,用于删除文件系统中的一个文件。
# Path : 这是 pathlib 模块中的 Path 类,用于表示文件系统路径。
# unlink() : 这是 Path 类的方法,用于删除路径所指向的文件。
# 参数 :
# missing_ok : 这是一个可选参数,默认值为 False 。如果设置为 True ,则在文件不存在时不会抛出异常,而是静默地忽略这个错误。
# 功能 :
# unlink() 方法用于删除文件系统中的一个文件。如果文件不存在,并且 missing_ok 参数为 False (默认值),则会引发一个 FileNotFoundError 。
# 注意事项 :
# 使用 unlink() 方法时要小心,因为一旦文件被删除,就无法恢复。
# 确保在删除文件之前有适当的错误处理和文件存在性检查,除非你确定文件存在,或者你不在乎文件是否实际存在。
# 在多线程或多进程环境中,文件可能会在不同的执行线程或进程中被访问或修改,因此在使用 unlink() 时要特别注意同步和竞态条件。
Path(fn).unlink(missing_ok=True)
im = cv2.imread(f) # BGR
# 如果 .npy 文件不存在,直接从原始文件路径 f 加载图像。
else: # read image
im = cv2.imread(f) # BGR
# 如果从 f 加载图像失败,则抛出 FileNotFoundError 。
if im is None:
raise FileNotFoundError(f"Image Not Found {f}") # 未找到图片 {f}。
# 缩放图像。
# 这段代码处理图像的缩放操作,确保图像的尺寸符合特定的要求。
# 获取原始尺寸。这行代码从图像矩阵 im 中获取原始的高度 h0 和宽度 w0 。
h0, w0 = im.shape[:2] # orig hw
# 矩形模式缩放。
# 如果 rect_mode 为 True ,则执行以下操作。
if rect_mode: # resize long side to imgsz while maintaining aspect ratio
# 计算缩放比例 r 。 self.imgsz 是目标图像尺寸, max(h0, w0) 是原始图像的较大边, r 是为了保持宽高比而需要缩放的比例。
r = self.imgsz / max(h0, w0) # ratio
# 检查比例是否为1(即是否需要缩放)。
# 如果 r 不等于1,说明需要缩放图像。
if r != 1: # if sizes are not equal
# 计算新的宽度 w 和高度 h ,并确保不超过目标尺寸 self.imgsz 。
# 这里使用 math.ceil 向上取整,因为宽度和高度必须是整数。 min 函数确保新的尺寸不会超过目标尺寸 self.imgsz 。
w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
# 执行缩放操作。使用 cv2.resize 函数按照计算出的新尺寸 w 和 h 缩放图像, interpolation=cv2.INTER_LINEAR 指定了线性插值方法。
im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
# 非矩形模式缩放。
# 如果 rect_mode 为 False 且原始尺寸不等于目标尺寸 self.imgsz ,则执行以下操作。
elif not (h0 == w0 == self.imgsz): # resize by stretching image to square imgsz
# 将图像拉伸至目标尺寸 self.imgsz 。里直接将图像缩放到目标尺寸,不保持宽高比。
im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)
# 这段代码提供了两种缩放模式。矩形模式 :保持图像的宽高比,只缩放较长的一边至目标尺寸 self.imgsz 。非矩形模式 :不保持宽高比,直接将图像拉伸至目标尺寸 self.imgsz 的正方形。
# 这样的设计使得图像处理更加灵活,可以根据不同的应用场景选择合适的缩放策略。
# Add to buffer if training with augmentations
# 缓存和缓冲区管理。
# 这段代码是 load_image 方法的一部分,它处理图像的加载、缓存和数据增强。
# 数据增强检查。
# 如果 self.augment 为 True ,表示启用了数据增强功能。
if self.augment:
# 缓存图像信息。
# self.ims[i] 缓存加载的图像 im 。 self.im_hw0[i] 缓存图像的原始尺寸 (h0, w0) 。 self.im_hw[i] 缓存图像调整后的尺寸 im.shape[:2] 。
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
# 管理缓冲区。将当前图像的索引 i 添加到 self.buffer 缓冲区中。
self.buffer.append(i)
# 缓冲区长度检查。
# 检查缓冲区的长度是否超过了 self.max_buffer_length 。 1 < len(self.buffer) 确保缓冲区不为空。
if 1 < len(self.buffer) >= self.max_buffer_length: # prevent empty buffer
# 释放最早缓存。
# 如果缓冲区长度达到最大值,从缓冲区中移除最早的索引 j 。
j = self.buffer.pop(0)
# 如果缓存策略不是 "ram",则释放与索引 j 相关的缓存,即将 self.ims[j] 、 self.im_hw0[j] 和 self.im_hw[j] 设置为 None 。
if self.cache != "ram":
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
# 返回图像信息。返回加载的图像 im ,原始尺寸 (h0, w0) 和调整后的尺寸 im.shape[:2] 。
return im, (h0, w0), im.shape[:2]
# 这段代码处理了图像的加载、缓存和数据增强。它确保 :
# 在数据增强时,图像被缓存并管理在缓冲区中。
# 缓冲区长度不超过最大限制,超出时释放最早的缓存。
# 返回图像及其尺寸信息,无论是从缓存中还是新加载的图像。
# 这种设计使得图像加载和处理更加高效,特别是在数据增强时,可以减少重复的磁盘I/O操作,提高训练速度。同时,通过管理缓冲区,它还有助于控制内存使用。
# 返回缓存的图像信息。如果图像已经缓存在 RAM 中,则直接返回缓存的图像 self.ims[i] ,原始尺寸 self.im_hw0[i] 和调整后的尺寸 self.im_hw[i] 。
return self.ims[i], self.im_hw0[i], self.im_hw[i]
# 这个方法确保了图像数据的高效加载和处理,同时支持缓存和数据增强,适用于深度学习模型训练中的数据预处理。
# 这段代码定义了一个名为 cache_images 的方法,它是 BaseDataset 类的一部分。这个方法的作用是将数据集中的图像缓存到内存(RAM)或磁盘上,以加快后续的数据加载速度。
# 参数。 self : 类实例的引用。
def cache_images(self):
# 将图像缓存到内存或磁盘。
"""Cache images to memory or disk."""
# 初始化变量。 b 用于记录缓存的图像总字节数。 gb 是字节到吉字节的转换因子(1GB = 2^30 bytes)。
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
# 选择缓存函数和存储类型。
# 根据 self.cache 的值选择缓存函数和存储类型。如果 self.cache 是 "disk",则使用 self.cache_images_to_disk 函数将图像缓存到磁盘;否则,使用 self.load_image 函数将图像缓存到 RAM。
fcn, storage = (self.cache_images_to_disk, "Disk") if self.cache == "disk" else (self.load_image, "RAM")
# 使用线程池进行并行处理。
# 创建一个 ThreadPool 对象,用于并行处理图像的加载和缓存。 NUM_THREADS 是线程池的大小。
# NUM_THREADS :它用于设置 OpenMP(一个用于多平台共享内存并行编程的API)的线程数。
with ThreadPool(NUM_THREADS) as pool:
# 并行处理图像。使用 pool.imap 方法并行调用 fcn 函数,对数据集中的每个图像进行处理。 range(self.ni) 生成从 0 到 self.ni - 1 的索引序列。
results = pool.imap(fcn, range(self.ni))
# 进度条显示。
# 创建一个 TQDM 进度条对象,用于显示缓存进度。 enumerate(results) 将结果序列与索引配对, total=self.ni 设置进度条的总步数。
pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
# 处理结果并更新缓存。
# 遍历处理结果,根据缓存类型更新缓存信息和进度条描述。
for i, x in pbar:
# 如果缓存到磁盘,更新已缓存的字节数 b 。
if self.cache == "disk":
# 这行代码是在 cache_images 方法中用于累加缓存到磁盘的图像文件大小的。
# self.npy_files[i] : self.npy_files 是一个列表,其中包含了数据集中每个图像对应的 .npy 文件的路径。 i 是当前处理的图像索引。 self.npy_files[i] 表示第 i 个图像对应的 .npy 文件的路径。
# .stat() : stat() 方法是 pathlib.Path 对象的一个方法,用于获取文件的状态信息。 它返回一个对象,其中包含了文件的各种属性,如大小、修改时间等。
# st_size : st_size 是 stat() 方法返回的对象的一个属性,表示文件的大小,单位是字节。
# 累加文件大小 : b += self.npy_files[i].stat().st_size 将第 i 个 .npy 文件的大小加到变量 b 上。 b 是一个累计变量,用于记录所有缓存图像的总字节数。
# 这行代码的作用是更新变量 b ,使其包含到目前为止处理的所有图像文件的总大小。 这个总大小用于计算和显示缓存进度,以及评估缓存操作对磁盘空间的影响。
b += self.npy_files[i].stat().st_size
# 如果缓存到 RAM,将加载的图像和尺寸信息存储到相应的成员变量中,并更新已缓存的字节数 b 。
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
# 这行代码是在 cache_images 方法中用于累加缓存到内存(RAM)的图像数据大小的。
# self.ims[i] : self.ims 是一个列表,其中包含了数据集中每个图像的像素数据。 i 是当前处理的图像索引。 self.ims[i] 表示第 i 个图像的像素数据。
# .nbytes : .nbytes 是 NumPy 数组的一个属性,表示数组占用的字节数。 对于图像数据,这通常包括图像的宽度、高度和颜色通道数。
# 累加字节数 : b += self.ims[i].nbytes 将第 i 个图像数据的大小(以字节为单位)加到变量 b 上。 b 是一个累计变量,用于记录所有缓存图像的总字节数。
# 这行代码的作用是更新变量 b ,使其包含到目前为止处理的所有图像数据的总大小。 这个总大小用于计算和显示缓存进度,以及评估缓存操作对内存使用的影响。
b += self.ims[i].nbytes
# 这行代码是用于更新进度条描述的,它在 cache_images 方法中显示当前缓存操作的进度和状态。
# 格式化字符串、
# f"{self.prefix}Caching images ({b / gb:.1f}GB {storage})" 是一个格式化字符串,用于构建进度条的描述文本。
# self.prefix 是类实例的一个成员变量,用于在日志或进度条前添加前缀。
# Caching images 是一个静态字符串,表示当前正在进行的操作是缓存图像。
# (b / gb:.1f}GB 计算当前缓存的总字节数 b 除以吉字节的字节数 gb ( 1GB = 2^30 bytes ),并格式化为保留一位小数的浮点数,表示已缓存数据的大小(以GB为单位)。
# {storage} 是一个变量,表示当前的存储位置,可以是 "RAM" 或 "Disk"。
pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {storage})"
# 关闭进度条。
pbar.close()
# cache_images 方法通过并行处理和缓存机制,提高了图像数据加载的效率。它允许用户选择将图像缓存到 RAM 或磁盘,以适应不同的内存和存储条件。进度条的使用使得缓存过程可视化,方便用户监控缓存进度。
# 这段代码定义了一个名为 cache_images_to_disk 的方法,它是 BaseDataset 类的一部分。这个方法的作用是将图像数据保存为 .npy 文件,以便在将来的数据加载过程中能够更快地加载图像。
# 参数。
# 1.self : 类实例的引用。
# 2.i : 数据集中图像的索引。
def cache_images_to_disk(self, i):
# 将图像保存为 *.npy 文件以便更快地加载。
"""Saves an image as an *.npy file for faster loading."""
# 获取 .npy 文件路径。
# self.npy_files 是一个列表,包含了数据集中每个图像对应的 .npy 文件的路径。 i 是当前处理的图像索引。 f 表示第 i 个图像对应的 .npy 文件的路径。
f = self.npy_files[i]
# 检查文件是否存在。使用 Path 对象的 exists() 方法检查对应的 .npy 文件是否已经存在于文件系统中。
if not f.exists():
# path.as_posix()
# 在 Python 的 pathlib 模块中, Path 类的 .as_posix() 方法用于将 Path 对象表示的路径转换为 POSIX 风格的字符串。POSIX 是一个操作系统标准,它规定了文件路径应该使用正斜杠( / )作为目录分隔符。
# path : Path 类的实例。
# 返回值 :
# 返回一个字符串,表示 Path 对象的路径,其中所有的路径分隔符都被替换为正斜杠( / )。
# 方法功能 :
# .as_posix() 方法将 Path 对象中的路径转换为一个字符串,这个字符串使用正斜杠( / )作为所有目录的分隔符,无论在原始路径中使用的是哪种操作系统的路径分隔符(例如,在 Windows 中可能是反斜杠 \ )。
# 此外,该方法还会处理路径中的一些特殊情况,例如,将相对路径(如 ./ 或 ../ )转换为简化形式,但不改变它们的相对性。 去除路径中多余的分隔符。
# 注意事项 :
# .as_posix() 方法不检查路径的实际存在性,它仅仅进行字符串层面的转换。
# 如果你需要在不同的操作系统之间移植代码,或者与期望 POSIX 路径风格的外部工具或库交互,使用 .as_posix() 方法可以帮助确保路径的兼容性。
# 读取图像并保存为 .npy 文件。
# 如果 .npy 文件不存在,则执行以下操作 :
# 使用 cv2.imread 从原始文件路径 self.im_files[i] 读取图像数据。
# 使用 np.save 将读取的图像数据保存为 .npy 文件。 f.as_posix() 将 Path 对象转换为字符串路径。
# allow_pickle=False 参数用于提高安全性,防止不受信任的数据被pickle序列化。
np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)
# cache_images_to_disk 方法通过将图像数据保存为 .npy 文件,可以显著提高数据集的加载速度,特别是在处理大型数据集时。这种方法减少了每次从原始文件(如JPEG或PNG)加载图像的开销,使得数据加载更加高效。
# 注意事项 :
# 保存为 .npy 文件时,会占用一定的磁盘空间。因此,在使用这种方法时,需要确保有足够的磁盘空间。
# .npy 文件是二进制文件,只能使用NumPy库加载。这意味着,如果你需要在其他不支持NumPy的环境中访问这些图像数据,可能需要额外的处理步骤。
# 使用 allow_pickle=False 是一个安全措施,因为pickle序列化可能被用来执行任意代码。在处理不受信任的数据时,应该始终禁用pickle。
# 将图像保存为 .npy 文件(NumPy数组文件)有几个好处,特别是在数据处理和机器学习工作流程中 :
# 加载速度 :
# .npy 文件是二进制格式,存储了NumPy数组的原始字节,这意味着它们可以非常快速地被加载到内存中,因为NumPy可以直接将这些字节映射到数组对象,无需重新解释或解码数据。
# 存储效率 :
# .npy 文件存储的是原始像素值,不包含任何额外的元数据或文件格式特定的开销,这使得它们比图像文件格式(如JPEG或PNG)更加紧凑,可以节省存储空间。
# 一致性 :
# 在机器学习和深度学习中,数据通常以NumPy数组的形式处理。将图像保存为 .npy 文件可以保持数据的一致性,使得数据在不同阶段(如预处理、训练和推理)的处理更加方便。
# 并行处理 :
# 当处理大规模数据集时,可以并行地将图像转换为 .npy 文件,然后并行地加载它们,这有助于提高数据处理的效率。
# 减少I/O操作 :
# 读取和写入磁盘是一个相对慢的操作。通过预先将图像转换为 .npy 文件,可以减少在训练过程中对磁盘I/O操作的需求,因为 .npy 文件可以被更快地加载。
# 避免格式转换 :
# 直接从原始图像文件(如JPEG或PNG)加载图像通常需要格式转换(解码),这会增加处理时间。使用 .npy 文件可以避免这一步骤。
# 数据安全 :
# 如前所述, .npy 文件不支持pickle序列化,这可以防止潜在的安全风险,因为pickle可以执行任意代码。
# 兼容性 :
# .npy 文件是纯数据文件,可以在不同的操作系统和硬件平台上使用,无需担心兼容性问题。
# 尽管有这些优点,但 .npy 文件也有一些局限性,比如它们不是图像文件,不能直接被图像查看器打开,且只能使用支持NumPy的程序来读取。
# 因此,是否使用 .npy 文件需要根据具体的应用场景和需求来决定。
# 这段代码定义了一个名为 check_cache_ram 的方法,它是 BaseDataset 类的一部分。这个方法的作用是检查是否有足够的可用内存来缓存整个数据集的图像到RAM中。
# 参数。
# 1.self : 类实例的引用。
# 2.safety_margin : 一个浮点数,表示在计算所需内存时要考虑的安全边际,默认为0.5(即50%)。
def check_cache_ram(self, safety_margin=0.5):
# 检查图像缓存要求与可用内存。
"""Check image caching requirements vs available memory."""
# 初始化变量。 b 用于记录估计的缓存图像总字节数。 gb 是字节到吉字节的转换因子(1GB = 2^30 bytes)。
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
# 选择样本数量。 n 是样本数量,用于估计整个数据集所需的内存,最多为30或数据集中的图像数量。
n = min(self.ni, 30) # extrapolate from 30 random images
# 估计单张图像的内存需求。
# 循环 n 次,每次随机选择一个图像文件。
for _ in range(n):
im = cv2.imread(random.choice(self.im_files)) # sample image
# 计算图像缩放比例。
ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
# 计算缩放后的图像大小。
b += im.nbytes * ratio**2
# 计算总内存需求。根据样本图像计算整个数据集所需的内存,并考虑安全边际。
# 这行代码计算了缓存整个数据集到RAM所需的总内存量。
# 变量解释。
# b :累加的样本 图像数据大小 乘以 缩放比例的平方 ,用于估计单张图像缩放后所需的内存。
# self.ni :数据集中图像的总数。
# n :样本数量,用于估计整个数据集所需的内存,这里取数据集中图像总数和30的最小值。
# safety_margin :安全边际比例,默认为0.5(即50%),用于确保计算中包含一定的额外空间以应对不确定性。
# 计算过程。
# b * self.ni :首先,将单张图像的估计内存乘以数据集中图像的总数,得到未考虑样本比例和安全边际时的总内存需求。
# ... / n :然后,将上述结果除以样本数量 n ,以根据样本估计整个数据集的内存需求。这是因为 b 是基于样本图像计算的,需要根据样本比例进行放大。
# ... * (1 + safety_margin) :最后,将上述结果乘以 (1 + safety_margin) ,以包含安全边际。这意味着实际计算的内存需求比估计值多出安全边际指定的比例,以确保有足够的缓冲空间。
# 作用 :
# 这行代码的作用是计算缓存整个数据集到RAM所需的总内存量,并考虑了安全边际。 mem_required 表示最终计算出的所需内存量,以字节为单位。
mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
# psutil.virtual_memory()
# psutil.virtual_memory() 是一个函数,属于 psutil 库,用于获取系统虚拟内存(RAM)的使用情况。
# 参数 :无参数。
# 返回值 :
# 该函数返回一个命名元组( psutil._common.smem ),其中包含了以下属性 :
# total :总物理内存大小,单位为字节。
# available :可供分配的内存大小,单位为字节,这个值是系统认为可用的内存,包括缓存和缓冲区占用的内存。
# percent :已使用内存的百分比。
# used :已使用的内存大小,单位为字节。
# free :空闲的内存大小,单位为字节。
# active :当前正在使用或最近使用的内存,单位为字节。
# inactive :标记为未使用的内存,单位为字节。
# buffers :缓存数据,如文件系统元数据,单位为字节。
# cached :缓存数据,单位为字节。
# shared :可由多个进程共享的内存,单位为字节。
# slab :用于内核数据结构的内存,单位为字节。
# 获取系统内存信息。使用 psutil 库获取系统的虚拟内存信息。
mem = psutil.virtual_memory()
# 判断是否缓存。 如果所需的内存小于可用内存,则 success 为 True ,表示可以缓存到RAM。
success = mem_required < mem.available # to cache or not to cache, that is the question
# 处理不可缓存的情况。如果 success 为 False ,则设置 self.cache 为 None 并记录日志。
if not success:
self.cache = None
LOGGER.info(
f"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images " # {self.prefix}{mem_required / gb:.1f}缓存图像所需的 GB RAM。
f"with {int(safety_margin * 100)}% safety margin but only " # 具有 {int(safety_margin * 100)}% 安全裕度,但只有。
f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images ⚠️" # {mem.available / gb:.1f}/{mem.total / gb:.1f}GB 可用,不缓存图像 ⚠️。
)
# 返回结果。返回 success 值,指示是否可以缓存图像到RAM。
return success
# check_cache_ram 方法通过估计数据集的内存需求并与系统可用内存进行比较,来决定是否将图像缓存到RAM中。这种方法有助于避免内存溢出,并确保数据加载过程的高效性。通过考虑安全边际,它还提供了一定程度的灵活性,以应对内存需求的波动。
# 这段代码定义了一个名为 set_rectangle 的方法,它是 BaseDataset 类的一部分。这个方法的作用是为 YOLO 目标检测设置边界框的形状为矩形,并根据图像的宽高比来确定训练时每个批次的图像尺寸。
def set_rectangle(self):
# 将 YOLO 检测的边界框形状设置为矩形。
"""Sets the shape of bounding boxes for YOLO detections as rectangles."""
# 计算批次索引。
# np.arange(self.ni) 生成一个从0到 self.ni - 1 的索引数组。
# np.floor(...) / self.batch_size 计算每个索引对应的批次索引。
# .astype(int) 将结果转换为整数类型。
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
# 计算批次数量。 bi[-1] 是最后一个批次的索引, +1 得到总的批次数量。
nb = bi[-1] + 1 # number of batches
# 提取图像尺寸并计算宽高比。
# 从标签中提取图像的原始尺寸(宽和高)。
s = np.array([x.pop("shape") for x in self.labels]) # hw
# 计算每个图像的宽高比。
ar = s[:, 0] / s[:, 1] # aspect ratio
# 根据宽高比排序。
# ar.argsort() 获取宽高比的排序索引。
irect = ar.argsort()
# 根据排序索引重新排列图像文件路径、标签和宽高比。
self.im_files = [self.im_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
ar = ar[irect]
# Set training image shapes
# 设置训练图像形状。
# 这段代码是 set_rectangle 方法中的一部分,用于确定每个训练批次的图像尺寸,这些尺寸将被用于 YOLO 目标检测的训练过程中。
# 初始化形状列表。
# shapes 是一个列表,用于存储每个批次的图像形状(宽和高)。
# [1, 1] 是初始形状,表示如果没有任何图像的宽高比(aspect ratio, ar)超过1,那么所有批次的图像都将被设置为正方形(1x1)。
# nb 是批次的数量。
shapes = [[1, 1]] * nb
# 遍历每个批次。循环遍历每个批次。
for i in range(nb):
# 提取批次的宽高比。 ari 是当前批次的图像宽高比数组。
ari = ar[bi == i]
# 计算最小和最大宽高比。 mini 是当前批次中最小的宽高比。 maxi 是当前批次中最大的宽高比。
mini, maxi = ari.min(), ari.max()
# 根据宽高比设置形状。
# 如果 maxi 小于1,意味着该批次中所有图像都是高宽比(更窄),因此设置形状为 [maxi, 1] ,即高度被拉伸至1,宽度按比例缩小。
if maxi < 1:
shapes[i] = [maxi, 1]
# 如果 mini 大于1,意味着该批次中所有图像都是宽高比(更宽),因此设置形状为 [1, 1 / mini] ,即宽度被压缩至1,高度按比例增加。
elif mini > 1:
shapes[i] = [1, 1 / mini]
# 这段代码的目的是根据每个批次中图像的宽高比来动态调整图像的形状,以确保在训练过程中能够适应不同尺寸的图像。通过这种方式,可以提高训练的灵活性和效率,特别是在处理具有不同宽高比的图像时。这种方法有助于保持图像的宽高比,避免不必要的形变,这对于某些对宽高比敏感的任务(如目标检测)是非常重要的。
# 计算批次尺寸。根据 形状列表 、 图像尺寸 、 步长 和 填充 计算每个批次的实际尺寸。 np.ceil 向上取整,确保尺寸是整数。 .astype(int) 将结果转换为整数类型。
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
# 设置批次索引。将批次索引 bi 存储为类的成员变量。
self.batch = bi # batch index of image
# set_rectangle 方法通过考虑图像的宽高比来设置 YOLO 目标检测的矩形边界框,并确定每个训练批次的图像尺寸。这种方法有助于优化训练过程中的内存使用和计算效率。通过根据图像尺寸动态调整批次尺寸,可以确保不同尺寸的图像都能被有效处理。
# 这段代码定义了一个名为 __getitem__ 的方法,它是 BaseDataset 类的一部分,并且是 Python 中特殊方法的特殊用法。这个方法使得 BaseDataset 类的实例可以像普通的列表或数组一样,通过索引来访问数据集中的元素。
# 1.self : 类实例的引用。
# 2.index : 要访问的数据集中图像的索引。
def __getitem__(self, index):
# 返回给定索引的转换标签信息。
"""Returns transformed label information for given index."""
# 获取图像和标签 : get_image_and_label(index) 。 调用 get_image_and_label 方法,根据提供的 index 索引获取对应的图像和标签数据。
# 应用变换 : self.transforms(...) 。将获取到的图像和标签数据传递给 self.transforms ,这是一个在类的初始化过程中定义的变换函数或变换序列。 这些变换包括图像的缩放、裁剪、归一化等操作,以及标签的相应转换。
# 返回变换后的数据。返回经过变换后的图像和标签数据。
return self.transforms(self.get_image_and_label(index))
# __getitem__ 方法是 PyTorch Dataset 类的核心方法之一,它允许数据集与 PyTorch 的 DataLoader 配合使用,实现数据的批量处理、打乱和多线程加载。这个方法确保了当索引访问数据集时,能够返回正确变换后的数据,这对于构建高效的数据管道至关重要。
# 示例 :
# 假设你有一个 BaseDataset 类的实例 dataset ,并且想要获取索引为 5 的图像和标签 :
# image, label = dataset[5]
# 这将调用 dataset 的 __getitem__ 方法,获取索引为 5 的图像和标签,然后返回它们。这些数据可以被进一步用于模型的训练或评估。
# 这段代码定义了一个名为 get_image_and_label 的方法,它是 BaseDataset 类的一部分。这个方法的作用是从数据集中获取指定索引 index 的图像和对应的标签信息,并进行必要的处理。
# 参数。
# 1.self : 类实例的引用。
# 2.index : 要访问的数据集中图像的索引。
def get_image_and_label(self, index):
# 从数据集中获取并返回标签信息。
"""Get and return label information from the dataset."""
# 深拷贝标签信息。
# 使用 deepcopy 对索引 index 对应的标签信息进行深拷贝,以确保标签信息在后续处理中被安全地修改,而不会影响原始数据。
# 这一步骤是必要的,因为浅拷贝可能会导致修改拷贝数据时不小心修改原始数据的问题,这一点在 Ultralytics 的 GitHub 仓库的 Pull Request #1948 中被提及和修复。
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
# 移除不必要的标签信息。从标签中移除 shape 键,因为 shape 是用于矩形训练的,如果不需要则移除。
label.pop("shape", None) # shape is for rect, remove it
# 加载图像并更新标签。调用 load_image 方法加载索引 index 对应的 图像 ,并获取 原始尺寸 和 调整后的尺寸 。 更新标签信息中的图像数据和尺寸。
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
# 计算缩放比例。计算 调整后的尺寸 与 原始尺寸 的 比例 ,用于评估。
label["ratio_pad"] = (
label["resized_shape"][0] / label["ori_shape"][0],
label["resized_shape"][1] / label["ori_shape"][1],
) # for evaluation
# 处理矩形训练。如果启用矩形训练,更新标签信息中的 rect_shape 键,以存储当前批次的图像形状。
if self.rect:
label["rect_shape"] = self.batch_shapes[self.batch[index]]
# 更新标签信息。调用 update_labels_info 方法进一步更新标签信息,并返回更新后的标签。
return self.update_labels_info(label)
# get_image_and_label 方法负责从数据集中获取图像和标签,进行必要的处理,并返回更新后的标签信息。这个方法是数据预处理流程的关键部分,确保了图像和标签在训练过程中的正确性和一致性。通过深拷贝和仔细处理标签信息,它有助于避免数据共享和修改时的潜在问题。
# 这段代码定义了一个名为 __len__ 的方法,它是 BaseDataset 类的一部分,并且是一个 Python 特殊方法,用于返回对象的长度。在数据集对象的上下文中, __len__ 方法返回数据集中的标签列表的长度,即数据集中图像的数量。
# 方法解释。
# 1.self : 类实例的引用
def __len__(self):
# 返回数据集的标签列表的长度。
"""Returns the length of the labels list for the dataset."""
# self.labels : 一个列表,包含了数据集中所有图像的标签信息。
# 调用内置的 len 函数来获取 self.labels 列表的长度。
return len(self.labels)
# 这个方法使得 BaseDataset 类的实例可以像普通的列表或数组一样,使用 len() 函数来获取数据集的大小。
# 在 PyTorch 中, __len__ 方法是 Dataset 类所必需的,因为它被 DataLoader 用来确定数据集中有多少个样本,从而可以正确地进行批处理和数据迭代。
# 示例 :
# 假设你有一个 BaseDataset 类的实例 dataset :
# dataset = BaseDataset(...)
# size_of_dataset = len(dataset)
# 这将调用 dataset 的 __len__ 方法,返回数据集中的图像数量。这个数量可以用于确定训练周期的数量、数据加载的批次大小等。
# 这段代码定义了一个名为 update_labels_info 的方法,它是 BaseDataset 类的一部分。这个方法的目的是提供一个接口,允许用户根据自己的需要自定义标签信息的格式。在当前的实现中,这个方法仅仅是返回了传入的标签信息,没有进行任何修改。
# 方法解释。
# 1.self : 类实例的引用。
# 2.label : 一个字典,包含了当前图像的标签信息。
def update_labels_info(self, label):
# 在此自定义您的标签格式。
"""Custom your label format here."""
# 直接返回传入的 label 字典,不做任何处理。
return label
# 这个方法作为一个占位符,提示用户在这里实现自定义的标签信息更新逻辑。
# 在实际应用中,用户可能需要根据特定的需求修改标签信息,例如添加新的标签字段、修改标签的格式或者执行某些数据增强操作。
# update_labels_info 方法提供了一个灵活的接口,允许用户根据自己的需求来调整和扩展标签信息,以适应不同的数据处理和模型训练需求。
# 这段代码定义了一个名为 build_transforms 的方法,它是 BaseDataset 类的一部分。这个方法的目的是构建和返回一组图像变换,这些变换将在数据加载时应用于图像和标签。然而,当前的实现只是抛出了一个 NotImplementedError 异常,表明这个方法需要在子类中具体实现。
# 方法解释。
# 1.self : 类实例的引用。
# 2.hyp : 一个可选参数,通常包含超参数或配置信息,用于定制变换。
def build_transforms(self, hyp=None):
# 用户可以在此处自定义增强。
"""
Users can customize augmentations here.
Example:
```python
if self.augment:
# Training transforms
return Compose([])
else:
# Val transforms
return Compose([])
```
"""
# 这行代码表明 build_transforms 方法是一个抽象方法,需要在继承 BaseDataset 的子类中具体实现。
raise NotImplementedError
# 这个方法作为一个模板,提示子类在这里实现具体的变换逻辑。
# 在深度学习和计算机视觉任务中,图像变换是数据预处理的重要部分,可以包括缩放、裁剪、旋转、归一化等多种操作。
# build_transforms 方法提供了一个框架,允许用户根据自己的需求来定制图像变换,以适应不同的数据处理和模型训练需求。
# 这段代码定义了一个名为 get_labels 的方法,它是 BaseDataset 类的一部分。这个方法的目的是获取数据集中所有图像的标签信息。然而,当前的实现只是抛出了一个 NotImplementedError 异常,这表明这个方法需要在子类中具体实现。
# 方法解释。
# 1.self : 类实例的引用。
def get_labels(self):
# 用户可以在此处自定义自己的格式。
"""
Users can customize their own format here.
Note:
Ensure output is a dictionary with the following keys:
```python
dict(
im_file=im_file,
shape=shape, # format: (height, width)
cls=cls,
bboxes=bboxes, # xywh
segments=segments, # xy
keypoints=keypoints, # xy
normalized=True, # or False
bbox_format="xyxy", # or xywh, ltwh
)
```
"""
# 这行代码表明 get_labels 方法是一个抽象方法,需要在继承 BaseDataset 的子类中实现具体的逻辑。
raise NotImplementedError
# 这个方法作为一个模板,提示子类在这里实现获取标签信息的具体逻辑。
# 在机器学习和深度学习中,标签信息是训练模型时必需的,它们可以是类别标签、边界框坐标、分割掩码等。
# get_labels 方法提供了一个框架,允许用户根据自己的数据源和格式来定制标签信息的加载逻辑,以适应不同的数据处理和模型训练需求。
原文地址:https://blog.csdn.net/m0_58169876/article/details/144303453
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!