Introduction
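A generative adversarial network (GAN) is an unsupervised learning method in which two neural networks learn by competing against each other. It was proposed by Ian Goodfellow et al. in 2014. A GAN consists of a generator network and a discriminator network. The generator takes random samples from a latent space as input, and its output should imitate the real samples in the training set as closely as possible. The discriminator takes either a real sample or the generator's output as input, and its goal is to tell the generator's output apart from real samples as reliably as it can. The generator, in turn, tries to fool the discriminator. The two networks compete and keep adjusting their parameters. GANs are commonly used to generate images realistic enough to pass as real; they have also been used to generate videos, 3D object models, and more.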
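The adversarial game above is usually written as the minimax objective min_G max_D E_x[log D(x)] + E_z[log(1 − D(G(z)))]. The NumPy sketch below is our own illustration, not code from this project: it turns that objective into the two binary cross-entropy losses typically optimized in practice, using the common "non-saturating" generator loss from the original GAN paper. When D can no longer separate real from fake it outputs 0.5, and the losses become 2 ln 2 ≈ 1.386 for D and ln 2 ≈ 0.693 for G. Assuming this project defines D-Loss and DG-Loss the same way (we have not verified this), that is the neighborhood its training log hovers just below.

```python
import numpy as np

def bce(pred, target, eps=1e-12):
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    pred = np.clip(pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))

def gan_losses(d_real, d_fake):
    """Return (discriminator loss, generator loss) from D's output probabilities.

    d_real: D(x) on real samples; d_fake: D(G(z)) on generated samples.
    """
    # Discriminator: label real samples 1 and generated samples 0.
    d_loss = bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake))
    # Generator (non-saturating form): push D to label generated samples as 1.
    g_loss = bce(d_fake, np.ones_like(d_fake))
    return d_loss, g_loss

# When D cannot tell real from fake, it outputs 0.5 everywhere.
d_loss, g_loss = gan_losses(np.full(8, 0.5), np.full(8, 0.5))
print(round(d_loss, 4), round(g_loss, 4))  # 1.3863 0.6931
```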
Download and installation commands
## Installation command for the CPU version
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle
## Installation command for the GPU version
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu
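Original GAN
Compared with other generative models, the adversarial approach of GAN no longer requires an assumed data distribution: there is no need to formulate p(x) explicitly. Instead, samples are drawn directly from a distribution, so in theory the model can approximate the real data arbitrarily closely. This is GAN's greatest advantage. The drawback of not modeling the distribution in advance, however, is that the approach is too unconstrained: for larger images with more pixels, a plain GAN becomes hard to control. A natural way to rein it in is to add constraints to the GAN, which led to Conditional Generative Adversarial Nets (CGAN).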
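CGAN, the conditional generative adversarial network, is a GAN with conditional constraints. It introduces a conditional variable y into the modeling of both the generator (G) and the discriminator (D); this extra information conditions the model and guides the data-generation process. The conditional variable y can be based on many kinds of information, such as class labels, partial data for image inpainting, or data from a different modality. When y is a class label, CGAN can be viewed as an improvement that turns the purely unsupervised GAN into a supervised model. This simple, direct change has proven very effective and is widely used in later work. The network structure is shown in the figure below: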
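Results predicted by a conditionalGAN model trained for 19 epochs are shown in the figure below:
Before reading this project, we recommend reading the original paper, https://arxiv.org/pdf/1411.1784.pdf , and this helpful explanatory blog post: https://blog.csdn.net/stalbo/article/details/79359380
This project uses the MNIST dataset.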
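To make "introducing y into both G and D" concrete: in the spirit of the original paper's MLP setup, the simplest scheme concatenates a one-hot encoding of the label with G's noise input and with D's flattened image input. The NumPy sketch below shows only that input construction. The latent size of 100, the flattened 28×28 MNIST input, and the helper names are assumptions for illustration; this project's network.py may inject the condition differently. The example labels 5, 0, 4, 1 are the conditions printed by the inference run later in this project.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Encode integer class labels (shape [N]) as one-hot vectors (shape [N, num_classes])."""
    labels = np.asarray(labels, dtype=np.int64).reshape(-1)
    out = np.zeros((labels.size, num_classes), dtype=np.float32)
    out[np.arange(labels.size), labels] = 1.0
    return out

def generator_input(z, labels, num_classes=10):
    """Condition G by concatenating noise z [N, z_dim] with one-hot y."""
    return np.concatenate([z, one_hot(labels, num_classes)], axis=1)

def discriminator_input(images, labels, num_classes=10):
    """Condition D by concatenating flattened images [N, H*W] with one-hot y."""
    flat = images.reshape(images.shape[0], -1)
    return np.concatenate([flat, one_hot(labels, num_classes)], axis=1)

rng = np.random.default_rng(0)
labels = np.array([5, 0, 4, 1])                               # digits G should draw
z = rng.standard_normal((4, 100)).astype(np.float32)          # hypothetical z_dim = 100
images = rng.uniform(-1, 1, (4, 28, 28)).astype(np.float32)   # stand-in for MNIST images
print(generator_input(z, labels).shape)           # (4, 110)
print(discriminator_input(images, labels).shape)  # (4, 794)
```

Concatenation keeps G and D unchanged apart from their input width, which is why it is the usual first choice; feeding y through its own embedding layer is a common alternative.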
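# Code structure
# ├── network.py   # defines the basic generator and discriminator networks
# ├── utility.py   # defines common utility functions
# ├── c_gan.py     # conditionalGAN training script
# ├── infer.py     # inference script
# └── reader.py    # data reading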
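# During training, one batch of data is tested at fixed training intervals, and the results are saved as images to the path given by the --output option.
# Run python c_gan/c_gan.py --help for more usage options and detailed parameter descriptions.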
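The --help output is not reproduced here. As a hedged illustration of how such flags are commonly handled, the argparse sketch below rebuilds a parser from the option names echoed under "Configuration Arguments" in the training log (batch_size, epoch, output, run_ce, use_gpu). It is not the project's actual utility.py: every default except batch_size=128 is a placeholder, and the str2bool helper is our assumption about why --use_gpu=True is echoed back as use_gpu: 1 (the old distutils.util.strtobool also returned 1/0).

```python
import argparse

def str2bool(value):
    """Parse 'True'/'False'-style strings into 1/0 (assumed behavior, see lead-in)."""
    value = value.strip().lower()
    if value in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if value in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise argparse.ArgumentTypeError("invalid boolean value: %r" % value)

def build_parser():
    # Hypothetical parser mirroring the printed "Configuration Arguments".
    # Defaults other than batch_size=128 are placeholders, not the project's values.
    parser = argparse.ArgumentParser(description="conditionalGAN training (sketch)")
    parser.add_argument("--batch_size", type=int, default=128, help="minibatch size")
    parser.add_argument("--epoch", type=int, default=20, help="number of training epochs")
    parser.add_argument("--output", type=str, default="./output", help="where test images are saved")
    parser.add_argument("--use_gpu", type=str2bool, default=True, help="train on GPU if true")
    parser.add_argument("--run_ce", type=str2bool, default=False, help="(assumed) continuous-evaluation mode")
    return parser

# Same flags as the training command used in this project.
args = build_parser().parse_args(["--epoch=5", "--output=./C_result", "--use_gpu=True"])
print(vars(args))
```

Because argparse applies `type` only to string defaults and to values given on the command line, the unset run_ce stays the bool False while the explicit --use_gpu=True becomes 1, matching the mix seen in the printed configuration.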
# Train CGAN on the GPU; test results are saved as images to the path given by --output.
!python c_gan/c_gan.py --epoch=5 --output="./C_result" --use_gpu=True
2020-02-18 11:03:15,594-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
2020-02-18 11:03:16,011-INFO: generated new fontManager
----------- Configuration Arguments -----------
batch_size: 128
epoch: 5
output: ./C_result
run_ce: False
use_gpu: 1
------------------------------------------------
W0218 11:03:17.667068 93 device_context.cc:236] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0218 11:03:17.672333 93 device_context.cc:244] device: 0, cuDNN Version: 7.3.
Epoch ID=0 Batch ID=0 D-Loss=1.4754687547683716 DG-Loss=0.48697683215141296 gen=[0.020384092, -0.628501, 0.8264506, -0.059517004, 0.08393577] Batch_time_cost=0.07
Epoch ID=0 Batch ID=50 D-Loss=1.3325045108795166 DG-Loss=0.5968529582023621 gen=[-0.32408723, -0.99999803, 0.9999368, -0.8411597, 0.089144364] Batch_time_cost=0.07
Epoch ID=0 Batch ID=100 D-Loss=1.3283110857009888 DG-Loss=0.6360062956809998 gen=[-0.48128778, -1.0, 0.9999977, -0.9618741, -0.15321578] Batch_time_cost=0.06
Epoch ID=0 Batch ID=150 D-Loss=1.3380171060562134 DG-Loss=0.6329773664474487 gen=[-0.50825375, -0.9999999, 0.99999666, -0.96668917, -0.24650867] Batch_time_cost=0.06
Epoch ID=0 Batch ID=200 D-Loss=1.3771660327911377 DG-Loss=0.6287590861320496 gen=[-0.48687252, -0.99999964, 0.99999994, -0.96308786, -0.18785018] Batch_time_cost=0.06
Epoch ID=0 Batch ID=250 D-Loss=1.3531957864761353 DG-Loss=0.6571791768074036 gen=[-0.39359128, -0.9999998, 1.0, -0.9617758, 0.14621706] Batch_time_cost=0.07
Epoch ID=0 Batch ID=300 D-Loss=1.363043189048767 DG-Loss=0.6445130109786987 gen=[-0.39853948, -0.99999535, 0.9999999, -0.9558876, 0.10362165] Batch_time_cost=0.07
Epoch ID=0 Batch ID=350 D-Loss=1.3479816913604736 DG-Loss=0.6430233716964722 gen=[-0.6183855, -1.0, 0.9999999, -0.98799354, -0.5378779] Batch_time_cost=0.06
Epoch ID=0 Batch ID=400 D-Loss=1.345306634902954 DG-Loss=0.6640985012054443 gen=[-0.39159825, -0.99999994, 1.0, -0.96443295, 0.1392249] Batch_time_cost=0.06
Epoch ID=0 Batch ID=450 D-Loss=1.3394019603729248 DG-Loss=0.6445193886756897 gen=[-0.4453725, -1.0, 1.0, -0.97341233, -0.037463646] Batch_time_cost=0.06
Epoch ID=1 Batch ID=0 D-Loss=1.3356835842132568 DG-Loss=0.6343369483947754 gen=[-0.73340905, -1.0, 0.9999999, -0.9945303, -0.78381747] Batch_time_cost=0.07
Epoch ID=1 Batch ID=50 D-Loss=1.3558433055877686 DG-Loss=0.6355264186859131 gen=[-0.5320427, -1.0, 1.0, -0.9876131, -0.31378862] Batch_time_cost=0.07
Epoch ID=1 Batch ID=100 D-Loss=1.3402173519134521 DG-Loss=0.6352111101150513 gen=[-0.55672026, -1.0, 1.0, -0.9907406, -0.41422915] Batch_time_cost=0.06
Epoch ID=1 Batch ID=150 D-Loss=1.3480899333953857 DG-Loss=0.6375709772109985 gen=[-0.5827232, -1.0, 1.0, -0.9946148, -0.47664627] Batch_time_cost=0.06
Epoch ID=1 Batch ID=200 D-Loss=1.3259913921356201 DG-Loss=0.641555905342102 gen=[-0.73367566, -1.0, 0.99999887, -0.99813175, -0.7840915] Batch_time_cost=0.06
Epoch ID=1 Batch ID=250 D-Loss=1.3230714797973633 DG-Loss=0.6369626522064209 gen=[-0.49091578, -1.0, 1.0, -0.98903924, -0.18111996] Batch_time_cost=0.07
Epoch ID=1 Batch ID=300 D-Loss=1.3248735666275024 DG-Loss=0.627005398273468 gen=[-0.48370016, -1.0, 1.0, -0.98776495, -0.15307435] Batch_time_cost=0.07
Epoch ID=1 Batch ID=350 D-Loss=1.3656104803085327 DG-Loss=0.6305217742919922 gen=[-0.7182529, -1.0, 1.0, -0.9988163, -0.80599004] Batch_time_cost=0.06
Epoch ID=1 Batch ID=400 D-Loss=1.3329648971557617 DG-Loss=0.6695771217346191 gen=[-0.4926114, -1.0, 1.0, -0.9927192, -0.18072756] Batch_time_cost=0.06
Epoch ID=1 Batch ID=450 D-Loss=1.3494174480438232 DG-Loss=0.6599210500717163 gen=[-0.5725556, -1.0, 1.0, -0.99533355, -0.46978894] Batch_time_cost=0.06
Epoch ID=2 Batch ID=0 D-Loss=1.3190510272979736 DG-Loss=0.6534481048583984 gen=[-0.6627138, -1.0, 1.0, -0.9986317, -0.75535697] Batch_time_cost=0.07
Epoch ID=2 Batch ID=50 D-Loss=1.3075931072235107 DG-Loss=0.662345826625824 gen=[-0.5798778, -1.0, 1.0, -0.9970621, -0.46988058] Batch_time_cost=0.07
Epoch ID=2 Batch ID=100 D-Loss=1.2966625690460205 DG-Loss=0.6551341414451599 gen=[-0.64662725, -1.0, 1.0, -0.9989106, -0.69297016] Batch_time_cost=0.06
Epoch ID=2 Batch ID=150 D-Loss=1.344698429107666 DG-Loss=0.642096221446991 gen=[-0.6564131, -1.0, 1.0, -0.9989565, -0.7414377] Batch_time_cost=0.06
Epoch ID=2 Batch ID=200 D-Loss=1.2718031406402588 DG-Loss=0.6544758081436157 gen=[-0.7044133, -1.0, 1.0, -0.99934673, -0.805151] Batch_time_cost=0.06
Epoch ID=2 Batch ID=250 D-Loss=1.2980257272720337 DG-Loss=0.6246281862258911 gen=[-0.80004555, -1.0, 1.0, -0.9998431, -0.9489559] Batch_time_cost=0.07
Epoch ID=2 Batch ID=300 D-Loss=1.3250139951705933 DG-Loss=0.6461377143859863 gen=[-0.5818067, -1.0, 1.0, -0.9984216, -0.5537344] Batch_time_cost=0.06
Epoch ID=2 Batch ID=350 D-Loss=1.3336296081542969 DG-Loss=0.6458700895309448 gen=[-0.6561196, -1.0, 1.0, -0.9994945, -0.8055736] Batch_time_cost=0.07
Epoch ID=2 Batch ID=400 D-Loss=1.332587480545044 DG-Loss=0.6567844152450562 gen=[-0.6696944, -1.0, 1.0, -0.99964476, -0.8398167] Batch_time_cost=0.06
Epoch ID=2 Batch ID=450 D-Loss=1.3682767152786255 DG-Loss=0.649559736251831 gen=[-0.7100771, -1.0, 1.0, -0.99959046, -0.85634416] Batch_time_cost=0.06
Epoch ID=3 Batch ID=0 D-Loss=1.3280253410339355 DG-Loss=0.6489391326904297 gen=[-0.7317555, -1.0, 1.0, -0.9997766, -0.89817333] Batch_time_cost=0.07
Epoch ID=3 Batch ID=50 D-Loss=1.3578753471374512 DG-Loss=0.6147778034210205 gen=[-0.72396195, -1.0, 1.0, -0.9998474, -0.90283734] Batch_time_cost=0.06
Epoch ID=3 Batch ID=100 D-Loss=1.3406777381896973 DG-Loss=0.6485224962234497 gen=[-0.718904, -1.0, 1.0, -0.99990094, -0.9153805] Batch_time_cost=0.06
Epoch ID=3 Batch ID=150 D-Loss=1.312596082687378 DG-Loss=0.6498677730560303 gen=[-0.67478544, -1.0, 1.0, -0.9998216, -0.8468633] Batch_time_cost=0.06
Epoch ID=3 Batch ID=200 D-Loss=1.3386080265045166 DG-Loss=0.6432048082351685 gen=[-0.6547149, -1.0, 1.0, -0.99988073, -0.83443314] Batch_time_cost=0.06
Epoch ID=3 Batch ID=250 D-Loss=1.326845645904541 DG-Loss=0.6399492025375366 gen=[-0.6931246, -1.0, 1.0, -0.99992764, -0.881379] Batch_time_cost=0.06
Epoch ID=3 Batch ID=300 D-Loss=1.3271305561065674 DG-Loss=0.6585754156112671 gen=[-0.6919609, -1.0, 1.0, -0.999896, -0.8784841] Batch_time_cost=0.07
Epoch ID=3 Batch ID=350 D-Loss=1.3486320972442627 DG-Loss=0.6269784569740295 gen=[-0.6996336, -1.0, 1.0, -0.99992526, -0.8959064] Batch_time_cost=0.06
Epoch ID=3 Batch ID=400 D-Loss=1.342246651649475 DG-Loss=0.6233903169631958 gen=[-0.77165926, -1.0, 1.0, -0.99998367, -0.9691971] Batch_time_cost=0.06
Epoch ID=3 Batch ID=450 D-Loss=1.3426618576049805 DG-Loss=0.6478078365325928 gen=[-0.6975712, -1.0, 1.0, -0.9999601, -0.9236312] Batch_time_cost=0.06
Epoch ID=4 Batch ID=0 D-Loss=1.3098371028900146 DG-Loss=0.6660643815994263 gen=[-0.7207169, -1.0, 1.0, -0.9999414, -0.92297864] Batch_time_cost=0.07
Epoch ID=4 Batch ID=50 D-Loss=1.3217291831970215 DG-Loss=0.6328650712966919 gen=[-0.71822643, -1.0, 1.0, -0.99995583, -0.9326522] Batch_time_cost=0.07
Epoch ID=4 Batch ID=100 D-Loss=1.3669278621673584 DG-Loss=0.6450607776641846 gen=[-0.6679931, -1.0, 1.0, -0.99994075, -0.8517525] Batch_time_cost=0.06
Epoch ID=4 Batch ID=150 D-Loss=1.333882451057434 DG-Loss=0.6543253660202026 gen=[-0.6812095, -1.0, 1.0, -0.99996954, -0.91442144] Batch_time_cost=0.06
Epoch ID=4 Batch ID=200 D-Loss=1.3355824947357178 DG-Loss=0.6455910205841064 gen=[-0.7346105, -1.0, 1.0, -0.9999781, -0.9356306] Batch_time_cost=0.06
Epoch ID=4 Batch ID=250 D-Loss=1.3421473503112793 DG-Loss=0.6508429050445557 gen=[-0.7158679, -1.0, 1.0, -0.9999834, -0.9500096] Batch_time_cost=0.07
Epoch ID=4 Batch ID=300 D-Loss=1.3312251567840576 DG-Loss=0.6514350771903992 gen=[-0.6991075, -1.0, 1.0, -0.99997765, -0.9242321] Batch_time_cost=0.07
Epoch ID=4 Batch ID=350 D-Loss=1.351043701171875 DG-Loss=0.6400759220123291 gen=[-0.7066042, -1.0, 1.0, -0.99998295, -0.9380742] Batch_time_cost=0.06
Epoch ID=4 Batch ID=400 D-Loss=1.3609528541564941 DG-Loss=0.6274992823600769 gen=[-0.7641589, -1.0, 1.0, -0.99999577, -0.97579837] Batch_time_cost=0.06
Epoch ID=4 Batch ID=450 D-Loss=1.3415582180023193 DG-Loss=0.652595043182373 gen=[-0.7490909, -1.0, 1.0, -0.9999947, -0.97528887] Batch_time_cost=0.06
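A raw log like this is hard to scan. The hypothetical helper below (not part of the project) pulls the D-Loss and DG-Loss fields out with a regular expression and averages them per epoch. In this run, after the first few batches, D-Loss stays between about 1.27 and 1.38 and DG-Loss between about 0.61 and 0.67. The sample lines are copied from the log, with the gen values elided.

```python
import re
from collections import defaultdict

# Matches records such as "Epoch ID=0 Batch ID=50 D-Loss=1.33 DG-Loss=0.59 ..."
RECORD = re.compile(
    r"Epoch ID=(\d+) Batch ID=(\d+) D-Loss=([\d.eE+-]+) DG-Loss=([\d.eE+-]+)"
)

def epoch_means(log_text):
    """Return {epoch: (mean D-Loss, mean DG-Loss)} over all records in log_text."""
    sums = defaultdict(lambda: [0.0, 0.0, 0])  # epoch -> [sum D, sum DG, count]
    for epoch, _batch, d_loss, dg_loss in RECORD.findall(log_text):
        s = sums[int(epoch)]
        s[0] += float(d_loss)
        s[1] += float(dg_loss)
        s[2] += 1
    return {e: (d / n, g / n) for e, (d, g, n) in sorted(sums.items())}

sample = """
Epoch ID=0 Batch ID=0 D-Loss=1.4754687547683716 DG-Loss=0.48697683215141296 gen=[...] Batch_time_cost=0.07
Epoch ID=0 Batch ID=50 D-Loss=1.3325045108795166 DG-Loss=0.5968529582023621 gen=[...] Batch_time_cost=0.07
Epoch ID=4 Batch ID=450 D-Loss=1.3415582180023193 DG-Loss=0.652595043182373 gen=[...] Batch_time_cost=0.06
"""
for epoch, (d, g) in epoch_means(sample).items():
    print("epoch %d: mean D-Loss=%.4f, mean DG-Loss=%.4f" % (epoch, d, g))
```

To use it on a full run, save the cell output to a text file and pass its contents to `epoch_means`.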
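# Run inference with the saved (frozen) model trained for 5 epochs; results are saved to the --output path, and --batch_size sets the number of images to generate.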
!python c_gan/infer.py --output="./infer_result" --use_gpu=True --batch_size=4
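# Visualize the inference results
%matplotlib inline
import matplotlib.pyplot as plt
import cv2
# cv2.imread returns None if the file is missing, and loads pixels in BGR order
img = cv2.imread('infer_result/generated_image.png')
assert img is not None, 'infer_result/generated_image.png not found; run infer.py first'
# Matplotlib expects RGB, so convert before displaying
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()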
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py:795: UserWarning: The current program is empty.
  warnings.warn(error_info)
W0218 11:06:43.294374 178 device_context.cc:236] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0218 11:06:43.299666 178 device_context.cc:244] device: 0, cuDNN Version: 7.3.
condition: [[5.] [0.] [4.] [1.]]
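The printed `condition` array holds the four digit labels (5, 0, 4 and 1) the generator was asked to draw, one per image in the batch of 4. The gen values in the training log are bounded by -1.0 and 1.0, which suggests, though we have not checked network.py, a tanh output layer. Under that assumption, the NumPy sketch below shows how such outputs could be mapped back to 8-bit pixels and laid side by side in the same order as the conditions. The helper names are hypothetical, and the actual layout of generated_image.png is decided by infer.py.

```python
import numpy as np

def to_uint8(images):
    """Map generator outputs in [-1, 1] (e.g. from a tanh layer) to uint8 pixels in [0, 255]."""
    images = np.clip(images, -1.0, 1.0)
    return ((images + 1.0) * 127.5).round().astype(np.uint8)

def tile_row(images):
    """Place N single-channel images of shape [N, H, W] side by side as one [H, N*W] image."""
    n, h, w = images.shape
    return images.transpose(1, 0, 2).reshape(h, n * w)

# Hypothetical batch: 4 generated 28x28 digits for conditions [5, 0, 4, 1].
fake = np.random.default_rng(0).uniform(-1, 1, (4, 28, 28)).astype(np.float32)
strip = tile_row(to_uint8(fake))
print(strip.shape, strip.dtype)  # (28, 112) uint8
```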
Click the link to try this project hands-on in AI Studio with one click: https://aistudio.baidu.com/aistudio/projectdetail/169443
>> Visit the PaddlePaddle website to learn more.