Picking up where the last article left off.
The previous article finished with the constructor __init__() of class NeuralNetwork(); next in line is its run() function. Why run()? Look at the code below:
def test_train():
    visual_loss.init()
    visual_acc.init()
    nn = NeuralNetwork(epoch = 50, lr = 0.0001, eval_every_N_steps = 200)
    nn.run(test = True)

if __name__ == '__main__':
    test_train()
Now it should all make sense. The main entry point of network.py calls test_train(), and the two calls that matter most inside test_train() are NeuralNetwork() and run(). The former handles data preparation and preprocessing and was covered in full in the previous article. This article walks through run() in detail, since it implements the core training and prediction functionality. The code is as follows:
def run(self, test = False):
    '''
    Train model with mini-batch stochastic gradient descent.
    '''
    total_step = 0
    for n in range(self.epoch):
        for mini_batch in corpus.load_train():
            nabla_b = [np.zeros(b.shape) for b in self.biases]
            nabla_w = [np.zeros(w.shape) for w in self.weights]
            total_cost = 0.0
            for x, y_ in mini_batch:
                # here scale the input's word ids with 0.001 for x to make sure the Z-vector can pass sigmoid fn
                delta_nabla_b, delta_nabla_w, cost = self.back_propagation( \
                    np.reshape(x, (self.input_layer_size, 1)) * 0.001, \
                    np.reshape(y_, (self.output_layer_size, 1)))
                nabla_b = [ nb+mnb for nb, mnb in zip(nabla_b, delta_nabla_b)]
                nabla_w = [ nw+mnw for nw, mnw in zip(nabla_w, delta_nabla_w)]
                total_cost += cost
            self.weights = [ w - (self.lr * w_)/len(mini_batch) for w, w_ in zip(self.weights, nabla_w)]
            self.biases = [ b - (self.lr * b_)/len(mini_batch) for b, b_ in zip(self.biases, nabla_b)]
            total_step += 1
            print("Epoch %s, total step %d, cost %f" % (n, total_step, total_cost/len(mini_batch)))
            visual_loss.plot(total_step, total_cost/len(mini_batch))
            if (total_step % self.eval_every_N_steps ) == 0 and test:
                accuracy = self.evaluate()
                print("Epoch %s, total step %d, accuracy %s" % (n, total_step, accuracy))
                visual_acc.plot(total_step, accuracy)
Don't be fooled by how short these dozen-odd lines of code look; unpacking them takes quite a bit of explanation. The first few lines need no comment and are easy to follow, so let's start from the line for n in range(self.epoch):.
This line is not hard to understand either: the for loop runs the training procedure for epoch rounds in total, and the loop body describes what happens in each round.
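One detail in the code above deserves a note before we go on: the inputs handed to back_propagation() are raw word ids multiplied by 0.001 (see the inline comment in run()). Presumably the point is to keep the pre-activation values z = w·x + b small enough that the sigmoid activation does not saturate; with word ids in the hundreds or thousands, sigmoid(z) would be pinned at 0 or 1 and its gradient would be essentially zero. Below is a minimal sketch of the effect; the sigmoid helper is my own, written just for this demo, and is not the project's function:

import numpy as np

def sigmoid(z):
    # standard logistic function
    return 1.0 / (1.0 + np.exp(-z))

raw_ids = np.array([3.0, 520.0, 4999.0])   # word ids used as-is
scaled = raw_ids * 0.001                   # the 0.001 scaling applied in run()

print(sigmoid(raw_ids))   # -> [0.953 1.    1.   ]  saturated, gradient ~ 0
print(sigmoid(scaled))    # -> [0.501 0.627 0.993]  still in the responsive range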
Next comes for mini_batch in corpus.load_train():. Let's start with corpus.load_train(). Its call chain is analogous to that of corpus.load_test(), which was analyzed in the previous article; it ultimately ends up calling load_train() in data.py. The code is as follows:
def load_train(batch_size = 100, question_max_length = 20, utterance_max_length = 99):
    '''
    load train data
    '''
    return __resolve_input_data(_train_data, batch_size, question_max_length, utterance_max_length)
Let's first look at the _train_data argument passed to __resolve_input_data(). Like _test_data in the previous article, it is defined in data.py: _train_data = insuranceqa.load_pairs_train(). In other words, _train_data holds the return value of insuranceqa.load_pairs_train(). As discussed in the previous article, insuranceqa.load_pairs_train() ultimately returns the contents of the file XXX\Anaconda3\Lib\site-packages\pairs\insuranceqa_data\iqa.train.json.gz.
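From the way __resolve_input_data() accesses each element further down (o['question'], o['utterance'], o['label']), and by analogy with the _test_data samples in the previous article, each item of _train_data should look roughly like the dict below. The concrete ids and the exact label encoding are my guesses, for illustration only:

# One hypothetical element of _train_data (ids and label values invented for illustration)
sample = {
    'question': [17, 254, 3096, 8],        # word ids of the question
    'utterance': [90, 11, 7353, 4297, 5],  # word ids of the candidate answer
    'label': [1, 0]                        # a two-element label marking a positive or negative pair
}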
Now for __resolve_input_data() itself, whose source is as follows:
def __resolve_input_data(data, batch_size, question_max_length = 20, utterance_max_length = 99):
    '''
    resolve input data
    '''
    batch_iter = BatchIter(data = data, batch_size = batch_size)
    for mini_batch in batch_iter.next():
        result = []
        for o in mini_batch:
            x = pack_question_n_utterance(o['question'], o['utterance'], question_max_length, utterance_max_length)
            y_ = o['label']
            assert len(x) == utterance_max_length + question_max_length + 1, "Wrong length after padding"
            assert VOCAB_GO_ID in x, "<GO> must be in input x"
            assert len(y_) == 2, "desired output."
            result.append([x, y_])
        if len(result) > 0:
            # print('data in batch:%d' % len(mini_batch))
            yield result
        else:
            raise StopIteration
Right at the top, this function brings in yet another class: BatchIter. The BatchIter class is defined and implemented in data.py, with the following code:
class BatchIter():
    '''
    Load data with mini-batch
    '''
    def __init__(self, data = None, batch_size = 100):
        assert data is not None, "data should not be None."
        self.batch_size = batch_size
        self.data = data

    def next(self):
        random.shuffle(self.data)
        index = 0
        total_num = len(self.data)
        while index <= total_num:
            yield self.data[index:index + self.batch_size]
            index += self.batch_size
Inside __resolve_input_data(), the first step is to call the BatchIter constructor, which simply takes two arguments, batch_size and data. Since batch_size is not specified at the call site, it keeps its default of 100, while data is _train_data, i.e. the raw contents of the file XXX\Anaconda3\Lib\site-packages\pairs\insuranceqa_data\iqa.train.json.gz.
Next, __resolve_input_data() iterates over BatchIter's next() method. That method first calls random.shuffle() to shuffle the elements of self.data (i.e. _train_data), and then yields batch_size items per iteration until all the data has been handed out. One subtlety in this code is worth pointing out: when index has not yet passed total_num but index plus batch_size does, Python slicing does not raise an out-of-range error; the slice is simply truncated, so the final batch just comes out shorter than batch_size. And because the loop condition is index <= total_num, if len(self.data) is an exact multiple of batch_size the very last iteration yields an empty list, which is precisely why __resolve_input_data() guards the yield with if len(result) > 0.
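If you want to see this behaviour for yourself, the following self-contained snippet runs BatchIter (copied verbatim from data.py above) on two toy lists: one whose length is not a multiple of batch_size, and one whose length is.

import random

class BatchIter():
    def __init__(self, data = None, batch_size = 100):
        assert data is not None, "data should not be None."
        self.batch_size = batch_size
        self.data = data
    def next(self):
        random.shuffle(self.data)
        index = 0
        total_num = len(self.data)
        while index <= total_num:
            yield self.data[index:index + self.batch_size]
            index += self.batch_size

# 7 items, batch_size 3: batches of 3, 3, then 1 -- the out-of-range slice is truncated, not an error
for batch in BatchIter(data = list(range(7)), batch_size = 3).next():
    print(batch)

# 6 items, batch_size 3: the condition index <= total_num lets one final empty batch through
for batch in BatchIter(data = list(range(6)), batch_size = 3).next():
    print(batch)   # the last line printed is []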
Next, __resolve_input_data() calls pack_question_n_utterance(o['question'], o['utterance'], question_max_length, utterance_max_length) and assigns the return value to x. pack_question_n_utterance() was analyzed in detail in the previous article; here its job is to stitch together the question, the separator token, and the utterance (answer) of each sample in the mini-batch. At the same time, each sample's label (positive or negative example) is assigned to y_, paired with the corresponding x into a two-element list [x, y_], and appended to the result list in turn. Finally, result is handed back to the caller.
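For readers who skipped the previous article, here is a rough re-creation of what pack_question_n_utterance() has to do for the assertions above to hold. It is a sketch only: VOCAB_PAD_ID is my stand-in name for the padding id, the padding strategy is my assumption, and the real implementation in data.py is the authoritative one.

# Simplified sketch for illustration; VOCAB_GO_ID / VOCAB_PAD_ID stand in for the ids defined in data.py
VOCAB_GO_ID = 1
VOCAB_PAD_ID = 0

def pack_question_n_utterance(q, u, q_max = 20, u_max = 99):
    # truncate or right-pad the question to exactly q_max ids
    q = (q[:q_max] + [VOCAB_PAD_ID] * q_max)[:q_max]
    # truncate or right-pad the utterance to exactly u_max ids
    u = (u[:u_max] + [VOCAB_PAD_ID] * u_max)[:u_max]
    # question + separator + utterance -> q_max + 1 + u_max ids in total
    return q + [VOCAB_GO_ID] + u

x = pack_question_n_utterance([17, 254, 3096], [90, 11, 7353, 4297], 20, 99)
print(len(x))             # 120 == 20 + 99 + 1
print(VOCAB_GO_ID in x)   # True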
With __resolve_input_data() done, corpus.load_train() is done as well. So we are back at for mini_batch in corpus.load_train(): in run(), which is where this whole detour started; the difference is that this time we return with every unknown along the way accounted for. corpus.load_train() yields, one at a time, shuffled batches of up to batch_size samples built from _train_data (that is, from the raw contents of XXX\Anaconda3\Lib\site-packages\pairs\insuranceqa_data\iqa.train.json.gz), and mini_batch in the for loop is one such batch per iteration.
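As a quick sanity check of what run() actually iterates over, something along the lines of the snippet below should work; the import line is hypothetical and should match however network.py pulls in data.py:

# Hypothetical sanity check; import data.py under the same alias network.py uses (assumed here)
import data as corpus

mini_batch = next(corpus.load_train())
print(len(mini_batch))   # up to 100 samples, the default batch_size
x, y_ = mini_batch[0]
print(len(x))            # 120 = question_max_length (20) + 1 separator + utterance_max_length (99)
print(len(y_))           # 2: the label is always a two-element list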
In the next article I will take apart the body of the for mini_batch in corpus.load_train(): loop in detail.