Iterators
对torchtext的batch实现的修改算法原理
Batching matters a ton for speed. We want to have very evenly divided batches, with absolutely minimal padding. To do this we have to hack a bit around the default torchtext batching. This code patches their default batching to make sure we search over enough sentences to find tight batches.
这里是对torchtext中默认的batching操作进行的优化修改。
参考:https://towardsdatascience.com/how-to-use-torchtext-for-neural-machine-translation-plus-hack-to-make-it-5x-faster-77f3884d95
Torchtext本身已经很好了,并且sort_key使得dataset中的数据排序,这样batching后序列长度相近的会被放在同一个batch中,可以很大程度上降低padding的个数。
但是下面代码又进行了优化:根据每个batch中序列的最大长度,动态更改batch_size,使得可以更好的利用计算资源。
举个例子:
假设你的RAM每个iteration可以处理1500个tokens, batch_size = 20, 那么只有当batch中的序列长度为sequence length = 1500 / 20 = 75时,才可以将计算资源利用完全。
现实中,每个batch的sequence length的显然是在变化的,那么如果希望尽量多的利用计算资源,就需要可以动态调整当前的batch_size.
Transformer中的MyIterator重载了data.Iterator中的create_batches函数:
1 class MyIterator(data.Iterator): 2 def create_batches(self): 3 if self.train: 4 def pool(d, random_shuffler): 5 for p in data.batch(d, self.batch_size * 100): 6 p_batch = data.batch( 7 sorted(p, key=self.sort_key), 8 self.batch_size, self.batch_size_fn) 9 for b in random_shuffler(list(p_batch)): 10 yield b 11 self.batches = pool(self.data(), self.random_shuffler) 12 13 else: 14 self.batches = [] 15 for b in data.batch(self.data(), self.batch_size, 16 self.batch_size_fn): 17 self.batches.append(sorted(b, key=self.sort_key)) 18 19 def rebatch(pad_idx, batch): 20 "Fix order in torchtext to match ours" 21 src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1) 22 return Batch(src, trg, pad_idx)
pool函数
其中pool函数的功能与https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py中定义的class BucketIterator(Iterator)的pool函数功能类似。
1. 将原始的data分成大小为 100 * batch_size的一些chunks => (以上迭代 p 即为 每个chunk)
2. 在每个chunk中根据 sort_key 对examples进行排序,并对每个chunk按照batch_size分成100个batch =>
( p_batch = data.batch( sorted(p, key=self.sort_key), self.batch_size, self.batch_size_fn) )
3. 将这些chunks进行shuffle => (random_shuffler(list(p_batch)))
4. 在每个chunk中再把examples分成 大小为 batch_size 的 100 个 batch => (以上 b 即为每个 batch)
5. 生成器每次 yield一个batch => (yield b)