最近需要用到 DALL·E的推断功能,在现有开源代码基础上发现还有几个问题需要注意,谨以此篇博客记录之。
我用的源码主要是 https://github.com/borisdayma/dalle-mini 仓库中的Inference pipeline.ipynb 文件。
运行环境:Ubuntu服务器
⚠️注意:本博客仅涉及 DALL · E 推断,不涉及训练过程。
建议使用anaconda新建一个dalle环境,然后在该环境中进行相关配置,避免与环境中的其他库产生版本冲突。
使用下述命令新建名为dalle的环境:
conda create -n dalle python==3.8.0
在终端分别运行下述命令,安装所需的python库:
# 安装 dalle运行需要的依赖库(注意版本只能是0.3.25)# Required only for colab environments + GPU
pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 安装 dalle特定的库
pip install dalle-mini
# 安装 VQGAN
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
PS:如果由于网络连接问题无法通过pip
命令下载VQGAN,就采取Plan-B:将仓库 https://github.com/patil-suraj/vqgan-jax 下载到服务器并解压,然后使用cd
命令将当前目录到对应的仓库下载路径下,在终端运行python setup.py install
安装VQGAN即可。
由于网络连接问题,我采取「事先把模型下载到本地」的策略对模型进行直接调用,首先要明确的一点是,本项目中使用DALL · E 对图像进行编码,使用VQGAN对图像进行解码,所以我们需要分别下载DALL · E 和 VQGAN 两个模型。
DALL · E 模型下载地址:
mini版本:https://huggingface.co/dalle-mini/dalle-mini/tree/main
mega版本:https://huggingface.co/dalle-mini/dalle-mega/tree/main
VQGAN 模型下载地址:
https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main
下载完毕后,将模型部署到服务器,注意保存路径。
相较于ipynb文件,我个人更加喜欢操作py文件,所以对于给定的ipynb文件,首先使用命令jupyter nbconvert --to script Inference pipeline.ipynb
将其转为同名py文件,该文件的主要内容如下(不含CLIP排序部分),其中模型路径 DALLE_MODEL和VQGAN_REPO 已改为本地路径(就是第二步中两个模型的保存路径),可以看到文件的注释也比较详细。
# dalle-mini
DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
DALLE_COMMIT_ID = None
# VQGAN model
VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
import jax
import jax.numpy as jnp
# check how many devices are available
jax.local_device_count()
# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Load dalle-mini
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)
# Model parameters are replicated on each device for faster inference.
from flax.jax_utils import replicate
params = replicate(params)
vqgan_params = replicate(vqgan_params)
# Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial
# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
return model.generate(
**tokenized_prompt,
prng_key=key,
params=params,
top_k=top_k,
top_p=top_p,
temperature=temperature,
condition_scale=condition_scale,
)
# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
return vqgan.decode_code(indices, params=params)
# Keys are passed to the model on each device to generate unique inference per device.
import random
# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
# ## Text Prompt
# Our model requires processing prompts.
from dalle_mini import DalleBartProcessor
# from transformers import AutoProcessor
processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID) # force_download=True, , local_only=True
# Let's define some text prompts
prompts = [
"sunset over a lake in the mountains",
"the Eiffel tower landing on the moon",
]
# print(prompts)
# Note: we could use the same prompt multiple times for faster inference.
tokenized_prompts = processor(prompts)
# Finally we replicate the prompts onto each device.
tokenized_prompt = replicate(tokenized_prompts)
# ## We generate images using dalle-mini model and decode them with the VQGAN.
# number of predictions per prompt
n_predictions = 8
# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0 # 越高,生成的图像越接近 prompt
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange
print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
# get a new key
key, subkey = jax.random.split(key) # jax.device_count()=1,returns the number of available jax devices
# generate images
encoded_images = p_generate(
tokenized_prompt,
shard_prng_key(subkey),
params,
gen_top_k,
gen_top_p,
temperature,
cond_scale,
)
# remove BOS
encoded_images = encoded_images.sequences[..., 1:]
decoded_images = p_decode(encoded_images, vqgan_params)
decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
for idx, decoded_img in enumerate(decoded_images):
img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
images.append(img)
...
使用命令 python /newdata/SD/inference_dalle-mini.py
运行程序。理想情况下就能够直接得到dalle生成的图像啦!
由于外部环境因素和一些不当操作,本人在运行该程序过程中还是遇到一些问题,主要有三个,在此将抱错信息与解决方法一并分享给大家。
...
requests.exceptions.ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /dalle-mini/dalle-mini/resolve/main/enwiki-words-frequency.txt (Caused by ConnectTimeoutError(, 'Connection to huggingface.co timed out. (connect timeout=10)'))" ), '(Request ID: 61b7c191-3fb8-4dfa-9025-e9acd4ee4d28)')
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/newdata/SD/inference_dalle-mini.py", line 84, in <module>
processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID) # force_download=True, , local_only=True
File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/utils.py", line 25, in from_pretrained
return super(PretrainedFromWandbMixin, cls).from_pretrained(
File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 62, in from_pretrained
return cls(tokenizer, config.normalize_text, config.max_text_length)
File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 21, in __init__
self.text_processor = TextNormalizer()
File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 215, in __init__
self._hashtag_processor = HashtagProcessor()
File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 25, in __init__
# wiki_word_frequency = hf_hub_download(
File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
return fn(*args, **kwargs)
File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1363, in hf_hub_download
raise LocalEntryNotFoundError(
huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.
顺着上面的报错信息,定位到/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py
文件的如下内容:
...
class HashtagProcessor:
# Adapted from wordninja library
# We use our wikipedia word count + a good heuristic to make it work
def __init__(self):
wiki_word_frequency = hf_hub_download(
"dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
)
self._word_cost = (
l.split()[0]
for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
)
...
于是问题的根源就在于,程序运行到这里时,没有找到本地的enwiki-words-frequency.txt
文件(经检查该文件其实是存在本地的,不知为何没有找到,很迷),于是尝试通过联网从huggingface官网下载,但由于网络状况欠佳,联网失败,于是报错。解决办法如下:
...
class HashtagProcessor:
# Adapted from wordninja library
# We use our wikipedia word count + a good heuristic to make it work
def __init__(self):
wiki_word_frequency = "/newdata/SD/dalle-mini/dalle-mini/enwiki-words-frequency.txt"
self._word_cost = (
l.split()[0]
for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
)
...
也就是将enwiki-words-frequency.txt
文件的本地路径直接赋值给wiki_word_frequency
变量,其余部份保持不变,问题解决。
FIx for "Couldn't invoke ptxas --version"
这个错误的产生是不同python库安装时带来的版本冲突导致的,DALLE-mini要求jax和jaxlib版本必须为0.3.25,但是通过pip imstall dalle-mini 命令安装后的jaxlib版本为0.4.13,但使用pip install jaxlib
的方式并不能找到0.3.25版本的jaxlib,而且会产生与flax、orbax-checkpoint等其他库的版本不兼容问题……在尝试多种方法合理降低jaxlib版本均失败后,发现答案就在ipynb中……也就是:pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
启示:要以官方说明文档为主,可以少走很多弯路!!!
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/newdata/SD/inference_dalle-mini.py", line 130, in <module>
decoded_images = p_decode(encoded_images, vqgan_params)
ValueError: pmap got inconsistent sizes for array axes to be mapped:
* most axes (101 of them) had size 512, e.g. axis 0 of argument params['decoder']['conv_in']['bias'] of type float32[512];
* some axes (71 of them) had size 3, e.g. axis 0 of argument params['decoder']['conv_in']['kernel'] of type float32[3,3,256,512];
* some axes (69 of them) had size 256, e.g. axis 0 of argument params['decoder']['up_1']['block_0']['norm1']['bias'] of type float32[256];
* some axes (67 of them) had size 128, e.g. axis 0 of argument params['decoder']['norm_out']['bias'] of type float32[128];
* some axes (35 of them) had size 1, e.g. axis 0 of argument indices of type int32[1,2,256];
* one axis had size 16384: axis 0 of argument params['quantize']['embedding']['embedding'] of type float32[16384,256]
后来发现,是因为之前调试的时候不小心把下面这行代码注释掉了……这个bug排得最辛苦,还挺无语的
vqgan_params = replicate(vqgan_params)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
2023-11-07 11:30:35.139851: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.257514: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.258648: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.628768: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2023-11-07 11:30:35.628915: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 525.53.0 does not match DSO version 530.41.3 -- cannot find working devices in this configuration
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']
0%| | 0/8 [00:00<?, ?it/s]
/root/anaconda3/envs/dalle/lib/python3.8/site-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
后记:第一次接触到基于jax框架编写的程序,还挺新鲜的,感觉和pytorch有一些不一样的地方。了解到jax是tensorflow的轻量级版本。上述博客内容中如果有个人理解不当之处,还望各位批评指正!
参考链接