论文笔记--Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting

Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting (AAAI 19)

Summary

作者提出ASTGCN的主要由三个独立组件组成,分别对交通流的三种节奏特性(近期依赖、日周期依赖和周周期依赖)进行建模。每个组件包含两个主要部分:1)有效捕获交通数据中动态时空相关性的时空注意机制;2)时空卷积,即同时使用图卷积来捕获空间模式和通用标准卷积来捕获时间特征。三个组件的结果相融合得到最终预测结果。

Problem Definition

交通预测问题最大挑战还是如何有效提取数据的时空相关性。如下图

线条颜色越深,影响越大。从图(a)表示的是空间依赖的关系,不同的地点对A的影响是不同的,即使是同一个地点随着时间的推移对A的影响也是不同的。在时间维度下图(b),不同位置的历史观测结果对A未来不同时段的交通状态有不同的影响。综上所述,公路网交通数据相关性在空间维度和时间维度上均表现出较强的动态性。

问题定义

将交通网络定义为一个无向图表示为 G = ( V , E , A ) G=(V,E,A) G=(V,E,A),V表示节点列表,E是边集, A ∈ R N × N A\in \mathbb{R}^{N\times N} ARN×N是邻接矩阵。定义 X = ( X 1 , X 2 , … , X τ ) T ∈ R N × F × τ \mathcal{X}=\left(\mathbf{X}_{1}, \mathbf{X}_{2}, \ldots, \mathbf{X}_{\tau}\right)^{T} \in \mathbb{R}^{N \times F \times \tau} X=(X1,X2,,Xτ)TRN×F×τ表示τ时间内所有节点的所有特征值。利用交通网络中所有节点在过去 τ \tau τ时间内的各种历史测度,预测未来交通流序列 ( y 1 , y 2 , . . . , y N ) t ∈ R N × T p (y^1,y^2,...,y^N)^t \in \mathbb{R}^{N\times T_p} (y1,y2,...,yN)tRN×Tp

Method

ASTGCN算法框架

输入

X h = ( X t 0 − T h + 1 , X t 0 − T h + 2 , … , X t 0 ) ∈ R N × F × T h \mathcal{X}_{h}=\left(\mathbf{X}_{t_{0}-T_{h}+1}, \mathbf{X}_{t_{0}-T_{h}+2}, \ldots, \mathbf{X}_{t_{0}}\right) \in \mathbb{R}^{N \times F \times T_{h}} Xh=(Xt0Th+1,Xt0Th+2,,Xt0)RN×F×Th表示的是最近时间段交通信息,长度为Th。从直观上看,交通拥挤的形成和扩散是渐进的。因此,过去的交通流必然会对未来的交通流产生影响。


X d = ( X t 0 − ( T d / T p ) ∗ q + 1 , … , X t 0 − ( T d / T p ) ∗ q + T p X t 0 − ( T d / T p − 1 ) ∗ q + 1 , … , X t 0 − ( T d / T p − 1 ) ∗ q + T p , ⋯ X t 0 − q + 1 , … , X t 0 − q + T p ) ∈ R N × F × T d \begin{aligned} &\mathcal{X}_{d}=\left(\mathbf{X}_{t_{0}-\left(T_{d} / T_{p}\right) * q+1}, \ldots, \mathbf{X}_{t_{0}-\left(T_{d} / T_{p}\right) * q+T_{p}}\right. \\ &\mathbf{X}_{t_{0}-\left(T_{d} / T_{p}-1\right) * q+1}, \ldots, \mathbf{X}_{t_{0}-\left(T_{d} / T_{p}-1\right) * q+T_{p}}, \cdots \\ &\left.\mathbf{X}_{t_{0}-q+1}, \ldots, \mathbf{X}_{t_{0}-q+T_{p}}\right) \in \mathbb{R}^{N \times F \times T_{d}} \end{aligned} Xd=(Xt0(Td/Tp)q+1,,Xt0(Td/Tp)q+TpXt0(Td/Tp1)q+1,,Xt0(Td/Tp1)q+Tp,Xt0q+1,,Xt0q+Tp)RN×F×Td
表示日周期时间段交通信息,长度为Td。公式中q表示一天采集的时间步长度。由于人的日常规律,交通数据可能会呈现重复的模式,例如每天的早晨高峰。日周期数据的目的是对交通数据的日周期性进行建模。


X w = ( X t 0 − 7 ∗ ( T w / T p ) ∗ q + 1 , … , X t 0 − 7 ∗ ( T w / T p ) ∗ q + T p X t 0 − 7 ∗ ( T w / T p − 1 ) ∗ q + 1 , … , X t 0 − 7 ∗ ( T w / T p − 1 ) ∗ q + T p , … X t 0 − 7 ∗ q + 1 , … , X t 0 − 7 ∗ q + T p ) ∈ R F × N × T w \begin{aligned} &\mathcal{X}_{w}=\left(\mathbf{X}_{t_{0}-7 *\left(T_{w} / T_{p}\right) * q+1}, \ldots, \mathbf{X}_{t_{0}-7 *\left(T_{w} / T_{p}\right) * q+T_{p}}\right. \\ &\mathbf{X}_{t_{0}-7 *\left(T_{w} / T_{p}-1\right) * q+1}, \ldots, \mathbf{X}_{t_{0}-7 *\left(T_{w} / T_{p}-1\right) * q+T_{p}}, \ldots \\ &\left.\mathbf{X}_{t_{0}-7 * q+1}, \ldots, \mathbf{X}_{t_{0}-7 * q+T_{p}}\right) \in \mathbb{R}^{F \times N \times T_{w}} \end{aligned} Xw=(Xt07(Tw/Tp)q+1,,Xt07(Tw/Tp)q+TpXt07(Tw/Tp1)q+1,,Xt07(Tw/Tp1)q+Tp,Xt07q+1,,Xt07q+Tp)RF×N×Tw
表示周周期时间段交通信息,长度为Tw,其中7表示一周7天。通常情况下,周一的交通模式与历史上周一的交通模式有一定的相似性,但可能与周末的交通模式有很大的不同。所以每周周期数据被设计用来捕获流量数据中的每周周期特征。

输入数据可视化表示如下

时空注意力模块

①空间注意力

首先通过 X h ( r − 1 ) = ( X 1 , X 2 , … X T r − 1 ) ∈ R N × C r − 1 × T r − 1 \boldsymbol{X}_{h}^{(r-1)}=\left(\mathbf{X}_{1}, \mathbf{X}_{2}, \ldots \mathbf{X}_{T_{r-1}}\right) \in \mathbb{R}^{N \times C_{r-1} \times T_{r-1}} Xh(r1)=(X1,X2,XTr1)RN×Cr1×Tr1计算出注意力矩阵S, S i , j S_{i,j} Si,j表示节点i与节点j的相关强度。然后通过softmax使节点注意权值之和为1。
S = V s ⋅ σ ( ( X h ( r − 1 ) W 1 ) W 2 ( W 3 X h ( r − 1 ) ) T + b s ) S i , j ′ = exp ⁡ ( S i , j ) ∑ j = 1 N exp ⁡ ( S i , j ) \begin{gathered} \mathbf{S}=\mathbf{V}_{s} \cdot \sigma\left(\left(\boldsymbol{X}_{h}^{(r-1)} \mathbf{W}_{1}\right) \mathbf{W}_{2}\left(\mathbf{W}_{3} \mathcal{X}_{h}^{(r-1)}\right)^{T}+\mathbf{b}_{s}\right) \\ \mathbf{S}_{i, j}^{\prime}=\frac{\exp \left(\mathbf{S}_{i, j}\right)}{\sum_{j=1}^{N} \exp \left(\mathbf{S}_{i, j}\right)} \end{gathered} S=Vsσ((Xh(r1)W1)W2(W3Xh(r1))T+bs)Si,j=j=1Nexp(Si,j)exp(Si,j)
其中 V s , b s ∈ R N × N , W 1 ∈ R T r − 1 , W 2 ∈ R C r − 1 × T r − 1 , W 3 ∈ R C r − 1 \mathbf{V}_{s}, \mathbf{b}_{s} \in \mathbb{R}^{N \times N}, \mathbf{W}_{1} \in \mathbb{R}^{ {T}_{r-1}}, \mathbf{W}_{2} \in \mathbb{R}^{C_{r-1} \times T_{r-1}}, \mathbf{W}_{3} \in \mathbb{R}^{C_{r-1}} Vs,bsRN×N,W1RTr1,W2RCr1×Tr1,W3RCr1使可学习参数。

然后注意力矩阵S’在图卷积部分将与邻接矩阵A共同调节节点间的影响权重。

②时间注意力

计算时间维度的注意力系数

E = V e ⋅ σ ( ( ( X h ( r − 1 ) ) T U 1 ) U 2 ( U 3 X h ( r − 1 ) ) + b e ) E i , j ′ = exp ⁡ ( E i , j ) ∑ j = 1 T r − 1 exp ⁡ ( E i , j ) \begin{gathered} \mathbf{E}=\mathbf{V}_{e} \cdot \sigma\left(\left(\left(\mathcal{X}_{h}^{(r-1)}\right)^{T} \mathbf{U}_{1}\right) \mathbf{U}_{2}\left(\mathbf{U}_{3} \mathcal{X}_{h}^{(r-1)}\right)+\mathbf{b}_{e}\right) \\ \mathbf{E}_{i, j}^{\prime}=\frac{\exp \left(\mathbf{E}_{i, j}\right)}{\sum_{j=1}^{T_{r-1}} \exp \left(\mathbf{E}_{i, j}\right)} \end{gathered} E=Veσ(((Xh(r1))TU1)U2(U3Xh(r1))+be)Ei,j=j=1Tr1exp(Ei,j)exp(Ei,j)
其中$\mathbf{V}{e{2}} \mathbf{b}{e} \in \mathbb{R}^{T{r-1} \times T_{r-1}}, \mathbf{U}{1} \in \mathbb{R}^{N}, \mathbf{U}{2} \in \mathbb{R}^{C_{r-1} \times N}, \mathbf{U}{3} \in \mathbb{R}^{C{r-1}}
$是可学习参数。

对于时间注意力块,作者直接将归一化的时间注意矩阵应用于输入,计算公式如下

X ^ h ( r − 1 ) = ( X ^ 1 , X ^ 2 , … , X ^ T r − 1 ) = ( X 1 , X 2 , … , X T r − 1 ) E ′ \hat{\boldsymbol{X}}_{h}^{(r-1)}=\left(\hat{\mathbf{X}}_{1}, \hat{\mathbf{X}}_{2}, \ldots, \hat{\mathbf{X}}_{T_{r-1}}\right)=\left(\mathbf{X}_{1}, \mathbf{X}_{2}, \ldots, \mathbf{X}_{T_{r-1}}\right) \mathbf{E}^{\prime} X^h(r1)=(X^1,X^2,,X^Tr1)=(X1,X2,,XTr1)E
时空卷积模块

①空间维卷积

采用的是谱域方法(具体是Cheby-conv方法改进得到)。Cheby-conv计算公式如下

为了动态调整节点之间的相关性,对每一项的 T k ( L ~ ) T_k(\tilde{L}) Tk(L~)与空间注意力矩阵 S ′ ∈ R N × N S'\in \mathbb{R}^{N\times N} SRN×N进行哈达玛乘积。

具体公式如下

②时间维卷积

*表示标准卷积,此处应该是1D-conv,将时间点的前后数据也一起融合了一下,得到了整个模块的最终输出。

最后对三种输入得到的三种输出进行融合,公式如下

Experiments

数据使用两个加州高速数据PeMSD4和PeMSD8。

参数设置

T h = 24 , T d = 12 , T w = 24 T_h=24,T_d=12,T_w=24 Th=24,Td=12,Tw=24切比雪夫多项式K={1,2,3},预测时间步长 T p = 12 T_p=12 Tp=12

实验结果如下

MSTGCN是未使用注意力机制的模型。

下图是各种方法在预测区间增大下的影响。

作者挑选了包含10个点的子图,并显示训练集中节点之间的平均空间注意矩阵。如下,最后一行,我们可以知道第9个点的车流与第3个点和第8个点上的车流是密切相关的。他们三个点在空间上也是相互接近的,很合理。

创新点

不仅仅使用相近时间的历史数据来预测,还考虑了同一天的同一时刻,同一周的时刻的影响来辅助预测。还有就是使用注意力直接学习时间空间相关性。

全部评论

相关推荐

点赞 收藏 评论
分享
正在热议
# 牛客帮帮团来啦!有问必答 #
1151203次浏览 17148人参与
# 通信和硬件还有转码的必要吗 #
11194次浏览 101人参与
# 不去互联网可以去金融科技 #
20335次浏览 255人参与
# 和牛牛一起刷题打卡 #
18899次浏览 1635人参与
# 实习与准备秋招该如何平衡 #
203348次浏览 3625人参与
# 大厂无回复,继续等待还是奔赴小厂 #
4970次浏览 30人参与
# OPPO开奖 #
19192次浏览 267人参与
# 通信硬件薪资爆料 #
265877次浏览 2484人参与
# 国企是理工四大天坑的最好选择吗 #
2220次浏览 34人参与
# 互联网公司评价 #
97672次浏览 1280人参与
# 简历无回复,你会继续海投还是优化再投? #
25034次浏览 354人参与
# 0offer是寒冬太冷还是我太菜 #
454821次浏览 5124人参与
# 国企和大厂硬件兄弟怎么选? #
53896次浏览 1012人参与
# 参加过提前批的机械人,你们还参加秋招么 #
14636次浏览 349人参与
# 硬件人的简历怎么写 #
82284次浏览 852人参与
# 面试被问第一学历差时该怎么回答 #
19393次浏览 213人参与
# 你见过最离谱的招聘要求是什么? #
28057次浏览 248人参与
# 学历对求职的影响 #
161224次浏览 1804人参与
# 你收到了团子的OC了吗 #
538675次浏览 6386人参与
# 你已经投递多少份简历了 #
344169次浏览 4963人参与
# 实习生应该准时下班吗 #
96965次浏览 722人参与
# 听劝,我这个简历该怎么改? #
63517次浏览 622人参与
牛客网
牛客企业服务