PyTorch Basics: the .pth File and the DataLoader

When training a model with PyTorch you will often come across .pth files. If you open one directly you only see what looks like garbage, so what is this file for and what does it store? Under the hood it is essentially a pickle (.pkl) file; for background on pickle, see: Python基础知识汇总.
Since it is pickled data, we can save a plain key-value dictionary (a dict) to it and read it back:

import torch

#save a dictionary to tony.pth
torch.save({'hi':123,'hello':'haha','name':'Tony'}, 'tony.pth')
#load it back
t=torch.load("tony.pth")
for k,v in t.items():
    print(k,v)

'''
hi 123
hello haha
name Tony
'''

The above is just a plain dictionary. In model training, what gets saved is usually an ordered dictionary (OrderedDict).

Let's look at a real example. Training the WGAN earlier produced a number of .pth files; I pick netG_epoch_24.pth and inspect its contents.

import torch
net2 = torch.load("netG_epoch_24.pth")
print(type(net2),len(net2))

for k,v in net2.items():
    print(k,type(v),v.size())

'''
<class 'collections.OrderedDict'> 25
main.initial:100-512:convt.weight <class 'torch.Tensor'> torch.Size([100, 512, 4, 4])
main.initial:512:batchnorm.weight <class 'torch.Tensor'> torch.Size([512])
main.initial:512:batchnorm.bias <class 'torch.Tensor'> torch.Size([512])
main.initial:512:batchnorm.running_mean <class 'torch.Tensor'> torch.Size([512])
main.initial:512:batchnorm.running_var <class 'torch.Tensor'> torch.Size([512])
main.initial:512:batchnorm.num_batches_tracked <class 'torch.Tensor'> torch.Size([])
main.pyramid:512-256:convt.weight <class 'torch.Tensor'> torch.Size([512, 256, 4, 4])
main.pyramid:256:batchnorm.weight <class 'torch.Tensor'> torch.Size([256])
main.pyramid:256:batchnorm.bias <class 'torch.Tensor'> torch.Size([256])
main.pyramid:256:batchnorm.running_mean <class 'torch.Tensor'> torch.Size([256])
main.pyramid:256:batchnorm.running_var <class 'torch.Tensor'> torch.Size([256])
main.pyramid:256:batchnorm.num_batches_tracked <class 'torch.Tensor'> torch.Size([])
main.pyramid:256-128:convt.weight <class 'torch.Tensor'> torch.Size([256, 128, 4, 4])
main.pyramid:128:batchnorm.weight <class 'torch.Tensor'> torch.Size([128])
main.pyramid:128:batchnorm.bias <class 'torch.Tensor'> torch.Size([128])
main.pyramid:128:batchnorm.running_mean <class 'torch.Tensor'> torch.Size([128])
main.pyramid:128:batchnorm.running_var <class 'torch.Tensor'> torch.Size([128])
main.pyramid:128:batchnorm.num_batches_tracked <class 'torch.Tensor'> torch.Size([])
main.pyramid:128-64:convt.weight <class 'torch.Tensor'> torch.Size([128, 64, 4, 4])
main.pyramid:64:batchnorm.weight <class 'torch.Tensor'> torch.Size([64])
main.pyramid:64:batchnorm.bias <class 'torch.Tensor'> torch.Size([64])
main.pyramid:64:batchnorm.running_mean <class 'torch.Tensor'> torch.Size([64])
main.pyramid:64:batchnorm.running_var <class 'torch.Tensor'> torch.Size([64])
main.pyramid:64:batchnorm.num_batches_tracked <class 'torch.Tensor'> torch.Size([])
main.final:64-3:convt.weight <class 'torch.Tensor'> torch.Size([64, 3, 4, 4])
'''

As you can see, the keys are parameter names and the values are the parameter tensors themselves: convolution weights, BatchNorm weights and biases, running means and variances, and so on. Each value is a Tensor, and the whole object is an ordered dictionary (OrderedDict).
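This is exactly what a model's state_dict() looks like, and it is also how such a file is usually produced and loaded back. Below is a minimal sketch (the tiny nn.Sequential model and the file name net_demo.pth are made up for illustration; they are not part of the WGAN code above):

import torch
import torch.nn as nn

#any nn.Module works the same way; this stand-in model just keeps the sketch small
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

#state_dict() returns an OrderedDict mapping parameter/buffer names to tensors
sd = net.state_dict()
print(type(sd))          # <class 'collections.OrderedDict'>
for k, v in sd.items():
    print(k, type(v), v.size())

#save only the parameters, then restore them into a freshly built model of the same structure
torch.save(sd, 'net_demo.pth')
net2 = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
net2.load_state_dict(torch.load('net_demo.pth'))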

Using the DataLoader

DataLoader comes from the torch.utils.data module. Its parameters include the dataset and, optionally, a sampler, and it can load data with multiple worker processes (which works out of the box on Linux). When the dataset is loaded for training, the data is split into mini-batches (the number of samples per batch is the batch size), which are then iterated over one batch at a time.

import torch
import torch.utils.data as udata

x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)
#use TensorDataset to wrap the tensors into a Dataset
dset = udata.TensorDataset(x, y)
#batch size 5; shuffle=False here so the output is easy to follow, but in real training shuffling (shuffle=True) is usually better
#loader = udata.DataLoader(dataset=dset, batch_size=5,shuffle=False,num_workers=0)
'''
epoch:0, step:0, batch_x:tensor([1., 2., 3., 4., 5.]), batch_y:tensor([10.,  9.,  8.,  7.,  6.])
epoch:0, step:1, batch_x:tensor([ 6.,  7.,  8.,  9., 10.]), batch_y:tensor([5., 4., 3., 2., 1.])
epoch:1, step:0, batch_x:tensor([1., 2., 3., 4., 5.]), batch_y:tensor([10.,  9.,  8.,  7.,  6.])
epoch:1, step:1, batch_x:tensor([ 6.,  7.,  8.,  9., 10.]), batch_y:tensor([5., 4., 3., 2., 1.])
epoch:2, step:0, batch_x:tensor([1., 2., 3., 4., 5.]), batch_y:tensor([10.,  9.,  8.,  7.,  6.])
epoch:2, step:1, batch_x:tensor([ 6.,  7.,  8.,  9., 10.]), batch_y:tensor([5., 4., 3., 2., 1.])
'''
#batch size 4: 10 samples are not evenly divisible by 4, so the leftover 2 samples form one extra, smaller batch
#you can also set drop_last=True to drop that remainder so it is never iterated
#loader = udata.DataLoader(dataset=dset, batch_size=4,shuffle=False,num_workers=0)
'''
epoch:0, step:0, batch_x:tensor([1., 2., 3., 4.]), batch_y:tensor([10.,  9.,  8.,  7.])
epoch:0, step:1, batch_x:tensor([5., 6., 7., 8.]), batch_y:tensor([6., 5., 4., 3.])
epoch:0, step:2, batch_x:tensor([ 9., 10.]), batch_y:tensor([2., 1.])
epoch:1, step:0, batch_x:tensor([1., 2., 3., 4.]), batch_y:tensor([10.,  9.,  8.,  7.])
epoch:1, step:1, batch_x:tensor([5., 6., 7., 8.]), batch_y:tensor([6., 5., 4., 3.])
epoch:1, step:2, batch_x:tensor([ 9., 10.]), batch_y:tensor([2., 1.])
epoch:2, step:0, batch_x:tensor([1., 2., 3., 4.]), batch_y:tensor([10.,  9.,  8.,  7.])
epoch:2, step:1, batch_x:tensor([5., 6., 7., 8.]), batch_y:tensor([6., 5., 4., 3.])
epoch:2, step:2, batch_x:tensor([ 9., 10.]), batch_y:tensor([2., 1.])
'''
indices=range(len(dset))
sub_rnd_sample=indices[:10]
#SubsetRandomSampler already draws the indices in random order (similar to shuffling), so do not also set shuffle=True when a sampler is given
subsampler = udata.sampler.SubsetRandomSampler(sub_rnd_sample)
loader = udata.DataLoader(dataset=dset, batch_size=4,sampler=subsampler)
'''
epoch:0, step:0, batch_x:tensor([ 8., 10.,  3.,  1.]), batch_y:tensor([ 3.,  1.,  8., 10.])
epoch:0, step:1, batch_x:tensor([6., 2., 9., 7.]), batch_y:tensor([5., 9., 2., 4.])
epoch:0, step:2, batch_x:tensor([4., 5.]), batch_y:tensor([7., 6.])
epoch:1, step:0, batch_x:tensor([7., 6., 8., 3.]), batch_y:tensor([4., 5., 3., 8.])
epoch:1, step:1, batch_x:tensor([5., 9., 2., 1.]), batch_y:tensor([ 6.,  2.,  9., 10.])
epoch:1, step:2, batch_x:tensor([10.,  4.]), batch_y:tensor([1., 7.])
epoch:2, step:0, batch_x:tensor([10.,  1.,  2.,  5.]), batch_y:tensor([ 1., 10.,  9.,  6.])
epoch:2, step:1, batch_x:tensor([8., 9., 6., 7.]), batch_y:tensor([3., 2., 5., 4.])
epoch:2, step:2, batch_x:tensor([3., 4.]), batch_y:tensor([8., 7.])
'''

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print("epoch:{}, step:{}, batch_x:{}, batch_y:{}".format(epoch,step, batch_x, batch_y))

More parameters can be found by reading the DataLoader source code or the official documentation.
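For reference, here is a sketch that combines a few of the other commonly used arguments, reusing the dset defined above (num_workers, pin_memory and drop_last are real DataLoader parameters; the particular values are only illustrative):

loader = udata.DataLoader(
    dataset=dset,
    batch_size=4,
    shuffle=True,      #reshuffle the data at every epoch
    num_workers=2,     #load batches in 2 worker processes (use 0 if multiprocessing causes trouble)
    pin_memory=True,   #put batches in pinned memory to speed up host-to-GPU copies
    drop_last=True,    #drop the final incomplete batch (here the leftover 2 samples)
)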
