深度学习系列41:多模态Dalle-min生成图像

1. dalle-min模型介绍

参考https://huggingface.co/flax-community/dalle-mini,可以用这个版本进行探索和学习。
dalle模型包括:

  1. 一个基于BART的编码器,将文本token转为图像token
  2. 一个基于VQGAN模型的编解码器,将图像token和图片之间互相转换

首先要训练VAGAN模型。开源的模型对于人脸重构效果不佳,期待有人做优化训练;此外还需要一个预训练好的BART模型。
训练模型包括如下几个部分:
1)将图片用VQGAN的编码器转为图像token
2)将文字用BART的编码器转为文字token
3)两者拼接后用BART的解码器转为图像toke
4)与第一步的图像token计算交叉熵,进行优化

深度学习系列41:多模态Dalle-min生成图像_第1张图片

2. 推理过程

使用训练好的BART模型将文字转为图片token,然后用训练的VQGAN模型解码器生成图片,然后用CLIP模型挑出最优的K张图片。
深度学习系列41:多模态Dalle-min生成图像_第2张图片

3. 如何使用

这里是在线测试地址:https://huggingface.co/spaces/dalle-mini/dalle-mini
这里是git地址:https://github.com/borisdayma/dalle-mini
首先要安装库:

pip install -q dalle-mini jax
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

然后执行下面的代码

import jax
import jax.numpy as jnp
# dalle模型,资源不够的话可以用下面那个
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  
# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0" 
DALLE_COMMIT_ID = None

# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel

# Load dalle-mini
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)

# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)

from dalle_mini import DalleBartProcessor
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
prompts = ["sunset over a lake in the mountains", "the Eiffel tower landing on the moon"]
tokenized_prompts = processor(prompts)
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")
# generate images
images = []
n_predictions = 8
for i in trange(n_predictions):
    # generate images
    encoded_images = model.generate(**tokenized_prompts,params=params,condition_scale=10.0)
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = vqgan.decode_code(encoded_images, params=vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()

到这里为止已经可以生成一系列图片了
深度学习系列41:多模态Dalle-min生成图像_第3张图片

接下来用clip来评分:

# CLIP model
CLIP_REPO = "openai/clip-vit-base-patch32"
CLIP_COMMIT_ID = None

# Load CLIP
clip, clip_params = FlaxCLIPModel.from_pretrained(
    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)
clip_params = replicate(clip_params)

# score images
@partial(jax.pmap, axis_name="batch")
def p_clip(inputs, params):
    logits = clip(params=params, **inputs).logits_per_image
    return logits

from flax.training.common_utils import shard

# get clip scores
clip_inputs = clip_processor(
    text=prompts * jax.device_count(),
    images=images,
    return_tensors="np",
    padding="max_length",
    max_length=77,
    truncation=True,
).data
logits = p_clip(shard(clip_inputs), clip_params)

# organize scores per prompt
p = len(prompts)
logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()
#logits = rearrange(logits, '1 b p -> p b')

for i, prompt in enumerate(prompts):
    print(f"Prompt: {prompt}\n")
    for idx in logits[i].argsort()[::-1]:
        display(images[idx*p+i])
        print(f"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\n")
    print()

你可能感兴趣的:(深度学习系列,深度学习,python,计算机视觉)