使用NVlabs-StyleGAN训练自己的数据集
开源链接
训练自己的数据集
数据处理:转换成tfrecords
在git clone仓库代码后,需要将自己的数据集转换成tfrecords,是stylegan提供了dataset_tool.py来进行这一步,命令行代码为
python dataset_tool.py create_from_images /root/autodl-tmp/stylegan/datasets/custom_dataset /root/autodl-tmp/HE_data/trainA
# /root/autodl-tmp/stylegan/datasets/custom_dataset换成你的目标存储文件夹,/root/autodl-tmp/HE_data copy/trainA替换为你的数据集所在文件夹
此时可能会遇到版本不兼容的问题,根据错误信息修改即可,我进行的处理包括:
- 降级 protobuf 包到兼容的版本:使用 protobuf 版本 3.20.x 或更低。
pip install protobuf==3.20.3
- 升级 TensorBoard版本
pip install --upgrade tensorboard
成功后预计显示读入图片的数量,例如
Added 846 images.
修改train.py文件
修改train.py的第37行datasets部分
# Dataset.
desc += '-custom_dataset'; dataset = EasyDict(tfrecord_dir='custom_dataset', resolution=256); train.mirror_augment = False
#desc += '-ffhq'; dataset = EasyDict(tfrecord_dir='ffhq');
以及46-50行的gpu部分
# Number of GPUs.
desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}
#desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
#desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
#desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}
训练
随后即可开始训练
python train.py
适配的numpy版本为
pip install numpy==1.19.5
开始训练会显示
Setting up snapshot image grid...
2024-05-28 15:25:52.533438: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2024-05-28 15:25:53.760595: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
2024-05-28 15:25:58.213745: W tensorflow/stream_executor/cuda/redzone_allocator.cc:312] Internal: ptxas exited with non-zero error code 65280, output: ptxas fatal : Value 'sm_89' is not defined for option 'gpu-name'
Relying on driver to perform ptx compilation. This message will be only logged once.
Setting up run dir...
WARNING:tensorflow:From /root/autodl-tmp/stylegan/training/training_loop.py:202: The name tf.summary.FileWriter is deprecated. Please use tf.compat.v1.summary.FileWriter instead.
Training...
从 Google Drive 下载文件失败的解决
在训练过程中出现了从 Google Drive 下载文件失败的错误。程序尝试下载 inception_v3_features.pkl 文件时,因网络连接超时而失败。因此我本地下载了该文件并重启训练,具体步骤为
- 在浏览器中访问:https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn 下载文件 inception_v3_features.pkl放置在datasets目录下。
- 修改代码以使用本地文件
# 修改前 inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl
修改后,使用本地文件路径
inception = misc.load_pkl('/path/to/inception_v3_features.pkl') # 修改为你的本地文件路径
tensorboard 版本问题
pip install tensorboard==1.14.0
预测
python generate.py --outdir=result/ --trunc=1 --seeds=85,297,849 --network=./runs/00002-face_ori-mirror-paper1024-batch2-resumeffhq1024/network-snapshot-000600.pkl