BLIP:Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generat 论文的两个贡献如下:
An MED can operate either as a unimodal encoder, or an image-grounded text encoder, or an image-grounded text decoder.
We finetune a pre-trained MED into two modules: a captioner to produce synthetic captions given web images, and a filter to remove noisy captions from both the original web texts and the synthetic texts.
定义了一个处理训练集的类,继承PyTorch中用于处理数据集的基类Dataset,通常情况下,自定义的Dataset类需要实现两个方法:__ len__
和__ getitem__
:
class coco_karpathy_train(Dataset):
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
'''省略部分代码'''
# 给每个图像进行编号,编号方式:
# image_id:n
self.img_ids = {}
n = 0
for ann in self.annotation:
img_id = ann['image_id']
if img_id not in self.img_ids.keys():
self.img_ids[img_id] = n
n += 1
# 之前用函数加载了annotation文件:
# self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
# self.annotation是一个数组,数组中的每个元素是一个dict,如:
# [{"caption": "A woman wearing a net on her head cutting a cake. ",
# "image": "val2014/COCO_val2014_000000522418.jpg", "image_id": "coco_522418"},
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
# Image是一个Python图像处理库,常用于图像的加载、处理和保存操作。
image = Image.open(image_path).convert('RGB')
# 对图像对变换
image = self.transform(image)
# prompt + 对caption进行预处理后 得到新的caption
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
# 返回transform后的图形、处理后的caption、图像对应的编号
return image, caption, self.img_ids[ann['image_id']]
附上pre_caption
函数代码:
def pre_caption(caption,max_words=50):
# 把这些符号:.!\"()*#:;~ 替换为空格,并且将caption全部转换为小写字母
caption = re.sub(
r"([.!\"()*#:;~])",
' ',
caption.lower(),
)
# 将连续出现两个或更多空格的地方替换为单个空格
caption = re.sub(
r"\s{2,}",
' ',
caption,
)
# 去掉caption末尾的换行符
caption = caption.rstrip('\n')
# 去掉caption 两边的空格
caption = caption.strip(' ')
#truncate caption
caption_words = caption.split(' ')
if len(caption_words)>max_words: # 如果超过了max_words,就只取前max_words个单词
caption = ' '.join(caption_words[:max_words])
return caption
# 定义 normalize
# transforms.Normalize()函数接受两个参数,分别是均值(mean)和标准差(std)
# 均值(mean)和标准差(std) 这些参数是根据训练数据集的特征计算得出的。
# 分别对应三个通道(R、G、B)
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
# 对训练集进行的transform
transform_train = transforms.Compose([
# 根据给定的 image_size 进行scale,以及使用BICUBIC插值方法进行图像的插值填充
transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
# 随机水平翻转
transforms.RandomHorizontalFlip(),
# 自定义的 RandomAugment 函数,下面会做记录
# Identity(无操作)、AutoContrast(自动对比度调整)、Brightness(亮度调整)、
# Sharpness(锐度调整)、Equalize(直方图均衡化)、ShearX(X轴方向的错切变换)、
# ShearY(Y轴方向的错切变换)、TranslateX(X轴方向的平移变换)、
# TranslateY(Y轴方向的平移变换)、Rotate(旋转变换)
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
# 将图像数据转换为PyTorch张量的格式
transforms.ToTensor(),
normalize,
])
当使用BICUBIC插值方法进行图像插值填充时,原始图像上的像素值被用于计算新图像上每个像素的值。通过计算原始图像中像素的加权平均值,BICUBIC插值可以提供更平滑和连续的图像结果。
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
# 是否是PIL格式的图像
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
# 从augs这个数组中随机选择N个存储在 sampled_ops 列表中
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
# 将PIL图像对象转换为NumPy数组形式
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
# 根据概率判断是否应用当前的增强操作
if np.random.random() > prob:
continue
args = arg_dict[name](level)
# 这个 *args 包括 上一行代码得到的(level, replace_value)
img = func_dict[name](img, *args)
return img
__call__函数是Python中的特殊方法(special method),用于使对象可以像函数一样被调用,当调用该实例时,会自动执行__call__方法,并按照其中的逻辑进行
有很多ops操作,只选择一个记录TranslateX
:
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
func_dict = {
'''省略部分代码'''
'TranslateX': translate_x_func,
'''省略部分代码'''
}
def translate_x_func(img, offset, fill=(0, 0, 0)):
# offset:水平平移的偏移量,表示图像将向右平移的像素数。
# fill:边界填充的颜色,默认为(0, 0, 0),表示黑色填充
'''
same output as PIL.Image.transform
'''
# 这个img已经是numpy数组了
H, W = img.shape[0], img.shape[1]
# 平移矩阵M
M = np.float32([[1, 0, -offset], [0, 1, 0]])
# 对输入图像进行仿射变换,将平移矩阵M应用于图像
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
arg_dict = {
'''省略部分代码'''
'TranslateX': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'''省略部分代码'''
}
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level): # 将level转换为一组用于平移操作的参数
# 将传入的level除以MAX_LEVEL,然后乘以translate_const,得到一个平移的具体数值
level = (level / MAX_LEVEL) * float(translate_const)
# 以50%的概率将平移的数值取反,实现随机选择正向或负向平移
if np.random.random() > 0.5: level = -level
return (level, replace_value)
# 返回 level_to_args 这个函数
return level_to_args
val和test的annotation也是list,list中每个元素都是dict,包含两个键值,一个image,一个caption,其中caption是list,如下:
{"image": "val2014/COCO_val2014_000000184613.jpg",
"caption": ["A child holding a flowered umbrella and petting a yak.",
"A young man holding an umbrella next to a herd of cattle.",
"a young boy barefoot holding an umbrella touching the horn of a cow",
"A young boy with an umbrella who is touching the horn of a cow.",
"A boy holding an umbrella while standing next to livestock."]}
class coco_karpathy_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
'''省略部分代码'''
self.text = []
# 保存每一张图片的路径的list
self.image = []
self.txt2img = {}
self.img2txt = {}
txt_id = 0
# ann就是一个dict,包含"image"和 "caption",img_id 就是索引 index
for img_id, ann in enumerate(self.annotation):
self.image.append(ann['image'])
self.img2txt[img_id] = []
# 一个图片对应多个caption
for i, caption in enumerate(ann['caption']):
# 对caption做预处理之后,把新的caption 放入text数组中
self.text.append(pre_caption(caption,max_words))
# txt_id是每一张图片对应的多个caption的index,这些txt_id放在一个list中:
# {0 : [0, 1, 2,3,4]}
self.img2txt[img_id].append(txt_id)
# {0:0} {1:0} {2:0} {3:0} 表示txt_id到img_id的映射,
# 多个text可以映射到同一张图片
self.txt2img[txt_id] = img_id
txt_id += 1
'''__len__和 __getitem__的代码省略,和训练集的类一样'''
test和val数据集的transform相对于train的简单很多:
transform_test = transforms.Compose([
transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
normalize,
])
调用这两个实例就能得到三个数据集:
elif dataset=='retrieval_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
以上,完成了自定义数据集,接下来则需要做数据集的loader,也就是可迭代的数据加载器
torch.utils.data.DataLoader
是PyTorch中用于数据加载的类。它提供了一种方便的方式来迭代和批量处理数据。
DataLoader的主要作用是将自定义的数据集包装成一个可迭代的数据加载器,以便于在训练或测试过程中以批量的方式加载和处理数据。
使用DataLoader可以实现以下功能:
使用DataLoader需要指定以下参数:
除了上述的参数,还有:
pin_memory
:通常情况下,在使用GPU进行训练时,如果主机内存足够,建议将pin_memory设置为True,以提高数据加载到GPU的速度。但如果遇到内存不足的情况,可以将pin_memory设置为False,以节省内存资源。
sampler
:用于指定数据加载的顺序和采样方式。sampler参数可以接受以下几种类型的取值:
(1)SequentialSampler:顺序采样器,按照数据集的顺序依次采样数据,不进行随机打乱。
(2)RandomSampler:随机采样器,在每个时期(epoch)中随机打乱数据,并按照打乱后的顺序进行采样。
(3)SubsetRandomSampler:子集随机采样器,从给定的索引列表中随机采样数据,适用于对数据集的子集进行采样。
(4)WeightedRandomSampler:加权随机采样器,根据给定的样本权重进行采样,用于处理类别不平衡的数据集。
(5)自定义采样器:用户可以自定义采样器类,继承自Sampler,实现自己的数据采样逻辑。
drop_last
:如果数据集的样本数量无法被批次大小整除,并且drop_last参数设置为True,则最后一个不完整的批次将被丢弃。这通常在训练过程中用于确保每个批次的大小保持一致,以提高训练的效率。
ps:本论文采用的是Pytorch提供的DistributedSampler作为分布式训练的采样器,而如果不是分布式训练,则把sampler设置成了None
if args.distributed:
num_tasks = utils.get_world_size()
global_rank = utils.get_rank()
samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
else:
samplers = [None, None, None]
调用create_loader
函数:
train_loader, val_loader, test_loader =
create_loader([train_dataset, val_dataset, test_dataset],samplers,
batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
# 工作线程数的列表
num_workers=[4,4,4],
is_trains=[True, False, False],
#数据集的collate函数列表,用于对每个批次的样本进行处理和组合。如果不需要特定的处理逻辑,可以设置为None
collate_fns=[None,None,None])
create_loader
函数:
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
'''也许是我看的代码比较少的原因,看到这样做loader真的感觉很高效,代码简洁、清晰、好看,
使用zip就可以依次把三个数据集的loader做好,灵活使用if来判断,可以共用代码,并且传入的参数
也很特别,不是单独一个,而是包含3个元素的list,这样正好对应三个数据集'''
loaders = [] # 用来保存三个数据集的loader
for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
if is_train:
# 如果sampler 是 None,也就是非分布式训练,则随机打乱
# 否在,在分布式训练下,不需随机打乱
shuffle = (sampler is None)
# 训练集会把 最后一个不完整的批次丢掉
drop_last = True
else:
# 在val 和 test 数据集,既不随机打乱数据,也不会丢弃最后一个不完整的批次
shuffle = False
drop_last = False
loader = DataLoader(
dataset,
batch_size=bs,
num_workers=n_worker,
pin_memory=True,
sampler=sampler,
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=drop_last,
)
# 把做好的loader加到list中
loaders.append(loader)
return loaders
create_sampler()
函数用于创建分布式训练中的采样器(sampler):
# num_tasks:总任务数,即分布式训练中的进程数
# global_rank:当前进程的全局排名
def create_sampler(datasets, shuffles, num_tasks, global_rank):
samplers = []
for dataset,shuffle in zip(datasets,shuffles):
# 遍历datasets和shuffles列表,对每个数据集创建一个分布式采样器,
# 并将其添加到samplers列表中
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
samplers.append(sampler)
return samplers
分布式采样器使用torch.utils.data.DistributedSampler类进行创建,需要指定数据集、总任务数、当前进程的全局排名和是否进行洗牌
如果是要进行分布式训练,则需要获得总进程数以及进程排名,最后调用create_sampler函数:
if args.distributed:
# 获得分布式训练环境中的总进程数
num_tasks = utils.get_world_size()
# 获取当前进程在分布式训练环境中的排名
# 这样可以了解当前进程在整个分布式训练中的位置和角色,以便进行相应的操作和通信。
global_rank = utils.get_rank()
# 对训练集做sampler,验证集和测试集不需要
samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
检查环境以及获得进程数:
def is_dist_avail_and_initialized():
# 检查当前环境是否支持分布式训练
if not dist.is_available():
return False
# 检查是否已经初始化了分布式训练环境
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
# 如果环境不支持或者未初始化,则默认进程数为1
# 表示当前环境中只有一个进程
return 1
# 获取分布式训练环境中的总进程数,并返回该值
return dist.get_world_size()