代码笔记《Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting》

目录

  • 1. train.py
    • 1.1 argparse模块
    • 1.2 ConfigParser模块
    • 1.3 初始化模型参数
    • 1.4 transpose()
  • 2. data_preparation.py
    • 2.1 数据归一化
  • 代码地址⭐,见万老师的个人主页

《Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting》代码

原始论文代码见上方链接,本文仅记录不懂之处

1. train.py

1.1 argparse模块

参考链接
【1】Python解析命令行读取参数–argparse模块使用方法
【2】python3中argparse模块

功能

  • argparse 是python自带的命令行参数解析包,可以用来方便地读取命令行参数。当代码需要频繁修改参数时,使用这个工具可以将代码和参数分离开来,让代码更简洁,适用范围更广。
  • 多数情况下,脚本可能需要很多参数,而且每次参数类型/用途各不相同,此时在参数前添加标签,表明参数的类型/用途便十分有用,该功能可以利用argparse 实现。

步骤

  • 导入模块 import argparse
  • 实例化解析器对象 parser = argparse.ArgumentParser()
  • 添加参数 parser.add_argument()
  • 获取参数集合 args = parser.parse_args()
# 实例化解析器对象,description参数可以用于插入描述脚本用途的信息,可以为空
parser = argparse.ArgumentParser()

# 添加参数
# 添加--config标签
# required参数表示--config参数是必须的,并且类型为str,输入其他的会报错
# help参数描述--config参数的用途或意义
parser.add_argument("--config", type=str, help="configuration file path", required=True)
parser.add_argument("--force", type=str, default=False, help="remove params dir", required=False)

# 将变量以“标签-值”的字典形式存入args字典
args = parser.parse_args()

1.2 ConfigParser模块

参考链接
【1】python -ConfigParser模块讲解
【2】Python 之ConfigParser模块

功能

  • ConfigParser 是用来读取配置文件的包
  • 配置文件格式与windows ini文件类似,包含1/多个节(section),每个节可以有多个参数(键值对)
# read configuration
# 初始化实例
config = configparser.ConfigParser()
print('Read configuration file: %s' % (args.config))
# 读取配置文件
config.read(args.config)
data_config = config['Data']
training_config = config['Training']

# Data相关
adj_filename = data_config['adj_filename']
graph_signal_matrix_filename = data_config['graph_signal_matrix_filename']
num_of_vertices = int(data_config['num_of_vertices'])
points_per_hour = int(data_config['points_per_hour'])
num_for_predict = int(data_config['num_for_predict'])

# Training相关
model_name = training_config['model_name']
ctx = training_config['ctx']
optimizer = training_config['optimizer']
learning_rate = float(training_config['learning_rate'])
epochs = int(training_config['epochs'])
batch_size = int(training_config['batch_size'])
num_of_weeks = int(training_config['num_of_weeks'])
num_of_days = int(training_config['num_of_days'])
num_of_hours = int(training_config['num_of_hours'])
merge = bool(int(training_config['merge']))

1.3 初始化模型参数

参考链接
【1】MXNet模型参数的访问、初始化和共享
【2】MXNET:深度学习计算-模型参数

  • MXNet 的 init 模块里提供了多种预设的初始化方法,如将权重参数初始化为均值=0,标准差=0.01的正态分布随机数,举例:
# 非首次对模型初始化需要指定 force_reinit
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)

# 如对第一个隐藏层的权重使用Xavier 初始化方法
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
  • 自定义初始化方法
# 通过继承init.Initializer,重载构造函数
class MyInit(mx.init.Initializer):
    xavier = mx.init.Xavier()
    uniform = mx.init.Uniform() # 均匀分布

    def _init_weight(self, name, data):
        if len(data.shape) < 2:
            self.uniform._init_weight(name, data)
            print('Init', name, data.shape, 'with Uniform')
        else:
            self.xavier._init_weight(name, data)
            print('Init', name, data.shape, 'with Xavier')

1.4 transpose()

参考链接
【1】numpy中的transpose函数使用方法

transpose()

  • 二维矩阵的transpose() 比较好理解,就是转置操作
  • 三维张量中,有三个维度x/y/z,分为用0/1/2表示,transpose() 相当于做轴变换,看了几篇博客 感觉还是不太清晰,虽然本篇文章做的变换是最简单的一种情况,但是为了搞清楚函数用法,下文将以三维举例,说说我的理解方法。

首先,将三维数组想成一本书(x, y, z),x-页数,y-行,z-列
其次,(x, y, z) = (0, 1, 2)
把每次的transpose() 看作是轴的互换,对应更新索引顺序,并添加到新位置就行了

感觉自己会忘,所以具体转换步骤详细说明了一下,作图不方便 手写如下
代码笔记《Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting》_第1张图片
代码示例:

import numpy as np
arr = np.arange(24).reshape((2,3,4))
>>>结果
[[[ 0  1  2  3]  
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
  
vc = arr.transpose(1,0,2)  # X轴(0)与Y轴(1)发生变换
print(vc)
>>>结果
[[[ 0  1  2  3]
  [12 13 14 15]]

 [[ 4  5  6  7]
  [16 17 18 19]]

 [[ 8  9 10 11]
  [20 21 22 23]]]

vc = arr.transpose(0,2,1)  # 表示Y轴(1)与Z轴(2)发生轴变换
print(vc)
[[[ 0  4  8]
  [ 1  5  9]
  [ 2  6 10]
  [ 3  7 11]]

 [[12 16 20]
  [13 17 21]
  [14 18 22]
  [15 19 23]]]

vc = arr.transpose(2,1,0)  # 表示X轴(0)与Z轴(2)发生轴变换
print(vc)
[[[ 0 12]
  [ 4 16]
  [ 8 20]]

 [[ 1 13]
  [ 5 17]
  [ 9 21]]

 [[ 2 14]
  [ 6 18]
  [10 22]]

 [[ 3 15]
  [ 7 19]
  [11 23]]]

代码笔记《Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting》_第2张图片

2. data_preparation.py

2.1 数据归一化

代码地址⭐,见万老师的个人主页

自取地址
或者这个

你可能感兴趣的:(实例)