PyTorch深度学习入门笔记(二)PyTorch加载数据初认识

一些简单函数的使用

获取数据集的数据

# Kyrie Irving
# !/9462...
from torch.utils.data import Dataset
import cv2
from PIL import Image
import os

# img = cv2.imread('E:\\CodeCodeCodeCode\\Python-data\\hymenoptera_data\\train\\ants\\5650366_e22b7e1065.jpg', 0)
# cv2.imshow('a', img)
# cv2.waitKey(0)

# img = Image.open('E:\\CodeCodeCodeCode\\AI\\Pytorch-study\\hymenoptera_data\\train\\ants\\5650366_e22b7e1065.jpg')
# img.show()

class MyData(Dataset):
    """Dataset of the files in one sub-folder, labelled by that folder's name.

    Every entry of ``os.path.join(root_dir, label_dir)`` is treated as an
    image file; the label returned for each item is the ``label_dir``
    string itself.

    Args:
        root_dir: Directory containing the per-label sub-folders
            (e.g. ``"../hymenoptera_data/train"``).
        label_dir: Name of the sub-folder to read; doubles as the label
            (e.g. ``"ants"``).
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # File names (relative to self.path). os.listdir returns entries in
        # arbitrary, platform-dependent order; sort them so a given index
        # always refers to the same file.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, item):
        """Return ``(image, label)`` for the file at position *item*.

        The image is decoded eagerly with ``load()`` so that Pillow can close
        the underlying file instead of keeping the handle open until the
        image object is garbage-collected.
        """
        img_name = self.img_path[item]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        img.load()
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in the label folder."""
        return len(self.img_path)

# Demo: build one dataset per class folder, concatenate them, and show the
# first sample. The relative root is resolved against the current working
# directory, not against this file's location.
root_dir = "../hymenoptera_data/train"
ants_label_dir, bees_label_dir = "ants", "bees"
ants_dataset, bees_dataset = (
    MyData(root_dir, ants_label_dir),
    MyData(root_dir, bees_label_dir),
)

# Dataset.__add__ yields a ConcatDataset: ant samples first, then bee samples.
train_dataset = ants_dataset + bees_dataset
img, label = train_dataset[0]  # routed to MyData.__getitem__ of the first dataset
print(label)
img.show()

你可能感兴趣的:(pytorch,机器学习,人工智能,深度学习,python)