使用NVlabs-StyleGAN训练自己的数据集

开源链接

NVlabs/stylegan

训练自己的数据集

数据处理:转换成tfrecords

在git clone仓库代码后,需要将自己的数据集转换成tfrecords,是stylegan提供了dataset_tool.py来进行这一步,命令行代码为

python dataset_tool.py create_from_images /root/autodl-tmp/stylegan/datasets/custom_dataset  /root/autodl-tmp/HE_data/trainA
# /root/autodl-tmp/stylegan/datasets/custom_dataset换成你的目标存储文件夹,/root/autodl-tmp/HE_data copy/trainA替换为你的数据集所在文件夹

此时可能会遇到版本不兼容的问题,根据错误信息修改即可,我进行的处理包括:

  • 降级 protobuf 包到兼容的版本:使用 protobuf 版本 3.20.x 或更低。
    pip install protobuf==3.20.3
    
  • 升级 TensorBoard版本
    pip install --upgrade tensorboard
    

成功后预计显示读入图片的数量,例如

Added 846 images.

修改train.py文件

修改train.py的第37行datasets部分

# Dataset.
    desc += '-custom_dataset';     dataset = EasyDict(tfrecord_dir='custom_dataset', resolution=256);                 train.mirror_augment = False
    #desc += '-ffhq';     dataset = EasyDict(tfrecord_dir='ffhq'); 

以及46-50行的gpu部分

    # Number of GPUs.
    desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}
    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

训练

随后即可开始训练

python train.py

适配的numpy版本为

pip install numpy==1.19.5

开始训练会显示

Setting up snapshot image grid...
2024-05-28 15:25:52.533438: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2024-05-28 15:25:53.760595: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
2024-05-28 15:25:58.213745: W tensorflow/stream_executor/cuda/redzone_allocator.cc:312] Internal: ptxas exited with non-zero error code 65280, output: ptxas fatal   : Value 'sm_89' is not defined for option 'gpu-name'

Relying on driver to perform ptx compilation. This message will be only logged once.
Setting up run dir...
WARNING:tensorflow:From /root/autodl-tmp/stylegan/training/training_loop.py:202: The name tf.summary.FileWriter is deprecated. Please use tf.compat.v1.summary.FileWriter instead.

Training...

从 Google Drive 下载文件失败的解决

在训练过程中出现了从 Google Drive 下载文件失败的错误。程序尝试下载 inception_v3_features.pkl 文件时,因网络连接超时而失败。因此我本地下载了该文件并重启训练,具体步骤为

  • 在浏览器中访问:https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn 下载文件 inception_v3_features.pkl放置在datasets目录下。
  • 修改代码以使用本地文件
    # 修改前
    inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl
    

    修改后,使用本地文件路径

    inception = misc.load_pkl('/path/to/inception_v3_features.pkl') # 修改为你的本地文件路径
    

tensorboard 版本问题

pip install tensorboard==1.14.0

预测

python generate.py --outdir=result/ --trunc=1 --seeds=85,297,849 --network=./runs/00002-face_ori-mirror-paper1024-batch2-resumeffhq1024/network-snapshot-000600.pkl
© 2024 - 2025 Sihan Gao. All rights reserved.