pytorch设置batch

使用pytorch进行数据网络训练时,数据集可能有上万条数据,训练的话比较浪费时间,设置batch,一次训练一个batch_size的大小,既节省时间,又可以快速收敛。
使用前需要导入包:

from torch.utils.data import Dataset, DataLoader, TensorDataset

设置batch,需要将训练数据的输入属性和标签放入DataLoader中,见下:


def addbatch(data_train,data_test,batchsize):
    """
    设置batch
    :param data_train: 输入
    :param data_test: 标签
    :param batchsize: 一个batch大小
    :return: 设置好batch的数据集
    """
    data = TensorDataset(data_train,data_test)
    data_loader = DataLoader(data, batch_size=batchsize, shuffle=False)#shuffle是是否打乱数据集,可自行设置

    return data_loader

使用时调用即可:

#设置batch
    traindata=addbatch(traininput,trainlabel,1000)#1000为一个batch_size大小为1000,训练集为10000时一个epoch会训练10次。

进行神经网络训练用下面方法:

    for epoch in range(EPOCH):
        for step, data in enumerate(traindata):
            inputs, labels = data
            # 前向传播
            out = net(inputs)
            # 计算损失函数
            loss = loss_func(out, labels)
            # 清空上一轮的梯度
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

你可能感兴趣的:(python,pytorch,batch,深度学习)