需求:实现batch内正负1:1采样比例,验证这种采样会不会影响模型的最终精度
探索:搜索良久,发现没有比较直接的实现,需要自己重写一下DataLoader中的sampler
1:确定一下DataLoader的定义
2:确认一下DataLoader, Sampler, Dataset三者的关系
链接:https://zhuanlan.zhihu.com/p/76893455
Sampler提供indices
Dataset根据indices提供data
DataLoader将上面两个组合起来,提供最终的batch训练数据
3:注意事项:自定义sampler后,shuffle不能指定(默认即可)
实现:可以参考文末的链接中的demo,也可以参考本文中的实战例子
import os
import cv2
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
from torchvision import transforms, utils
def img_noise(img_data):
    """Add Gaussian noise (mean 0, variance 1e-5) to a uint8 image.

    NOTE(review): the original docstring claimed variance 0.001, but the
    code draws noise with ``0.00001 ** 0.5`` as the std-dev; the doc is
    corrected here to match the actual behavior.

    Args:
        img_data: array-like image with uint8 values in [0, 255].

    Returns:
        np.uint8 array of the same shape, values clipped to [0, 255].
    """
    image = np.array(img_data)
    # Work in normalized float space so the noise scale is resolution-free.
    image = np.array(image / 255, dtype=float)
    noise = np.random.normal(0, 0.00001 ** 0.5, image.shape)
    out = np.clip(image + noise, 0, 1.0)
    return np.uint8(out * 255)
def img_add_stripe(img_data):
    """Overlay gray (value 150) stripes on an image every 11 pixels.

    With probability 0.5 the stripes are horizontal (rows), otherwise
    vertical (columns). Each stripe is two lines thick with one untouched
    line between them (lines i and i+2).

    Args:
        img_data: array-like H x W x 3 image.

    Returns:
        np.ndarray copy of the input with the stripes painted in.
    """
    image = np.array(img_data)
    h, w, c = image.shape
    if np.random.random() <= 0.5:
        # Horizontal stripes. Fix: the stripe buffer is loop-invariant —
        # build it once instead of on every iteration.
        stripe = np.uint8(np.ones((1, w, 3)) * 150)
        for i in range(0, h - 1, 11):
            image[i, :, :] = stripe
            if (i + 2) <= (h - 1):
                image[i + 2, :, :] = stripe
    else:
        # Vertical stripes; shape (h, 3) matches image[:, i, :] exactly.
        stripe = np.uint8(np.ones((h, 3)) * 150)
        for i in range(0, w - 1, 11):
            image[:, i, :] = stripe
            if (i + 2) <= (w - 1):
                image[:, i + 2, :] = stripe
    return image
def img_gamma(img, para):
    """Gamma-correct an image: out = (img / 255) ** para * 255 as uint8."""
    normalized = img / 255
    corrected = np.power(normalized, para) * 255
    return corrected.astype(np.uint8)
def data_augmentation(img, label):
    """Randomly augment a training image.

    A single uniform draw ``b`` in [0, 1) selects the augmentation:
      * b >= 0.7        -> add Gaussian noise (img_noise)
      * 0.2 <= b <= 0.5 -> gamma correction with para=1
        NOTE(review): gamma 1 is the identity curve (only a uint8
        round-trip) — confirm the intended gamma value.
    Negative samples (label == 0) are additionally converted to
    grayscale (kept 3-channel) with probability 7/99.
    """
    b = torch.rand((1, 1)).item()
    if b >= 0.7:
        img = img_noise(img)
    if 0.2 <= b <= 0.5:
        img = img_gamma(img, 1)
    if label == 0:
        # torch.randint(1, 100) yields p in [1, 99]; p < 8 covers 7 values.
        p = torch.randint(1, 100, (1, 1)).item()
        if p < 8:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            img = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    return img
def default_loader(path):
    """Read the image at *path* and return it as a 112x112 RGB ndarray.

    Returns None (after logging) when the file cannot be read. Bug fix:
    the original printed the warning but then crashed on ``im.shape``;
    its trailing ``if im is None: return None`` was unreachable.

    Args:
        path: filesystem path to the image.

    Returns:
        np.ndarray (112, 112, 3) in RGB order, or None on read failure.
    """
    im = cv2.imread(path)
    if im is None:
        print("None:", path)
        return None
    if im.shape[0] != 112 or im.shape[1] != 112:
        im = cv2.resize(im, (112, 112), interpolation=cv2.INTER_NEAREST)
    # OpenCV decodes to BGR; the rest of the pipeline expects RGB.
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# define Dataset. Assume each line in your .txt file is [name/tab/label], for example:0001.jpg 1
class MyDatasets(Dataset):
    """Dataset over ``name label`` lines (space-separated) from a txt file.

    Positives (label 1) and negatives (label 0) are balanced at
    construction time by oversampling the smaller class, so the first
    half of the index space is all-positive and the second half is
    all-negative — MySampler relies on this layout.
    """

    def __init__(self, img_path, txt_path, dataset='', data_transforms=None,
                 loader=default_loader, data_augmentation=data_augmentation):
        self.img_name_pos, self.img_label_pos = [], []
        self.img_name_neg, self.img_label_neg = [], []
        with open(txt_path) as input_file:
            for line in input_file:
                parts = line.strip().split(' ')
                if int(parts[-1]) == 1:
                    self.img_name_pos.append(os.path.join(img_path, parts[0]))
                    self.img_label_pos.append(1)
                else:
                    self.img_name_neg.append(os.path.join(img_path, parts[0]))
                    self.img_label_neg.append(0)
        # Balance the classes by oversampling the smaller one. Bug fix:
        # random.choices samples WITH replacement — the original
        # random.sample raised ValueError whenever the deficit exceeded
        # the smaller class size.
        value = len(self.img_label_pos) - len(self.img_label_neg)
        if value > 0:
            self.img_name_neg += random.choices(self.img_name_neg, k=value)
            self.img_label_neg += [0] * value
        elif value < 0:
            self.img_name_pos += random.choices(self.img_name_pos, k=-value)
            self.img_label_pos += [1] * (-value)
        self.img_name = self.img_name_pos + self.img_name_neg
        self.img_label = self.img_label_pos + self.img_label_neg
        self.data_augmentation = data_augmentation
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)
        if self.data_augmentation is not None:
            # Bug fix: call the augmentation passed to __init__; the
            # original always called the module-level function and
            # silently ignored the parameter.
            img = self.data_augmentation(img, label)
        if self.data_transforms is not None:
            try:
                img = Image.fromarray(img.astype(np.uint8))
                img = self.data_transforms(img)
            except Exception:
                # Narrowed from a bare except: still best-effort like the
                # original, but no longer swallows KeyboardInterrupt.
                print("Cannot transform image: {}".format(img_name))
        return img, label
class MySampler(Sampler):
    """Yield indices alternating positive/negative for 1:1 batches.

    Assumes the dataset keeps all positives in its first half and all
    negatives in its second half (as MyDatasets guarantees after class
    balancing).
    """

    def __init__(self, dataset):
        halfway_point = int(len(dataset) / 2)
        self.pos_indices = list(range(halfway_point))
        self.neg_indices = list(range(halfway_point, len(dataset)))

    def __iter__(self):
        # Reshuffle each epoch, then interleave so every consecutive pair
        # is one positive followed by one negative.
        random.shuffle(self.pos_indices)
        random.shuffle(self.neg_indices)
        interleaved = []
        for pos, neg in zip(self.pos_indices, self.neg_indices):
            interleaved.append(pos)
            interleaved.append(neg)
        # Fix: dropped the per-epoch debug prints of the full index lists.
        return iter(interleaved)

    def __len__(self):
        # Bug fix: the original referenced nonexistent attributes
        # (self.first_half_indices / self.second_half_indices) and raised
        # AttributeError whenever len() was taken.
        return len(self.pos_indices) + len(self.neg_indices)
# load datasets
def load_mydata(img_path_default, txt_path_default):
    """Build train/test DataLoaders using the balanced 1:1 sampler.

    Expects ``<txt_path_default>/train.txt`` and ``<txt_path_default>/test.txt``
    with ``name label`` lines; image names are resolved under
    img_path_default.

    Returns:
        (train_loader, test_loader) tuple of DataLoaders with batch_size=2.
    """
    # One randomly chosen jitter/rotation, applied 20% of the time.
    jitter = transforms.RandomChoice([
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ColorJitter(saturation=0.5),
        transforms.RandomRotation(5, resample=Image.BILINEAR,
                                  expand=False, center=(56, 56)),
    ])
    maybe_jitter = transforms.RandomApply([jitter], p=0.2)
    data_transforms = {
        'train': transforms.Compose([
            maybe_jitter,
            transforms.RandomCrop(96),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
        'test': transforms.Compose([
            transforms.CenterCrop(96),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
    }
    dataloders = {}
    for phase in ['train', 'test']:
        dataset = MyDatasets(img_path=img_path_default,
                             txt_path=(txt_path_default + '/' + phase + '.txt'),
                             data_transforms=data_transforms[phase],
                             dataset=phase)
        # NOTE: shuffle must stay unset when a custom sampler is supplied.
        dataloders[phase] = torch.utils.data.DataLoader(dataset,
                                                        sampler=MySampler(dataset),
                                                        batch_size=2,
                                                        num_workers=2)
    return dataloders['train'], dataloders['test']
list_train = r'/home/mntsde/lilai/imgs_ir'
list_test = r'/home/mntsde/lilai/imgs_ir'

if __name__ == '__main__':
    # Guarded entry point: importing this module no longer builds the
    # loaders as a side effect, and the guard is required anyway for
    # num_workers > 0 under the 'spawn' start method (Windows/macOS).
    train_loader, test_loader = load_mydata(list_train, list_test)
    for epoch in range(0, 10):
        print("********************************")
        print("epoch: ", epoch)
        print("********************************")
        # Each printed label tensor should hold one 1 and one 0 per batch.
        for i, data in enumerate(train_loader):
            print(i, data[1])
        print("--------------------------------")
        for i, data in enumerate(test_loader):
            print(i, data[1])
参考:
(1)https://www.scottcondron.com/jupyter/visualisation/audio/2020/12/02/dataloaders-samplers-collate.html#SequentialSampler
(2)https://blog.csdn.net/u010087338/article/details/117927204
(3)https://github.com/ufoym/imbalanced-dataset-sampler