Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction


本博文是对郑宇老师团队所提出的STResNet网络的一个略微扩充说明。本人自己在看完这篇论文的时候,感觉就一个字‘懵’。你说不懂吧,好像又明白点,你说懂吧又感觉有好多细节还是不清楚。好在该论文开放了源代码。经过对源代码的一番剖析,总算是弄懂之前不明白的一些细节。不过该源码是基于Keras实现的,由于本人之前一直使用Tensorflow,所以又对其利用tf进行了重构,代码整体上看起也来更加简洁,地址见文末。

1.背景

这是2017年发表在AAAI上论文,其研究目的是对某个地方下一时刻车辆进出流量进行预测。作者说到,按照这样的思想利用论文中提出的模型同样还可以对某个区域的人流量,外卖订单量,快递收发量进行预测等等。具体可以参见郑宇老师在CFF上做的报告,里面也详细阐述了本论文的核心思想及其可拓展的相关问题。

所谓车流量预测,指的是利用历史数据对某个区域下一时间点的进/出车流量进行预测,也就是论文中所指的In-flow和Out-flow. 同时对于In-flow和Out-flow的统计定义如下所示:

Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction_第1张图片

其中图p0063中公式所表示的含义如下图所示:

Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction_第2张图片

上图为一个 4 × 4 4\times4 4×4的区域,代表的是在某时间片 t t t时两个车辆的运行轨迹,则 x t i n , 3 , 2 = 1 , x t o u t , 3 , 2 = 2 x_t^{in,3,2}=1,x_t^{out,3,2}=2 xtin,3,2=1,xtout,3,2=2。其统计规则如下:
t t t时间时,车辆A的移动轨迹(蓝色)历经了4个区域( g 1 , g 2 , g 3 , g 4 g_1,g_2,g_3,g_4 g1,g2,g3,g4);车辆B的移动轨迹历经了3个区域。
计算in-flow:对于A来说( i = 2 , 3 , 4 i=2,3,4 i=2,3,4)此时有: g 2 ∉ ( 3 , 2 ) , g 3 ∈ ( 3 , 2 ) g_2\notin(3,2),g_3\in(3,2) g2/(3,2),g3(3,2);对于B来说( i = 2 , 3 i=2,3 i=2,3),但此时不满足公式,所以 x t i n , 3 , 2 = 1 x_t^{in,3,2}=1 xtin,3,2=1
计算out-flow:对于A来说( i = 2 , 3 , 4 i=2,3,4 i=2,3,4)此时有: g 3 ∈ ( 3 , 2 ) , g 4 ∉ ( 3 , 2 ) g_3\in(3,2),g_4\notin(3,2) g3(3,2),g4/(3,2);对于B来说( i = 2 , 3 i=2,3 i=2,3)此时有: g 1 ∈ ( 3 , 2 ) , g 2 ∉ ( 3 , 2 ) g_1\in(3,2),g_2\notin(3,2) g1(3,2),g2/(3,2),所以 x t i n , 3 , 2 = 2 x_t^{in,3,2}=2 xtin,3,2=2

不过这都不需要你来统计,论文中所提供的数据集都已统计好了,明白这个意思就好。

2.论文介绍

2.1 数据预处理

在理解模型前我们首先来看看喂给网络的数据都长什么样,这样有助于理解。

Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction_第3张图片

如图p0067所示,最原始的数据集已经将整个北京市划分成了一个 32 × 32 32\times32 32×32的小区域,并且也已经统计出了每个小区域每隔半小时(一个时间片)的进出流量,即已经表示成了 [ 2 , 32 , 32 ] [2,32,32] [2,32,32]的格式。同时论文在实现时候,采用的是用当前时刻的前3个时间片来模拟邻近性(Closeness),用当前时刻前一天的相同时刻的一个时间片来模拟周期性(Period),用当前时刻前一周的相同时刻的一个时间片来模拟趋势性(Trend),即代码中的len_closeness=3,len_period=1,len_trend=1作为三个超参数。也就是用这三个部分来预测 t i t_i ti时刻的流量。

Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction_第4张图片

同时除了车流量数据之外,论文中还引入了其它额外的气象等数据,分别是:time_feature,holiday_feature,meteorol_feature最终将这三个部分拼成一个向量meta_feature.

Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction_第5张图片

对于每个时间片来说:
time_feature有8维度,前面7个维度为one-hot形式,最后以为表示当天是否为工作日;例如图p0069中的含义为,该时间片对应为星期四且为工作日。
holiday_feature有1个维度,0表示时间片所在的当天为工作日,1表示假期。
meteorol_feature有19个维度,前面17个也为one-hot形式,表示天气类型中的一种,后面两个维度分别表示风速和温度

最后将这个三个向量拼接成了一个28维度的向量。也就是说,现在我们已经知道了整个网络输入数据的形式了。对于数据预处理的这部分,直接调用下面函数即可获取:

    X_train, Y_train, X_test, Y_test, mmn, external_dim, timestamp_train, timestamp_test = \
        load_data(len_closeness=3, len_period=1, len_trend=1, len_test=4*7* 48)

2.2 网络构建

首先定义了网络的输入部分,笔者将其分成了5个placeholder,其含义如变量名;然后接着就是定义网络的部分,即Closeness,Period,Trend这三个部分和天气模块;最后就是评估和训练模块。分别在下面这几个方法中被定义:

def _build_placeholder(self):
def _build_stresnet(self, ):
def evaluate(self, mmn, x, y):
def train(self, x, y):

其它的部分在代码中都有详细的注释。

3.论文结果

笔者在用tersorflow重构的时候考虑到BN对实验结果影响不大就没有加上。但是即便如此,其它配置下效果依旧不如论文中的结果。笔记将参数初始化的方式等多个细节都同论文作者的代码进行过对比,都保持了一致,最终其它几个参数配置的结果如下:

Config epoches RMSE
L2-E 6082 22.32

后面的结果会不定期放在github上: https://github.com/TolicWang/DeepST

更多内容欢迎扫描关注公众号月来客栈!
在这里插入图片描述

你可能感兴趣的:(深度学习相关)