由于大规模模型需要端到端训练,视觉-语言预训练的成本越来越高。本文提出了BLIP-2,这是一种通用且高效的预训练策略,它利用现成的冻结预训练图像编码器和冻结的大型语言模型来引导视觉-语言预训练。BLIP-2通过一个轻量级的查询Transformer(Q-Former)弥合模态差距,该Transformer分两个阶段进行预训练:第一阶段从冻结的图像编码器中引导视觉-语言表示学习(representation learning);第二阶段从冻结的语言模型中引导视觉到语言的生成式学习。
优点:BLIP-2 achieves state-of-the-art performance on various vision-language tasks, despite having significantly fewer trainable parameters than existing methods. For example, our model outperforms Flamingo80B by 8.7% on zero-shot VQAv2 with 54x fewer trainable parameters. We also demonstrate the model’s emerging capabilities of zero-shot image-to-text generation that can follow natural language instructions.
Overview of BLIP-2 two-stage pre-training strategy
Overview of Q-Former and the first stage of vision-language representation learning in BLIP-2
pip install salesforce-lavis
import torch
from PIL import Image
import requests
from lavis.models import load_model_and_preprocess
# Download the sample image and show a resized preview in the notebook.
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
display(raw_image.resize((596, 437)))
# setup device to use (always a torch.device, on GPU and on CPU alike)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# we associate a model with its preprocessors to make it easier for inference.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device
)
# vis_processors.keys() -> dict_keys(['train', 'eval'])
# Preprocess with the "eval" processor and add a leading batch dimension.
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
model.generate({"image": image})
输出 :# ‘singapore’
# Due to the non-deterministic nature of nucleus sampling, you may get different captions.
model.generate({"image": image}, use_nucleus_sampling=True, num_captions=3)
Ask the model to explain its answer.
# Prompted generation: a question is passed alongside the image under the "prompt" key.
model.generate({"image": image, "prompt": "Question: which city is this? Answer:"})
# Follow-up question: the earlier question and its answer are replayed in the prompt.
model.generate({
    "image": image,
    "prompt": "Question: which city is this? Answer: singapore. Question: why?"})
[‘it has a statue of a merlion’]
# Follow-up question: the earlier question and its answer are replayed in the prompt.
model.generate({
    "image": image,
    "prompt": "Question: which city is this? Answer: singapore. Question: why?"})
# 'it has a statue of a merlion'
# Conversation so far, as (question, answer) pairs.
context = [
    ("which city is this?", "singapore"),
    ("why?", "it has a statue of a merlion"),
]
question = "where is the name merlion coming from?"
template = "Question: {} Answer: {}."
# Replay every earlier turn through the template, then append the new
# question followed by an open "Answer:" slot for the model to complete.
prompt = " ".join(template.format(q, a) for q, a in context) + " Question: " + question + " Answer:"
# generate model's response
model.generate({"image": image,"prompt": prompt})
# 'merlion is a portmanteau of mermaid and lion'
BLIP-2 models are very large to load, so I use memory-saving techniques such as init_empty_weights.
In order to finish the submission within 9 hours, the beam width of the decoder's beam search is reduced to 3.
# Install salesforce-lavis from the locally downloaded package directory
# (--no-index makes pip resolve only from --find-links, not from PyPI).
!pip install salesforce-lavis --no-index --find-links=file:///kaggle/input/lavis-pip/
# Loading local weight files requires a modified version of salesforce-lavis, so uninstall the stock package first.
# NOTE(review): presumably the install above is kept only for its dependencies — confirm.
!pip uninstall -y salesforce-lavis
# Install the modified salesforce-lavis wheel.
!pip install salesforce-lavis --no-index --find-links=file:///kaggle/input/lavis-mod-wheel/salesforce_lavis-1.0.0.dev1-py3-none-any.whl
import os
import gc
import cv2
import sys
import torch
import numpy as np
import torch.nn as nn
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
from PIL import Image
from lavis.models import load_model, load_preprocess, load_model_and_preprocess
from lavis.processors import load_processor
from lavis.models.blip2_models.blip2_opt import Blip2OPT
from typing import Dict
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from accelerate import init_empty_weights
from sentence_transformers import SentenceTransformer, models
# these helper functions are based on the following repository.
# https://github.com/FrancescoSaverioZuppichini/Loading-huge-PyTorch-models-with-linear-memory-consumption/blob/main/README.md
def get_keys_to_submodule(model: nn.Module) -> Dict[str, nn.Module]:
    """Map each parameter's fully qualified key to the submodule that owns it.

    Args:
        model: The module whose parameters are indexed.

    Returns:
        A dict from dotted parameter key (for example ``"0.weight"``) to the
        submodule that holds that parameter directly as an attribute.
        Parameters held by ``model`` itself are keyed by their bare name.
    """
    keys_to_submodule = {}
    for submodule_name, submodule in model.named_modules():
        # recurse=False yields only the parameters this submodule holds
        # directly; parameters of nested modules are picked up when
        # named_modules() reaches their own owner.
        for param_name, _ in submodule.named_parameters(recurse=False):
            # The root module is reported under the empty name, so its own
            # parameters take no prefix.
            if submodule_name != '':
                key = f"{submodule_name}.{param_name}"
            else:
                key = param_name
            keys_to_submodule[key] = submodule
    return keys_to_submodule
def load_state_dict_with_low_memory(model: nn.Module, state_dict: Dict[str, torch.Tensor]):
    """Replace the parameters of ``model`` one by one with tensors from ``state_dict``.

    Each tensor found in ``state_dict`` is cast to the dtype of the parameter
    it replaces, wrapped in a ``torch.nn.Parameter`` with
    ``requires_grad=False``, and assigned onto the owning submodule.

    Args:
        model: Module whose parameters are overwritten in place.
        state_dict: Mapping from dotted parameter key to tensor. Keys that are
            absent (or map to ``None``) are skipped, leaving the existing
            parameter untouched.
    """
    keys_to_submodule = get_keys_to_submodule(model)
    for key, submodule in keys_to_submodule.items():
        val = state_dict.get(key)
        if val is None:
            continue
        # The attribute name on the owning submodule is the last dotted segment.
        param_name = key.split('.')[-1]
        # Keep the dtype the model was built with, not the checkpoint's dtype.
        param_dtype = getattr(submodule, param_name).dtype
        val = val.to(param_dtype)
        new_val = torch.nn.Parameter(val, requires_grad=False)
        setattr(submodule, param_name, new_val)
# Root directory of the competition data.
comp_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Construct the model inside init_empty_weights() so building it does not
# allocate real parameter storage; the weights are filled in afterwards by
# load_state_dict_with_low_memory.
with init_empty_weights():
    my_model = Blip2OPT(opt_model="facebook/opt-2.7b")
class DictWrapper:
    """Expose a plain dict through attribute access and a ``get`` method.

    ``wrapper.key`` returns ``d["key"]`` and ``wrapper.get(key, default)``
    forwards to ``dict.get``, so a plain dict can stand in for a config
    object that is read in both ways.
    """

    def __init__(self, d):
        # The wrapped mapping; kept by reference, not copied.
        self.dict = d

    def __getattr__(self, name):
        """Return ``self.dict[name]``; raise ``AttributeError`` if it is missing.

        ``__getattr__`` is only called after normal attribute lookup fails.
        Raising ``AttributeError`` (not ``KeyError``) keeps ``hasattr`` and
        ``getattr(obj, name, default)`` working as the attribute protocol
        requires.
        """
        try:
            # Go through __dict__ so a missing `dict` attribute (for example
            # an instance created without __init__) cannot re-enter
            # __getattr__ and recurse without end.
            return self.__dict__["dict"][name]
        except KeyError:
            raise AttributeError(name) from None

    def get(self, name, default_val=None):
        """Return ``self.dict.get(name, default_val)``."""
        return self.dict.get(name, default_val)
# Image processor settings for the train and eval splits.
dict_tr = {
    "name": "blip_image_train",
    "image_size": 224,
}
dict_ev = {
    "name": "blip_image_eval",
    "image_size": 224,
}
# Text (caption) processor settings, shared by both splits.
dict_t = {
    "name": "blip_caption",
}
# NOTE(review): the body of `config` was missing from the source. It is
# reconstructed here from the dicts above and DictWrapper, following the
# "vis_processor"/"text_processor" -> "train"/"eval" layout that
# load_preprocess is understood to read — confirm against the original notebook.
config = {
    "vis_processor": DictWrapper({"train": DictWrapper(dict_tr), "eval": DictWrapper(dict_ev)}),
    "text_processor": DictWrapper({"train": DictWrapper(dict_t), "eval": DictWrapper(dict_t)}),
}
# load_preprocess returns a pair; only the first element (visual processors) is kept.
vis_processors = load_preprocess(config)[0]
# Fill the empty model skeleton with the locally stored checkpoint.
load_state_dict_with_low_memory(my_model, torch.load("/kaggle/input/blip2-pretrained-opt27b-sdpth/blip2_pretrained_opt2.7b_sd.pth"))
# The inputs below are moved to `device`, so the model must live there too
# (a no-op when the checkpoint tensors are already on that device).
# NOTE(review): confirm the original notebook did not also change the dtype here.
my_model.to(device)
# Inference only: switch to eval mode.
my_model.eval()
images = os.listdir(comp_path / 'images')
pred_prompt_list = []
for image_name in images:
    image = Image.open(comp_path / 'images' / image_name).convert('RGB')
    # Preprocess with the "eval" processor and add a leading batch dimension.
    image = vis_processors["eval"](image).unsqueeze(0).to(device)
    # Generate a caption with beam search (beam width 3).
    pred_prompt = my_model.generate({"image": image}, num_beams=3)
    # Keep exactly one caption per image so the list stays aligned with `images`.
    pred_prompt_list.append(pred_prompt[0])
# Drop the captioning model before the sentence encoder is loaded.
del my_model
gc.collect()
# Encode every predicted caption, then flatten the result to one dimension.
st_model = SentenceTransformer('/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2')
caption_vectors = st_model.encode(pred_prompt_list, batch_size=256)
prompt_embeddings = caption_vectors.flatten()
# Image ids are the file names up to the first '.'.
imgIds = [file_name.split('.')[0] for file_name in images]
eIds = list(range(EMBEDDING_LENGTH))
# One "<imgId>_<eId>" string per (image, embedding index) pair, image-major.
imgId_eId = [
    f"{img_id}_{e_id}"
    for img_id in imgIds
    for e_id in range(EMBEDDING_LENGTH)
]
submission = pd.DataFrame(
