批量训练pytorch练习

利用pytorch框架训练,有一个重要步骤就是批量训练
要用到torch.nn.data模块

import torch.nn.data as Data

创建一个TensorDataset对象存放训练数据x,y

mydataset=Data.TensorDataset(x,y)

创建一个DataLoader对象加载数据,设置dataset,batch_size,shuffle,num_work等参数

 BATCH_SIZE=5 #设置批量训练的数量,超参数用大写字母表示
 data_loader=Data.Dataloder(dataset=mydataset, #数据集
    						batch_size=BATCH_SIZE, #批量训练数量
    						shuffle=True,#(打乱数据顺序)
    					   num_workers=2  #(2个线程)

数据集和数据加载器都写好后,可以开始训练了

for epoch in range(num_epoch):
	for step ,(batch_x,batch_y) in enumerate(data_loader):
		#开始训练
		#可视化,打印batch train 的数据
		print("epoch:{}\t step :{} \n batch_x:{} batch_y:{}".format(epoch ,step, batch_x,batch_y

下面贴入全部完整代码

import torch 
import torch.utils.data as Data
import numpy as np

torch.manual_seed(1)

x=torch.linspace(1,10,10)


y=torch.linspace(10,1,10)

BATCH_SIZE=6
num_epoches=3
data_set=Data.TensorDataset(x,y)

data_loader=Data.DataLoader(dataset=data_set,
                             batch_size=BATCH_SIZE,
                             shuffle=True,
                             num_workers=2
                             )

def show_batch():
    for epoch in range(num_epoches):
        for step ,(batch_x,batch_y) in enumerate(data_loader):
            print("epoch:{} step: {} \nbatch_x:{} batch_y : {}\n".format(epoch,
                  step,
                  batch_x,
                  batch_y))
            
import import_test    

if __name__=="__main__":
    print("当前运行的脚本名",__name__)
    
    show_batch()

你可能感兴趣的:(莫烦pytorch学习)