记录Pytorch使用过程中的经验,避免再次踩坑,保持更新ing…
官方API:class torch.nn.CrossEntropyLoss(weight=None, size_average=True)
调用时参数:
input : 包含每个类的得分,2-D tensor,shape为 batch*n
target: 大小为 n 的 1—D tensor,包含类别的索引(0到 n-1)
坑点:①在输入前不需要进行Softmax计算,CrossEntropyLoss会自动对输入进行softmax计算,仅需要输入原始数据。②在计算loss时,Label不需要转换成one-hot形式。
调用时参数形状(shape):
Input: (N,C) C 是类别的数量
Target: (N) N是mini-batch的大小,0 <= targets[i] <= C-1
一开始不知道这有什么用,即使没有设置requires_grad = False,它也只是累积梯度,没有后向传播,参数不会更新,大不了占用内存。在模型验证阶段的时候用with torch.no_grad():
不就好了吗?
但其实不然,因为类似于Batch Normalization的操作,它是不需要计算梯度就可以自动更新参数,所以必须指明:model.eval()
,否则它就会一直自动更新自身的γ和β参数,以及在测试的时候还继续计算输入样本的均值和方差。这样会导致模型对不同Batch中的同一样本Batch Norm后的结果不一样。
关于Batch Normalization的理论知识可以看我另一篇博文:
Batch Normalization 批归一化是什么? 有什么用?
还有Dropout层,如果你不指定为测试模型model.eval()
,那么在预测的时候也会启用Dropout,具体可以看我另外一篇博文:
PyTorch的Dropout
一般在卷积层转换到全连接层的时候,需要将特征维度合并,变成[batch, features]的2维,我以前都是直接根据最后一层卷积的输出大小用torch.view(-1, 77512)来改变Tensor(假设最后一层卷积输出为7x7x512channels的特征维度)。但是今天发现一个新操作,比较简单,就是torch.flatten(x, 1)
.
官方API的解释如下:
>>> t = torch.tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
这样对于我们维度转换就轻松很多了,假设最后一层卷积输出为X,我们只需要进行torch.flatten(x,1)操作之后就可以输入到全连接层了。
在transforms.Compose()组合操作里面,所有的PIL Image操作都要在Tensor操作前面,否则会报错。因为执行transforms.ToTensor()之后,数据会从[H,W,C]变为[C,H,W],并且取值范围为[0.0,1.0]之间。
在官方教程里经常会看到transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
,这是什么意思呢?我们自己在处理数据的时候应该怎么办呢?
官方:input[channel] = (input[channel] - mean[channel]) / std[channel]
首先我们要做标准化处理主要就是因为数据的标准化能够加速网络的训练,一定程度上避免梯度消失和梯度爆炸,除此之外,我们还需要保持训练集、验证集、测试集标准化数据的一致性,所以一些数据集官方往往会给出标准化系数,例如mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
就是ImageNet数据集给出的标准化系数分别表示RGB三个通道上计算得到的均值和方差值。
如果我们用到了ImageNet数据集上的预训练模型,那么被迁移的网络最好也使用ImageNet的标准化系数。 但如果我们只是在自己的数据上从头训练的话,可以设置为mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]
,这样数据会落在[-1, 1]上,但最好也算一下自己数据集的标准化系数,来用作标准化处理。
在用到多GPU训练模型的时候,我们经常会用到model = nn.DataParallel(model)
这个函数来加速并行。但在模型加载的时候一不小心就会报错RuntimeError: Error(s) in loading state_dict for xxx: Missing key(s) in state_dict: ...
。
这主要因为在模型加载的时候,模型没有处于并行化状态。在model.load_state_load(torch.load('XXXX.pth'))
之前加上model=nn.DataParallel(model)
即可。