首先是传入数据,这里我们不做赘述。

对于数据格式的裁剪,可以通过以下代码进行:

from glob import glob
from PIL import Image
import os
from tqdm import tqdm
from tqdm.std import trange

img_path = glob(r"C:\Users\Administrator\Desktop\Resize\*.png")
path_save = r'C:\Users\Administrator\Desktop\ReSize512'
a = range(0, len(img_path))
i = 0
for file in tqdm(img_path):
    name = os.path.join(path_save, "%d.jpeg" % a[i])
    im = Image.open(file).convert("RGB")
    im=im.resize((512, 512)) # 踩坑一:记得返回!
    print(im.format, im.size, im.mode)
    im.save(name, 'jpeg')
    i += 1

这些代码主要是通过glob模块抓取文件,然后通过PIL读取文件并转化成RGB格式,然后再resize后保存。

将数据打包成压缩包,进入服务器终端。

oss cp oss://stylegan3-main.zip ./hy-tmp/
 oss cp oss://ReSize512.zip ./hy-tmp/

解包文件

unzip -q stylegan3-main.zip
unzip -q ReSize512.zip -d DataSets/

进入到项目文件,可以输入

conda env create -f environment.yml

安装依赖环境,但是服务器已经配置好了部分环境,也可自己选择配置:

pip install pillow==8.3.1
pip install click==8.0
pip install scipy==1.7.1
pip install requests==2.26.0
pip install tqdm==4.62.2
pip install ninja==1.10.2
pip install matplotlib==3.4.2
pip install imageio==2.9.0
pip install imgui==1.3.0
pip install glfw==2.2.0
pip install pyopengl==3.1.5
pip install imageio-ffmpeg==0.4.3
pip install pyspng

此时处理我们的训练数据:

python dataset_tool.py --source=../ReSize512 --dest=../ReSize_Input.zip

删除文件夹

rm -rf 00000/
python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
    --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
    --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl

注意这里的参数,stylegan3-rR表示平移旋转等变性,t表示平移等变性,resume表示上次使用的网络,snap表示每隔多少批次进行输出,kimg表示训练批次,gamma表示R1正则参数,越大模型约有创意,记得把matrix设置为None

python train.py --outdir=../training-runs --cfg=stylegan3-t --data=../DataSets/ReSize_Input.zip --gpus=2 --batch=16 --gamma=10 --mirror=1 --kimg=25000 --snap=10 --matrics=None

进入训练后,在training-runs里可以得到训练过程信息:

在云服务器上运行StyleGAN3生成伪样本-LMLPHP

上面是生成的Fake影像,下面是网络权重,这跟我们的snap参数有关

在云服务器上运行StyleGAN3生成伪样本-LMLPHP

这个文件放到tensorboard里打开即可查看训练状态。

tensorboard --logdir path --port 8848

在云服务器上运行StyleGAN3生成伪样本-LMLPHP

在Roboflow网址上进行数据标注

在云服务器上运行StyleGAN3生成伪样本-LMLPHP

完成数据增广工作

05-15 19:50