自学内容网 自学内容网

【深度学习实战(11)】搭建自己的dataset和dataloader

一、dataset和dataloader要点说明

在我们搭建自己的网络时,往往需要定义自己的datasetdataloader,将图像和标签数据送入模型。
(1)在我们定义dataset时,需要继承torch.utils.data.dataset,再重写三个方法:

  • init方法,主要用来定义数据的预处理
  • getitem方法,数据增强;返回数据的item和label
  • len方法,返回数据数量

(2)在我们定义dataloader时,需要考虑下面几个参数:

  • dataset :使用哪个数据集
  • batch_size:将数据集拆成一组多少个进行训练
  • shuffle:是否需要打乱数据
  • num_workers:几个mini_batch并行计算,一般<=你的电脑cpu数目
  • collect_fn:数据打包方式

(3)通过迭代的方式,按批次,获取dataloader中的数据

(4)关系图

在这里插入图片描述

二、核心代码框架

import os
import cv2
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader


# -------------------------------------------------------------#
#   自定义dataset需要继承torch.utils.data.dataset,
#   再重写def __init__,def __len__,def __getitem__三个方法
# -------------------------------------------------------------#
class YourDataset(Dataset):
    def __init__(self,  root_path):
        super(YourDataset, self).__init__()
        self.root_path = root_path
        #-------------------------------------------------------------------------#
        #   获取样本名,以jpg原始图片为参考,修改后缀名为json,png,获取json,png标签文件路径
        #-------------------------------------------------------------------------#
        self.sample_names = []
        jpg_path = os.path.join(os.path.join(self.root_path, "images"),)
        for file in os.listdir(jpg_path):
            if file.endswith(".jpg"):
                self.sample_names.append(os.path.splitext(file)[0]) # 去掉.json

    def __len__(self):
        #----------------------#
        #   返回数据数量
        #----------------------#
        return len(self.sample_names)

    def __getitem__(self, index):
        name = self.sample_names[index]

        # ----------------------#
        #   读取图像
        # ----------------------#
        img_path = os.path.join(os.path.join(self.root_path, "images"), name + '.jpg')
        image = cv2.imread(img_path)
        # ----------------------#
        #   读取标签
        # ----------------------#
        label_path = os.path.join(os.path.join(self.root_path, "jsons"), name + '.json')
        with open(label_path) as label_file:
            points = self.get_data_from_json(label_file)
        #----------------------#
        #   图像数据增强
        #----------------------#
        image = self.random_color(image)
        #----------------------#
        #   标签归一化
        #----------------------#
        labels = self.convert_labels(points)
        return image,  labels

# -------------------------------------#
#   图片和标签格式转换后,按批次(batch)打包
# -------------------------------------#
def dataloader_collate_fn(batch):
    images = []
    labels = []
    for img, label in batch:
        images.append(transforms.ToTensor()(img))
        labels.append(label)
    return images, labels


if __name__ == '__main__':
    # -------------------------------------#
    #   构建dataset
    # -------------------------------------#
    path = './data/train'
    train_dataset = YourDataset(path)

    # -------------------------------------#
    #   构建Dataloader
    # -------------------------------------#
    dataset = train_dataset
    batch_size = 32
    shuffle = True
    num_workers = 0
    collate_fn = dataloader_collate_fn
    sampler = None
    train_gen = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,drop_last=True, collate_fn=collate_fn, sampler=sampler)
    # ---------------------------------------------#
    #   通过迭代的方式,一批一批读取训练集中的图像和标签数据
    # ---------------------------------------------#
    for iter, batch in enumerate(train_gen):
        images,  labels = batch

原文地址:https://blog.csdn.net/m0_51579041/article/details/137988404

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!