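Stable Diffusion is a diffusion model trained on 512x512 images from a subset of LAION-5B. It uses a frozen CLIP ViT-L/14 text encoder to encode the text prompt. The model consists of an 860M-parameter UNet and a 123M-parameter text encoder, and runs on a GPU with at least 10GB of VRAM.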
Hugging Face model page: https://huggingface.co/CompVis/stable-diffusion
Colab: https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb
diffusers official documentation: https://huggingface.co/docs/diffusers
Now let's walk through a local deployment.
conda create -n diffenv python=3.8
conda activate diffenv
pip install diffusers==0.4.0
pip install transformers scipy ftfy
# pip install "ipywidgets>=7,<8"  # only needed for the interactive input widgets in the Colab notebook
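If running the code later fails with RuntimeError: CUDA error: no kernel image is available for execution on the device, your PyTorch build does not match the CUDA setup of the machine. Reinstall PyTorch built for your machine's CUDA version, for example for CUDA 11.3: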
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 -f https://download.pytorch.org/whl/torch_stable.html
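Before reinstalling, it can help to confirm the diagnosis. The sketch below is a rough check, not an official tool: arch_supported and diagnose are hypothetical helpers. It assumes the usual CUDA compatibility rules: a compiled sm_XY kernel runs on GPUs with the same major compute capability and an equal or higher minor version, and a compute_XY (PTX) entry can be JIT-compiled for equal or newer GPUs. torch.cuda.get_device_capability and torch.cuda.get_arch_list are existing PyTorch functions.

```python
def arch_supported(capability, arch_list):
    """Rough check of whether a PyTorch build ships kernels usable on a GPU.

    capability: (major, minor), e.g. from torch.cuda.get_device_capability()
    arch_list:  e.g. from torch.cuda.get_arch_list(), like ["sm_70", "sm_80"]
    """
    major, minor = capability
    for arch in arch_list:
        kind, _, digits = arch.partition("_")
        if not digits.isdigit() or len(digits) < 2:
            continue  # skip entries such as "sm_90a" that this sketch does not parse
        a_major, a_minor = int(digits[:-1]), int(digits[-1])
        # Compiled SASS (sm_XY): same major version, equal or lower minor version.
        if kind == "sm" and a_major == major and a_minor <= minor:
            return True
        # PTX (compute_XY): can be JIT-compiled for equal or newer GPUs.
        if kind == "compute" and (a_major, a_minor) <= (major, minor):
            return True
    return False


def diagnose():
    """Print the relevant facts about the local PyTorch install (needs a GPU)."""
    import torch

    print("torch", torch.__version__, "built for CUDA", torch.version.cuda)
    if not torch.cuda.is_available():
        print("CUDA is not available to this PyTorch build")
        return
    cap = torch.cuda.get_device_capability(0)
    archs = torch.cuda.get_arch_list()
    print("GPU capability", cap, "compiled archs", archs)
    print("kernels available:", arch_supported(cap, archs))
```

Calling diagnose() on the problem machine and seeing "kernels available: False" suggests the reinstall above is the right fix.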
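Getting the model:
First, you have to accept the model's license agreement.
If you use the official Colab, you must enter a Hugging Face access token so it can check online that you have accepted the license. If you'd rather not, first download the model weights and other files to your machine with: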
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
You can then load the model directly with DiffusionPipeline.from_pretrained("./MODEL_PATH/stable-diffusion-v1-4") and no longer need to pass use_auth_token=AUTH_TOKEN.
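A common pitfall with the clone above: if git lfs was not set up first, the weight files in the clone are small Git LFS pointer text files instead of the real weights, and loading fails. The sketch below checks a local folder for that and for the model_index.json file that diffusers pipelines keep at the repository root. check_local_model is a hypothetical helper, and the list of weight-file suffixes is an assumption.

```python
import os

# Git LFS pointer files start with this line instead of binary weight data.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"
WEIGHT_SUFFIXES = (".bin", ".ckpt", ".safetensors")  # assumed weight file types


def check_local_model(model_dir):
    """Return a list of problems found in a locally cloned diffusers model folder."""
    problems = []
    if not os.path.isfile(os.path.join(model_dir, "model_index.json")):
        problems.append("model_index.json not found")
    for root, dirs, files in os.walk(model_dir):
        dirs[:] = [d for d in dirs if d != ".git"]  # do not descend into git internals
        for name in files:
            if name.endswith(WEIGHT_SUFFIXES):
                path = os.path.join(root, name)
                with open(path, "rb") as f:
                    head = f.read(len(LFS_POINTER_PREFIX))
                if head == LFS_POINTER_PREFIX:
                    problems.append(f"{path} is a Git LFS pointer, not real weights")
    return problems
```

For example, print(check_local_model("./MODEL_PATH/stable-diffusion-v1-4")) should print an empty list before you call from_pretrained on that path. If it reports LFS pointers, running git lfs pull inside the clone should fetch the actual files.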
For full precision (at the cost of more VRAM), remove revision="fp16" and torch_dtype=torch.float16 from the code below.
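A back-of-the-envelope calculation shows why fp16 saves memory. It uses only the parameter counts quoted at the top (860M UNet, 123M text encoder) and ignores the VAE, the safety checker and activations, so real usage is higher. Each fp32 weight takes 4 bytes and each fp16 weight takes 2 bytes.

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}

# Parameter counts quoted at the top of this post.
UNET_PARAMS = 860_000_000
TEXT_ENCODER_PARAMS = 123_000_000


def weight_bytes(num_params, dtype):
    """Memory needed just to hold the weights (no activations, no VAE)."""
    return num_params * BYTES_PER_PARAM[dtype]


total = UNET_PARAMS + TEXT_ENCODER_PARAMS
for dtype in ("fp32", "fp16"):
    gib = weight_bytes(total, dtype) / 2**30
    print(f"{dtype}: {gib:.2f} GiB")
```

This prints 3.66 GiB for fp32 and 1.83 GiB for fp16: half precision halves the weight memory, while the rest of the VRAM budget goes to the other components and the intermediate activations during denoising.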
import os

import torch
from diffusers import StableDiffusionPipeline

# Expose only physical GPU 2 to this process; must run before the first CUDA call
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # a local path works too
    revision="fp16",  # for full precision, delete this line and the next
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
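About CUDA_VISIBLE_DEVICES in the code above: CUDA reads the variable when it initializes, so setting it after import torch still works as long as no CUDA call has happened yet. Once it is set to "2", physical GPU 2 is the only visible device and appears as "cuda:0", which is why pipe.to("cuda") lands on it. The sketch below wraps this in a hypothetical select_gpus helper that warns when it is called too late; torch.cuda.is_initialized is an existing PyTorch function.

```python
import os
import sys
import warnings


def select_gpus(gpu_ids):
    """Restrict this process to the given physical GPUs via CUDA_VISIBLE_DEVICES."""
    torch = sys.modules.get("torch")  # only inspect torch if it was already imported
    if torch is not None and torch.cuda.is_initialized():
        # CUDA has already read the variable, so changing it now has no effect.
        warnings.warn("CUDA is already initialized; CUDA_VISIBLE_DEVICES will be ignored")
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return os.environ["CUDA_VISIBLE_DEVICES"]
```

Calling select_gpus([2]) at the very top of the script is equivalent to the os.environ line in the loading code.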
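Images are 512x512 pixels by default; pass height and width to control the size, e.g. pipe(prompt, height=512, width=768). Note that height and width must both be multiples of 8. Because the model was trained on 512x512 images, going well beyond 512 in both dimensions tends to produce repeated content, and going below 512 may lower quality; for non-square images, keep one side at 512 and make the other larger.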
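The multiple-of-8 rule comes from the latent space: Stable Diffusion v1 denoises a 4-channel latent that the VAE decodes into an image 8 times larger per side, so a 512x512 image starts as a 64x64 latent. The hypothetical latent_shape helper below makes that arithmetic explicit; it is a sketch of the size check, not the pipeline's own code.

```python
def latent_shape(height=512, width=512, batch_size=1, latent_channels=4, vae_scale=8):
    """Shape of the initial noise tensor that Stable Diffusion v1 denoises."""
    if height % vae_scale or width % vae_scale:
        raise ValueError(
            f"height and width must be multiples of {vae_scale}, got {height}x{width}"
        )
    return (batch_size, latent_channels, height // vae_scale, width // vae_scale)


print(latent_shape())          # (1, 4, 64, 64)
print(latent_shape(512, 768))  # (1, 4, 64, 96)
```

Since the UNet works on these small latents rather than on pixels, larger images cost proportionally more memory and time.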
prompt = "a photograph of an astronaut swimming in the river"
image = pipe(prompt).images[0]  # a PIL image (https://pillow.readthedocs.io/en/stable/)
image.save("astronaut_swimming_in_river.png")
image  # in a notebook, this displays the image
The image generated above is different on every run. For reproducible output, fix the random seed by passing a seeded generator to pipe() through the generator argument.
import torch
generator = torch.Generator("cuda").manual_seed(1024)
image = pipe(prompt, generator=generator).images[0]
image
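Why a seed makes the output reproducible: the pipeline starts from random latent noise drawn from the generator, and the same seed yields the same starting noise. The sketch below roughly mirrors that step with plain PyTorch; initial_latents is a hypothetical helper, and it uses a CPU generator so it runs without a GPU (CPU and CUDA generators give different numbers for the same seed). It also shows why the examples here create a new generator for each call: a generator's state advances every time it is used.

```python
import torch


def initial_latents(seed, height=512, width=512, device="cpu"):
    """Draw starting noise of the shape Stable Diffusion v1 denoises, from a seeded generator."""
    generator = torch.Generator(device).manual_seed(seed)
    return torch.randn((1, 4, height // 8, width // 8), generator=generator, device=device)


a = initial_latents(1024)
b = initial_latents(1024)
c = initial_latents(1025)
print(torch.equal(a, b), torch.equal(a, c))  # True False

# Reusing one generator: the second draw continues the random stream.
g = torch.Generator("cpu").manual_seed(1024)
first = torch.randn((1, 4, 64, 64), generator=g)
second = torch.randn((1, 4, 64, 64), generator=g)
print(torch.equal(first, second))  # False
```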
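The num_inference_steps argument sets the number of denoising steps. More steps generally give better results but slower inference. Stable Diffusion already performs well with relatively few steps, so the default of 50 is recommended. In my test, raising num_inference_steps to 100 while keeping the same seed made little visible difference.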
import torch
generator = torch.Generator("cuda").manual_seed(1024)
image = pipe(prompt, num_inference_steps=100, generator=generator).images[0]
image
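To compare step counts side by side, the seed has to stay fixed and a fresh generator has to be created for every run. sweep_steps below is a hypothetical helper sketching that; the make_generator parameter is an assumption added so the function can be tried without a GPU, and by default it builds a CUDA generator as in the code above.

```python
def sweep_steps(pipe, prompt, steps_list, seed=1024, make_generator=None):
    """Run the same prompt and seed at several step counts and return the images."""
    if make_generator is None:
        import torch

        def make_generator(s):
            return torch.Generator("cuda").manual_seed(s)

    images = []
    for steps in steps_list:
        # A fresh generator per run, so every run starts from the same noise.
        generator = make_generator(seed)
        result = pipe(prompt, num_inference_steps=steps, generator=generator)
        images.append(result.images[0])
    return images
```

For example, images = sweep_steps(pipe, prompt, [25, 50, 100]) returns three images that differ only in step count, which can then be tiled with the image_grid function defined below.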
A helper function that tiles images into a grid:
from PIL import Image


def image_grid(imgs, rows, cols):
    """Paste rows*cols equally sized images into a single grid image."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Image i goes to column i % cols, row i // cols
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
To generate 3 images in one call, pass the prompt as a list instead of a str:
num_images = 3
prompt = ["a traditional Chinese painting of a squirrel eating a banana"] * num_images
images = pipe(prompt).images
grid = image_grid(images, rows=1, cols=3)
grid
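Batching the prompt is convenient, but peak VRAM grows with the batch size, and, as far as I can tell from the diffusers 0.4.0 pipeline, the noise for the whole batch is drawn from one generator, so a single image from the grid cannot easily be regenerated on its own. An alternative is to generate one image per call, each with its own seed. generate_with_seeds is a hypothetical helper, and make_generator is an assumed parameter that lets it run without a GPU.

```python
def generate_with_seeds(pipe, prompt, seeds, make_generator=None):
    """Generate one image per seed, one pipeline call at a time."""
    if make_generator is None:
        import torch

        def make_generator(s):
            return torch.Generator("cuda").manual_seed(s)

    # Each image depends only on its own seed, so any one can be reproduced later.
    return [pipe(prompt, generator=make_generator(s)).images[0] for s in seeds]
```

For example, images = generate_with_seeds(pipe, prompt[0], seeds=[1, 2, 3]) followed by image_grid(images, rows=1, cols=3) gives a grid in which each cell can be regenerated from its seed.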
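The environment used (pip list output from a Colab runtime, so most packages are unrelated; the ones that matter here are diffusers 0.4.0, transformers 4.22.2, huggingface-hub 0.10.0, torch 1.12.1+cu113, scipy 1.7.3 and ftfy 6.1.1):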
Package Version
----------------------------- ----------------------------
absl-py 1.2.0
aeppl 0.0.33
aesara 2.7.9
aiohttp 3.8.3
aiosignal 1.2.0
alabaster 0.7.12
albumentations 1.2.1
altair 4.2.0
appdirs 1.4.4
arviz 0.12.1
astor 0.8.1
astropy 4.3.1
astunparse 1.6.3
async-timeout 4.0.2
asynctest 0.13.0
atari-py 0.2.9
atomicwrites 1.4.1
attrs 22.1.0
audioread 3.0.0
autograd 1.5
Babel 2.10.3
backcall 0.2.0
beautifulsoup4 4.6.3
bleach 5.0.1
blis 0.7.8
bokeh 2.3.3
branca 0.5.0
bs4 0.0.1
CacheControl 0.12.11
cached-property 1.5.2
cachetools 4.2.4
catalogue 2.0.8
certifi 2022.9.24
cffi 1.15.1
cftime 1.6.2
chardet 3.0.4
charset-normalizer 2.1.1
click 7.1.2
clikit 0.6.2
cloudpickle 1.5.0
cmake 3.22.6
cmdstanpy 1.0.7
colorcet 3.0.1
colorlover 0.3.0
community 1.0.0b1
confection 0.0.2
cons 0.4.5
contextlib2 0.5.5
convertdate 2.4.0
crashtest 0.3.1
crcmod 1.7
cufflinks 0.17.3
cupy-cuda11x 11.0.0
cvxopt 1.3.0
cvxpy 1.2.1
cycler 0.11.0
cymem 2.0.6
Cython 0.29.32
daft 0.0.4
dask 2022.2.0
datascience 0.17.5
debugpy 1.0.0
decorator 4.4.2
defusedxml 0.7.1
descartes 1.1.0
diffusers 0.4.0
dill 0.3.5.1
distributed 2022.2.0
dlib 19.24.0
dm-tree 0.1.7
docutils 0.17.1
dopamine-rl 1.0.5
earthengine-api 0.1.326
easydict 1.10
ecos 2.0.10
editdistance 0.5.3
en-core-web-sm 3.4.0
entrypoints 0.4
ephem 4.1.3
et-xmlfile 1.1.0
etils 0.8.0
etuples 0.3.8
fa2 0.3.5
fastai 2.7.9
fastcore 1.5.27
fastdownload 0.0.7
fastdtw 0.3.4
fastjsonschema 2.16.2
fastprogress 1.0.3
fastrlock 0.8
feather-format 0.4.1
filelock 3.8.0
firebase-admin 4.4.0
fix-yahoo-finance 0.0.22
Flask 1.1.4
flatbuffers 22.9.24
folium 0.12.1.post1
frozenlist 1.3.1
fsspec 2022.8.2
ftfy 6.1.1
future 0.16.0
gast 0.5.3
GDAL 2.2.2
gdown 4.4.0
gensim 3.6.0
geographiclib 1.52
geopy 1.17.0
gin-config 0.5.0
glob2 0.7
google 2.0.3
google-api-core 1.31.6
google-api-python-client 1.12.11
google-auth 1.35.0
google-auth-httplib2 0.0.4
google-auth-oauthlib 0.4.6
google-cloud-bigquery 1.21.0
google-cloud-bigquery-storage 1.1.2
google-cloud-core 1.0.3
google-cloud-datastore 1.8.0
google-cloud-firestore 1.7.0
google-cloud-language 1.2.0
google-cloud-storage 1.18.1
google-cloud-translate 1.5.0
google-colab 1.0.0
google-pasta 0.2.0
google-resumable-media 0.4.1
googleapis-common-protos 1.56.4
googledrivedownloader 0.4
graphviz 0.10.1
greenlet 1.1.3
grpcio 1.49.1
gspread 3.4.2
gspread-dataframe 3.0.8
gym 0.25.2
gym-notices 0.0.8
h5py 3.1.0
HeapDict 1.0.1
hijri-converter 2.2.4
holidays 0.16
holoviews 1.14.9
html5lib 1.0.1
httpimport 0.5.18
httplib2 0.17.4
httplib2shim 0.0.3
httpstan 4.6.1
huggingface-hub 0.10.0
humanize 0.5.1
hyperopt 0.1.2
idna 2.10
imageio 2.9.0
imagesize 1.4.1
imbalanced-learn 0.8.1
imblearn 0.0
imgaug 0.4.0
importlib-metadata 5.0.0
importlib-resources 5.9.0
imutils 0.5.4
inflect 2.1.0
intel-openmp 2022.2.0
intervaltree 2.1.0
ipykernel 5.3.4
ipython 7.9.0
ipython-genutils 0.2.0
ipython-sql 0.3.9
ipywidgets 7.7.1
itsdangerous 1.1.0
jax 0.3.21
jaxlib 0.3.20+cuda11.cudnn805
jedi 0.18.1
jieba 0.42.1
Jinja2 2.11.3
joblib 1.2.0
jpeg4py 0.1.4
jsonschema 4.3.3
jupyter-client 6.1.12
jupyter-console 6.1.0
jupyter-core 4.11.1
jupyterlab-widgets 3.0.3
kaggle 1.5.12
kapre 0.3.7
keras 2.8.0
Keras-Preprocessing 1.1.2
keras-vis 0.4.1
kiwisolver 1.4.4
korean-lunar-calendar 0.3.1
langcodes 3.3.0
libclang 14.0.6
librosa 0.8.1
lightgbm 2.2.3
llvmlite 0.39.1
lmdb 0.99
locket 1.0.0
logical-unification 0.4.5
LunarCalendar 0.0.9
lxml 4.9.1
Markdown 3.4.1
MarkupSafe 2.0.1
marshmallow 3.18.0
matplotlib 3.2.2
matplotlib-venn 0.11.7
miniKanren 1.0.3
missingno 0.5.1
mistune 0.8.4
mizani 0.7.3
mkl 2019.0
mlxtend 0.14.0
more-itertools 8.14.0
moviepy 0.2.3.5
mpmath 1.2.1
msgpack 1.0.4
multidict 6.0.2
multipledispatch 0.6.0
multitasking 0.0.11
murmurhash 1.0.8
music21 5.5.0
natsort 5.5.0
nbconvert 5.6.1
nbformat 5.6.1
netCDF4 1.6.1
networkx 2.6.3
nibabel 3.0.2
nltk 3.7
notebook 5.3.1
numba 0.56.2
numexpr 2.8.3
numpy 1.21.6
oauth2client 4.1.3
oauthlib 3.2.1
okgrade 0.4.3
opencv-contrib-python 4.6.0.66
opencv-python 4.6.0.66
opencv-python-headless 4.6.0.66
openpyxl 3.0.10
opt-einsum 3.3.0
osqp 0.6.2.post0
packaging 21.3
palettable 3.3.0
pandas 1.3.5
pandas-datareader 0.9.0
pandas-gbq 0.13.3
pandas-profiling 1.4.1
pandocfilters 1.5.0
panel 0.12.1
param 1.12.2
parso 0.8.3
partd 1.3.0
pastel 0.2.1
pathlib 1.0.1
pathy 0.6.2
patsy 0.5.2
pep517 0.13.0
pexpect 4.8.0
pickleshare 0.7.5
Pillow 7.1.2
pip 21.1.3
pip-tools 6.2.0
plotly 5.5.0
plotnine 0.8.0
pluggy 0.7.1
pooch 1.6.0
portpicker 1.3.9
prefetch-generator 1.0.1
preshed 3.0.7
prettytable 3.4.1
progressbar2 3.38.0
promise 2.3
prompt-toolkit 2.0.10
prophet 1.1.1
protobuf 3.17.3
psutil 5.4.8
psycopg2 2.9.3
ptyprocess 0.7.0
py 1.11.0
pyarrow 6.0.1
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycocotools 2.0.5
pycparser 2.21
pyct 0.4.8
pydantic 1.9.2
pydata-google-auth 1.4.0
pydot 1.3.0
pydot-ng 2.0.0
pydotplus 2.0.2
PyDrive 1.3.1
pyemd 0.5.1
pyerfa 2.0.0.1
Pygments 2.6.1
pygobject 3.26.1
pylev 1.4.0
pymc 4.1.4
PyMeeus 0.5.11
pymongo 4.2.0
pymystem3 0.2.0
PyOpenGL 3.1.6
pyparsing 3.0.9
pyrsistent 0.18.1
pysimdjson 3.2.0
pysndfile 1.3.8
PySocks 1.7.1
pystan 3.3.0
pytest 3.6.4
python-apt 0.0.0
python-chess 0.23.11
python-dateutil 2.8.2
python-louvain 0.16
python-slugify 6.1.2
python-utils 3.3.3
pytz 2022.4
pyviz-comms 2.2.1
PyWavelets 1.3.0
PyYAML 6.0
pyzmq 23.2.1
qdldl 0.1.5.post2
qudida 0.0.4
regex 2022.6.2
requests 2.23.0
requests-oauthlib 1.3.1
resampy 0.4.2
rpy2 3.4.5
rsa 4.9
scikit-image 0.18.3
scikit-learn 1.0.2
scipy 1.7.3
screen-resolution-extra 0.0.0
scs 3.2.0
seaborn 0.11.2
Send2Trash 1.8.0
setuptools 57.4.0
setuptools-git 1.2
Shapely 1.8.4
six 1.15.0
sklearn-pandas 1.8.0
smart-open 5.2.1
snowballstemmer 2.2.0
sortedcontainers 2.4.0
soundfile 0.11.0
spacy 3.4.1
spacy-legacy 3.0.10
spacy-loggers 1.0.3
Sphinx 1.8.6
sphinxcontrib-serializinghtml 1.1.5
sphinxcontrib-websupport 1.2.4
SQLAlchemy 1.4.41
sqlparse 0.4.3
srsly 2.4.4
statsmodels 0.12.2
sympy 1.7.1
tables 3.7.0
tabulate 0.8.10
tblib 1.7.0
tenacity 8.1.0
tensorboard 2.8.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow 2.8.2+zzzcolab20220929150707
tensorflow-datasets 4.6.0
tensorflow-estimator 2.8.0
tensorflow-gcs-config 2.8.0
tensorflow-hub 0.12.0
tensorflow-io-gcs-filesystem 0.27.0
tensorflow-metadata 1.10.0
tensorflow-probability 0.16.0
termcolor 2.0.1
terminado 0.13.3
testpath 0.6.0
text-unidecode 1.3
textblob 0.15.3
thinc 8.1.2
threadpoolctl 3.1.0
tifffile 2021.11.2
tokenizers 0.12.1
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 1.12.1+cu113
torchaudio 0.12.1+cu113
torchsummary 1.5.1
torchtext 0.13.1
torchvision 0.13.1+cu113
tornado 5.1.1
tqdm 4.64.1
traitlets 5.1.1
transformers 4.22.2
tweepy 3.10.0
typeguard 2.7.1
typer 0.4.2
typing-extensions 4.1.1
tzlocal 1.5.1
ujson 5.5.0
uritemplate 3.0.1
urllib3 1.24.3
vega-datasets 0.9.0
wasabi 0.10.1
wcwidth 0.2.5
webargs 8.2.0
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.37.1
widgetsnbextension 3.6.1
wordcloud 1.8.2.2
wrapt 1.14.1
xarray 0.20.2
xarray-einstats 0.2.2
xgboost 0.90
xkit 0.0.0
xlrd 1.1.0
xlwt 1.3.0
yarl 1.8.1
yellowbrick 1.5
zict 2.2.0
zipp 3.8.1
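Rather than comparing the whole list by hand, you can check just the key packages in your own environment. The sketch below uses importlib.metadata from the standard library (Python 3.8+, matching the conda environment created earlier); version_mismatches is a hypothetical helper, and the EXPECTED dictionary simply copies versions from the list above. torch is left out because the install command earlier pins 1.12.0+cu113 while this list shows 1.12.1+cu113.

```python
from importlib.metadata import PackageNotFoundError, version

# Key versions copied from the list above; adjust if you deliberately use newer releases.
EXPECTED = {
    "diffusers": "0.4.0",
    "transformers": "4.22.2",
    "huggingface-hub": "0.10.0",
    "scipy": "1.7.3",
    "ftfy": "6.1.1",
}


def version_mismatches(expected):
    """Return {package: (expected, installed or None)} for packages that differ."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = version(name)
        except PackageNotFoundError:
            have = None  # not installed
        if have != want:
            mismatches[name] = (want, have)
    return mismatches


print(version_mismatches(EXPECTED))
```

An empty dict means the key packages match the environment listed here.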