Stablediffusion模型diffusesr格式和ckpt格式相互转换

  参考资料:

  diffusers的源码 [github]

  因为小博客可能看的人很少,所以我写的啰嗦一点,想直接看如何互相转换的朋友可以直接转到文末的代码段。

  当你在学习Stablediffusion这个开源的t2i模型时,不可避免地会碰到两种模型权重的存储格式,即diffusers格式和ckpt格式:

Stablediffusion模型diffusesr格式和ckpt格式相互转换_第1张图片

  如上图所示,这是一个hugging face的仓库,仓库里有文件夹和.ckpt文件.safetensors文件。这个截图同时包含了两种格式的权重,那么哪些属于diffusers格式?哪些属于ckpt格式?

  答:明摆着的,后缀为.ckpt的文件就是ckpt格式。而diffusers格式其实包含:feature_extractor, scheduler, text_encoder, tokenizer, unet, vae这些文件夹以及model_index.json这个文件。大的二进制文件主要位于text_encoder,unet和vae文件夹下。

  safetensors是ckpt转换得到的,防止别有用心之人在ckpt文件中加入恶意代码。safetensor和ckpt文件都能直接用于AUTOMATIC111这个为T2I模型开发的WebUI上,而diffusers不行。而diffusers提供的一些代码example又非常有借鉴意义,但一旦使用,其存储类型是一系列目录,这就催生了两种各格式相互转化的需求。

  怎么转化?使用diffusers官方提供的转化脚本即可,这些转换脚本在diffuser源代码仓库的scripts文件夹下:主要涉及两个文件:convert_original_stable_diffusion_to_diffusers.py和convert_diffusers_to_original_stable_diffusion.py。

  diffusers to ckpts

  使用convert_diffusers_to_original_stable_diffusion.py脚本,一个典型的使用场景是你训练了一个模型,然后就想把这个模型转换成.ckpt文件放到webui上。使用范例如下:

python convert_diffusers_to_original_stable_diffusion.py --model_path model_dir --checkpoint_path path_to_ckpt.ckpt

  还有两个参数--half和--use_safetensors,如果加上就会把数值存为fp16以及把ckpt存为safetensors

  ckpts to diffusers

  使用convert_original_stable_diffusion_to_diffusers.py脚本,一个典型的使用场景是把ckpt文件(自己训练的或者是SD官方发布的)解包成diffusers的目录形式上传hugging face或者是自己使用。需要注意的是diffuser的目录存储形式实际上提供了比ckpt文件更多的(实际上多得多)信息。为什么?因为ckpt是pytorch使用的存储权重的二进制文件,你在load_state_dict的时候需要先初始化model,然后再load权重,但是ckpts文件中没有这些信息(或很少),就只能靠脚本来推断到底用了哪些参数初始化的model。所以你会发现diffusers to ckpts的脚本很简单,ckpts to diffusers的脚本很复杂。下面提供一个范例:

python convert_original_stable_diffusion_to_diffusers.py --checkpoint_path path_to_ckpt.ckpt --dump_path model_dir --image_size 512 --prediction_type epsilon

  在使用v2-base或v2.1-base的时候,一定要加上--image_size 512和--prediction_type epsilon这两个参数,不然脚本就会推断错模型的类型(ckpt中的信息太少)。由于脚本已经写的很完备了,大多数情况下只需要--checkpoint_path和--dump_path两个参数就能正常完成转换。

  最后,diffusers这个代码仓库日新月异,每天都在高强度更新,所以读者需要多git pull安装最新版本,这两个转换脚本什么时候发生变化也不能保证滴。

你可能感兴趣的:(stable,diffusion,深度学习,人工智能,机器学习,pytorch)