《Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting》代码
原始论文代码见上方链接,本文仅记录不懂之处
参考链接
【1】Python解析命令行读取参数–argparse模块使用方法
【2】python3中argparse模块
功能
步骤
# 实例化解析器对象,description参数可以用于插入描述脚本用途的信息,可以为空
parser = argparse.ArgumentParser()
# 添加参数
# 添加--config标签
# required参数表示--config参数是必须的,并且类型为str,输入其他的会报错
# help参数描述--config参数的用途或意义
parser.add_argument("--config", type=str, help="configuration file path", required=True)
parser.add_argument("--force", type=str, default=False, help="remove params dir", required=False)
# 将变量以“标签-值”的字典形式存入args字典
args = parser.parse_args()
参考链接
【1】python -ConfigParser模块讲解
【2】Python 之ConfigParser模块
功能
# read configuration
# 初始化实例
config = configparser.ConfigParser()
print('Read configuration file: %s' % (args.config))
# 读取配置文件
config.read(args.config)
data_config = config['Data']
training_config = config['Training']
# Data相关
adj_filename = data_config['adj_filename']
graph_signal_matrix_filename = data_config['graph_signal_matrix_filename']
num_of_vertices = int(data_config['num_of_vertices'])
points_per_hour = int(data_config['points_per_hour'])
num_for_predict = int(data_config['num_for_predict'])
# Training相关
model_name = training_config['model_name']
ctx = training_config['ctx']
optimizer = training_config['optimizer']
learning_rate = float(training_config['learning_rate'])
epochs = int(training_config['epochs'])
batch_size = int(training_config['batch_size'])
num_of_weeks = int(training_config['num_of_weeks'])
num_of_days = int(training_config['num_of_days'])
num_of_hours = int(training_config['num_of_hours'])
merge = bool(int(training_config['merge']))
参考链接
【1】MXNet模型参数的访问、初始化和共享
【2】MXNET:深度学习计算-模型参数
# 非首次对模型初始化需要指定 force_reinit
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
# 如对第一个隐藏层的权重使用Xavier 初始化方法
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
# 通过继承init.Initializer,重载构造函数
class MyInit(mx.init.Initializer):
xavier = mx.init.Xavier()
uniform = mx.init.Uniform() # 均匀分布
def _init_weight(self, name, data):
if len(data.shape) < 2:
self.uniform._init_weight(name, data)
print('Init', name, data.shape, 'with Uniform')
else:
self.xavier._init_weight(name, data)
print('Init', name, data.shape, 'with Xavier')
参考链接
【1】numpy中的transpose函数使用方法
transpose()
首先,将三维数组想成一本书(x, y, z),x-页数,y-行,z-列
其次,(x, y, z) = (0, 1, 2)
把每次的transpose() 看作是轴的互换,对应更新索引顺序,并添加到新位置就行了
感觉自己会忘,所以具体转换步骤详细说明了一下,作图不方便 手写如下
代码示例:
import numpy as np
arr = np.arange(24).reshape((2,3,4))
>>>结果
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
vc = arr.transpose(1,0,2) # X轴(0)与Y轴(1)发生变换
print(vc)
>>>结果
[[[ 0 1 2 3]
[12 13 14 15]]
[[ 4 5 6 7]
[16 17 18 19]]
[[ 8 9 10 11]
[20 21 22 23]]]
vc = arr.transpose(0,2,1) # 表示Y轴(1)与Z轴(2)发生轴变换
print(vc)
[[[ 0 4 8]
[ 1 5 9]
[ 2 6 10]
[ 3 7 11]]
[[12 16 20]
[13 17 21]
[14 18 22]
[15 19 23]]]
vc = arr.transpose(2,1,0) # 表示X轴(0)与Z轴(2)发生轴变换
print(vc)
[[[ 0 12]
[ 4 16]
[ 8 20]]
[[ 1 13]
[ 5 17]
[ 9 21]]
[[ 2 14]
[ 6 18]
[10 22]]
[[ 3 15]
[ 7 19]
[11 23]]]
自取地址
或者这个