因为是二进制文件,所以需要自己转换成图片、txt标签
#调用一些和操作系统相关的函数
import os
#输入输出相关
from skimage import io
#dataset相关
import torchvision.datasets.mnist as mnist
#路径
root="/home/s/PycharmProjects/untitled/fashion-mnist/data/fashion"
#读取二进制文件,这里不知道是不是必须使用mnist读
train_set = (
mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),#路径拼接,split()是分割路径与文件名,和这个正好相反
mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
)
test_set = (
mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
)
#打印test_set类型
print(type(test_set))
>>>out:
#打印test_set中元素个数
print(len(test_set))
>>>out:2
#打印第元素类型,都是tensor
print(type(test_set[0]))
print(type(test_set[1]))
>>>out:
>>>out:
#打印元素形状,可以第一个元素是所有照片的tensor,第二个元素是所有标签的tensor.这里用test_set[0].shape是一样的
print("test set[0] :",test_set[0].size())
print("test set[1] :",test_set[1].size())
>>>out:('test set[0] :', (10000, 28, 28))
>>>out:('test set[1] :', (10000,))
#取出一个图片看一下,这两种都可以,就是看一下这个tensor的形状
a = test_set[0]
print(a[0].shape)
print(test_set[0][0].shape)
>>>out:(28, 28)
>>>out:(28, 28)
#定义一个tensor转图片的函数
def convert_to_img(train=True):
if(train):
#创建一个train.txt文件,用来保存标签
f=open(root+'train.txt','w')#python中并没有这种路径表示方式,这个不对
data_path=root+'/train/'
#如果不存在这个路径,就创建文件夹
if(not os.path.exists(data_path)):
os.makedirs(data_path)
#zip打包成元组,train_set本来不就是元组么?
for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
img_path=data_path+str(i)+'.jpg'
#tensor与numpy格式转换tensor_img = torch.from_numpy(numpy_img)
io.imsave(img_path,img.numpy())
a=str(label)
a = a.rstrip(')')
a = a.strip('tensor(')#这里如果不进行字符串的处理,会输出“tensor(9)”而不是“9”
f.write(img_path+' '+ a +'\n')
f.close()
else:
f = open(root + 'test.txt', 'w')
data_path = root + '/test/'
if (not os.path.exists(data_path)):
os.makedirs(data_path)
for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
img_path = data_path+ str(i) + '.jpg'
io.imsave(img_path, img.numpy())
a=str(label)
a = a.rstrip(')')
a = a.strip('tensor(')
f.write(img_path + ' ' + a + '\n')
f.close()
convert_to_img(True)
convert_to_img(False)
主要是以torch.utils.data.Dataset为基类进行编写:
__init__
__getitem__
__len__
这几个函数也都是大同小异,可以添加一些自己需要的返回值
import torch
from torch.utils.data import Dataset
from PIL import Image
#以torch.utils.data.Dataset为基类创建MyDataset
class MyDataset(Dataset):
#stpe1:初始化
def __init__(self, txt, transform=None, target_transform=None,):
fh = open(txt, 'r')#打开标签文件
imgs = []#创建列表,装东西
for line in fh:#遍历标签文件每行
line = line.rstrip()#删除字符串末尾的空格
words = line.split()#通过空格分割字符串,变成列表
imgs.append((words[0],int(words[1])))#把图片名words[0],标签int(words[1])放到imgs里
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):#检索函数
fn, label = self.imgs[index]#读取文件名、标签
img = Image.open(fn).convert('RGB')#通过PIL.Image读取图片
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
官方文档里关于torch.utils.data.Dataset这个基类的说明
用到的一些函数
1)Python rstrip() 删除 string 字符串末尾的指定字符(默认为空格)从后边删
str.rstrip([chars])
str = " this is string example....wow!!! ";
print str.rstrip();
str = "88888888this is string example....wow!!!8888888";
print str.rstrip('8');
out:
>>> this is string example....wow!!!
>>>88888888this is string example....wow!!!
2)Python strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。两头删
str.strip([chars])
str = "00000003210Runoob01230000000";
print str.strip( '0' );
str2 = " Runoob ";
print str2.strip();
out:
3210Runoob0123
Runoob
3)split() 通过指定分隔符对字符串进行切片,如果参数 num 有指定值,则分隔 num+1 个子字符串。变成了列表
str = "Line1-abcdef \nLine2-abc \nLine4-abcd";
print str.split( ); # 以空格为分隔符,包含 \n
print str.split(' ', 1 ); # 以空格为分隔符,分隔成两个
out:
>>>['Line1-abcdef', 'Line2-abc', 'Line4-abcd']
>>>['Line1-abcdef', '\nLine2-abc \nLine4-abcd']
1.对图像进行预处理
from torchvision import transforms as transforms
trans_form = transforms.Compose([
transforms.Resize(96), # 缩放到 96 * 96 大小
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])
通过 torchvision.transforms模块对数据进行预处理
torchvision.transforms.
Compose
(transforms)可以把许多transforms合在一起
transforms.ToTensor()是必须做的,这里也可以不用官方给的,自己写data_tf函数
具体有那些transforms:https://pytorch.org/docs/stable/torchvision/transforms.html?highlight=transforms
2.dataset、dataloader加载
train_data=MyDataset(txt='train.txt', transform=trans_form)
test_data=MyDataset(txt='test.txt', transform=trans_form)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
dataloader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它
迭代器返回的是一个list[imgs,labels]
imgs是64张28X28X3的图片组成的一个tensor
labels是64个标签的tensor
3.划分训练接、验证集、测试集
https://cloud.tencent.com/developer/article/1435013
train_data=MyDataset(txt='train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt='test.txt', transform=transforms.ToTensor())
print('train:', len(train_data), 'test:', len(test_data))
train_data, val_data = torch.utils.data.random_split(train_data, [55000, 5000])
print('train:', len(train_data), 'validation:', len(val_data))
>>>('train:', 60000, 'test:', 10000)
>>>('train:', 55000, 'validation:', 5000)
pytorch 0.4.1版本以上才支持random_split函数
参考链接:
https://www.runoob.com/python/att-string-split.html
https://www.cnblogs.com/denny402/p/7520063.html
https://blog.csdn.net/Teeyohuang/article/details/79587125
https://pytorch.org/docs/stable/data.html
https://blog.csdn.net/TH_NUM/article/details/80877687