我正在运行基于DCGAN的GAN,并正在尝试WGAN,但是对于如何训练WGAN感到有些困惑。
在官方的Wasserstein GAN PyTorch implementation中,每一次生成器训练都将歧视者/批评者训练为(cc)次(通常5次)。
这是否意味着批评者/区分者对Diters
批次或整个数据集Diters
进行训练?如果我没记错的话,官方实现建议对鉴别器/批评者进行整个数据集Diters
次的训练,但是WGAN的其他实现(在PyTorch和TensorFlow等中)则相反。
哪个是对的? The WGAN paper(至少对我而言)表示它是Diters
个批次。整个数据集的训练显然要慢几个数量级。
提前致谢!
最佳答案
正确的做法是将迭代视为批处理。
在原始的paper中,对于批注者/鉴别器的每次迭代,他们正在采样一批实际数据大小为m
的实数据和一批先前大小为m
的样本p(z)
以对其进行处理。在对批注者进行Diters
迭代训练之后,他们会训练生成器,该生成器也将从对p(z)
的一批先前样本进行采样开始。
因此,每个迭代都在批量处理。
在official implementation中,这也正在发生。可能令人困惑的是,他们使用变量名称niter
表示训练模型的时期数。尽管他们使用不同的方案在162 -166行设置Diters
:
# train the discriminator Diters times
if gen_iterations < 25 or gen_iterations % 500 == 0:
Diters = 100
else:
Diters = opt.Diters
如本文中所述,它们正在对批注者进行
Diters
培训。关于python-3.x - Wasserstein GAN评论家培训模棱两可,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53401431/