自学内容网 自学内容网

GAN的损失函数和二元交叉熵损失的对应及代码

以下解释为GPT生成

在这里插入图片描述
在这里插入图片描述在这里插入图片描述
这里有个问题,使用二元交叉熵,的时候生成器的损失如何体现
在这里插入图片描述
看代码

import torch
import torch.nn as nn
import torch.optim as optim

# 设置设备为GPU或CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义生成器 (Generator)
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()  # 输出范围在 [-1, 1] 之间
        )
    
    def forward(self, x):
        return self.model(x)

# 定义判别器 (Discriminator)
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid()  # 输出概率 [0, 1]
        )
    
    def forward(self, x):
        return self.model(x)

# 超参数设置
input_size = 100  # 生成器输入噪声向量的维度
hidden_size = 128
output_size = 1   # 判别器的输出是一个概率
data_size = 784   # 假设输入数据的维度(例如,MNIST 图片 28x28 展开成 784 维向量)

# 初始化生成器和判别器
G = Generator(input_size, hidden_size, data_size).to(device)
D = Discriminator(data_size, hidden_size, output_size).to(device)

# 定义损失函数(二元交叉熵损失)
criterion = nn.BCELoss()

# 优化器
lr = 0.0002
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)

# 生成随机噪声的函数(用于生成器)
def generate_noise(batch_size, input_size):
    return torch.randn(batch_size, input_size).to(device)

# 假设我们有一个简单的生成数据和真实数据的函数
def get_real_data(batch_size):
    # 这里我们用随机生成的假数据来模拟真实数据
    return torch.randn(batch_size, data_size).to(device)

# 训练步骤
epochs = 1000
batch_size = 64

for epoch in range(epochs):
    # 训练判别器
    real_data = get_real_data(batch_size)  # 获取真实数据
    noise = generate_noise(batch_size, input_size)  # 生成噪声
    fake_data = G(noise)  # 生成数据

    # 判别器的目标:正确区分真实数据和生成数据
    real_labels = torch.ones(batch_size, 1).to(device)  # 真实数据的标签为1
    fake_labels = torch.zeros(batch_size, 1).to(device)  # 生成数据的标签为0

    # 判别器对真实数据的损失
    outputs_real = D(real_data)
    D_loss_real = criterion(outputs_real, real_labels)

    # 判别器对生成数据的损失
    outputs_fake = D(fake_data.detach())  # 对生成数据的判别(生成数据不传递梯度给生成器)
    D_loss_fake = criterion(outputs_fake, fake_labels)

    # 判别器总损失
    D_loss = D_loss_real + D_loss_fake

    # 更新判别器
    optimizer_D.zero_grad()
    D_loss.backward()
    optimizer_D.step()

    # 训练生成器
    noise = generate_noise(batch_size, input_size)  # 生成新的噪声
    fake_data = G(noise)  # 生成新的假数据

    # 生成器的目标:欺骗判别器,让判别器认为生成的数据是真实的
    outputs_fake = D(fake_data)
    G_loss = criterion(outputs_fake, real_labels)  # 生成器希望生成的数据被判为真实数据,因此标签设为1

    # 更新生成器
    optimizer_G.zero_grad()
    G_loss.backward()
    optimizer_G.step()

    # 每隔一段时间打印损失
    if epoch % 100 == 0:
        print(f"Epoch [{epoch}/{epochs}] | D Loss: {D_loss.item():.4f} | G Loss: {G_loss.item():.4f}")


原文地址:https://blog.csdn.net/zfhsfdhdfajhsr/article/details/142315114

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!