使用pytorch自定义dataset

使用pytorch自定义dataset


from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
file_train = [os.path.join("./dog_breed/train_path",i) for i in file ]

def train_transform(self, rgb): #训练集预处理
 
    do_flip = np.random.uniform(0.0, 1.0) > 0.5  # random horizontal flip
    transform = transforms.Compose([
        transforms.HorizontalFlip(do_flip),  # 0.5概率水平翻转
        transforms.CenterCrop((228, 304))  #中心裁剪,裁成(228,304)的大小
        transforms.ColorJitter(0.4, 0.4, 0.4)  # (亮度,对比度,饱和度)在1-0.4,1+0.4间抖动
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]
)
    ])
    rgb_np = transform(rgb)
    rgb_np = np.asfarray(rgb_np, dtype='float') / 255  #对RGB归一化
        
   	return rgb_np

def default_loader(path):
    img_pil =  Image.open(path)
    img_pil = img_pil.resize((224,224))
    img_tensor = train_transform(img_pil)
    return img_tensor

#当然出来的时候已经全都变成了tensor
class trainset(Dataset):
	# 这个里面一般要初始化一个loader(可以对图片进行处理 比如换轴、打开),一个images_path的列表,一个target的列表
    def __init__(self, loader=default_loader):
        #定义好 image 的路径
        self.images = file_train
        self.target = number_train
        self.loader = loader
	# 这里吗就是在给你一个index的时候,你返回一个图片的tensor和target的tensor,使用了loader方法,经过 归一化,剪裁,类型转化,从图像变成tensor
    def __getitem__(self, index):
        fn = self.images[index]
        img = self.loader(fn)
        target = self.target[index]
        return img,target
	#return你所有数据的个数
    def __len__(self):
        return len(self.images)

你可能感兴趣的:(python,深度学习)