pytorch tips

一、数据导入部分
torch.utils.data.Dataset,这是一个抽象类,在pytorch中所有和数据相关的类都要继承这个类来实现。
torchvision.datasets.ImageFolder接口实现数据导入。torchvision.datasets.ImageFolder会返回一个列表(比如image_datasets[‘train’]或者image_datasets[‘val]),列表中的每个值都是一个tuple,每个tuple包含图像和标签信息。列表list是不能作为模型输入的,因此在PyTorch中需要用另一个类来封装list,那就是:
torch.utils.data.DataLoader,它可以将list类型的输入数据封装成Tensor数据格式,以备模型使用。这里是对图像和标签分别封装成一个Tensor。
data_transforms是一个字典。主要是进行一些图像预处理,比如resize、crop等。实现的时候采用的是torchvision.transforms模块,比如torchvision.transforms.Compose是用来管理所有transforms操作的,torchvision.transforms.RandomSizedCrop是做crop的。需要注意的是对于torchvision.transforms.RandomSizedCrop和**transforms.RandomHorizontalFlip()**等,输入对象都是PIL Image,也就是用python的PIL库读进来的图像内容,而transforms.Normalize([0.5, 0.5, 0.4], [0.2, 0.2, 0.5])的作用对象需要是一个Tensor,因此在transforms.Normalize([0.5, 0.5, 0.4], [0.2, 0.2, 0.5])之前有一个
**transforms.ToTensor()**就是用来生成Tensor的。另外transforms.Scale(256)其实就是resize操作,目前已经被transforms.Resize类取代了。
将Tensor数据类型封装成Variable数据类型后就可以作为模型的输入了,用torch.autograd.Variable将Tensor封装成模型真正可以用的Variable数据类型。 Variable可以看成是tensor的一种包装,其不仅包含了tensor的内容,还包含了梯度等信息。
二、模块导入
torchvision.models用来导入模块
torch.nn模块来定义网络的所有层,比如卷积、降采样、损失层等等
torch.optim模块定义优化函数
三、训练
在每个epoch开始时都要更新学习率:scheduler.step()
设置模型状态为训练状态:model.train(True)
先将网络中的所有梯度置0:optimizer.zero_grad()
网络的前向传播:outputs = model(inputs)
然后将输出的outputs和原来导入的labels作为loss函数的输入就可以得到损失了:loss = criterion(outputs, labels)
输出的outputs也是torch.autograd.Variable格式,得到输出后(网络的全连接层的输出)还希望能到到模型预测该样本属于哪个类别的信息,这里采用torch.max。torch.max()的第一个输入是tensor格式,所以用outputs.data而不是outputs作为输入;第二个参数1是代表dim的意思,也就是取每一行的最大值,其实就是我们常见的取概率最大的那个index;第三个参数loss也是torch.autograd.Variable格式。
_, preds = torch.max(outputs.data, 1)
计算得到loss后就要回传损失。要注意的是这是在训练的时候才会有的操作,测试时候只有forward过程。
loss.backward()
回传损失过程中会计算梯度,然后需要根据这些梯度更新参数,optimizer.step()就是用来更新参数的。**optimizer.step()**后,你就可以从optimizer.param_groups[0][‘params’]里面看到各个层的梯度和权值信息。
optimizer.step()
这样一个batch数据的训练就结束了!不断重复这样的训练过程。

需要注意的地方:
1.num_worker = 2 暂时不要用多线程提取数据,容易陷入死循环程序崩溃(原因未知)
2.running_loss的计算,loss.data[0]在新版本pytorch中不能使用了。loss计算出来是一个tensor格式的结果,需要用xxxx.item()来将tensor转化为int格式
3.correct(正确判断的样本数量)也是一个tensor格式的

你可能感兴趣的:(pytorch,pytorch,deep,learing)