import torch
import torch.nn as nn
from models import EncoderDecoder
from data_utils import DataOrderScaner
import os
import h5py
import constants
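## Inference entry points provided by this module:
##   evaluator(args) -- interactively decode source sequences typed on stdin
##   t2vec(args)     -- encode trajectories from {prefix}-trj.t and save the
##                      vectors into {prefix}-trj.h5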
def evaluate(src, model, max_length):
    """
    Evaluate one source sequence: greedily decode it into a target sequence.
    """
    ## unpack the model into m0 (the encoder-decoder) and m1 (linear + LogSoftmax)
    m0, m1 = model
    ## get the source sequence length and convert the sequence to a LongTensor
    length = len(src)
    src = torch.LongTensor(src)
    ## (seq_len,) => (seq_len, 1): a batch containing a single sequence
    src = src.view(-1, 1)
    ## (1, 1)
    length = torch.LongTensor([[length]])
    ## encoder hidden state and outputs
    encoder_hn, H = m0.encoder(src, length)
    ## map the encoder's final hidden state to the decoder's initial hidden state
    h = m0.encoder_hn2decoder_h0(encoder_hn)
    ## running the decoder step by step with BOS as input
    input = torch.LongTensor([[constants.BOS]])
    ## the predicted target sequence
    trg = []
    for _ in range(max_length):
        ## `h` is updated for the next iteration
        o, h = m0.decoder(input, h, H)
        o = o.view(-1, o.size(2)) ## => (1, hidden_size)
        ## log-probability distribution over the vocabulary
        o = m1(o) ## => (1, vocab_size)
        ## greedy decoding: take the most likely word
        _, word_id = o.data.topk(1)
        word_id = word_id[0][0].item()
        ## stop generating once the end-of-sequence token (EOS) is predicted
        if word_id == constants.EOS:
            break
        trg.append(word_id)
        ## update `input` for the next iteration
        input = torch.LongTensor([[word_id]])
    return trg
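## Example (hypothetical cell ids; m0 and m1 must already be loaded, as in
## `evaluator` below):
##   trg = evaluate([8463, 8464, 8731], (m0, m1), max_length=100)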
def evaluator(args):
    """
    do evaluation interactively
    """
    ## the EncoderDecoder model m0
    m0 = EncoderDecoder(args.vocab_size, args.embedding_size,
                        args.hidden_size, args.num_layers,
                        args.dropout, args.bidirectional)
    ## linear layer + LogSoftmax projecting hidden states onto the vocabulary
    m1 = nn.Sequential(nn.Linear(args.hidden_size, args.vocab_size),
                       nn.LogSoftmax(dim=1))
    if os.path.isfile(args.checkpoint):
        print("=> loading checkpoint '{}'".format(args.checkpoint))
        checkpoint = torch.load(args.checkpoint)
        ## restore the trained parameters from the checkpoint
        m0.load_state_dict(checkpoint["m0"])
        m1.load_state_dict(checkpoint["m1"])
        ## switch to evaluation mode so that dropout is disabled
        m0.eval()
        m1.eval()
        while True:
            try:
                print("> ", end="")
                ## read one source sequence from the user
                src = input()
                ## split the input string on whitespace into a list of token ids
                src = [int(x) for x in src.split()]
                ## decode the most likely target cell-id sequence for this input
                trg = evaluate(src, (m0, m1), args.max_length)
                print(" ".join(map(str, trg)))
            ## exit the loop on KeyboardInterrupt (e.g. the user pressing Ctrl+C)
            except KeyboardInterrupt:
                break
    else:
        print("=> no checkpoint found at '{}'".format(args.checkpoint))
def t2vec(args):
    """
    Read the source trajectory sequences from {prefix}-trj.t, encode them with
    the trained encoder-decoder model, and write the resulting vector
    representations into the HDF5 file {prefix}-trj.h5.
    """
    ## unlike `evaluator`, only the encoder is used here, so m1 is not needed
    m0 = EncoderDecoder(args.vocab_size, args.embedding_size,
                        args.hidden_size, args.num_layers,
                        args.dropout, args.bidirectional)
    if os.path.isfile(args.checkpoint):
        print("=> loading checkpoint '{}'".format(args.checkpoint))
        ## load the trained parameters from the checkpoint
        checkpoint = torch.load(args.checkpoint)
        m0.load_state_dict(checkpoint["m0"])
        if torch.cuda.is_available():
            m0.cuda()
        m0.eval()
        vecs = []
        ## scan the source file and return its contents batch by batch
        scaner = DataOrderScaner(os.path.join(args.data, "{}-trj.t".format(args.prefix)),
                                 args.t2vec_batch)
        scaner.load()
        i = 0
        while True:
            if i % 100 == 0:
                print("{}: Encoding {} trjs...".format(i, args.t2vec_batch))
            i = i + 1
            ## fetch one batch without repetition; getbatch pads the sequences
            ## and sorts them by length, and invp recovers the pre-sorting order
            src, lengths, invp = scaner.getbatch()
            ## src is None once every trajectory has been scanned exactly once
            if src is None: break
            if torch.cuda.is_available():
                src, lengths, invp = src.cuda(), lengths.cuda(), invp.cuda()
            ## encode the batch with the encoder to get its hidden state
            h, _ = m0.encoder(src, lengths)
            ## reshape into the decoder initial-state layout
            ## => (num_layers, batch, hidden_size * num_directions)
            h = m0.encoder_hn2decoder_h0(h)
            ## => (batch, num_layers, hidden_size * num_directions)
            h = h.transpose(0, 1).contiguous()
            ## (batch, *)
            #h = h.view(h.size(0), -1)
            ## index with invp to restore the original order of this batch
            vecs.append(h[invp].cpu().data)
        ## (num_seqs, num_layers, hidden_size * num_directions)
        vecs = torch.cat(vecs)
        ## (num_layers, num_seqs, hidden_size * num_directions)
        vecs = vecs.transpose(0, 1).contiguous()
        path = os.path.join(args.data, "{}-trj.h5".format(args.prefix))
        print("=> saving vectors into {}".format(path))
        ## create the HDF5 file; each dataset "layer<i>" stores the encoder
        ## hidden state of layer i for all sequences
        with h5py.File(path, "w") as f:
            for i in range(m0.num_layers):
                f["layer"+str(i+1)] = vecs[i].squeeze(0).numpy()
    else:
        print("=> no checkpoint found at '{}'".format(args.checkpoint))