STGCN复现第二弹:解读main.py

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0" #设置当前使用的GPU设备仅为0号设备  设备名称为'/gpu:0'
from os.path import join as pjoin

import tensorflow as tf
tf.reset_default_graph() #函数用于清除默认图形堆栈并重置全局默认图形。

#动态申请显存

config = tf.ConfigProto()  #tf.ConfigProto()主要的作用是配置tf.Session的运算方式,比如gpu运算或者cpu运算
config.gpu_options.allow_growth = True #当使用GPU时候,Tensorflow运行自动慢慢达到最大GPU的内存
tf.Session(config=config)
from utils.math_graph import *
from data_loader.data_utils import *
from models.trainer import model_train
from models.tester import model_test

import argparse

创建一个 ArgumentParser 对象:

parser = argparse.ArgumentParser()	 #ArgumentParser对象包含将命令行解析成 Python 数据类型所需的全部信息。

通过调用 add_argument() 方法给 ArgumentParser对象添加程序所需的参数信息:

#ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])
  • name or flags - 一个命名或者一个选项字符串的列表,例如 foo 或 -f, --foo。
  • action - 当参数在命令行中出现时使用的动作基本类型。
  • nargs - 命令行参数应当消耗的数目。
  • const - 被一些 action 和 nargs 选择所需求的常数。
  • default - 当参数未在命令行中出现时使用的值。
  • type - 命令行参数应当被转换成的类型。
  • choices - 可用的参数的容器。
  • required - 此命令行选项是否可省略 (仅选项可用)。
  • help - 一个此选项作用的简单描述。
  • metavar - 在使用方法消息中使用的参数值示例。
  • dest - 被添加到 parse_args() 所返回对象上的属性名。
parser.add_argument('--n_route', type=int, default=228)# 
parser.add_argument('--n_his', type=int, default=12)
parser.add_argument('--n_pred', type=int, default=9)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--epoch', type=int, default=50)
parser.add_argument('--save', type=int, default=10)
parser.add_argument('--ks', type=int, default=3)
parser.add_argument('--kt', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--opt', type=str, default='RMSProp')
parser.add_argument('--graph', type=str, default='default')
parser.add_argument('--inf_mode', type=str, default='merge')

通过 parse_args() 方法解析参数:

args = parser.parse_args()  #args是dict类型,包含参数信息

字符串前面加f表示格式化字符串,加f后可以在字符串里面使用用花括号括起来的变量和表达式,如果字符串里面没有表达式,那么前面加不加f输出应该都一样.python 3.6后可用

print(f'Training configs: {args}')
n, n_his, n_pred = args.n_route, args.n_his, args.n_pred
Ks, Kt = args.ks, args.kt
# blocks: settings of channel(通道) size in st_conv_blocks / bottleneck design
blocks = [[1, 32, 64], [64, 32, 128]]

# Load wighted adjacency matrix W
if args.graph == 'default':
    W = weight_matrix(pjoin('./dataset', f'PeMSD7_W_{n}.csv'))  #pjoin()拼接地址字符串
else:
    # load customized(自定义) graph weight matrix
    W = weight_matrix(pjoin('./dataset', args.graph))

# Calculate graph kernel
L = scaled_laplacian(W)

scaled_laplacian()函数详解

# Alternative(供替代的选择) approximation(近似) method(方法): 1st approx - first_approx(W, n).
Lk = cheb_poly_approx(L, Ks, n)

STGCN复现第二弹:解读main.py_第1张图片

tf.add_to_collection(name='graph_kernel', value=tf.cast(tf.constant(Lk), tf.float32)) 
’‘’
tf.cast():数据类型转换
tf.Constant()创建一个常量tensor,按照给出value来赋值,可以用shape来指定其形状。value可以是一个数,也可以是一个list。
‘’‘

tf.cast()函数详解:数据类型转换
tf.Constant()函数详解

# Data Preprocessing
data_file = f'PeMSD7_V_{n}.csv'
n_train, n_val, n_test = 34, 5, 5
PeMS = data_gen(pjoin('./dataset', data_file), (n_train, n_val, n_test), n, n_his + n_pred)
print(f'>> Loading dataset with Mean: {PeMS.mean:.2f}, STD: {PeMS.std:.2f}')
if __name__ == '__main__':
    model_train(PeMS, blocks, args)
    model_test(PeMS, PeMS.get_len('test'), n_his, n_pred, args.inf_mode)

你可能感兴趣的:(复现STGCN)