pytorch之将图片及str标签转为tensor

数据集的文件夹格式:
(数据集的文件路径:file_path = r'.\Dataset')

Dataset
|------train
|------------000000
|------------000001
… …
|------------000305

|------test
|------------000000
|------------000001
… …
|------------000305
|------valid
|------------000000
|------------000001
… …
|------------000305

# Load every split of the dataset into memory and wrap each one in MyDataSet.
#
# Expected layout: <file_path>\<split>\<class folder>\<image files>.
# NOTE(review): the original fragment used i, j, x and y without defining
# them; the per-split loop below is the presumed intent -- confirm.

# Split folder names; the position of each name selects the branch at the
# bottom of the loop (0 = test, 1 = train, 2 = valid).
path = ["test", "train", "valid"]
# Dataset root. The original literal (r' .\Dataset ') carried stray spaces and
# was concatenated with the split name without a separator, so no derived
# path could exist.
file_path = r'.\Dataset'
# class306 maps a label index to its class name (str). Sorted so that the
# index does not depend on the order in which the OS lists directories.
# NOTE(review): the original read from an undefined `data` variable;
# presumably the dataset root was meant -- confirm.
class306 = sorted(ClassToList(os.path.join(file_path, "test")))
train_set = []
test_set = []
valid_set = []
for i, split_name in enumerate(path):
    the_path = os.path.join(file_path, split_name)
    x = []  # preprocessed images of this split
    y = []  # integer label of each image, aligned with x
    # Sorted for the same reason as class306: label j must mean the same
    # class folder in every split.
    for j, class_path in enumerate(sorted(os.listdir(the_path))):
        class_dir = os.path.join(the_path, class_path)
        print("reading " + class_dir)
        for img_path in os.listdir(class_dir):
            # The context manager releases the file handle as soon as the
            # RGB copy has been made.
            with Image.open(os.path.join(class_dir, img_path)) as raw_img:
                img = raw_img.convert('RGB')
            # img_transforms converts the PIL image into the model input format
            x.append(img_transforms(img))
            y.append(j)
        print("Have read " + class_dir)
        print("size: ", len(x))
    if i == 1:
        train_set = MyDataSet(x, y, train_transforms)
    elif i == 0:
        test_set = MyDataSet(x, y, test_transforms)
    elif i == 2:
        valid_set = MyDataSet(x, y, test_transforms)
  1. transforms.Resize()只能对PIL.Image.open()打开的图片修改大小。
    用io.imread读取图片时
    (img = io.imread(the_path+"\\"+class_path+"\\"+img_path)),
    不能使用transforms.Resize()修改读入的图片大小。而图片转为tensor后,tensor的大小需要一致,所以这里使用PIL.Image来读取图片。
  2. 图片标签为字符串类型,需转为torch.Tensor类型。然而,元素是字符串的list, tuple等不能直接转为torch.Tensor类型。解决方法:
    (1)使用sklearn中的preprocessing
from sklearn import preprocessing
import torch
# Every class name, one per class folder ('000000' ... '000305'). They are
# generated here so the snippet runs as written (a literal `...` inside the
# list is an Ellipsis object, which LabelEncoder cannot handle next to str);
# the names can equally be read from disk with ClassToList on the test folder.
labels = ["{:06d}".format(k) for k in range(306)]
le = preprocessing.LabelEncoder()
# fit_transform replaces every string label with an integer class index
targets = le.fit_transform(labels)
# targets: array([0, 1, 2, ..., 305])
targets = torch.as_tensor(targets)
# targets: tensor([0, 1, 2, ..., 305])

(2)将字符串标签存在list中,用标签对应的索引替换字符串标签。

# Folder that holds one sub-folder per class.
test_dir = os.path.join('.', 'Dataset', 'test')
# Class names in a fixed (sorted) order: class306[j] is the string label that
# integer label j stands for. ClassToList(test_dir) yields the same names.
class306 = sorted(os.listdir(test_dir))
# One integer label per image; must exist before the loop appends to it.
targets = []
for j, class_path in enumerate(class306):
    for each_img in os.listdir(os.path.join(test_dir, class_path)):
        targets.append(j)  # store the index of the image's class
targets = torch.tensor(targets)  # convert to torch.Tensor
  3. transforms.Resize()需放在transforms.ToTensor()前面,否则无效。图片需要在转为Tensor类型前修改大小。
# Preprocessing pipeline applied to each image after it is read.
# `size`, `mean` and `stdv` are not defined in this snippet -- presumably
# module-level settings; TODO confirm their values.
img_transforms = transforms.Compose([
    transforms.Resize((size, size)),  # resize first, while the input is still an image, to a fixed size x size
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=stdv),
])

ClassToList():

def ClassToList(file_path):
    """Return the names of the entries directly inside ``file_path``.

    The names come back as a new list, in whatever order ``os.listdir``
    reports them. Callers in this file treat each name as a class label.

    Args:
        file_path: Directory to list.

    Returns:
        list: One entry name (str) per item found in the directory.
    """
    return [entry_name for entry_name in os.listdir(file_path)]

MyDataSet():

# Subclass of torch.utils.data.Dataset
class MyDataSet(torch.utils.data.Dataset):
    """Map-style dataset over samples and targets that are already in memory.

    Item ``k`` is the pair ``(x[k], y[k])``. ``transform`` is only stored on
    the instance; ``__getitem__`` does not apply it, so the samples are
    returned exactly as they were passed in.
    """

    def __init__(self,x,y,transform):
        """Store the samples and make sure the targets are a tensor.

        Args:
            x: Iterable of samples; it is copied into a list.
            y: Targets, one per sample. Anything that is not already a
                ``torch.Tensor`` is converted with ``torch.tensor``.
            transform: Kept as ``self.transform``; not used by this class.

        Raises:
            ValueError: If the number of targets differs from the number of
                samples.
        """
        super().__init__()
        self.x = x
        if isinstance(y, torch.Tensor):
            self.y = y
        else:
            print("将y从", type(y),end='')
            self.y = torch.tensor(y)
            print("转化为", type(self.y))
        self.transform = transform
        # Snapshot of the samples that __getitem__ and __len__ work on
        self.idx = list(x)
        # Fail here rather than with an IndexError deep inside a DataLoader
        if self.y.dim() == 0 or self.y.shape[0] != len(self.idx):
            raise ValueError(
                "x and y must have the same length, got {} samples and "
                "targets of shape {}".format(len(self.idx), tuple(self.y.shape))
            )

    def __getitem__(self, index):
        """Return the ``(sample, target)`` pair stored at ``index``."""
        return self.idx[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.idx)

你可能感兴趣的:(python新手)