JHMDB(joint-annotated HMDB)是对 HMDB 数据集的二次标注。
类别列表(共 21 类)
每个视频最多只包含一类目标行为,bbox 只标注了执行该目标行为的人。
JHMDB-GT.pkl 文件中存放的是一个字典(dict),包含以下 6 个 key:
labels (list): List of the 21 labels,即上一小节中提到的 21 类行为标签。
gttubes (dict): Dictionary that contains the ground truth tubes for each video.
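Each value (e.g. gttubes['walk/Panic_in_the_Streets_walk_u_cm_np1_ba_med_5']) is a dict that maps a label index to a list of tubes. A tube is an array with nframes rows and 5 columns; each row has the format <frame index> <x1> <y1> <x2> <y2>.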
nframes (dict): Dictionary that contains the number of frames for each video, like 'walk/Panic_in_the_Streets_walk_u_cm_np1_ba_med_5': 16.
train_videos (list): A list with nsplits=1 elements, each one containing the list of training videos.
test_videos (list): A list with nsplits=1 elements, each one containing the list of testing videos.
resolution (dict): Dictionary that outputs a tuple (h, w) of the resolution for each video, like 'pour/Bartender_School_Students_Practice_pour_u_cm_np1_fr_med_1': (240, 320).
UCF101_24是UCF101数据集的子集,使用了一些不一样的标签。
类别信息(共24类)
数据下载及其他准备工作可参考 mmaction2 的数据准备文档,该文档已介绍了下载方式;数据就是一个压缩包,这里不再赘述。
标签格式与 1.2 节“数据准备以及标签详解”中介绍的完全相同,这里不再赘述。
import argparse
import os
import cv2
import pickle
from collections import defaultdict
def _parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="ucf101_24")
parser.add_argument("--dataset_root_path", type=str,
default="/ssd01/data/ucf101_24")
parser.add_argument("--rgb-dir-name", type=str, default="rgb-images")
# JHMDB & UCF101_24
parser.add_argument("--pkl-filename", type=str, default="UCF101v2-GT.pkl")
parser.add_argument("--img-impl", type=str, default="%05d.jpg")
return parser.parse_args()
def _darknet_draw_bbox(bboxes,
labels,
scores,
img,
bboxes_color=(0, 255, 0),
bboxes_thickness=1,
text_color=(0, 255, 0),
text_thickness=2,
text_front_scale=0.5):
"""
bbox的形式是 xyxy,取值范围是像素的值
labels是标签名称
scores是置信度,[0, 1]的浮点数
"""
for idx, (bbox, label) in enumerate(zip(bboxes, labels)):
xmin, ymin, xmax, ymax = bbox
pt1 = (int(xmin), int(ymin)) # 左下
pt2 = (int(xmax), int(ymax)) # 右上
# 画bbox
cv2.rectangle(img, pt1, pt2, bboxes_color, bboxes_thickness)
# 写上对应的文字
cur_label = label
if scores is not None:
cur_label += " [" + str(round(scores[idx] * 100, 2)) + "]"
cv2.putText(
img=img,
text=cur_label,
org=(pt1[0], pt1[1] - 5),
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=.5,
color=(0, 255, 0),
thickness=2,
)
return img
def _show_single_video(rgb_dir, tubes, nframes, label, img_impl):
draw_imgs = []
for i in range(nframes):
img = cv2.imread(os.path.join(rgb_dir, img_impl % (i+1)))
boxes = tubes[i+1]
draw_img = _darknet_draw_bbox(
boxes, [label]*len(boxes), None, img)
draw_imgs.append(draw_img)
cv2.imshow("demo", draw_img)
cv2.waitKey(100)
return draw_imgs
def _filter_samples(data_dict, args):
# TODO: filter input data with categories
return data_dict
def _build_data_dict(nframes, resolution, gttubes):
    """Merge the per-video GT tables into a single dict keyed by video dir.

    Args:
        nframes: video relative path -> number of frames.
        resolution: video relative path -> (h, w) tuple.
        gttubes: video relative path -> {label index: list of tubes}, where
            each tube is a sequence of rows ``[frame, x1, y1, x2, y2]``.

    Returns:
        Dict mapping video relative path -> dict with these keys:
        'nframes' and 'resolution' for every video.
        For videos with at least one GT label, also 'label' (index into
        the label list) and 'tubes' (defaultdict(list): frame index ->
        list of xyxy boxes).

    Raises:
        ValueError: if a video has a resolution but no frame count, or has
            GT tubes but no resolution.
    """
    data_dict = {key: {"nframes": count} for key, count in nframes.items()}
    for key, res in resolution.items():
        if key not in data_dict:
            raise ValueError(
                "video {} has a resolution but no frame count".format(key))
        data_dict[key]["resolution"] = res
    for key, tubes_by_label in gttubes.items():
        if "resolution" not in data_dict.get(key, {}):
            raise ValueError(
                "video {} has GT tubes but no resolution".format(key))
        if not tubes_by_label:
            # No annotated action in this video: nothing to draw.
            continue
        # A video carries at most one target action, so the first (only)
        # key is its label index.
        label = next(iter(tubes_by_label))
        frame_boxes = defaultdict(list)
        for tube in tubes_by_label[label]:
            for row in tube:
                # row = [frame index, x1, y1, x2, y2]
                frame_boxes[int(row[0])].append(row[1:])
        data_dict[key]["label"] = label
        data_dict[key]["tubes"] = frame_boxes
    return data_dict


def _show_spatiotemporal_datasets(args):
    """Play every annotated video described by a JHMDB / UCF101_24 GT pickle.

    Args:
        args: parsed CLI namespace. Reads ``dataset_root_path``,
            ``pkl_filename``, ``rgb_dir_name`` and ``img_impl``.

    Raises:
        ValueError: if the GT tables are inconsistent with each other.

    Videos without any GT tube are skipped, since there is nothing to draw.
    """
    pkl_path = os.path.join(args.dataset_root_path, args.pkl_filename)
    rgb_dir_path = os.path.join(args.dataset_root_path, args.rgb_dir_name)

    # SECURITY: pickle.load can execute arbitrary code, so only open GT files
    # from a trusted source. With encoding='bytes' the stored strings come
    # back as bytes (the file was presumably written by Python 2), hence the
    # b'...' keys and the .decode() calls below.
    with open(pkl_path, "rb") as fid:
        cache = pickle.load(fid, encoding="bytes")
    labels = [name.decode() for name in cache[b"labels"]]
    nframes = {k.decode(): v for k, v in cache[b"nframes"].items()}
    resolution = {k.decode(): v for k, v in cache[b"resolution"].items()}
    gttubes = {k.decode(): v for k, v in cache[b"gttubes"].items()}

    data_dict = _filter_samples(
        _build_data_dict(nframes, resolution, gttubes), args)

    for relative_path, sample in data_dict.items():
        if "tubes" not in sample:
            continue
        _show_single_video(
            os.path.join(rgb_dir_path, relative_path),
            sample["tubes"],
            sample["nframes"],
            labels[sample["label"]],
            args.img_impl,
        )
def main(args):
    """Dispatch to the visualiser matching ``args.dataset``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    supported = ("ucf101_24", "jhmdb")
    if args.dataset not in supported:
        raise ValueError("Unknown dataset {}".format(args.dataset))
    # Both supported datasets share the same GT pickle layout.
    show_dataset_fn = _show_spatiotemporal_datasets
    show_dataset_fn(args)
if __name__ == "__main__":
    # Script entry point: parse the command line, then dispatch on --dataset.
    cli_args = _parse_args()
    main(cli_args)