我正在运行基于DCGAN的GAN,并正在尝试WGAN,但是对于如何训练WGAN感到有些困惑。

在官方的Wasserstein GAN PyTorch implementation中,每一次生成器训练都将歧视者/批评者训练为(cc)次(通常5次)。

这是否意味着批评者/区分者对Diters批次或整个数据集Diters进行训练?如果我没记错的话,官方实现建议对鉴别器/批评者进行整个数据集Diters次的训练,但是WGAN的其他实现(在PyTorch和TensorFlow等中)则相反。

哪个是对的? The WGAN paper(至少对我而言)表示它是Diters个批次。整个数据集的训练显然要慢几个数量级。

提前致谢!

最佳答案

正确的做法是将迭代视为批处理。
在原始的paper中,对于批注者/鉴别器的每次迭代,他们正在采样一批实际数据大小为m的实数据和一批先前大小为m的样本p(z)以对其进行处理。在对批注者进行Diters迭代训练之后,他们会训练生成器,该生成器也将从对p(z)的一批先前样本进行采样开始。
因此,每个迭代都在批量处理。

official implementation中,这也正在发生。可能令人困惑的是,他们使用变量名称niter表示训练模型的时期数。尽管他们使用不同的方案在162 -166行设置Diters

# train the discriminator Diters times
    if gen_iterations < 25 or gen_iterations % 500 == 0:
        Diters = 100
    else:
        Diters = opt.Diters


如本文中所述,它们正在对批注者进行Diters培训。

关于python-3.x - Wasserstein GAN评论家培训模棱两可,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53401431/

10-12 22:10