pytorch学习笔记——batch生成器

背景

深度学习的训练数据往往很多,如果一次性训练所有的数据,不但会导致时间过长,而且训练次数不够,参数也不能得到很好的迭代。为此,将训练数据分成小的batch,一次batch迭代就可以完成一次参数更新,大大提高了训练速度。
Pytorch中有现成的batch生成器,但是为了底层原理的理解,最好自己能够写出这样的代码,就先从能看懂现成代码开始吧。

batch生成器函数

def data_iter(batch_size,features,labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0,num_examples,batch_size):
        j = torch.LongTensor(indices[i:min(i+batch_size,num_examples)])
        yield features.index_select(0,j),labels.index_select(0,j)

这个函数名为batch_iter,参数有三:batch_size(batch的大小)、features(训练数据的特征,可以视为自变量)和labels(训练数据的标签,可以视为因变量)。
首先,num_examples获得features变量的长度,这个值就是训练数据的个数;
接着,利用range函数生成从0到training number-1(num_examples-1)的range,利用list函数将其转为列表,这样,indices就是一个包含从0到num_examples-1的list了;
然后利用random包的shuffle函数对indices的数进行洗牌(实际上就是打乱,然后随机排列);
接下来,range(0, num_examples, batch_size)是从0到num_examples-1,每隔batch_size步长产生一个数,即这里的i分别为0,10,20,…,990;
接着,indices[i:min(i+batch_size,num_examples)]是一个索引操作,表示取出indices这个list里从i到下一个变量值-1的子list,其中min(i+batch_size,num_examples)的作用是防止索引超出最大范围;这样,j就得到了一个值类型为long的tensor,值是indices里索引出来的子list;
最后,yield函数是一个生成器,在这里可以简单的看作是return;features.index_select(0,j)中,0表示的是dim,这个操作从features中索引出了所有行数为j的features,组成一个tensor;labels同理。
这样,调用这个函数之后就可以获得一个训练数据的batch了。

实例演示

import torch
import numpy as np
import random

num_inputs = 2
num_examples = 1000
true_w = [2,-3.4]
true_b = 4.2
features = torch.from_numpy(np.random.normal(0,1,(num_examples,num_inputs)))
labels = true_w[0]*features[:,0]+true_w[1]*features[:,1]+true_b
labels += torch.from_numpy(np.random.normal(0,0.01,size = labels.size()))

def data_iter(batch_size,features,labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0,num_examples,batch_size):
        j = torch.LongTensor(indices[i:min(i+batch_size,num_examples)])
        yield features.index_select(0,j),labels.index_select(0,j)

batch_size = 10

for x,y in data_iter(batch_size,features,labels):
    print(x,y) #这里只演示出第一个batch
    break

输出结果:

tensor([[ 1.0289, -0.5676],
        [ 0.4811,  0.0651],
        [-0.7113, -0.7735],
        [ 0.5077,  1.5935],
        [ 0.5343,  0.8802],
        [-1.1659, -1.0234],
        [ 0.1249, -0.2690],
        [-1.9804,  0.9771],
        [-0.5953, -0.0802],
        [ 0.2558, -1.0796]], dtype=torch.float64) tensor([ 8.2047,  4.9386,  5.4247, -0.2031,  2.2704,  5.3475,  5.3715, -3.0973,
         3.2829,  8.3943], dtype=torch.float64)

你可能感兴趣的:(笔记,pytorch,学习,batch)