pytorch 部分知识点

一、Dataloader使用
参数设置:
1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。
2、batch_size,根据具体情况设置即可。
3、shuffle,一般在训练数据中会采用。
4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。
5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。
6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。
7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。
8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。
9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。

调用:
使用enumerate访问可遍历的数组对象:

for i,(input,target) in enumerate(trainloader):
    print(input,target)

root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples

二、训练阶段和测试阶段
model.eval(),pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大;在模型测试阶段使用

model.train() 让model变成训练模式,此时 dropout和batch normalization的操作在训练q起到防止网络过拟合的问题

总结: model.train() 和 model.eval() 一般在模型训练和评价的时候会加上这两句,主要是针对由于model 在训练时和评价时 Batch Normalization 和 Dropout 方法模式不同;因此,在使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval

  # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

optimizer.zero_grad()
即将梯度初始化为零(因为一个batch的loss关于weight的导数是所有sample的loss关于weight的导数的累加和)
outputs = net(inputs) 即前向传播求出预测的值
loss = criterion(outputs, labels) 求loss
loss.backward() 即反向传播求梯度
optimizer.step() 即更新所有参数

你可能感兴趣的:(python机器学习实战,笔记)