在生成对抗网络(GAN)中,生成器(G)和判别器(D)通常是两个独立的神经网络,它们之间会有梯度传播的互动。下面是一个简单的GAN的PyTorch实现,用于生成一维数据,以展示何时应该使用detach()。
import torch
import torch.nn as nn
import torch.optim as optim
# 生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(10, 50),
nn.ReLU(),
nn.Linear(50, 1)
)
def forward(self, x):
return self.model(x)
# 判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(1, 50),
nn.ReLU(),
nn.Linear(50, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# 实例化生成器和判别器
G = Generator()
D = Discriminator()
# 定义优化器和损失函数
optimizer_G = optim.Adam(G.parameters(), lr=0.001)
optimizer_D = optim.Adam(D.parameters(), lr=0.001)
loss_func = nn.BCELoss()
# 训练循环
for epoch in range(1000):
# 训练判别器
D.zero_grad()
real_data = torch.randn(100, 1) # 真实数据
real_labels = torch.ones(100, 1) # 真实标签
fake_data = G(torch.randn(100, 10)).detach() # 使用detach(), 因为我们不想在这一步更新生成器
fake_labels = torch.zeros(100, 1) # 假的标签
real_loss = loss_func(D(real_data), real_labels)
# real_loss = loss_func(D(real_data.detach), real_labels)
fake_loss = loss_func(D(fake_data), fake_labels)
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_D.step()
# 训练生成器
G.zero_grad()
noise_data = torch.randn(100, 10) # 噪声数据
fake_data = G(noise_data) # 没有使用detach(), 因为我们想在这一步更新生成器
g_loss = loss_func(D(fake_data), torch.ones(100, 1))
g_loss.backward()
optimizer_G.step()
在这个例子中:
detach()
来中断梯度传播到生成器(G)。这是因为在这一步中,我们仅关心优化判别器,而不希望更新生成器的参数。detach()
,因为我们需要通过反向传播的梯度来更新生成器的参数。注意:在训练判别器时,不使用real_loss = loss_func(D(real_data.detach), real_labels)
, 也就是这里不需要对real_data进行detach操作。
而且即使对real_data
进行.detach()
操作实际上应该不会有明显影响,原因在于real_data
并不是通过模型参数生成的,也不是一个需要优化的变量。.detach()
方法主要用于将一个张量从当前计算图中分离出来,阻止反向传播过程中对其计算梯度。但在本例中,real_data
本身就没有与需要优化的模型参数有直接关系,也不是由其他需要优化的变量通过一些运算得到的。
注意: 在训练判别器时,使用fake_data = G(torch.randn(100, 10)).detach()
, 注意是因为这个fake_data
是由生成器G
生成的, 为了保证分开训练判别器和生成器,即在训练判别器的时候,不对生成器的参数进行更新,这里就要把G
生成的数据进行detach
操作
在训练生成器时, 也用到了判别器,用判别器去判别生成器生成的内容,希望判别器能把G
生成的内容当做真的,这样就说明G
的生成的内容可以以假乱真
fake_data = G(noise_data) # 没有使用detach(), 因为我们想在这一步更新生成器
g_loss = loss_func(D(fake_data), torch.ones(100, 1))
g_loss.backward()
optimizer_G.step()
上面没有对传进D
的fake_data进行detach,是因为下面的代码只有g_loss_backward(),也就是只对G进行参数更新,当然这里也不能对fake_data进行detach,如果detach了,就无法更新G的参数了