20210715:pytorch DataLoader 自定义 sampler

需求:在每个batch内实现正负样本1:1的采样比例,并验证这种采样方式是否会影响模型的最终精度

探索:搜索良久,发现没有比较直接的实现,需要自己重写一下DataLoader中的sampler

1:确定一下DataLoader的定义

2:确认一下DataLoader、Sampler、Dataset三者的关系

        链接:https://zhuanlan.zhihu.com/p/76893455

                Sampler提供indices(样本索引)

                Dataset根据indices提供data

                DataLoader将上面两个组合起来,提供最终的batch训练数据

3:注意事项:传入自定义sampler后,不能再指定shuffle=True(保持默认的False即可)。sampler与shuffle是互斥的,同时指定时DataLoader会直接报错

实现:可以参考文末的链接中的demo,也可以参考本文中的实战例子

import os
import cv2
import random
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
from torchvision import transforms, utils



def img_noise(img_data, mean=0, var=0.00001):
    """Add Gaussian noise to an image.

    The image is scaled to [0, 1] and Gaussian noise with the given mean and
    variance is added. The result is clipped back to [0, 1] and converted to
    ``uint8`` in [0, 255]; the cast truncates fractional parts.

    Note: the default variance is 1e-5 (std ~= 0.0032, i.e. less than one
    grey level after rescaling), not 0.001 as an earlier comment claimed.

    Args:
        img_data: array-like image with pixel values in [0, 255].
        mean: noise mean, on the normalized [0, 1] scale.
        var: noise variance, on the normalized [0, 1] scale; must be >= 0.

    Returns:
        ``np.uint8`` array with the same shape as ``img_data``.

    Raises:
        ValueError: if ``var`` is negative.
    """
    if var < 0:
        raise ValueError("var must be non-negative, got {}".format(var))
    image = np.asarray(img_data, dtype=float) / 255
    noise = np.random.normal(mean, var ** 0.5, image.shape)
    out = np.clip(image + noise, 0, 1.0)
    return np.uint8(out * 255)


def img_add_stripe(img_data, period=11, value=150):
    """Overlay evenly spaced constant-intensity stripes on an image.

    With probability 0.5 (one draw from ``np.random.random``) the stripes are
    horizontal, otherwise vertical. Along the chosen axis of length ``n``:
    - every ``period``-th line starting at 0 (excluding the last line
      ``n - 1``) is set to ``value``;
    - the line two positions after each of them is also set, when it
      exists. This gives pairs of lines such as 0/2, 11/13, ...

    Works for 2-D (grayscale) and 3-D images with any channel count; all
    channels of a striped line are set.

    Args:
        img_data: array-like image, height x width [x channels].
        period: spacing between stripe pairs.
        value: pixel value written on the stripes.

    Returns:
        A new array (the input is copied, not modified).
    """
    image = np.array(img_data)  # np.array copies, so the caller's data is untouched
    h, w = image.shape[:2]
    if np.random.random() <= 0.5:
        image[0:h - 1:period] = value
        image[2:h:period] = value
    else:
        image[:, 0:w - 1:period] = value
        image[:, 2:w:period] = value
    return image


def img_gamma(img, para):
    """Apply gamma correction to an 8-bit image.

    Pixel values are normalized to [0, 1], raised to the power ``para`` and
    scaled back to [0, 255].

    Args:
        img: array-like of pixel values in [0, 255].
        para: gamma exponent; 1 leaves the image unchanged, values below 1
            brighten and values above 1 darken.

    Returns:
        ``np.uint8`` array with the same shape as ``img``.
    """
    corrected = np.power(np.asarray(img) / 255, para) * 255
    # Round rather than truncate. Float round-off (e.g. x / 255 * 255 landing
    # just below x) would otherwise shift pixels down a level and bias every
    # gamma toward darker output. The clip keeps the uint8 cast in range.
    return np.clip(np.rint(corrected), 0, 255).astype(np.uint8)


def data_augmentation(img, label):
    """Randomly perturb a training image.

    A single uniform draw from ``torch.rand`` selects the photometric change:
    - at or above 0.7, Gaussian noise is added via ``img_noise``;
    - between 0.2 and 0.5 inclusive, ``img_gamma`` is applied with exponent 1;
    - otherwise the image passes through unchanged.

    For samples whose label is 0, a second draw from ``torch.randint(1, 100)``
    converts the image to grayscale (replicated back to 3 channels) when the
    draw is below 8, i.e. 7 of the 99 possible values.

    NOTE(review): gamma exponent 1 is (apart from uint8 quantization) an
    identity transform; a random exponent was presumably intended — confirm.
    """
    draw = torch.rand((1, 1)).item()
    if draw >= 0.7:
        img = img_noise(img)
    elif 0.2 <= draw <= 0.5:
        img = img_gamma(img, 1)

    if label == 0:
        roll = torch.randint(1, 100, (1, 1)).item()
        if roll < 8:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            img = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    return img


def default_loader(path):
    """Read an image file as a 112x112 RGB array.

    Args:
        path: image file path passed to ``cv2.imread``.

    Returns:
        The image as an RGB array. It is resized to 112x112 with
        nearest-neighbour interpolation when its height or width differs.
        Returns ``None`` when OpenCV cannot read the file; a message is
        printed in that case.
    """
    im = cv2.imread(path)
    if im is None:
        # cv2.imread returns None for missing/unreadable files; stop here
        # instead of dereferencing im.shape, which would raise AttributeError.
        print("None:", path)
        return None
    if im.shape[0] != 112 or im.shape[1] != 112:
        im = cv2.resize(im, (112, 112), interpolation=cv2.INTER_NEAREST)
    # OpenCV decodes images in BGR channel order; convert to RGB.
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
   
   
# define Dataset. Each non-empty line of the .txt file is "<name> <label>"
# separated by a single space, for example: 0001.jpg 1
class MyDatasets(Dataset):
    """Binary image dataset balanced to a 1:1 positive/negative ratio.

    Lines whose last field is ``1`` are positives (label 1). Every other
    label is treated as negative (label 0). The smaller class is oversampled
    at random until both classes have the same size N. Positives then occupy
    indices [0, N) and negatives [N, 2N); MySampler depends on this layout.
    """

    def __init__(self, img_path, txt_path, dataset='', data_transforms=None,
                 loader=default_loader, data_augmentation=data_augmentation):
        """
        Args:
            img_path: directory the image names in ``txt_path`` are relative to.
            txt_path: list file with one "<name> <label>" entry per line.
            dataset: split name (e.g. 'train' / 'test'); stored only.
            data_transforms: callable applied to a PIL image, or None.
            loader: callable mapping an image path to an array.
            data_augmentation: callable ``(img, label) -> img``, or None to
                disable augmentation.

        Raises:
            ValueError: if one class is empty while the other is not, since
                it then cannot be oversampled.
        """
        self.img_name_pos = []
        self.img_name_neg = []
        self.img_label_pos = []
        self.img_label_neg = []
        with open(txt_path) as input_file:
            for line in input_file:
                if not line.strip():
                    continue  # tolerate blank lines instead of failing on int('')
                fields = line.strip().split(' ')
                path = os.path.join(img_path, fields[0])
                if int(fields[-1]) == 1:
                    self.img_name_pos.append(path)
                    self.img_label_pos.append(1)
                else:
                    self.img_name_neg.append(path)
                    self.img_label_neg.append(0)

        # Oversample the minority class so both classes have the same size.
        value = len(self.img_label_pos) - len(self.img_label_neg)
        if value >= 0:
            self.img_name_neg += self._oversample(self.img_name_neg, value)
            self.img_label_neg += [0] * value
        else:
            self.img_name_pos += self._oversample(self.img_name_pos, -value)
            self.img_label_pos += [1] * (-value)

        # Positives first, then negatives; MySampler splits on the halfway point.
        self.img_name = self.img_name_pos + self.img_name_neg
        self.img_label = self.img_label_pos + self.img_label_neg

        self.data_augmentation = data_augmentation
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    @staticmethod
    def _oversample(pool, count):
        """Return ``count`` extra items drawn at random from ``pool``.

        Draws without replacement (``random.sample``) while ``pool`` is large
        enough. Otherwise it takes whole copies of ``pool`` first and samples
        only the remainder, instead of letting ``random.sample`` raise.
        """
        if count <= len(pool):
            return random.sample(pool, count)
        if not pool:
            raise ValueError("cannot balance classes: one class has no samples")
        repeats, remainder = divmod(count, len(pool))
        return pool * repeats + random.sample(pool, remainder)

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``."""
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)

        if self.data_augmentation is not None:
            # Use the configured callable, not the module-level default.
            img = self.data_augmentation(img, label)

        if self.data_transforms is not None:
            try:
                img = Image.fromarray(img.astype(np.uint8))
                img = self.data_transforms(img)
            except Exception:
                # Returning the untransformed image would only fail later, and
                # more confusingly, in batch collation. Report and propagate.
                print("Cannot transform image: {}".format(img_name))
                raise
        return img, label

 
class MySampler(Sampler):
    """Yield dataset indices alternating positive, negative, positive, ...

    Assumes the dataset stores its positive samples in the first half of its
    indices and its negative samples in the second half (the layout
    MyDatasets builds). Each epoch both halves are shuffled independently and
    interleaved, so with an even ``batch_size`` every batch holds equally
    many positives and negatives.
    """

    def __init__(self, dataset):
        halfway_point = len(dataset) // 2
        self.pos_indices = list(range(halfway_point))
        self.neg_indices = list(range(halfway_point, len(dataset)))

    def __iter__(self):
        random.shuffle(self.pos_indices)
        random.shuffle(self.neg_indices)
        # zip stops at the shorter half, so only complete (pos, neg) pairs
        # are emitted.
        interleaved = [idx
                       for pair in zip(self.pos_indices, self.neg_indices)
                       for idx in pair]
        return iter(interleaved)

    def __len__(self):
        # Must match what __iter__ yields: two indices per complete pair.
        return 2 * min(len(self.pos_indices), len(self.neg_indices))
         

# load datasets
def load_mydata(img_path_default, txt_path_default, batch_size=2, num_workers=2):
    """Build train/test DataLoaders whose batches are balanced 1:1.

    Args:
        img_path_default: directory the image names in the list files are
            relative to.
        txt_path_default: directory containing ``train.txt`` and ``test.txt``.
        batch_size: samples per batch; keep it even so each batch contains
            the same number of positives and negatives.
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(train_loader, test_loader)``.

    NOTE(review): the test split is also oversampled and shuffled by
    MySampler. Confirm this is intended for evaluation.
    """
    # Newer torchvision takes ``interpolation=InterpolationMode.*``; the legacy
    # ``resample=`` keyword was deprecated and later removed. Fall back to it
    # only on releases that predate InterpolationMode.
    if hasattr(transforms, "InterpolationMode"):
        rotation_kwargs = {"interpolation": transforms.InterpolationMode.BILINEAR}
    else:
        rotation_kwargs = {"resample": Image.BILINEAR}

    trans_list = [transforms.ColorJitter(brightness=0.5),
                  transforms.ColorJitter(contrast=0.5),
                  transforms.ColorJitter(saturation=0.5),
                  transforms.RandomRotation(5, expand=False, center=(56, 56),
                                            **rotation_kwargs)]
    # With probability 0.2, apply one transform chosen at random from the list.
    transform = transforms.RandomApply([transforms.RandomChoice(trans_list)], p=0.2)

    data_transforms = {
        'train': transforms.Compose([
            transform,
            transforms.RandomCrop(96),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
        'test': transforms.Compose([
            transforms.CenterCrop(96),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
    }

    splits = ['train', 'test']
    image_datasets = {x: MyDatasets(img_path=img_path_default,
                                    txt_path=os.path.join(txt_path_default, x + '.txt'),
                                    data_transforms=data_transforms[x],
                                    dataset=x) for x in splits}

    # A custom sampler replaces shuffle=True (DataLoader rejects both at once).
    our_sampler = {x: MySampler(image_datasets[x]) for x in splits}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  sampler=our_sampler[x],
                                                  batch_size=batch_size,
                                                  num_workers=num_workers)
                   for x in splits}

    return dataloaders['train'], dataloaders['test']


if __name__ == "__main__":
    # Smoke test: print each batch's labels for 10 epochs. With the default
    # batch_size=2 every batch should hold one positive and one negative.
    # Guarded so that importing this module does not start loading data.
    img_dir = r'/home/mntsde/lilai/imgs_ir'   # images referenced by the lists
    list_dir = r'/home/mntsde/lilai/imgs_ir'  # contains train.txt / test.txt
    train_loader, test_loader = load_mydata(img_dir, list_dir)
    for epoch in range(10):
        print("********************************")
        print("epoch: ", epoch)
        print("********************************")
        for i, data in enumerate(train_loader):
            print(i, data[1])
        print("--------------------------------")
        for i, data in enumerate(test_loader):
            print(i, data[1])

参考:

(1)https://www.scottcondron.com/jupyter/visualisation/audio/2020/12/02/dataloaders-samplers-collate.html#SequentialSampler 

(2)https://blog.csdn.net/u010087338/article/details/117927204

(3)https://github.com/ufoym/imbalanced-dataset-sampler

你可能感兴趣的:(深度学习)