训练常用API函数及方法
Tensor转Image
# Convert an image tensor with values in [0, 1] into a uint8 NumPy image.
# detach() + cpu() make this work for tensors that require grad or live on a GPU.
# clamp() is out-of-place on purpose: squeeze() returns a view and float() is a
# no-op for float32 tensors, so the original clamp_() wrote back into `img`.
output = img.detach().squeeze().float().cpu().clamp(0, 1).numpy()
if output.ndim == 3:
    # Presumably (C, H, W) with RGB channels -- confirm against the producer.
    # Reverse the channel order (RGB -> BGR, OpenCV order), then CHW -> HWC.
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)  # float32 [0, 1] -> uint8 [0, 255]
维度转置
# Swap the second and third axes of a 3-D array: shape (a, b, c) -> (a, c, b).
# transpose() returns a new view instead of modifying `array` in place, so the
# result has to be assigned back for the statement to have any effect.
array = array.transpose(0, 2, 1)
训练模型
def train(model):
    """Train ``model`` for ``num_epochs`` epochs, printing the mean loss of each epoch.

    Relies on ``num_epochs``, ``train_loader``, ``criterion``, ``optimizer`` and
    ``tqdm`` being defined at module level by the surrounding script.

    Args:
        model: the network to optimize; presumably a ``torch.nn.Module`` whose
            parameters are the ones registered with ``optimizer``.

    Returns:
        list[float]: mean training loss per epoch (``nan`` for an epoch in
        which the loader yielded no batches).
    """
    # Make sure dropout / batch-norm run in training mode; an earlier evaluation
    # pass may have switched the model to eval mode.
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in tqdm(train_loader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate a Python float rather than the loss tensor itself.
            running_loss += loss.item()
            num_batches += 1
        # Count batches actually seen so an empty loader does not divide by zero.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch + 1} loss: {epoch_loss:.4f}')
    return epoch_losses
颜色转换
# Reorder the channels of an RGB image into OpenCV's BGR order.
# NOTE(review): rgb_image is presumably an H x W x 3 NumPy array -- confirm.
bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
自定义数据集
def __init__(self, folder_path, transform=None):
    """Index every regular file in ``folder_path`` as one sample.

    Args:
        folder_path: directory containing the image files.
        transform: optional callable applied to each loaded image.
    """
    self.folder_path = folder_path  # directory the image files are read from
    self.transform = transform  # optional per-image transform (may be None)
    # os.listdir returns entries in arbitrary order; sort them so that a given
    # index always maps to the same file. Sub-directories are skipped because
    # they cannot be opened as images.
    self.image_files = sorted(
        name
        for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    )
def __len__(self):
    """Return the number of samples, i.e. how many file names were indexed."""
    file_names = self.image_files
    return len(file_names)
def __getitem__(self, idx):
    """Load the ``idx``-th image and return ``(image, label)``.

    The label is always 0 (a placeholder, e.g. for a single-class dataset);
    replace it with real labels as needed. ``self.transform``, when set, is
    applied to the loaded PIL image before it is returned.
    """
    img_name = os.path.join(self.folder_path, self.image_files[idx])
    # Image.open is lazy and keeps the file open until the pixels are read.
    # Copying inside the `with` block forces the load and closes the file right
    # away, so file descriptors do not accumulate across samples.
    with Image.open(img_name) as img:
        image = img.copy()
    # NOTE(review): the image keeps its on-disk mode (L, RGB, RGBA, P, ...);
    # add .convert("RGB") if the folder mixes modes and samples are batched.
    label = 0
    if self.transform:
        image = self.transform(image)
    return image, label
图表显示
import matplotlib.pyplot as plt
# Show num_images images side by side in a single row (a 1 x num_images grid).
# squeeze=False makes plt.subplots always return a 2-D array of Axes, so the
# loop also works when num_images == 1 (otherwise a bare Axes is returned).
fig, axes = plt.subplots(1, num_images, figsize=(20, 2), squeeze=False)
for i, ax in enumerate(axes[0]):
    # Presumably each image is channels-first (C, H, W) -- confirm. imshow needs
    # channels-last (H, W, C) and treats the channels as RGB(A).
    ax.imshow(np.transpose(generated_images[i], (1, 2, 0)))
    ax.axis('off')
plt.show()
原文地址:https://blog.csdn.net/weixin_51277037/article/details/142501386
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!