初学者体验扩散模型

可以在这个网址下载代码,里面有很多现有的,比如文字生成图像,图像生成图像

https://github.com/huggingface/diffusers

因为扩散模型训练起来很慢,不一定每个人都可以训练出来,所以他们提供了现成的模型,可以直接调用,就很爽。下面这个网址就是所有的模型汇总的,不仅仅局限于扩散模型。下面我来演示在服务器上用自己的数据训练模型。

Models - Hugging Face

1.下载模型源码

可以直接进入第一个链接去下载,也可以在服务器上输入如下命令:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

下包之前,最后自己手动下载torch,指定版本,不然就是最新版。

#这两个是不同版本的torch,对应不同版本的cuda
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 -f https://download.pytorch.org/whl/torch_stable.html

这里还不能下载其他包,按照你要做的扩散模型来下。

2.下载安装包

以图像生成图像为例:

初学者体验扩散模型_第1张图片

进入example,unconditional_image_generation,里面只有三个文件,就是图像中的后三个。前面那两个是我自己建的。

下载这里的requirements中的包,并在unconditional_image_generation中导入自己的数据集。如果没有自己的数据集,可以用该网站自带的。

数据集要求格式如下:

data_dir/xxx.png
data_dir/xxy.png
data_dir/[...]/xxz.png

3.修改参数

进入train_unconditional.py,找到main函数,这些看自己情况修改。参数有很多。

  --train_data_dir="imgs" \#数据集
  --resolution=64 \数据集的size大小,代码会把你的数据集里所有的图像压缩成这个大小,
#而且也是生成图像的大小
  --output_dir="ddpm-ema-flowers-64" \模型位置
  --train_batch_size=16 \
  --num_epochs=100 \

4.配置训练环境

这个扩散模型还需要额外的修改下环境配置,如下所示

accelerate config

初学者体验扩散模型_第2张图片

你可以照我这么来弄,也可以按照选项来。

5.训练 

这个起码训练10h+,弄个nohup。

nohup accelerate launch train_unconditional.py > ./output.log 2>&1 &

6.使用model

新建一个generate.py,改一下model_id,就可以用了

# !pip install diffusers
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
import os
model_id = "ddpm-model-64"
#生成的图像放的位置
img_path = 'results'+'/'+model_id+'-img'
if not os.path.exists(img_path): os.mkdir(img_path)
device = "cuda"

# load model and scheduler
ddpm = DDPMPipeline.from_pretrained(model_id)  # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
ddpm.to(device)
for i in range(100):
    # run pipeline in inference (sample random noise and denoise)
    image = ddpm().images[0]
    # save image
    #不修改格式
    #image.save(os.path.join(img_path,f'{i}.png'))
    #改成单通道
    image.convert('L').save(os.path.join(img_path,f'{i}.png'))
    #看看跑到哪里了
    if i%10==0:print(f"i={i}")

你可能感兴趣的:(python,深度学习,图像处理)