The iterator.init_epoch() function in PyTorch (torchtext)

First, let's look at a code snippet in which init_epoch() appears:

from torchtext import data, datasets   # datasets provides TranslationDataset

# EOS is assumed to be defined elsewhere, e.g. EOS = '<eos>'

def mt_iterator(opt, train=True):
    # Source-side field; optionally reverse the token order
    DE = data.Field(eos_token=EOS, lower=True,
                    preprocessing=(lambda x: x[::-1]) if opt.reverse else None)
    # Target-side field
    EN = data.Field(init_token=EOS, eos_token=EOS, lower=True)
    train_data, val_data, test_data = datasets.TranslationDataset.splits(
        path=opt.data, train='train', validation='valid', test='test',
        exts=('.input', '.output'), fields=(DE, EN))
    # Vocabularies must be built before EN.vocab can be returned or batches numericalized
    DE.build_vocab(train_data)
    EN.build_vocab(train_data)
    if train:
        # BucketIterator.splits returns one iterator per split: (train, valid)
        iterator = data.BucketIterator.splits(
            (train_data, val_data), batch_size=opt.batch_size,
            device=0 if opt.cuda else -1)
        iterator[0].repeat = False   # do not repeat the training iterator across epochs
    else:
        iterator = data.Iterator(
            test_data, batch_size=opt.batch_size,
            device=0 if opt.cuda else -1, train=False, shuffle=False, sort=False)
    return iterator, EN.vocab

iterator, vocab_en = mt_iterator(opt)   # with train=True, iterator is the (train, valid) pair
train_iter = iterator[0]                # pick out the training iterator

train_iter.init_epoch()                 # set up the batch generator for a new epoch
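
For context, this is roughly how the training iterator is driven: init_epoch() is called once at the top of every epoch, and the batches are then consumed until the epoch ends. In the minimal sketch below, model, optimizer, criterion and opt.epochs are hypothetical placeholders that do not appear in the snippet above; the src/trg field names are the ones TranslationDataset assigns.

for epoch in range(opt.epochs):                  # opt.epochs is an assumed option
    train_iter.init_epoch()                      # new batch generator for this epoch
    for batch in train_iter:                     # finite, since repeat was set to False
        src, trg = batch.src, batch.trg          # field names given by TranslationDataset
        optimizer.zero_grad()
        loss = criterion(model(src, trg), trg)   # placeholder forward pass and loss
        loss.backward()
        optimizer.step()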

Since init_epoch() comes from the Iterator class in torchtext.data, let's look at that source code.

Iterator is the iterator class used to load batches of data from a Dataset.

The init_epoch() function is near the bottom of the class; it sets up the batch generator for a new epoch, and during training it is called once at the start of every epoch.

class Iterator(object):
    """Defines an iterator that loads batches of data from a Dataset.

    Attributes:
        batch_size_fn: Function of three arguments (new example to add, current
            count of examples in the batch, and current effective batch size)
            that returns the new effective batch size resulting from adding
            that example to a batch. This is useful for dynamic batching, where
            this function would add to the current effective batch size the
            number of tokens in the new example.
        sort_key: A key to use for sorting examples, so that examples of similar
            length are batched together and padding is minimized.
        train: Whether the iterator represents a train set.
        repeat: Whether to repeat the iterator for multiple epochs. Defaults to train.
        shuffle: Whether to shuffle examples between epochs. Defaults to train.
        sort: Whether to sort examples according to self.sort_key. Defaults to not train.
        sort_within_batch: Whether to sort (in descending order according to
            self.sort_key) within each batch. Defaults to self.sort.
            If self.sort is True and this is False, the batch is left in the
            original (ascending) sorted order.
        device: Device to create batches on. Use -1 for CPU; defaults to the GPU.
    """

    def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None, train=True,
                 repeat=None, shuffle=None, sort=None,
                 sort_within_batch=None):
        self.batch_size, self.train, self.dataset = batch_size, train, dataset
        self.batch_size_fn = batch_size_fn
        self.iterations = 0
        self.repeat = train if repeat is None else repeat
        self.shuffle = train if shuffle is None else shuffle
        self.sort = not train if sort is None else sort
        if sort_within_batch is None:
            self.sort_within_batch = self.sort
        else:
            self.sort_within_batch = sort_within_batch
        if sort_key is None:
            self.sort_key = dataset.sort_key
        else:
            self.sort_key = sort_key
        self.device = device

        self.random_shuffler = RandomShuffler()

        # For state loading/saving only
        self._iterations_this_epoch = 0
        self._random_state_this_epoch = None
        self._restored_from_state = False

    @classmethod
    def splits(cls, datasets, batch_sizes=None, **kwargs):
        """Create Iterator objects for multiple splits of a dataset.

        Arguments:
            datasets: Tuple of Dataset objects corresponding to the splits. The
                first such object should be the train set.
            batch_sizes: Tuple of batch sizes to use for the different splits,
                or None to use the same batch_size for all splits.
            Remaining keyword arguments: Passed to the constructor of the
                iterator class being used.
        """
        if batch_sizes is None:
            batch_sizes = [kwargs.pop('batch_size')] * len(datasets)
        ret = []
        for i in range(len(datasets)):
            train = i == 0
            ret.append(cls(
                datasets[i], batch_size=batch_sizes[i], train=train, **kwargs))
        return tuple(ret)

    def data(self):
        """Return the examples in the dataset in order, sorted, or shuffled."""
        if self.sort:
            xs = sorted(self.dataset, key=self.sort_key)
        elif self.shuffle:
            xs = [self.dataset[i] for i in self.random_shuffler(range(len(self.dataset)))]
        else:
            xs = self.dataset
        return xs
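
The constructor's defaults are worth noting: repeat and shuffle default to the value of train, while sort defaults to not train. Below is a minimal sketch to confirm this, assuming train_data and test_data splits like the ones built inside mt_iterator above (TranslationDataset already defines a sort_key, so none needs to be passed explicitly; batch_size=32 is arbitrary).

train_it = data.Iterator(train_data, batch_size=32, train=True)
print(train_it.repeat, train_it.shuffle, train_it.sort)   # True True False

test_it = data.Iterator(test_data, batch_size=32, train=False)
print(test_it.repeat, test_it.shuffle, test_it.sort)      # False False True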
   
The init_epoch() function:
def init_epoch(self):
    """Set up the batch generator for a new epoch."""

    if self._restored_from_state:
        # Resuming from a saved state: reuse the saved random state
        self.random_shuffler.random_state = self._random_state_this_epoch
    else:
        # Otherwise this is a fresh epoch: record the current random state
        self._random_state_this_epoch = self.random_shuffler.random_state

    self.create_batches()   # build the batches for this epoch

    if self._restored_from_state:
        # The restore has been consumed; reset the flag to its default
        self._restored_from_state = False
    else:
        # Just starting a fresh epoch: no batches have been consumed yet
        self._iterations_this_epoch = 0

    if not self.repeat:
        # A non-repeating iterator resets its global iteration counter every epoch
        self.iterations = 0
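
The _restored_from_state branch only matters when the iterator's state has been saved and loaded again, for example to resume training from a checkpoint. The same legacy Iterator class also provides state_dict() and load_state_dict() for this; load_state_dict() sets _restored_from_state to True, so the next init_epoch() reuses the saved random state instead of drawing a new one. A rough usage sketch (the checkpoint file name is just an illustration):

import torch

# Save the iterator's progress alongside the model checkpoint
torch.save({'iter_state': train_iter.state_dict()}, 'checkpoint.pt')

# ... later, to resume:
ckpt = torch.load('checkpoint.pt')
train_iter.load_state_dict(ckpt['iter_state'])   # marks the iterator as restored
train_iter.init_epoch()                          # reuses _random_state_this_epoch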

 
def create_batches(self):
    # Group the ordered/sorted/shuffled examples from self.data() into batches
    self.batches = batch(self.data(), self.batch_size, self.batch_size_fn)

@property
def epoch(self):
    # Fractional number of epochs this iterator has completed
    return self.iterations / len(self)
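
create_batches() delegates to a module-level batch() helper, which groups the examples returned by data() into lists of batch_size examples (or uses batch_size_fn for dynamic batch sizes). A simplified sketch of that grouping, ignoring batch_size_fn (an illustration, not the exact torchtext source):

def simple_batch(examples, batch_size):
    # Yield consecutive lists of at most batch_size examples
    minibatch = []
    for ex in examples:
        minibatch.append(ex)
        if len(minibatch) == batch_size:
            yield minibatch
            minibatch = []
    if minibatch:
        yield minibatch

The epoch property then divides the running batch counter iterations by len(self), the number of batches in one pass over the dataset, giving fractional epoch progress (when repeat is False, init_epoch() resets the counter at the start of each epoch, as shown above).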



