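The Waymo Open Dataset has two parts: motion and perception. The motion dataset is mainly used for Sim Agents, Motion Prediction, Interaction Prediction, and Occupancy and Flow Prediction. Official page: Waymo Motion Dataset
The raw Waymo motion dataset is hosted in the Google Cloud Storage bucket waymo_open_dataset_motion_v_1_1_0 and has to be downloaded manually.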
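The dataset directory contains the folders scenario, tf_example, and occupancy_flow_challenge.
The data under scenario is essentially the same as the data under tf_example; only the storage format differs (tf_example stores the data as tensors). I have not used the data under occupancy_flow_challenge yet and will add notes once I do.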
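Under scenario, each scene in the regular splits covers 9 seconds, i.e. 91 frames at 10 Hz (10 history frames, 1 current frame, 80 future frames). Trajectory prediction and similar tasks usually predict 8 s of future data from 1 s of history. These 9 s scenes are cut from the 20 s data described below. The test set contains only the 1 s of history (to prevent cheating). Version 1.2.0 of the dataset later added lidar point clouds for the 1 s of history.
The training_20s split contains 70506 scenes, each with 20 seconds (199 frames) of data. It is usually sliced by the user to fit the experiment at hand.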
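Since the 9 s samples are slices of the 20 s recordings, here is a minimal sketch of one way to cut such windows yourself. The function name `slice_windows` and the `stride` parameter are my own choices for illustration; this is not necessarily how Waymo produced the official 9 s splits.

```python
from typing import List, Tuple

HISTORY = 10                    # history frames
FUTURE = 80                     # future frames
WINDOW = HISTORY + 1 + FUTURE   # 91 frames in total (10 Hz)


def slice_windows(num_frames: int, stride: int = 10) -> List[Tuple[int, int, int]]:
    """Return (start, current, end) frame indices for every 91-frame window.

    `end` is exclusive, so frames[start:end] has WINDOW elements and
    frames[current] is the "current" frame of that window.
    """
    windows = []
    for start in range(0, num_frames - WINDOW + 1, stride):
        windows.append((start, start + HISTORY, start + WINDOW))
    return windows


# A 199-frame recording sliced with a stride of 10 frames (1 s).
windows = slice_windows(199, stride=10)
print(len(windows), windows[0], windows[-1])
```

A smaller stride gives more (heavily overlapping) samples; a stride of 91 gives non-overlapping ones.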
# Create a new conda environment
conda create -n waymo python=3.8
conda activate waymo
# Install TensorFlow (the TensorFlow version must match the CUDA version; CUDA 11.0 is used here)
pip install tensorflow==2.4.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# Install PyTorch first, then pytorch-lightning (optional), so that pip does not pull in a different torch build
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-lightning==1.4.0
# Check that the GPU is usable (a GPU device list and True mean the installation is correct)
python
>>> import tensorflow as tf
>>> tf.config.list_physical_devices('GPU')
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
>>> import torch
>>> torch.cuda.is_available()
True
>>> quit()
# Install waymo-open-dataset; the package must match the TensorFlow version
pip install waymo-open-dataset-tf-2-4-0
# Install any other missing libraries as needed
import os, glob, pickle
import tensorflow as tf
from waymo_open_dataset.protos import scenario_pb2

raw_data_path = 'training_20s/'
process_data_path = 'training_20s_process/'
os.makedirs(process_data_path, exist_ok=True)  # make sure the output folder exists
raw_data = glob.glob(os.path.join(raw_data_path, '*.tfrecord*'))
raw_data.sort()
for data_file in raw_data:
    dataset = tf.data.TFRecordDataset(data_file, compression_type='')
    for cnt, data in enumerate(dataset):
        info = {}
        scenario = scenario_pb2.Scenario()
        scenario.ParseFromString(bytearray(data.numpy()))
        # The following lines dump all information of one scenario to a text file for inspection
        # with open('./one_scenario.txt', 'w+') as f:
        #     f.write(str(scenario))
        # print("--------------------- save successfully! ---------------------")
        # quit()
        print(type(scenario))
        info['scenario_id'] = scenario.scenario_id
        info['timestamps_seconds'] = list(scenario.timestamps_seconds)  # list of float, one entry per frame
        info['current_time_index'] = scenario.current_time_index  # int, index of the current frame
        info['sdc_track_index'] = scenario.sdc_track_index  # int
        info['objects_of_interest'] = list(scenario.objects_of_interest)  # list, could be empty list
        info['tracks_to_predict'] = {
            'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],
            'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]
        }  # for training: suggestion of objects to train on, for val/test: need to be predicted
        info['tracks'] = list(scenario.tracks)
        info['dynamic_map_states'] = list(scenario.dynamic_map_states)
        output_file = os.path.join(process_data_path, f'sample_{scenario.scenario_id}.pkl')
        with open(output_file, 'wb') as f:
            pickle.dump(info, f)
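The saved files can be read back with plain pickle. Below is a small sketch, assuming the `sample_*.pkl` naming used above; `load_samples` is a hypothetical helper name, and the demo writes a stand-in dictionary instead of a real scenario so that it runs without the dataset. For real files the `tracks` entries are protobuf messages, so the waymo-open-dataset package presumably has to be importable when unpickling.

```python
import glob
import os
import pickle
import tempfile


def load_samples(folder):
    """Yield (path, info) for every sample_*.pkl file in `folder`, sorted by name."""
    for path in sorted(glob.glob(os.path.join(folder, 'sample_*.pkl'))):
        with open(path, 'rb') as f:
            yield path, pickle.load(f)


# Demo with a stand-in dictionary (not real Waymo data).
with tempfile.TemporaryDirectory() as tmp:
    demo = {
        'scenario_id': 'demo',
        'current_time_index': 10,
        'tracks_to_predict': {'track_index': [3, 7], 'difficulty': [1, 2]},
    }
    with open(os.path.join(tmp, 'sample_demo.pkl'), 'wb') as f:
        pickle.dump(demo, f)
    loaded = [info for _, info in load_samples(tmp)]

print(loaded[0]['scenario_id'], loaded[0]['tracks_to_predict']['track_index'])
```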
What does each scene contain? Here each sample (one complete 20 s recording) is called a scenario; the official documentation seems to use the term segment.
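scenario.scenario_id: unique id of the scene
scenario.timestamps_seconds: timestamp of each frame in seconds, starting from 0
scenario.current_time_index: index of the current frame (earlier frames are history, later frames are to be predicted)
scenario.tracks: the trajectory of each object
- id: unique id of the object
- object_type: type of the object
- states: state of the object at each frame
  - center_x: x coordinate of the geometric center
  - center_y: y coordinate of the geometric center
  - center_z: z coordinate of the geometric center
  - length: length of the object
  - width: width of the object
  - height: height of the object
  - heading: heading (yaw) angle of the object
  - velocity_x: velocity component along the x axis
  - velocity_y: velocity component along the y axis
  - valid: whether the object is observed at this frame
scenario.dynamic_map_states: traffic signal states over time
- lane_states: the traffic signal states at a given time step and the ids of the lanes they control
scenario.map_features: map data
- lane centers: lane centerlines
- lane boundaries: lane boundaries
- road boundaries: road boundaries
- crosswalks: crosswalks
- speed bumps: speed bump positions
- stop signs: stop sign positions
scenario.sdc_track_index: index of the self-driving car (ego vehicle) in scenario.tracks
scenario.objects_of_interest: objects whose behavior may be useful for research and training
scenario.tracks_to_predict: indicates which objects must be predicted; provided only in the training and validation sets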
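To make the role of `current_time_index` and `valid` concrete, here is a small sketch that splits one track into observed and future points and computes the speed from the velocity components. The `State` dataclass is only a stand-in that I define for the example; it mirrors the field names listed above, so the same functions should also accept the real `track.states` entries, but I have only run it on the toy data shown.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class State:
    """Stand-in for one entry of track.states (only the fields used here)."""
    center_x: float
    center_y: float
    velocity_x: float
    velocity_y: float
    valid: bool


def split_track(states: List[State], current_time_index: int):
    """Split a track into observed and future (x, y) points, skipping invalid frames."""
    past = [(s.center_x, s.center_y) for s in states[:current_time_index + 1] if s.valid]
    future = [(s.center_x, s.center_y) for s in states[current_time_index + 1:] if s.valid]
    return past, future


def speed(state: State) -> float:
    """Magnitude of the velocity vector (velocity_x, velocity_y)."""
    return math.hypot(state.velocity_x, state.velocity_y)


# Toy track: 4 frames, the current frame is index 1, frame 2 is not observed.
track = [
    State(0.0, 0.0, 3.0, 4.0, True),
    State(0.3, 0.4, 3.0, 4.0, True),
    State(0.0, 0.0, 0.0, 0.0, False),
    State(0.9, 1.2, 3.0, 4.0, True),
]
past, future = split_track(track, current_time_index=1)
print(past, future, speed(track[1]))
```

Filtering on `valid` is safer than testing whether a coordinate equals 0, because an unobserved frame is marked by the flag rather than by any particular coordinate value.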
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from waymo_open_dataset.protos import scenario_pb2

plt.rcParams['axes.facecolor'] = 'grey'
plt.figure(figsize=(30, 30))
# Folder that contains the tfrecord files to visualize
data_path = '/data1/zyp/drl/PiH/MTR_PiH/data/training_20s/validation'
file_list = os.listdir(data_path)
for cnt_file, file in enumerate(file_list):
    file_path = os.path.join(data_path, file)
    dataset = tf.data.TFRecordDataset(file_path, compression_type='')
    for cnt_scenar, data in enumerate(dataset):
        scenario = scenario_pb2.Scenario()
        scenario.ParseFromString(bytearray(data.numpy()))
        print(scenario.scenario_id)
        # Draw the map
        for feature in scenario.map_features:
            # Lane centerlines
            if feature.HasField('lane'):
                line_x = [p.x for p in feature.lane.polyline]
                line_y = [p.y for p in feature.lane.polyline]
                plt.scatter(line_x, line_y, c='g', s=5)
                # plt.text(line_x[0], line_y[0], str(feature.id), fontdict={'family': 'serif', 'size': 20, 'color': 'black'})
            # Road edges
            if feature.HasField('road_edge'):
                road_edge_x = [p.x for p in feature.road_edge.polyline]
                road_edge_y = [p.y for p in feature.road_edge.polyline]
                # plt.text(road_edge_x[0], road_edge_y[0], feature.road_edge.type, fontdict={'family': 'serif', 'size': 20, 'color': 'black'})
                if feature.road_edge.type == 3:
                    plt.scatter(road_edge_x, road_edge_y, c='purple')
                    print(feature.road_edge)
                else:
                    plt.scatter(road_edge_x, road_edge_y, c='k')
            # Road lines (lane markings)
            if feature.HasField('road_line'):
                road_line_x = [p.x for p in feature.road_line.polyline]
                road_line_y = [p.y for p in feature.road_line.polyline]
                if feature.road_line.type in (6, 7, 8):  # solid single, solid double, passing double yellow
                    plt.plot(road_line_x, road_line_y, c='y')
                elif feature.road_line.type == 1:  # broken single white
                    # Plot 5 out of every 7 points to get a dashed line
                    for k in range(len(road_line_x) // 7):
                        plt.plot(road_line_x[k*7:5+k*7], road_line_y[k*7:5+k*7], color='w')
                else:  # solid single white and all other types
                    plt.plot(road_line_x, road_line_y, c='w')
        # Draw the agents and their trajectories
        for i, track in enumerate(scenario.tracks):
            # Keep only the frames in which the object was actually observed
            valid_states = [s for s in track.states if s.valid]
            if not valid_states:
                continue
            traj_x = [s.center_x for s in valid_states]
            traj_y = [s.center_y for s in valid_states]
            if i == scenario.sdc_track_index:  # self-driving car in red
                start_color, traj_color = 'r', 'r'
            else:  # all other agents: black start marker, blue trajectory
                start_color, traj_color = 'k', 'b'
            plt.scatter(traj_x[0], traj_y[0], s=140, c=start_color, marker='s')
            plt.scatter(traj_x, traj_y, s=14, c=traj_color)
        break  # only draw the first scenario of the file
    break  # only read the first file
plt.show()
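The plot above marks each agent with a square marker only. If you want to draw the actual footprint, the corners of the bounding box can be computed from center_x, center_y, length, width, and heading. The sketch below assumes that heading is given in radians and measured counterclockwise from the x axis; `box_corners` is a hypothetical helper name, not part of waymo-open-dataset.

```python
import math
from typing import List, Tuple


def box_corners(cx: float, cy: float, length: float, width: float,
                heading: float) -> List[Tuple[float, float]]:
    """Corners of the rotated box: front-left, rear-left, rear-right, front-right."""
    c, s = math.cos(heading), math.sin(heading)
    half_l, half_w = length / 2.0, width / 2.0
    corners = []
    # Offsets in the object's own frame (x forward, y to the left).
    for dx, dy in [(half_l, half_w), (-half_l, half_w),
                   (-half_l, -half_w), (half_l, -half_w)]:
        # Rotate by the heading, then translate to the box center.
        corners.append((cx + dx * c - dy * s, cy + dx * s + dy * c))
    return corners


corners = box_corners(0.0, 0.0, 4.0, 2.0, 0.0)
print(corners)
```

The returned list can be passed to `matplotlib.patches.Polygon(corners, closed=True)` and added to the axes with `add_patch` to draw the vehicle outline.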