前提:已完成caffe的配置,并安装好依赖库cython、opencv、pyyaml、easydict。
这里首先记录一下easydict引起的错误(再次强调,这个库一定要装低版本!)。下面的报错在通过pip安装指定的低版本easydict(即 pip install easydict==<低版本号>,而不是默认安装最新版)后可解决:
File "./tools/generate_tsv.py", line 221, in assert cfg.TEST.HAS_RPN
Assertion error: cfg.TEST.HAS_RPN == False
下面进入正题,genome features的提取
# Collect one (file path, image id) pair per entry of the Visual Genome
# image_data.json index. `split` is the list being filled; it is created
# outside this snippet.
with open('./data/visualgenome/image_data.json') as f:
    for item in json.load(f):
        image_id = int(item['image_id'])
        # Alternative that derives the path from the tail of the image URL
        # (presumably the author's original line):
        # filepath = os.path.join('./data/VGdata/', item['url'].split('rak248/')[-1])
        # Here every image is looked up as <image_id>.jpg directly under
        # ./data/VGdata/; the commented-out line above can be used instead.
        filepath = os.path.join('./data/VGdata/', str(image_id) + '.jpg')
        # print(filepath, os.path.exists(filepath))
        split.append((filepath, image_id))
python ./tools/generate_tsv.py --gpu 0 --cfg experiments/cfgs/faster_rcnn_end2end_resnet.yml --def ./models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt --out /home/share/bierone/genome_resnet101_faster_rcnn_genome.tsv --net data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel --split genome
验证:提取完tsv文件后,需要检查一下结果,验证box的位置是否合理。这里附上本人的代码show.py(难免有疏漏之处,希望大家不吝指正):
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# set display defaults
# plt.rcParams['figure.figsize'] = (10, 10) # large images
# plt.rcParams['image.interpolation'] = 'nearest' # don't interpolate: show square pixels
# plt.rcParams['image.cmap'] = 'gray' # use grayscale output rather than a (potentially misleading) color heatmap
import numpy as np
import cv2, base64
import csv, sys
# Each row stores whole base64-encoded arrays in single fields, so raise the
# csv module's per-field size cap to sys.maxsize.
csv.field_size_limit(sys.maxsize)

# Column names handed to csv.DictReader; because they are supplied explicitly,
# the first line of the file is treated as data rather than a header.
FIELDNAMES = [
    'image_id',
    'image_w',
    'image_h',
    'num_boxes',
    'boxes',
    'features',
]

# TSV file to inspect.
infile = '/home/share/lyb/genome_resnet101_faster_rcnn_genome.tsv'

# Prefix prepended to '<image_id>.jpg' when locating the source images.
data_root = '/home/lyb/bottom-up-attention/data/VGdata/'
def get_detections_from_tsv(nums=5, tsv_path=None, fieldnames=None):
    """Read the first ``nums`` records of a tab-separated detections file.

    Args:
        nums: Maximum number of rows to load. The previous implementation
            stored ``nums + 2`` rows because it tested ``i > nums`` only
            after saving the row; exactly ``nums`` rows are returned now.
        tsv_path: File to read. Defaults to the module-level ``infile``.
        fieldnames: Column names, in file order. Defaults to the module-level
            ``FIELDNAMES``. Must include ``image_h``, ``image_w``,
            ``num_boxes``, ``boxes`` and ``features``.

    Returns:
        A dict mapping the 0-based row index to the row's dict. In each row
        ``image_h``, ``image_w`` and ``num_boxes`` are converted to ``int``;
        ``boxes`` and ``features`` are base64-decoded into float32 arrays
        reshaped to ``(num_boxes, -1)``; ``image_id`` is left as the string
        read from the file.
    """
    path = infile if tsv_path is None else tsv_path
    columns = FIELDNAMES if fieldnames is None else fieldnames

    in_data = {}
    with open(path, "r") as tsv_in_file:
        reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=columns)
        for i, item in enumerate(reader):
            # Stop before decoding anything beyond the requested row count.
            if i >= nums:
                break
            item['image_h'] = int(item['image_h'])
            item['image_w'] = int(item['image_w'])
            item['num_boxes'] = int(item['num_boxes'])
            for field in ['boxes', 'features']:
                # base64.decodestring was removed in Python 3.9; b64decode is
                # the supported replacement.
                raw = base64.b64decode(item[field].encode('utf8'))
                item[field] = np.frombuffer(raw, dtype=np.float32).reshape(
                    (item['num_boxes'], -1))
            in_data[i] = item
    return in_data
def show_features(ax, boxes, objects='aa', attrs='bb'):
    """Outline every box in ``boxes`` on ``ax`` and hide the axis decorations.

    Args:
        ax: Matplotlib axes to draw on.
        boxes: One box per row; entries 0-3 of each row are read as
            ``(x1, y1, x2, y2)`` and drawn as a rectangle anchored at
            ``(x1, y1)`` with width ``x2 - x1`` and height ``y2 - y1``
            (presumably pixel corner coordinates; confirm against the
            extractor).
        objects: Unused; kept so existing calls stay valid.
        attrs: Unused; kept so existing calls stay valid.
    """
    for bbox in boxes:
        x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
        ax.add_patch(
            plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                          fill=False, edgecolor='red',
                          linewidth=2, alpha=0.8)
        )
    # Act on the axes that was passed in rather than on whatever pyplot
    # currently considers the active axes/figure.
    ax.axis('off')
    ax.figure.canvas.draw_idle()
if __name__ == '__main__':
    # For each loaded TSV record, draw its boxes over the source image and
    # save the picture as demo/<image_id>.jpg (the demo/ directory must
    # already exist).
    in_data = get_detections_from_tsv()
    for item in in_data.values():
        im_file = data_root + item['image_id'] + '.jpg'
        im = cv2.imread(im_file)
        if im is None:
            # cv2.imread returns None instead of raising when the file is
            # missing or unreadable; skip it rather than crash in imshow.
            print('skip image %s: cannot read %s' % (item['image_id'], im_file))
            continue
        # OpenCV loads pixels as BGR while matplotlib expects RGB; without
        # this conversion the saved image has red and blue swapped.
        rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        fig, ax = plt.subplots(figsize=(20, 20))
        ax.imshow(rgb)
        show_features(ax, item['boxes'])
        fig.savefig('demo/' + item['image_id'] + '.jpg')
        # Release the figure; pyplot keeps every figure alive until closed.
        plt.close(fig)
总结:作者的代码写得比较复杂,我只仔细查看了其中一部分,这里就不附上分析了。其实整个提取过程并不复杂,时间主要花在配置环境上。