For a project I needed the COCO dataset, so I went looking for some open-source data-loading code. Embarrassingly, after cloning it I forgot where it came from.
COCO download page: https://cocodataset.org/#download
Taking the 2014 val split as an example: 40,504 images, each paired with 5 captions, for a total of 202,654 captions.
Download the COCO API from https://github.com/cocodataset/cocoapi and put it in the same directory as train.py. It is required for from pycocotools.coco import COCO, which is the API used to parse the COCO annotation JSON files.
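If you just want to sanity-check the annotations before wiring up the full loader, pycocotools can be used on its own. A minimal sketch; the path is an assumption, so point it at wherever your captions_val2014.json actually lives:

from pycocotools.coco import COCO

coco = COCO('./annotations/captions_val2014.json')  # hypothetical path to the val2014 captions

print(len(coco.imgs))  # number of images, 40504 for val2014
print(len(coco.anns))  # number of captions, 202654 for val2014

# look at the captions attached to one image
img_id = list(coco.imgs.keys())[0]
ann_ids = coco.getAnnIds(imgIds=img_id)
for ann in coco.loadAnns(ann_ids):
    print(ann['caption'])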
The code for loading the data:
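One prerequisite worth noting before running it: the captions are tokenized with nltk.tokenize.word_tokenize, which needs NLTK's punkt tokenizer data. If that has never been downloaded on your machine, run this once first:

import nltk
nltk.download('punkt')  # one-time download of the tokenizer models used by word_tokenize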
import nltk
import os
import torch
import torch.utils.data as data
from vocabulary import Vocabulary
from PIL import Image
from pycocotools.coco import COCO
import numpy as np
from tqdm import tqdm
import random
import json
def get_loader(transform,
               mode='train',
               batch_size=1,
               vocab_threshold=None,        # minimum word count for a token to enter the vocabulary
               vocab_file='./vocab.pkl',    # vocabulary file
               start_word="<start>",        # special token marking the start of a sentence
               end_word="<end>",            # special token marking the end of a sentence
               unk_word="<unk>",            # special token standing in for unknown words
               vocab_from_file=True,        # if True, load the existing vocabulary file; if False, build it from scratch
               num_workers=0,               # nothing but 0 worked for me; I never figured out why
               cocoapi_loc='D:/项目学习/pytorch/image captioning/data'):  # folder that holds the data
    # training mode
    if mode == 'train':
        if vocab_from_file:
            assert os.path.exists(vocab_file), "vocab_file does not exist. Change vocab_from_file to False to create vocab_file."
        img_folder = os.path.join(cocoapi_loc, 'train2014/')                                  # training images folder
        annotations_file = os.path.join(cocoapi_loc, 'annotations/captions_train2014.json')   # training annotations json
    # test mode
    if mode == 'test':
        assert batch_size == 1, "batch_size must be 1 in test mode"
        assert os.path.exists(vocab_file), "vocab.pkl does not exist; generate it in train mode first"
        assert vocab_from_file == True, "vocab_from_file must be True"
        img_folder = os.path.join(cocoapi_loc, 'test2017/')                                   # test images folder
        annotations_file = os.path.join(cocoapi_loc, 'annotations/captions_test2017.json')    # test annotations json
    dataset = CoCoDataset(transform=transform,
                          mode=mode,
                          batch_size=batch_size,
                          vocab_threshold=vocab_threshold,
                          vocab_file=vocab_file,
                          start_word=start_word,
                          end_word=end_word,
                          unk_word=unk_word,
                          annotations_file=annotations_file,
                          vocab_from_file=vocab_from_file,
                          img_folder=img_folder)
    if mode == 'train':
        # randomly choose a caption length (note: most COCO caption lengths fall around 9-12 tokens)
        indices = dataset.get_train_indices()
        # sample the images whose captions have the chosen length
        # (note: each image has five captions, so the same image may be sampled more than once,
        #  but the sentence fed to the model will not necessarily be the same each time)
        initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader = data.DataLoader(dataset=dataset,
                                      num_workers=num_workers,
                                      batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
                                                                              batch_size=dataset.batch_size,
                                                                              drop_last=False))
    else:
        data_loader = data.DataLoader(dataset=dataset,
                                      batch_size=dataset.batch_size,
                                      shuffle=True,
                                      num_workers=num_workers)
    return data_loader
class CoCoDataset(data.Dataset):

    def __init__(self, transform, mode, batch_size, vocab_threshold, vocab_file, start_word,
                 end_word, unk_word, annotations_file, vocab_from_file, img_folder):
        self.transform = transform
        self.mode = mode
        self.batch_size = batch_size
        # create or load the vocabulary
        self.vocab = Vocabulary(vocab_threshold, vocab_file, start_word,
                                end_word, unk_word, annotations_file, vocab_from_file)
        self.img_folder = img_folder
        if self.mode == 'train':
            self.coco = COCO(annotations_file)
            # after processing by pycocotools.coco, self.coco.anns is a plain dict
            self.ids = list(self.coco.anns.keys())
            print('Obtaining caption lengths...')
            all_tokens = [nltk.tokenize.word_tokenize(str(self.coco.anns[self.ids[index]]['caption']).lower())
                          for index in tqdm(np.arange(len(self.ids)))]
            self.caption_lengths = [len(token) for token in all_tokens]
        else:
            test_info = json.loads(open(annotations_file).read())
            self.paths = [item['file_name'] for item in test_info['images']]
    def __getitem__(self, index):
        if self.mode == 'train':
            ann_id = self.ids[index]
            caption = self.coco.anns[ann_id]['caption']
            img_id = self.coco.anns[ann_id]['image_id']
            path = self.coco.loadImgs(img_id)[0]['file_name']
            # load and transform the image
            image = Image.open(os.path.join(self.img_folder, path)).convert('RGB')
            image = self.transform(image)
            # tokenize the caption and convert it to a tensor of vocabulary indices,
            # wrapped in the <start> and <end> tokens
            tokens = nltk.tokenize.word_tokenize(str(caption).lower())
            caption = []
            caption.append(self.vocab(self.vocab.start_word))
            caption.extend([self.vocab(token) for token in tokens])
            caption.append(self.vocab(self.vocab.end_word))
            caption = torch.Tensor(caption).long()
            return image, caption
        else:
            path = self.paths[index]
            PIL_image = Image.open(os.path.join(self.img_folder, path)).convert('RGB')
            orig_image = np.array(PIL_image)
            image = self.transform(PIL_image)
            # return both the original image (for display) and the transformed image (for the model)
            return orig_image, image
    # randomly choose a caption length, then return batch_size indices of captions with that length
    def get_train_indices(self):
        sel_length = np.random.choice(self.caption_lengths)
        all_indices = np.where([self.caption_lengths[i] == sel_length for i in np.arange(len(self.caption_lengths))])[0]
        indices = list(np.random.choice(all_indices, size=self.batch_size))
        return indices

    def __len__(self):
        if self.mode == 'train':
            return len(self.ids)
        else:
            return len(self.paths)
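In test mode, __getitem__ returns two things per sample: the untransformed image (useful for displaying the generated caption next to it) and the transformed tensor that actually goes into the model. A sketch of pulling one test sample; it assumes vocab.pkl already exists from a training run, the test2017 images and json are in place under cocoapi_loc, and transform_test is a hypothetical deterministic variant of the training transform (Resize/CenterCrop(224)/ToTensor/Normalize):

data_loader_test = get_loader(transform=transform_test, mode='test')
orig_image, image = next(iter(data_loader_test))
print(orig_image.shape)  # (1, H, W, 3): the raw image, batched by the DataLoader
print(image.shape)       # (1, 3, 224, 224): the normalized tensor fed to the model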
The code for loading or building the vocabulary (a word-to-index mapping):
import nltk
import pickle
import os.path
from pycocotools.coco import COCO
from collections import Counter
class Vocabulary(object):

    def __init__(self,
                 vocab_threshold,
                 vocab_file='./vocab.pkl',
                 start_word="<start>",
                 end_word="<end>",
                 unk_word="<unk>",
                 annotations_file='D:/项目学习/pytorch/image captioning/coco/annotations/captions_train2014.json',
                 vocab_from_file=False):
        self.vocab_threshold = vocab_threshold
        self.vocab_file = vocab_file
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word
        self.annotations_file = annotations_file
        self.vocab_from_file = vocab_from_file
        self.get_vocab()
    def get_vocab(self):
        # load the vocabulary from file, or build it from scratch
        if os.path.exists(self.vocab_file) and self.vocab_from_file:
            with open(self.vocab_file, 'rb') as f:
                vocab = pickle.load(f)
                self.word2idx = vocab.word2idx
                self.idx2word = vocab.idx2word
            print('Vocabulary loaded from vocab.pkl')
        else:
            self.build_vocab()
            # save vocab.pkl
            with open(self.vocab_file, 'wb') as f:
                pickle.dump(self, f)
    def build_vocab(self):
        # populate the dictionaries that map tokens to integers (and vice versa)
        self.init_vocab()
        self.add_word(self.start_word)
        self.add_word(self.end_word)
        self.add_word(self.unk_word)
        self.add_captions()

    def init_vocab(self):
        # initialize the dictionaries used to store the token-to-integer mapping
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        # add a token to the vocabulary
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def add_captions(self):
        # loop over the training captions and add every token whose count meets or exceeds the threshold
        coco = COCO(self.annotations_file)
        counter = Counter()
        ids = coco.anns.keys()
        for i, id in enumerate(ids):
            caption = str(coco.anns[id]['caption'])
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)
            if i % 100000 == 0:
                print("[%d/%d] Tokenizing captions..." % (i, len(ids)))
        words = [word for word, cnt in counter.items() if cnt >= self.vocab_threshold]
        for i, word in enumerate(words):
            self.add_word(word)

    def __call__(self, word):
        # map a token to its index, falling back to the <unk> index for unseen tokens
        if not word in self.word2idx:
            return self.word2idx[self.unk_word]
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)
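Once built (or loaded), a Vocabulary instance acts as a callable word-to-index map: calling it on a token returns the token's index, and tokens not in word2idx fall back to the index of the unknown token. A small usage sketch; the threshold of 5 is just an illustrative value, and building from scratch assumes annotations_file points at your captions_train2014.json:

vocab = Vocabulary(vocab_threshold=5, vocab_from_file=False)  # tokenizes all captions, then pickles vocab.pkl
print(len(vocab))            # vocabulary size
print(vocab('a'))            # index of a frequent token
print(vocab('qwertyuiop'))   # unseen token, so this returns vocab(vocab.unk_word)
print(vocab.idx2word[0])     # '<start>' was added first, so index 0 maps back to it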
The call inside train.py:
import torch
from torchvision import transforms
# get_loader comes from the data-loading module shown above (e.g. a data_loader.py)

# image preprocessing and augmentation for training
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),    # ImageNet channel means
                         (0.229, 0.224, 0.225))])  # ImageNet channel stds

# batch_size, vocab_threshold and vocab_from_file are defined earlier in train.py
data_loader = get_loader(transform=transform_train,
                         mode='train',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=vocab_from_file)
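One thing to keep in mind: get_train_indices is only called once inside get_loader, so without further intervention the loader would keep serving batches drawn from that single set of indices. The usual way to drive this kind of loader is to re-sample the indices and swap them into the batch sampler at every training step. A sketch of that loop; total_steps, the model forward/backward code, and the earlier train.py variables are assumptions, not part of the original snippet:

import torch.utils.data as data

for step in range(total_steps):  # total_steps: however many updates you plan to run (hypothetical name)
    # draw a fresh caption length and the indices of batch_size captions with that length
    indices = data_loader.dataset.get_train_indices()
    new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
    data_loader.batch_sampler.sampler = new_sampler

    # fetch the batch; every caption shares the same length, so they stack without padding
    images, captions = next(iter(data_loader))
    # images:   (batch_size, 3, 224, 224) after transform_train
    # captions: (batch_size, caption_length + 2), the +2 being <start> and <end>

    # ... forward pass, loss, backward pass and optimizer step go here ...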