Dataloader自定义数据集

Dataloader自定义数据集Dataloader自定义数据集_第1张图片

一般情况下所有的数据集都存放在一个文件夹中,如train,不会进行分类存放。这时需要一个txt、xml或者json格式文件记录图像及其对应的结果标签,如下图所示,每一行对应一条数据,包含输入图像的路径名称,以及输出的结果(这里为类别),两者用空格间隔。
Dataloader自定义数据集_第2张图片
Dataloader自定义数据集_第3张图片

ann_file表示txt、xml文件路径,strip()去掉换行符等,readlines()一行一行读取数据(一般一行为一条数据)。
Dataloader自定义数据集_第4张图片
Dataloader自定义数据集_第5张图片
Dataloader自定义数据集_第6张图片
Dataloader自定义数据集_第7张图片
自定义Dataloader,其中构造函数init和getitem是必须要重写,init函数的主要目的是获取数据部分和标签部分并分别用list存储。getitem函数主要目的是随机返回tensor数据类型的某个增强后的数据及其标签,torch.form_numpy()将数据转为tensor。
数据增强在Dataloader进行增强,可通过构造函数传入设定的transform增强。
Dataloader自定义数据集_第8张图片
Dataloader自定义数据集_第9张图片
shuffle=True打乱数据。
Dataloader自定义数据集_第10张图片
iter(train_loader).next()每一次从train_loader中取一个batch的数据。
squeeze()压缩,如image的大小为batch_sizexCxHxW,通过next(),一次取一个batch数据的大小为1xCxHxW,squeeze()将维度压缩为CxHxW。
permute()将图片维度从CxHxW转为HxWxC,numpy()将数据从tensor转为array,其目的都是为了满足plt去显示图像。

参考:
https://ke.gupaoedu.cn

你可能感兴趣的:(pytorch,pytorch)