Problem description
It's the first time I'm working with GANs, and I am facing an issue with the Discriminator repeatedly outperforming the Generator. I am trying to reproduce the PA model from this article, and I'm looking at this slightly different implementation to help me out.
I have read quite a lot of papers on how GANs work and also followed some tutorials to understand them better. Moreover, I've read articles on how to overcome the major instabilities, but I can't find a way to overcome this behavior.
In my environment, I'm using PyTorch and BCELoss(). Following the DCGAN PyTorch tutorial, I'm using the following training loop:
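```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# One training iteration. `generator`, `discriminator`, `optim_d`, `optim_g`,
# `device`, `lossT`, `image` (LR input) and `target` (HR ground truth) come
# from the surrounding training loop.
train_d = False

# Discriminator: real batch
optim_d.zero_grad()
disc_train_real = target.to(device)
batch_size = disc_train_real.size(0)
# BCELoss needs float targets, so fill with 1.0 rather than the integer 1
label = torch.full((batch_size,), 1.0, device=device)
output_d = discriminator(disc_train_real).view(-1)
loss_d_real = criterion(output_d, label)
if lossT:
    loss_d_real = loss_d_real * 2
# Only backpropagate the real-sample loss while it is above 0.3
if loss_d_real.item() > 0.3:
    loss_d_real.backward()
    train_d = True
D_x = output_d.mean().item()

# Discriminator: fake batch (detached, so no gradient reaches G here)
output_g = generator(image)
output_d = discriminator(output_g.detach()).view(-1)
label.fill_(0.0)
loss_d_fake = criterion(output_d, label)
D_G_z1 = output_d.mean().item()  # D(G(z)) before the D update
if lossT:
    loss_d_fake = loss_d_fake * 2
loss_d = loss_d_real + loss_d_fake  # used for logging only
if loss_d_fake.item() > 0.3:
    loss_d_fake.backward()
    train_d = True
if train_d:
    optim_d.step()

# Generator: push D to label the fake batch as real
optim_g.zero_grad()  # otherwise G's gradients accumulate across iterations
label.fill_(1.0)
output_d = discriminator(output_g).view(-1)
loss_g = criterion(output_d, label)
D_G_z2 = output_d.mean().item()  # D(G(z)) after the (possible) D update
if lossT:
    loss_g = loss_g * 2
loss_g.backward()
optim_g.step()
```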
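For reference, with BCELoss() (mean reduction) and the labels used above, the three terms the loop computes on a batch of B pairs can be written as follows, where x_i is a high-resolution target and z_i the corresponding low-resolution input (`image`). The generator term is the non-saturating form that BCE with a target of 1 produces, as in the DCGAN tutorial; when `lossT` is set, each term is doubled before the 0.3 comparisons.

$$
\begin{aligned}
\mathcal{L}_D^{\mathrm{real}} &= -\frac{1}{B}\sum_{i=1}^{B} \log D(x_i) \\
\mathcal{L}_D^{\mathrm{fake}} &= -\frac{1}{B}\sum_{i=1}^{B} \log\bigl(1 - D(G(z_i))\bigr) \\
\mathcal{L}_G &= -\frac{1}{B}\sum_{i=1}^{B} \log D(G(z_i))
\end{aligned}
$$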
After a period of settling in, everything seems to work fine:
Epoch 1/5 - Step: 1900/9338 Loss G: 3.057388 Loss D: 0.214545 D(x): 0.940985 D(G(z)): 0.114064 / 0.114064
Time for the last step: 51.55 s Epoch ETA: 01:04:13
Epoch 1/5 - Step: 2000/9338 Loss G: 2.984724 Loss D: 0.222931 D(x): 0.879338 D(G(z)): 0.159163 / 0.159163
Time for the last step: 52.68 s Epoch ETA: 01:03:24
Epoch 1/5 - Step: 2100/9338 Loss G: 2.824713 Loss D: 0.241953 D(x): 0.905837 D(G(z)): 0.110231 / 0.110231
Time for the last step: 50.91 s Epoch ETA: 01:02:29
Epoch 1/5 - Step: 2200/9338 Loss G: 2.807455 Loss D: 0.252808 D(x): 0.908131 D(G(z)): 0.218515 / 0.218515
Time for the last step: 51.72 s Epoch ETA: 01:01:37
Epoch 1/5 - Step: 2300/9338 Loss G: 2.470529 Loss D: 0.569696 D(x): 0.620966 D(G(z)): 0.512615 / 0.350175
Time for the last step: 51.96 s Epoch ETA: 01:00:46
Epoch 1/5 - Step: 2400/9338 Loss G: 2.148863 Loss D: 1.071563 D(x): 0.809529 D(G(z)): 0.114487 / 0.114487
Time for the last step: 51.59 s Epoch ETA: 00:59:53
Epoch 1/5 - Step: 2500/9338 Loss G: 2.016863 Loss D: 0.904711 D(x): 0.621433 D(G(z)): 0.440721 / 0.435932
Time for the last step: 52.03 s Epoch ETA: 00:59:02
Epoch 1/5 - Step: 2600/9338 Loss G: 2.495639 Loss D: 0.949308 D(x): 0.671085 D(G(z)): 0.557924 / 0.420826
Time for the last step: 52.66 s Epoch ETA: 00:58:12
Epoch 1/5 - Step: 2700/9338 Loss G: 2.519842 Loss D: 0.798667 D(x): 0.775738 D(G(z)): 0.246357 / 0.265839
Time for the last step: 51.20 s Epoch ETA: 00:57:19
Epoch 1/5 - Step: 2800/9338 Loss G: 2.545630 Loss D: 0.756449 D(x): 0.895455 D(G(z)): 0.403628 / 0.301851
Time for the last step: 51.88 s Epoch ETA: 00:56:27
Epoch 1/5 - Step: 2900/9338 Loss G: 2.458109 Loss D: 0.653513 D(x): 0.820105 D(G(z)): 0.379199 / 0.103250
Time for the last step: 53.50 s Epoch ETA: 00:55:39
Epoch 1/5 - Step: 3000/9338 Loss G: 2.030103 Loss D: 0.948208 D(x): 0.445385 D(G(z)): 0.303225 / 0.263652
Time for the last step: 51.57 s Epoch ETA: 00:54:47
Epoch 1/5 - Step: 3100/9338 Loss G: 1.721604 Loss D: 0.949721 D(x): 0.365646 D(G(z)): 0.090072 / 0.232912
Time for the last step: 52.19 s Epoch ETA: 00:53:55
Epoch 1/5 - Step: 3200/9338 Loss G: 1.438854 Loss D: 1.142182 D(x): 0.768163 D(G(z)): 0.321164 / 0.237878
Time for the last step: 50.79 s Epoch ETA: 00:53:01
Epoch 1/5 - Step: 3300/9338 Loss G: 1.924418 Loss D: 0.923860 D(x): 0.729981 D(G(z)): 0.354812 / 0.318090
Time for the last step: 52.59 s Epoch ETA: 00:52:11
That is, the gradients on the Generator are higher at first and start to decrease after a while, while the gradients on the Discriminator rise. As for the losses, the Generator's goes down while the Discriminator's goes up. Compared to the tutorial, I guess this is acceptable.
Here's my first question: I've noticed that in the tutorial, as D_G_z1 rises, D_G_z2 (usually) decreases, and vice versa, while in my example this happens much less often. Is it just a coincidence, or am I doing something wrong?
Given that, I've let the training procedure go on, but now I'm noticing this:
Epoch 3/5 - Step: 1100/9338 Loss G: 4.071329 Loss D: 0.031608 D(x): 0.999969 D(G(z)): 0.024329 / 0.024329
Time for the last step: 51.41 s Epoch ETA: 01:11:24
Epoch 3/5 - Step: 1200/9338 Loss G: 3.883331 Loss D: 0.036354 D(x): 0.999993 D(G(z)): 0.043874 / 0.043874
Time for the last step: 51.63 s Epoch ETA: 01:10:29
Epoch 3/5 - Step: 1300/9338 Loss G: 3.468963 Loss D: 0.054542 D(x): 0.999972 D(G(z)): 0.050145 / 0.050145
Time for the last step: 52.47 s Epoch ETA: 01:09:40
Epoch 3/5 - Step: 1400/9338 Loss G: 3.504971 Loss D: 0.053683 D(x): 0.999972 D(G(z)): 0.052180 / 0.052180
Time for the last step: 50.75 s Epoch ETA: 01:08:41
Epoch 3/5 - Step: 1500/9338 Loss G: 3.437765 Loss D: 0.056286 D(x): 0.999941 D(G(z)): 0.058839 / 0.058839
Time for the last step: 52.20 s Epoch ETA: 01:07:50
Epoch 3/5 - Step: 1600/9338 Loss G: 3.369209 Loss D: 0.062133 D(x): 0.955688 D(G(z)): 0.058773 / 0.058773
Time for the last step: 51.05 s Epoch ETA: 01:06:54
Epoch 3/5 - Step: 1700/9338 Loss G: 3.290109 Loss D: 0.065704 D(x): 0.999975 D(G(z)): 0.056583 / 0.056583
Time for the last step: 51.27 s Epoch ETA: 01:06:00
Epoch 3/5 - Step: 1800/9338 Loss G: 3.286248 Loss D: 0.067969 D(x): 0.993238 D(G(z)): 0.063815 / 0.063815
Time for the last step: 52.28 s Epoch ETA: 01:05:09
Epoch 3/5 - Step: 1900/9338 Loss G: 3.263996 Loss D: 0.065335 D(x): 0.980270 D(G(z)): 0.037717 / 0.037717
Time for the last step: 51.59 s Epoch ETA: 01:04:16
Epoch 3/5 - Step: 2000/9338 Loss G: 3.293503 Loss D: 0.065291 D(x): 0.999873 D(G(z)): 0.070188 / 0.070188
Time for the last step: 51.85 s Epoch ETA: 01:03:25
Epoch 3/5 - Step: 2100/9338 Loss G: 3.184164 Loss D: 0.070931 D(x): 0.999971 D(G(z)): 0.059657 / 0.059657
Time for the last step: 52.14 s Epoch ETA: 01:02:34
Epoch 3/5 - Step: 2200/9338 Loss G: 3.116310 Loss D: 0.080597 D(x): 0.999850 D(G(z)): 0.074931 / 0.074931
Time for the last step: 51.85 s Epoch ETA: 01:01:42
Epoch 3/5 - Step: 2300/9338 Loss G: 3.142180 Loss D: 0.073999 D(x): 0.995546 D(G(z)): 0.054752 / 0.054752
Time for the last step: 51.76 s Epoch ETA: 01:00:50
Epoch 3/5 - Step: 2400/9338 Loss G: 3.185711 Loss D: 0.072601 D(x): 0.999992 D(G(z)): 0.076053 / 0.076053
Time for the last step: 50.53 s Epoch ETA: 00:59:54
Epoch 3/5 - Step: 2500/9338 Loss G: 3.027437 Loss D: 0.083906 D(x): 0.997390 D(G(z)): 0.082501 / 0.082501
Time for the last step: 52.06 s Epoch ETA: 00:59:03
Epoch 3/5 - Step: 2600/9338 Loss G: 3.052374 Loss D: 0.085030 D(x): 0.999924 D(G(z)): 0.073295 / 0.073295
Time for the last step: 52.37 s Epoch ETA: 00:58:12
Not only has D(x) increased again and become stuck at almost one, but D_G_z1 and D_G_z2 also always show the same value. Moreover, looking at the losses, it seems pretty clear that the Discriminator has outperformed the Generator. This behavior went on for the rest of the epoch and for the whole of the next one, until the end of training.
Hence my second question: is this normal? If not, what am I doing wrong in the procedure? How can I achieve more stable training?
As suggested, I've tried training the network with MSELoss(), and this is the output:
Epoch 1/1 - Step: 100/9338 Loss G: 0.800785 Loss D: 0.404525 D(x): 0.844653 D(G(z)): 0.030439 / 0.016316
Time for the last step: 55.22 s Epoch ETA: 01:25:01
Epoch 1/1 - Step: 200/9338 Loss G: 1.196659 Loss D: 0.014051 D(x): 0.999970 D(G(z)): 0.006543 / 0.006500
Time for the last step: 51.41 s Epoch ETA: 01:21:11
Epoch 1/1 - Step: 300/9338 Loss G: 1.197319 Loss D: 0.000806 D(x): 0.999431 D(G(z)): 0.004821 / 0.004724
Time for the last step: 51.79 s Epoch ETA: 01:19:32
Epoch 1/1 - Step: 400/9338 Loss G: 1.198960 Loss D: 0.000720 D(x): 0.999612 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.47 s Epoch ETA: 01:18:09
Epoch 1/1 - Step: 500/9338 Loss G: 1.212810 Loss D: 0.000021 D(x): 0.999938 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.18 s Epoch ETA: 01:17:11
Epoch 1/1 - Step: 600/9338 Loss G: 1.216168 Loss D: 0.000000 D(x): 0.999945 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.24 s Epoch ETA: 01:16:02
Epoch 1/1 - Step: 700/9338 Loss G: 1.212301 Loss D: 0.000000 D(x): 0.999970 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.61 s Epoch ETA: 01:15:02
Epoch 1/1 - Step: 800/9338 Loss G: 1.214397 Loss D: 0.000005 D(x): 0.999973 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.58 s Epoch ETA: 01:14:04
Epoch 1/1 - Step: 900/9338 Loss G: 1.212016 Loss D: 0.000003 D(x): 0.999932 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.20 s Epoch ETA: 01:13:13
Epoch 1/1 - Step: 1000/9338 Loss G: 1.215162 Loss D: 0.000000 D(x): 0.999988 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.28 s Epoch ETA: 01:12:23
Epoch 1/1 - Step: 1100/9338 Loss G: 1.216291 Loss D: 0.000000 D(x): 0.999983 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.78 s Epoch ETA: 01:11:28
Epoch 1/1 - Step: 1200/9338 Loss G: 1.215526 Loss D: 0.000000 D(x): 0.999978 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.88 s Epoch ETA: 01:10:35
As can be seen, the situation gets even worse. Moreover, rereading the EnhanceNet paper, Section 4.2.4 (Adversarial Training) states that the adversarial loss used is a binary cross-entropy, i.e. BCELoss(), which I would expect to solve the vanishing-gradient problem I get with MSELoss().
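The vanishing-gradient point can be checked with a little calculus. The pure-Python sketch below assumes the discriminator ends in a sigmoid (which BCELoss() inputs require anyway) and computes, for one fake sample labelled "real", the gradient of the generator's adversarial loss with respect to D's pre-sigmoid logit, once for BCE and once for MSE. The helper names are illustrative, not taken from either implementation.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def generator_grads(logit):
    """Gradient of the generator's adversarial loss w.r.t. D's pre-sigmoid
    logit for one fake sample with target 1 (i.e. labelled 'real').
    Assumes D(G(z)) = sigmoid(logit)."""
    p = sigmoid(logit)
    grad_bce = p - 1.0                          # d/dlogit of -log(p)
    grad_mse = 2.0 * (p - 1.0) * p * (1.0 - p)  # d/dlogit of (p - 1)**2
    return p, grad_bce, grad_mse


# From a discriminator that is fooled half the time to one that is rarely fooled
for logit in (0.0, -3.0, -10.0):
    p, g_bce, g_mse = generator_grads(logit)
    print(f"D(G(z))={p:.6f}  BCE grad={g_bce:+.4f}  MSE grad={g_mse:+.6f}")
```

When D(G(z)) is pinned near zero, as in the MSELoss() run above, the MSE gradient shrinks roughly like 2·D(G(z)) because of the extra sigmoid-derivative factor, while the BCE gradient stays close to -1.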
Recommended answer
Interpreting GAN losses is a bit of a black art because the actual loss values
Question 1: In my experience, how often the training swings between discriminator and generator dominance depends mainly on a few factors: the learning rates and the batch sizes, which affect the propagated loss. The particular loss metric used will also affect the variance in how the D and G networks train. The EnhanceNet paper (for the baseline) and the tutorial also use a Mean Squared Error loss, whereas you're using a Binary Cross-Entropy loss, which will change the rate at which the networks converge. I'm no expert, so here's a pretty good link to Rohan Varma's article that explains the difference between loss functions. I'd be curious to see whether your network behaves differently when you change the loss function: try it and update the question?
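It may also help to look at where the two numbers are measured in the posted loop: D_G_z1 is read before optim_d.step() and D_G_z2 after it, on the same output_g. Because of the `> 0.3` gates, whenever both discriminator losses are small neither backward() runs, train_d stays False, the discriminator is not updated, and the two values can only coincide (assuming D has no dropout or other stochastic layers). In the epoch-3 log the printed Loss D, which is the sum of both terms, stays below 0.1, so if those values are representative of individual steps the discriminator is probably not being updated at all; the gating may also be part of why D_G_z1 and D_G_z2 move together less than in the tutorial. The sketch below reproduces only that gating logic in plain Python, with a hypothetical one-dimensional logistic discriminator rather than the PA model.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def bce(p, y):
    """Binary cross-entropy for one probability p and a label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def gated_d_step(w, b, real, fake, lr=0.1, threshold=0.3):
    """One discriminator iteration mimicking the question's loop with a toy
    D(v) = sigmoid(w * v + b): each loss term is backpropagated only if it
    exceeds `threshold`. Returns (w, b, D_G_z1, D_G_z2)."""
    p_real = sigmoid(w * real + b)
    p_fake = sigmoid(w * fake + b)
    loss_real, loss_fake = bce(p_real, 1), bce(p_fake, 0)
    d_g_z1 = p_fake                      # measured before the D update
    grad_w = grad_b = 0.0
    train_d = False
    if loss_real > threshold:            # d(-log sigmoid(s))/ds = p - 1
        grad_w += (p_real - 1.0) * real
        grad_b += p_real - 1.0
        train_d = True
    if loss_fake > threshold:            # d(-log(1 - sigmoid(s)))/ds = p
        grad_w += p_fake * fake
        grad_b += p_fake
        train_d = True
    if train_d:                          # plays the role of optim_d.step()
        w, b = w - lr * grad_w, b - lr * grad_b
    d_g_z2 = sigmoid(w * fake + b)       # measured after the (possible) update
    return w, b, d_g_z1, d_g_z2


# A confident D: both losses are below 0.3, so no update happens
print(gated_d_step(5.0, 0.0, real=2.0, fake=-2.0))
# An unsure D: both losses exceed 0.3, so D is updated
print(gated_d_step(0.1, 0.0, real=1.0, fake=-1.0))
```

In the first call the two D(G(z)) values are identical because the weights never change; in the second the update lowers D(G(z)), which is the "D_G_z2 differs from D_G_z1" pattern seen in the tutorial's logs.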
Question 2: Over time both the D and G losses should settle to a value; however, it's somewhat difficult to tell whether they've converged on strong performance or because of something like mode collapse or diminishing gradients (see Jonathan Hui's explanation of the problems in training GANs). The best approach I've found is to inspect a cross-section of the generated images, either visually or with some kind of perceptual metric (SSIM, PSNR, PIQ, etc.) across the generated image set.
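PSNR is the simplest of the metrics mentioned. A minimal pure-Python sketch of the standard definition, PSNR = 10 · log10(MAX² / MSE), is below; images are passed as flat sequences of pixel values only to keep the example self-contained, and `max_value` must match the data range (255 for 8-bit images, 1.0 for images scaled to [0, 1]). In practice a library implementation working on tensors would be more convenient. Since PSNR rewards pixel-wise fidelity, it tends to favour blurrier outputs over perceptually sharper ones, so it is best read alongside visual inspection.

```python
import math


def psnr(reference, generated, max_value=255.0):
    """Peak signal-to-noise ratio (in dB) between two equally sized images
    given as flat sequences of pixel values in [0, max_value]."""
    if len(reference) != len(generated):
        raise ValueError("images must have the same number of pixels")
    mse = sum((r - g) ** 2 for r, g in zip(reference, generated)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# Hypothetical 2x2 grayscale patches (ground truth vs. generator output)
reference = [52, 55, 61, 59]
generated = [50, 56, 60, 62]
print(f"PSNR: {psnr(reference, generated):.2f} dB")
```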
Some other leads you might find useful in finding an answer:
This post has a couple of reasonably good pointers on interpreting GAN losses.
Ian Goodfellow's NIPS2016 tutorial also has some solid ideas on how to balance D & G training.
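One technique discussed in that tutorial is one-sided label smoothing: the discriminator is trained on real samples with a softened target such as 0.9 instead of 1.0, while fake targets stay at 0 and the generator step is unchanged. With BCE, the loss for a real sample is then minimized at D(x) = 0.9 instead of being pushed toward 1, which penalizes the kind of saturated D(x) ≈ 1 seen in the epoch-3 log. The sketch below only compares loss values in plain Python; in the posted loop the equivalent change would be filling `label` with 0.9 for the real batch. The 0.9 value is an illustrative assumption, and whether it helps the PA model is untested.

```python
import math


def bce(p, y):
    """Binary cross-entropy for one prediction p in (0, 1) and a target y in [0, 1]."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


HARD_REAL = 1.0      # real-sample target used in the question's loop
SMOOTHED_REAL = 0.9  # one-sided smoothing: only the real target is softened

# Compare how each target scores increasingly confident D(x) values,
# including one close to the D(x) ~ 0.99997 seen in the epoch-3 log.
for d_x in (0.5, 0.9, 0.99, 0.99997):
    print(f"D(x)={d_x}: hard={bce(d_x, HARD_REAL):.4f}  smoothed={bce(d_x, SMOOTHED_REAL):.4f}")
```

With the hard target the loss keeps falling as D(x) approaches 1, whereas with the smoothed target an over-confident D(x) is penalized relative to D(x) = 0.9.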