pytorch应用

三大步骤:数据读取、网络构建、其他辅助
数据读取
常见的数据例如mnist就用torchvision的datasets方法来进行傻瓜式读取就行
对于分类问题,可以采用torchvision.datasets.ImageFolder读取image和label信息:

data_dir = '/data' 
image_datasets = {x: datasets.ImageFolder( 
os.path.join(data_dir, x), data_transforms[x]
), for x in ['train', 'val']}

torchvision.datasets.ImageFolder只是返回list,list是不能作为模型输入的,因此在PyTorch中需要用另一个类来封装list,那就是:torch.utils.data.DataLoader。torch.utils.data.DataLoader类可以将list类型的输入数据封装成Tensor数据格式,以备模型使用。注意,这里是对图像和标签分别封装成一个Tensor。

当你的数据不是按照一个类别一个文件夹这种方式存储时,你就要自定义一个类来读取数据,自定义的这个类必须继承自torch.utils.data.Dataset这个基类,最后同样用torch.utils.data.DataLoader封装成Tensor。

利用torchvision.model中的模型就可以满足条件(还有pretrain的参数),如果最后分类classes数目不相同那么可以提取前一层fc:

# coding=UTF-8
import torchvision.models as models
 
#调用模型
model = models.resnet50(pretrained=True)
#提取fc层中固定的参数
fc_features = model.fc.in_features
#修改类别为9
model.fc = nn.Linear(fc_features, 9)

一般我们都把分类的前一层FC叫features

对于自己读取数据要继承datasets并重写__len__()和__getitem__()两个方法
def getitem(self, index):可见需要index对所有的数据进行遍历读取,最终返回image&label也是所有数据的!

你可能感兴趣的:(pytorch)