def seq_gen(len_seq, data_seq, offset, n_frame, n_route, day_slot, C_0=1):
    '''
    Slice a time series into overlapping, fixed-length sequence units.

    Every day contributes ``day_slot - n_frame + 1`` windows; a window never
    crosses a day boundary because each one starts within the first
    ``day_slot - n_frame + 1`` slots of its day.

    :param len_seq: int, the number of days to cut windows from.
    :param data_seq: np.ndarray, source data / time-series, indexed as [time slot, values];
        each row must hold n_route * C_0 values for the reshape to succeed.
    :param offset: int, the index of the first day to read (days before it are skipped).
    :param n_frame: int, the number of consecutive frames in one sequence unit
        (history frames plus prediction frames).
    :param n_route: int, the number of routes in the graph.
    :param day_slot: int, the number of time slots per day.
    :param C_0: int, the size of input channel.
    :return: np.ndarray, [len_seq * (day_slot - n_frame + 1), n_frame, n_route, C_0].
    '''
    windows_per_day = day_slot - n_frame + 1
    unit_shape = [n_frame, n_route, C_0]
    units = np.zeros((len_seq * windows_per_day, n_frame, n_route, C_0))

    for day in range(len_seq):
        # First row of this day in data_seq, and first output row for this day.
        day_start = (day + offset) * day_slot
        row_base = day * windows_per_day
        for slot in range(windows_per_day):
            first = day_start + slot
            window = data_seq[first:first + n_frame, :]
            units[row_base + slot, :, :, :] = np.reshape(window, unit_shape)

    return units
# Module-level split sizes (train / validation / test).
# NOTE(review): data_gen unpacks its own n_train, n_val, n_test from
# data_config, so it does not read these globals — confirm who does.
n_train = 34
n_val = 5
n_test = 5
def data_gen(file_path, data_config, n_route, n_frame=21, day_slot=288):
    '''
    Source file load and dataset generation.

    The CSV at ``file_path`` is read without a header row, cut into train,
    validation and test sequence units with ``seq_gen`` (the three splits are
    taken back to back, in that order), and each split is passed through
    ``z_score`` with the mean and standard deviation of the training split.

    :param file_path: str, the file path of data source.
    :param data_config: tuple, (n_train, n_val, n_test); each value is handed to
        seq_gen as the number of days in that split.
    :param n_route: int, the number of routes in the graph.
    :param n_frame: int, the number of frame within a standard sequence unit,
        which contains n_his = 12 and n_pred = 9 (3 /15 min, 6 /30 min & 9 /45 min).
    :param day_slot: int, the number of time slots per day, controlled by the time window (5 min as default).
    :return: Dataset, built from the train / val / test arrays and the training stats.
    :raises FileNotFoundError: if ``file_path`` does not exist.
    '''
    n_train, n_val, n_test = data_config

    # Load the raw series. A missing file is reported and then re-raised:
    # continuing would only fail later with an unrelated UnboundLocalError
    # on data_seq, hiding the real cause.
    try:
        data_seq = pd.read_csv(file_path, header=None).values
    except FileNotFoundError:
        print(f'ERROR: input file was not found in {file_path}.')
        raise

    # generate training, validation and test data
    seq_train = seq_gen(n_train, data_seq, 0, n_frame, n_route, day_slot)
    seq_val = seq_gen(n_val, data_seq, n_train, n_frame, n_route, day_slot)
    seq_test = seq_gen(n_test, data_seq, n_train + n_val, n_frame, n_route, day_slot)

    # x_stats: dict, the stats for the train dataset, including the value of mean and standard deviation.
    # Only the training split is used, so validation/test data do not influence the statistics.
    x_stats = {'mean': np.mean(seq_train), 'std': np.std(seq_train)}

    # x_train, x_val, x_test: np.array, [sample_size, n_frame, n_route, channel_size].
    x_train = z_score(seq_train, x_stats['mean'], x_stats['std'])
    x_val = z_score(seq_val, x_stats['mean'], x_stats['std'])
    x_test = z_score(seq_test, x_stats['mean'], x_stats['std'])

    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
    dataset = Dataset(x_data, x_stats)
    return dataset