数据集的文件夹格式:
(数据集的文件路径:file_path = r'.\Dataset')
Dataset
|------train
|------------000000
|------------000001
… …
|------------000305
|------test
|------------000000
|------------000001
… …
|------------000305
|------valid
|------------000000
|------------000001
… …
|------------000305
# Split sub-folders under the dataset root; each one holds one folder per class.
path = ["test", "train", "valid"]
# Dataset root relative to the working directory. The original literal
# r' .\Dataset ' carried leading/trailing spaces, which broke every path.
file_path = os.path.join('.', 'Dataset')
# class306 records the class label (str) for each index, read from the test split.
class306 = ClassToList(os.path.join(file_path, "test"))
# Reverse lookup: every split assigns the same index to the same class name,
# regardless of the (arbitrary) order in which os.listdir returns folders.
class_to_idx = {name: idx for idx, name in enumerate(class306)}
train_set = []
test_set = []
valid_set = []
for split in path:
    the_path = os.path.join(file_path, split)
    # Fresh buffers per split so samples from one split never leak into another.
    x = []
    y = []
    for class_path in sorted(os.listdir(the_path)):
        class_dir = os.path.join(the_path, class_path)
        print("reading " + class_dir)
        label = class_to_idx[class_path]
        for img_path in sorted(os.listdir(class_dir)):
            # Context manager guarantees the image file handle is released.
            with Image.open(os.path.join(class_dir, img_path)) as im:
                img = im.convert('RGB')
            # img_transforms converts the PIL.Image into the model's input format.
            x.append(img_transforms(img))
            y.append(label)
        print("Have read " + class_dir)
        print("size: ", len(x))
    if split == "train":
        train_set = MyDataSet(x, y, train_transforms)
    elif split == "test":
        test_set = MyDataSet(x, y, test_transforms)
    elif split == "valid":
        valid_set = MyDataSet(x, y, test_transforms)
from sklearn import preprocessing
import torch
# The full class list can also be read from disk with ClassToList(<dataset root>/test).
# Here the 306 class names '000000' ... '000305' are generated explicitly; a
# literal containing `...` (Ellipsis) would make LabelEncoder raise TypeError.
labels = [str(k).zfill(6) for k in range(306)]
le = preprocessing.LabelEncoder()
# LabelEncoder sorts the distinct labels and maps each to its position.
targets = le.fit_transform(labels)
# targets: array([0, 1, 2, ..., 305])
targets = torch.as_tensor(targets)
# targets: tensor([0, 1, 2, ..., 305])
(2)将字符串标签存在list中,用标签对应的索引替换字符串标签。
# The full class list can also be read from disk with ClassToList(<dataset root>/test).
# Generated explicitly here: '000000' ... '000305' (306 classes).
class306 = [str(k).zfill(6) for k in range(306)]
# Map each class name to its index in class306, independent of listdir order.
class_to_idx = {name: idx for idx, name in enumerate(class306)}
test_dir = os.path.join('.', 'Dataset', 'test')
targets = []
for class_path in sorted(os.listdir(test_dir)):
    label = class_to_idx[class_path]
    for each_img in sorted(os.listdir(os.path.join(test_dir, class_path))):
        targets.append(label)  # store the index of the image's class label
targets = torch.tensor(targets)  # convert to a torch.Tensor
# Preprocessing pipeline applied to each image as it is read from disk:
# resize to a square of side `size`, convert to a tensor, then normalize
# per channel with `mean` / `stdv` (all three defined elsewhere).
img_transforms = transforms.Compose(
    [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=stdv),
    ]
)
ClassToList():
def ClassToList(file_path):
    """Return the names of all entries directly under *file_path*.

    In the dataset layout each sub-folder of a split directory is one class,
    so the result is the list of class labels (as strings). Entries are
    returned in whatever order ``os.listdir`` yields them.
    """
    return list(os.listdir(file_path))
MyDataSet():
# Subclass of torch.utils.data.Dataset backed by in-memory lists.
class MyDataSet(torch.utils.data.Dataset):
    """Dataset that serves pre-loaded samples together with their labels.

    Args:
        x: iterable of samples; kept as ``self.x`` and copied into the list
            ``self.idx``, which is what indexing reads from.
        y: one label per sample; converted with ``torch.tensor`` when it is
            not already a ``torch.Tensor``.
        transform: stored as ``self.transform``.
    """

    def __init__(self, x, y, transform):
        self.x = x
        if isinstance(y, torch.Tensor):
            self.y = y
        else:
            # Report the label container being converted to a tensor.
            print("将y从", type(y), end='')
            self.y = torch.tensor(y)
            print("转化为", type(self.y))
        # NOTE(review): `transform` is kept but __getitem__ never applies it;
        # samples are returned exactly as they were passed in.
        self.idx = list(x)
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at *index*."""
        return self.idx[index], self.y[index]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.idx)