GAN的损失函数和二元交叉熵损失的对应及代码
以下解释为GPT生成
这里有个问题,使用二元交叉熵,的时候生成器的损失如何体现
看代码
import torch
import torch.nn as nn
import torch.optim as optim
# 设置设备为GPU或CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义生成器 (Generator)
class Generator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size),
nn.Tanh() # 输出范围在 [-1, 1] 之间
)
def forward(self, x):
return self.model(x)
# 定义判别器 (Discriminator)
class Discriminator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size),
nn.Sigmoid() # 输出概率 [0, 1]
)
def forward(self, x):
return self.model(x)
# 超参数设置
input_size = 100 # 生成器输入噪声向量的维度
hidden_size = 128
output_size = 1 # 判别器的输出是一个概率
data_size = 784 # 假设输入数据的维度(例如,MNIST 图片 28x28 展开成 784 维向量)
# 初始化生成器和判别器
G = Generator(input_size, hidden_size, data_size).to(device)
D = Discriminator(data_size, hidden_size, output_size).to(device)
# 定义损失函数(二元交叉熵损失)
criterion = nn.BCELoss()
# 优化器
lr = 0.0002
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)
# 生成随机噪声的函数(用于生成器)
def generate_noise(batch_size, input_size):
return torch.randn(batch_size, input_size).to(device)
# 假设我们有一个简单的生成数据和真实数据的函数
def get_real_data(batch_size):
# 这里我们用随机生成的假数据来模拟真实数据
return torch.randn(batch_size, data_size).to(device)
# 训练步骤
epochs = 1000
batch_size = 64
for epoch in range(epochs):
# 训练判别器
real_data = get_real_data(batch_size) # 获取真实数据
noise = generate_noise(batch_size, input_size) # 生成噪声
fake_data = G(noise) # 生成数据
# 判别器的目标:正确区分真实数据和生成数据
real_labels = torch.ones(batch_size, 1).to(device) # 真实数据的标签为1
fake_labels = torch.zeros(batch_size, 1).to(device) # 生成数据的标签为0
# 判别器对真实数据的损失
outputs_real = D(real_data)
D_loss_real = criterion(outputs_real, real_labels)
# 判别器对生成数据的损失
outputs_fake = D(fake_data.detach()) # 对生成数据的判别(生成数据不传递梯度给生成器)
D_loss_fake = criterion(outputs_fake, fake_labels)
# 判别器总损失
D_loss = D_loss_real + D_loss_fake
# 更新判别器
optimizer_D.zero_grad()
D_loss.backward()
optimizer_D.step()
# 训练生成器
noise = generate_noise(batch_size, input_size) # 生成新的噪声
fake_data = G(noise) # 生成新的假数据
# 生成器的目标:欺骗判别器,让判别器认为生成的数据是真实的
outputs_fake = D(fake_data)
G_loss = criterion(outputs_fake, real_labels) # 生成器希望生成的数据被判为真实数据,因此标签设为1
# 更新生成器
optimizer_G.zero_grad()
G_loss.backward()
optimizer_G.step()
# 每隔一段时间打印损失
if epoch % 100 == 0:
print(f"Epoch [{epoch}/{epochs}] | D Loss: {D_loss.item():.4f} | G Loss: {G_loss.item():.4f}")
原文地址:https://blog.csdn.net/zfhsfdhdfajhsr/article/details/142315114
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!