源码解读:CSSRNN

源代码:CSSRNN github链接,IJCAI2017

模型解读

源码解读:CSSRNN_第1张图片

  • current embedding: [state_size, emb_dim],其中state_size为link或grid的数量,emb_dim为embedding的维度。
  • destination embedding: [state_size, emb_dim],同上;终点embedding为单独的一套。
  • neighbor embedding:其实是一套线性变换的系数,并不是embedding,只是用embedding来加速;w的shape为[hid_dim, state_size],b的shape为[state_size],其中hid_dim为LSTM的输出维度。CSSRNN 最后一层的实质就是一个带邻接表约束的Softmax层。所谓LPIRNN,相比CSSRNN增加了一个维度,shape为[hid_dim, state_size, adj_size]。可以理解为CSSRNN是按节点给系数,LPIRNN是按边给系数。

代码解读

ID Embedding
先创建边的embedding,shape=[state_size, emb_dim];embedding的本质就是全连接神经网络的系数矩阵W。


# 有pretrain
emb_ = tf.get_variable("embedding", dtype=tf.float64, initializer=pretrained_emb_)
# 无pretrain
emb_ = tf.get_variable("embedding", [state_size, emb_dim], dtype=tf.float64)

Input的编码是one-hot,对one-hot的输入构建全连接神经网络,等价于从embedding中根据id编号提取出one-hot即元素1所在的行来。这个功能类似tf.gather()方法,TensorFlow提供了tf.nn.embedding_lookup(),可以并行地从embedding中查表,得到输入Tensor(shape=[batch_size, time_steps, state_size])embedding后的Tensor(shape=[batch_size, time_steps, emb_dim])。

emb_inputs_ = tf.nn.embedding_lookup(emb_, input_label, name="emb_inputs") # [batch, time, emb_dim]

为了考虑终点的影响,可以用同样的方法对destination进行embedding,然后通过tf.concat拼接到one-hot embedding出来的Tensor里。

# 注意,终点单独做了一次embedding,与前面的emb不是一套
dest_emb_ = tf.get_variable("dest_emb", [state_size, emb_dim], dtype=tf.float64)
dest_inputs_ = tf.tile(tf.expand_dims(tf.nn.embedding_lookup(self.dest_emb_, dest_label_), 1), [1, self.max_t_, 1])  # [batch, time, dest_dim]
inputs_ = tf.concat(2, [emb_inputs_, dest_inputs_], "input_with_dest") # [batch, time, emb_dim + dest_dim]

RNN层:

cell = tf.keras.layers.LSTMCell(hidden_dim)
layer = tf.keras.layers.RNN(cell)
rnn_outputs = layer(emb_inputs_, return_sequences=True)  # [batch, time, hid_dim]

Softmax层:
根据RNN的输出计算损失,损失计算时要考虑邻接表的约束。

outputs_ = tf.reshape(rnn_outputs, ...) # [batch*time, hid_dim]

# 输出层的参数
wp_ = tf.get_variable("wp", [int(outputs_flat_.get_shape()[1]), config.state_size],
                          dtype=config.float_type)  # [hid_dim, state_size]
bp_ = tf.get_variable("bp", [config.state_size], dtype=config.float_type)  # [state_size]


adj_mat = ... # n_edge * n_neighbor, element is the id of edge
adj_mask = ... # n_edge * n_neighbor, element is 1 or 0, where 1 means it is an edge in adj_mat and 0 means a padding in adj_mat

input_flat_ = tf.reshape(input_, [-1])  # [batch*t]
target_flat_ = tf.reshape(target_, [-1, 1])  # [batch*t, 1]
sub_adj_mat_ = tf.nn.embedding_lookup(adj_mat_, input_flat_) # [batch*t, max_adj_num]
sub_adj_mask_ = tf.nn.embedding_lookup(adj_mask_, input_flat_)  # [batch*t, max_adj_num]
# first column is target_
target_and_sub_adj_mat_ = tf.concat(1, [target_flat_, sub_adj_mat_])  # [batch*t, max_adj_num+1]

outputs_3d_ = tf.expand_dims(outputs_, 1)  # [batch*max_seq_len, hid_dim] -> [batch*max_seq_len, 1, hid_dim]

sub_w_ = tf.nn.embedding_lookup(w_t_, target_and_sub_adj_mat_)  # [batch*max_seq_len, max_adj_num+1, hid_dim]
sub_b_ = tf.nn.embedding_lookup(b_, target_and_sub_adj_mat_)  # [batch*max_seq_len, max_adj_num+1] 
sub_w_flat_ = tf.reshape(sub_w_, [-1, int(sub_w_.get_shape()[2])])  # [batch*max_seq_len*max_adj_num+1, hid_dim]
sub_b_flat_ = tf.reshape(sub_b_, [-1]) # [batch*max_seq_len*max_adj_num+1]

outputs_tiled_ = tf.tile(outputs_3d_, [1, tf.shape(adj_mat_)[1] + 1, 1])  # [batch*max_seq_len, max+adj_num+1, hid_dim]
outputs_tiled_ = tf.reshape(outputs_tiled_, [-1, int(outputs_tiled_.get_shape()[2])])  # [batch*max_seq_len*max_adj_num+1, hid_dim]
target_logit_and_sub_logits_ = tf.reshape(tf.reduce_sum(tf.multiply(sub_w_flat_, outputs_tiled_), 1) + sub_b_flat_,
                                                          [-1, tf.shape(adj_mat_)[1] + 1])  # [batch*max_seq_len, max_adj_num+1]

# for numerical stability
scales_ = tf.reduce_max(target_logit_and_sub_logits_, 1)  # [batch*max_seq_len]
scaled_target_logit_and_sub_logits_ = tf.transpose(tf.subtract(tf.transpose(target_logit_and_sub_logits_), scales_))  # transpose for broadcasting [batch*max_seq_len, max_adj_num+1]

scaled_sub_logits_ = scaled_target_logit_and_sub_logits_[:, 1:]  # [batch*max_seq_len, max_adj_num]
exp_scaled_sub_logits_ = tf.exp(scaled_sub_logits_)  # [batch*max_seq_len, max_adj_num]
deno_ = tf.reduce_sum(tf.multiply(exp_scaled_sub_logits_, sub_adj_mask_), 1)  # [batch*max_seq_len]
#log_deno_ = tf.log(deno_)  # [batch*max_seq_len]
log_deno_ = tf.log(tf.clip_by_value(deno_,1e-8,tf.reduce_max(deno_))) #避免计算无意义
log_nume_ = tf.reshape(scaled_target_logit_and_sub_logits_[:, 0:1], [-1])  # [batch*max_seq_len]
loss_ = tf.subtract(log_deno_, log_nume_)  # [batch*t] since loss is -sum(log(softmax))

max_prediction_ = tf.one_hot(tf.argmax(exp_scaled_sub_logits_ * sub_adj_mask_, 1),
                           int(adj_mat_.get_shape()[1]),
                           dtype=tf.float32)  # [batch*max_seq_len, max_adj_num]

你可能感兴趣的:(深度学习tensorflow)