pytorch学习之Dataset类

一、

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):  # 提供一个全局变量
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self,idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label
    def __len__(self):
        return len(self.img_path)

root_dir = r"Dataset_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)

pytorch学习之Dataset类_第1张图片

 

例子1:
from PIL import Image
path = "d:a"                   #获取图片的地址
img = Image.open(path)
img.size                         #尺寸
Img.show()                #显示数据,此处Image.show(Img)错误

例子2:import os
想获取图片的地址
1.获取所有图片的地址的列表list
2.通过相应的索引获取图片的地址
dir_path = "dataset/train/ants"
import os
img_path_list = os.listdir(dir_path)

pytorch学习之Dataset类_第2张图片

想获取所有图片的地址
import os

你可能感兴趣的:(笔记,python,开发语言,后端)