自学内容网 自学内容网

深度学习:yolov3的使用--图像处理

定义了一个名为 ListDataset 的类,它继承自 PyTorch 的 Dataset 类,这个数据集从一个包含图像文件路径的列表中读取图像和对应的标签文件


class ListDataset(Dataset):
    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        with open(list_path, "r") as file:
            self.img_files = file.readlines()
        # 找到图片对应的label文件路径,将png和jpg格式变成txt格式
        self.label_files = [
            path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
            for path in self.img_files
        ]
        self.img_size = img_size#存储目标图像尺寸。
        self.max_objects = 100# 定义了图像中最大对象数量,默认为 100。
        self.augment = augment #是否进行数据增强
        self.multiscale = multiscale#是否多尺度训练
        self.normalized_labels = normalized_labels#是否标签归一化
        #定义了多尺度训练时图像尺寸的范围。
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0

调用函数

dataset = ListDataset(train_path, augment=True, multiscale=opt.multiscale_training)

定义了一个名为 collate_fn 的函数,它是 ListDataset 类的一个方法。这个函数用于将一个批次中的多个数据样本合并成一个批次

    def collate_fn(self, batch):
        # 解压批次数据
        paths, imgs, targets = list(zip(*batch))
        # Remove empty placeholder targets
        # 移除空目标
        targets = [boxes for boxes in targets if boxes is not None]
        # Add sample index to targets
        # 添加样本索引到目标
        for i, boxes in enumerate(targets):
            boxes[:, 0] = i
        # 拼接目标
        targets = torch.cat(targets, 0) #将targets按行进行拼接
        # Selects new image size every tenth batch
        # 多尺度训练
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # Resize images to input shape
        # 调整图像尺寸
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])
        self.batch_count += 1
        return paths, imgs, targets

调用函数

加载器可以批量加载数据集,并为训练过程提供数据

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,  #1个样本打包成一个batch进行加载
        shuffle=True,               #对数据进行随机打乱,
        num_workers=opt.n_cpu,      #用于指定子进程的数量,用于并行地加载数据。默认情况下,num_workers的值为0,表示没有使用子进程,所有数据都会在主进程中加载。当设置num_workers大于0时,DataLoader会创建指定数量的子进程,每个子进程都会负责加载一部分数据,然后主进程负责从这些子进程中获取数据。
                                    # 使用子进程可以加快数据的加载速度,因为每个子进程可以并行地加载一部分数据,从而充分利用多核CPU的计算能力。但是需要注意的是,使用子进程可能会导致数据的顺序被打乱,因此如果需要保持数据的原始顺序,应该将shuffle参数设置为False。
                                    # num_workers的值应该根据具体情况进行调整。如果数据集较大,可以考虑增加num_workers的值以充分利用计算机的资源。但是需要注意的是,如果num_workers的值过大,可能会导致内存消耗过大或者CPU负载过重,从而影响程序的性能。因此,需要根据实际情况进行调整。
        pin_memory=True,            #指定是否将加载进内存的数据的指针固定(pin),这个参数在某些情况下可以提高数据加载的速度。
                                    # 当设置pin_memory=True时,DataLoader会将加载进内存的数据的指针固定,即不进行移动操作。这样做的目的是为了提高数据传输的效率。因为当数据从磁盘或者网络等地方传输到内存中时,如果指针不固定,可能会导致数据在传输过程中被移动,从而需要重新读取,浪费了时间。而固定指针可以避免这种情况的发生,从而提高了数据传输的效率。
                                    # 需要注意的是,pin_memory参数的效果与操作系统和硬件的性能有关。在一些高性能的计算机上,固定指针可能并不会带来太大的性能提升。但是在一些内存带宽较小的计算机上,固定指针可能会显著提高数据加载的效率。因此,需要根据实际情况进行调整。
        collate_fn=dataset.collate_fn,
                                    # collate_fn是一个函数,用于对每个batch的数据进行合并。这个函数的输入是一个batch的数据,输出是一个合并后的数据。
                                    # collate_fn函数的主要作用是对每个batch的数据进行预处理,例如将不同数据类型的张量合并成一个张量,或者对序列数据进行padding操作等。这样可以使得每个batch的数据格式一致,便于模型进行训练。
                                    # 在默认情况下,collate_fn函数会将每个batch的数据按照第一个元素的张量形状进行合并。例如,如果一个batch的数据中第一个元素的张量形状是[
                                    # 3, 224, 224],那么collate_fn函数会将该batch的所有数据都调整为这个形状。
    )

创建了一个名为 optimizer 的优化器对象,用于在训练过程中更新模型的参数。

optimizer = torch.optim.Adam(model.parameters())

定义了一个名为 metrics 的列表,其中包含了在目标检测模型训练和评估过程中常用的一系列指标。

    metrics = [
        "grid_size",#表示模型输出的特征图的大小
        "loss",#损失值
        "x",
        "y",
        "w",
        "h",#表示目标检测中目标框的中心坐标(x, y)和宽高(w, h)
        "conf",#置信度
        "cls",#目标框中目标的类别预测的分数
        "cls_acc",#目标类别预测的正确率
        "recall50",
        "recall75",#表示在不同的置信度阈值(通常是 0.5 和 0.75)下的召回率。
        "precision",#精确度
        "conf_obj",
        "conf_noobj",#对象置信度和非对象置信度
    ]

pad_to_square 函数是一个用于将图像填充到正方形的函数

def pad_to_square(img, pad_value):
    c, h, w = img.shape
    # 计算高度和宽度之间的差异
    dim_diff = np.abs(h - w)
    # 计算高度和宽度差异的一半,以及剩余的部分
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2    #dim_diff - dim_diff // 2剩余部分,不一定能被2整除的空间
    # 根据高度和宽度的比较结果确定填充的方向和大小
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)  #填充方向(左,右,上,下)
    # 使用 PyTorch 的 F.pad 函数对图像进行填充
    img = F.pad(img, pad, "constant", value=pad_value)
    return img, pad

图像处理:获取图像文件的路径,将图像转换为 PyTorch 张量,并确保图像是 RGB 格式。获取图像的高度和宽度,使用 pad_to_square 函数将图像填充为正方形

    def __getitem__(self, index):

        # ---------
        #  Image图片的处理
        # ---------
        # 获取图像文件的路径,并将其拼接为绝对路径
        img_path = self.img_files[index % len(self.img_files)].rstrip()
        img_path = r'D:/KECHENG/pythonProject/.venv/Lib/PyTorch-YOLOv3/' + img_path
        #print (img_path)使用绝对路径,F:\人工智能学习\深度学习课件\代码\第7章yolo\PyTorch-YOLOv3\PyTorch-YOLOv3\data\coco
        # Extract image as PyTorch tensor
        # 转换为RGB格式,转换为PyTorch张量。
        img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))

        # Handle images with less than three channels
        if len(img.shape) != 3:#是为了防止你的图片中存在灰度图,
            img = img.unsqueeze(0)
            img = img.expand((3, img.shape[1:]))
        # 获取图像的高度和宽度
        _, h, w = img.shape
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
        # 当尺寸并不是标准的正方形,进行填充0。
        img, pad = pad_to_square(img, 0)
        _, padded_h, padded_w = img.shape

    

标签处理:获取标签文件的路径,将其转换为 PyTorch 张量。将标签中的归一化坐标转换回原始图像的坐标系,并根据填充调整坐标,将坐标重新归一化到填充后的图像尺寸。创建 targets 张量,用于存储更新后的标签信息。

# 获取标签文件的路径,并将其拼接为绝对路径。
        label_path = self.label_files[index % len(self.img_files)].rstrip()
        label_path = r'D:\KECHENG\pythonProject\.venv\Lib\PyTorch-YOLOv3/' + label_path
        #print (label_path)F:\人工智能学习\深度学习课件\代码\第7章yolo\PyTorch-YOLOv3\PyTorch-YOLOv3\data\coco

        targets = None
        if os.path.exists(label_path):
            # 读取标签文件内容,并将其转换为PyTorch张量
            boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
            # Extract coordinates for unpadded + unscaled image,
            # COCO数据集中的.txt文件每个字段的含义:
            # class_num:类别编号,从1开始。
            # box_cx:归一化后的中心横坐标,即像素坐标的cx除以图像宽度的结果。
            # box_cy:归一化后的中心纵坐标,即像素坐标的cy除以图像高度的结果。
            # box_w:归一化后的标注框宽度,即标注框宽度除以图像宽度的结果。
            # box_h:归一化后的标注框高度,即标注框高度除以图像高度的结果。
            x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)#还原回标注框的左上x值
            y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
            x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
            y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
            # Adjust for added padding
            # 根据添加的填充调整坐标
            x1 += pad[0]#图像填充0
            y1 += pad[2]
            x2 += pad[1]
            y2 += pad[3]
            # Returns (x, y, w, h)  填充后继续还原回原始的值
            # 将坐标转换回归一化值
            boxes[:, 1] = ((x1 + x2) / 2) / padded_w
            boxes[:, 2] = ((y1 + y2) / 2) / padded_h
            boxes[:, 3] *= w_factor / padded_w
            boxes[:, 4] *= h_factor / padded_h
            # 创建targets张量,用于存储更新后的标签信息
            targets = torch.zeros((len(boxes), 6))
            targets[:, 1:] = boxes

图像增强

 # 应用图像增强
        if self.augment:
            if np.random.random() < 0.5:
                img, targets = horisontal_flip(img, targets)

        return img_path, img, targets

调用函数

    for epoch in range(1,opt.epochs+1):
        model.train()
        # 记录当前轮次的开始时间
        start_time = time.time()
        # 从数据加载器 dataloader 中迭代获取批次数据。
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            # 计算批次完成的总数
            batches_done = len(dataloader) * epoch + batch_i
            # 数据移动到设备
            imgs = Variable(imgs.to(device))    #Variable类是PyTorch中的一个包装器,它将张量和它们的梯度信息封装在一起。当我们对一个张量进行操作时,PyTorch会自动地创建一个对应的Variable对象,其中包含了原始张量、梯度等信息。通过使用Variable,我们可以方便地进行自动微分和优化。
            targets = Variable(targets.to(device), requires_grad=False)
            print('imgs',imgs.shape)
            print('targets',targets.shape)
            

原文地址:https://blog.csdn.net/mohanyelong/article/details/143374431

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!