Common API Functions and Methods for Training

Tensor to Image

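import numpy as np

# Detach from the autograd graph and move to CPU before converting to NumPy;
# the out-of-place clamp leaves img itself unchanged
output = img.detach().squeeze().float().cpu().clamp(0, 1).numpy()
if output.ndim == 3:
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # 1. reverse channel order (RGB <-> BGR)  2. CHW -> HWC
output = (output * 255.0).round().astype(np.uint8)  # float32 in [0, 1] to uint8 in [0, 255]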
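
To make the effect of the conversion above on shapes and values concrete, here is a minimal NumPy-only sketch. It assumes the input is a float array in CHW layout; the helper name `chw_float_to_hwc_uint8` and the demo values are hypothetical and only serve the illustration.

```python
import numpy as np

def chw_float_to_hwc_uint8(chw):
    """Hypothetical helper: CHW float image -> HWC uint8 with channels reversed."""
    out = np.clip(chw.astype(np.float32), 0.0, 1.0)  # same role as clamp(0, 1)
    if out.ndim == 3:
        # Reverse the channel order, then move channels to the last axis
        out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))
    return (out * 255.0).round().astype(np.uint8)

demo = np.zeros((3, 2, 4), dtype=np.float32)  # 3 channels, height 2, width 4
demo[0] = 1.0    # first channel fully on
demo[1] = 0.2
demo[2] = -0.3   # out of range, clipped to 0
result = chw_float_to_hwc_uint8(demo)

print(result.shape)   # (2, 4, 3)
print(result[0, 0])   # [  0  51 255]
```

The first input channel ends up last in each pixel, which is what the `[2, 1, 0]` index does before the transpose.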

Transposing Dimensions

# Swap the second and third dimensions
array.transpose(0, 2, 1)
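
As a quick check of what this call does, the sketch below applies it to a small NumPy array (the shape and contents are made up). Note that this multi-axis form is the NumPy signature; for a PyTorch tensor the equivalent would be `tensor.permute(0, 2, 1)` or `tensor.transpose(1, 2)`, since `torch.Tensor.transpose` takes exactly two dimensions.

```python
import numpy as np

array = np.arange(24).reshape(2, 3, 4)
swapped = array.transpose(0, 2, 1)

# After the swap, swapped[i, j, k] refers to array[i, k, j]
print(array.shape, swapped.shape)  # (2, 3, 4) (2, 4, 3)

# transpose returns a view, not a copy: both arrays share the same memory
is_view = np.shares_memory(array, swapped)
```

Because the result is a view, writing to `swapped` also changes `array`; call `.copy()` if an independent array is needed.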

Training the Model

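from tqdm import tqdm

# Assumes num_epochs, train_loader, criterion and optimizer are defined in the enclosing scope
def train(model):
    model.train()  # training mode (affects dropout and batch norm)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)  # mean loss per batch
        print(f'Epoch {epoch + 1} loss: {epoch_loss:.4f}')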
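
The loop above relies on several names defined elsewhere. The following is a minimal, self-contained sketch of one way to set them up; the synthetic data, the `nn.Linear` model, the MSE loss and the SGD settings are all assumptions chosen for illustration, and `tqdm` is left out to keep dependencies small.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic regression data (made up): y = 2x + 1 plus small noise
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1 + 0.01 * torch.randn_like(x)
train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
num_epochs = 20

def train(model):
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()   # clear gradients from the previous step
            loss.backward()         # compute gradients
            optimizer.step()        # update parameters
            running_loss += loss.item()
        epoch_losses.append(running_loss / len(train_loader))
    return epoch_losses

losses = train(model)
print(f'first epoch: {losses[0]:.4f}, last epoch: {losses[-1]:.4f}')
```

Returning the per-epoch losses instead of only printing them makes it easy to plot or assert on them afterwards.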

Color Conversion

bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
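
For a 3-channel image, converting between RGB and BGR amounts to reversing the channel axis. The sketch below shows that with NumPy alone, assuming an HWC `uint8` array; the 2x2 test image is made up.

```python
import numpy as np

rgb_image = np.zeros((2, 2, 3), dtype=np.uint8)
rgb_image[..., 0] = 255  # pure red in RGB order

# Reversing the last axis swaps the first and third channels
bgr_image = rgb_image[..., ::-1]

# The slice is a negatively strided view; copy it into contiguous memory
# in case downstream code expects a contiguous array
bgr_image = np.ascontiguousarray(bgr_image)

print(bgr_image[0, 0])  # [  0   0 255]
```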

Custom Dataset

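import os
from PIL import Image
from torch.utils.data import Dataset

class ImageFolderDataset(Dataset):  # class name is a placeholder
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path  # path to the image folder
        self.transform = transform  # optional transform
        self.image_files = sorted(os.listdir(folder_path))  # all files in the folder, in a stable order

    def __len__(self):
        return len(self.image_files)  # number of images

    def __getitem__(self, idx):
        img_name = os.path.join(self.folder_path, self.image_files[idx])  # build the image path
        image = Image.open(img_name)  # load the image
        label = 0  # define the label as needed (for example, 0 when there is only one class)

        if self.transform:
            image = self.transform(image)  # apply the transform

        return image, label  # return the image and its label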
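
The dataset above returns a fixed label of 0. One possible way to define labels "as needed" is to derive them from the file name. The sketch below assumes a hypothetical naming convention of `<class>_<id>.<ext>`; the helper `label_from_filename` and the class names are made up for illustration.

```python
import os

def label_from_filename(filename, class_to_idx):
    """Hypothetical helper: map 'cat_001.jpg' to class_to_idx['cat']."""
    stem = os.path.splitext(os.path.basename(filename))[0]  # 'cat_001'
    class_name = stem.split('_')[0]                          # 'cat'
    return class_to_idx[class_name]

class_to_idx = {'cat': 0, 'dog': 1}  # made-up classes
labels = [label_from_filename(f, class_to_idx) for f in ['cat_001.jpg', 'dog_17.png']]
print(labels)  # [0, 1]
```

Inside `__getitem__`, such a helper could replace the constant with `label_from_filename(self.image_files[idx], class_to_idx)`.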

Displaying Figures

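import matplotlib.pyplot as plt
import numpy as np

# Create a 1 x num_images row of subplots
fig, axes = plt.subplots(1, num_images, figsize=(20, 2))
for i in range(num_images):
    # imshow expects HWC layout with RGB channel order, so move channels last (CHW -> HWC)
    axes[i].imshow(np.transpose(generated_images[i], (1, 2, 0)))
    axes[i].axis('off')
plt.show()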
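
One pitfall with this pattern: with the default `squeeze=True`, `plt.subplots(1, 1)` returns a single `Axes` object rather than an array, so `axes[i]` fails when there is only one image. A minimal sketch of a workaround, using made-up random data and the non-interactive `Agg` backend so it runs without a display:

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, chosen so no display is needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
generated_images = rng.random((1, 3, 8, 8))  # made-up data: one CHW image in [0, 1]
num_images = len(generated_images)

# squeeze=False keeps axes as a 2D array even when there is only one subplot
fig, axes = plt.subplots(1, num_images, figsize=(2 * num_images, 2), squeeze=False)
for i in range(num_images):
    axes[0, i].imshow(np.transpose(generated_images[i], (1, 2, 0)))  # CHW -> HWC
    axes[0, i].axis('off')

# Render to an in-memory PNG instead of opening a window
buf = io.BytesIO()
fig.savefig(buf, format="png")
plt.close(fig)
```

With `squeeze=False` the same indexing `axes[0, i]` works for any number of images.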

Original article: https://blog.csdn.net/weixin_51277037/article/details/142501386