PyTorch 入门——自定义数据集

import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from skimage import io, transform
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from torch.utils.data import Dataset, DataLoader
import cv2
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# Preprocessing pipeline applied to each loaded image.
# Alternative normalisation (not used here): Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# rescales pixel values to [-1, 1] via image = (image - mean) / std.
transform = transforms.Compose([
    # Convert the incoming array to a PIL image before the remaining steps.
    transforms.ToPILImage(),
    transforms.Resize((255, 255)),
    transforms.ToTensor(),
    # First tuple: per-channel (R, G, B) means; second tuple: per-channel
    # standard deviations. These are the widely used ImageNet statistics.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

def picCount(root_dir):
    """Return the number of entries directly inside ``root_dir``.

    Every name reported by ``os.listdir`` is counted, so sub-directories and
    hidden files are included in the total.

    Args:
        root_dir (string): Directory to inspect.

    Returns:
        int: Number of directory entries.
    """
    return len(os.listdir(root_dir))

# Dataset class
class TestDataset(Dataset):
    """Dataset that serves every entry of a directory as an image sample.

    Each item is a dict with two keys: ``'image'`` (the loaded image, after
    the optional transform) and ``'name'`` (the file name inside
    ``root_dir``).
    """

    def __init__(self, root_dir, transform=transform):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample. Defaults to the module-level ``transform``.
        """
        self.root_dir = root_dir
        self.transform = transform
        # List the directory once instead of on every access, and sort it:
        # os.listdir returns names in arbitrary order, so sorting makes the
        # index -> file mapping deterministic and keeps __len__ and
        # __getitem__ consistent with each other.
        # Hidden entries (e.g. macOS '.DS_Store') are NOT filtered out.
        self._file_names = sorted(os.listdir(root_dir))

    def __len__(self):
        """Return the number of entries found in ``root_dir``."""
        return len(self._file_names)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``{'image': ..., 'name': ...}``."""
        file_name = self._file_names[idx]
        image = io.imread(os.path.join(self.root_dir, file_name))
        if self.transform:
            image = self.transform(image)
        return {'image': image, 'name': file_name}
# Build the dataset over the input image folder, then wrap it in a loader
# that yields batches of five samples in dataset order (no shuffling),
# loading in the main process (num_workers=0).
train_data = TestDataset(root_dir='./data/someTry/in', transform=transform)

train_loader = DataLoader(
    dataset=train_data,
    batch_size=5,
    shuffle=False,
    num_workers=0,
)

# Helper that displays an image tensor
def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo the normalisation of a channel-first image tensor and display it.

    Args:
        img: Image tensor laid out as (channels, height, width).
        mean: Per-channel means that were subtracted during normalisation.
            Defaults to the values used by this file's ``Normalize`` step.
        std: Per-channel standard deviations used during normalisation.
            Defaults to the values used by this file's ``Normalize`` step.
    """
    npimg = img.detach().cpu().numpy()
    # Reshape to (C, 1, 1) so the statistics broadcast over height and width.
    channel_mean = np.asarray(mean).reshape(-1, 1, 1)
    channel_std = np.asarray(std).reshape(-1, 1, 1)
    # Invert (x - mean) / std, then clip so matplotlib gets values in [0, 1].
    npimg = np.clip(npimg * channel_std + channel_mean, 0.0, 1.0)
    # matplotlib expects (height, width, channels).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# To preview a single sample instead: imshow(train_data[2]['image'])

# Fetch ONE batch and reuse it, so the printed batch, the displayed images
# and the printed names all refer to the same samples. The builtin next()
# is used because DataLoader iterators no longer expose a .next() method.
dataiter = iter(train_loader)
batch = next(dataiter)
print(batch)
image = batch['image']
name = batch['name']
# Show the images
imshow(torchvision.utils.make_grid(image))
# Print the file name of every image in the batch (whatever its size)
print(' '.join('%5s' % file_name for file_name in name))

使用 skimage 的 io.imread 读取图片的结果:

Pytorch入门———自定义数据集_第1张图片

使用 cv2.imread 读取图片的结果:

 Pytorch入门———自定义数据集_第2张图片

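两张图片颜色不同,是因为通道顺序不一样:cv2.imread 返回的是 BGR 顺序,而 io.imread 返回的是 RGB 顺序。如果用 cv2 读取,需要先用 cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 转换通道顺序。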

io.imread 读取得到的是 numpy 数组,而 Resize 等变换不接受 numpy 数组,因此在进行 transforms 时,需要先把图像转换成 PIL 格式,即在变换流水线的开头加上

ToPILImage()

你可能感兴趣的:(pytorch学习,pytorch,python,深度学习)