pytorch学习之:常用的数据集处理方法(Dataset)和数据采样策略 (Sampler)

文章目录

  • 数据集处理方法
    • 小批量数据 & 为数据添加随机噪声
    • 两个数据集 Dataset 合并
    • 切分子数据集
  • 数据集采样策略
    • 构造一个按照正态分布从数据集中采样

数据集处理方法

小批量数据 & 为数据添加随机噪声

  • 使用小部分数据:做实验时,我们常常希望先用一小部分数据跑通代码,确认无误后再使用全部数据
  • 为 Dataset 中的图片数据添加高斯噪声
"""
 @file: codes.py
 @Time    : 2023/1/12
 @Author  : Peinuan qin
 """
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
import random
from torch.utils.data import Subset, Dataset

# Directory where torchvision stores / looks for the MNIST files.
DATA_ROOT = "./data"
# Single-channel mean / std passed to transforms.Normalize for the MNIST images below.
MEAN = (0.1307,)
STD = (0.3081,)

class MyDataset(Dataset):
    """Wrap an image dataset as (input, target) pairs for a denoising model.

    Each item is ``(noisy_image, clean_image)`` when ``add_noise`` is true and
    ``(image, image)`` otherwise; the wrapped dataset's label is discarded.

    :param dataset: indexable dataset whose items are ``(image_tensor, label)``.
    :param ratio: if truthy, keep only a random ``ratio`` fraction of ``dataset``
        (handy for smoke-testing on a small subset); ``0``/``None`` keeps all.
    :param add_noise: add standard Gaussian noise to the input image.
    :raises ValueError: if ``ratio`` is truthy but outside ``(0, 1]``.
    """

    def __init__(self, dataset, ratio=0.2, add_noise=True):
        self.dataset = dataset
        self.add_noise = add_noise
        if ratio:
            if not 0 < ratio <= 1:
                raise ValueError(f"ratio must be in (0, 1], got {ratio}")
            random_indices = random.sample(range(len(dataset)), int(ratio * len(dataset)))
            self.dataset = Subset(dataset, random_indices)
            print(f"using a small dataset with ratio: {ratio}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        # Index the wrapped dataset once: every lookup re-runs its loading/transforms.
        clean = self.dataset[item][0]
        if self.add_noise:
            # noisy image is the encoder-decoder input, the clean image the ground truth
            return self.make_noise(clean), clean
        return clean, clean

    def make_noise(self, x):
        """Return ``x`` plus standard Gaussian noise, leaving ``x`` itself unchanged.

        The noise comes from torch's RNG (which DataLoader seeds differently per
        worker, unlike NumPy's) and has the same dtype/device/shape as ``x``.

        :param x: floating-point image tensor.
        :return: a new tensor ``x + N(0, 1)``.
        """
        return x + torch.randn_like(x)



DATASET_RATIO = 0.2


def _load_mnist(train, download):
    """Build one MNIST split whose images come out as normalized tensors."""
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    return MNIST(DATA_ROOT, train=train, transform=preprocess, download=download)


trainset = _load_mnist(train=True, download=True)
# NOTE(review): the test split is opened with download=False, so it presumably
# relies on the files fetched for the training split above — confirm.
valset = _load_mnist(train=False, download=False)

# Keep a random DATASET_RATIO fraction of each split for training / validation.
train_set = MyDataset(trainset, ratio=DATASET_RATIO)
val_set = MyDataset(valset, ratio=DATASET_RATIO)

两个数据集 Dataset 合并

"""
 @file: codes.py
 @Time    : 2023/1/12
 @Author  : Peinuan qin
 """
from torchvision import transforms

# Directory where torchvision stores / looks for the MNIST files.
DATA_ROOT = "./data"
# Single-channel mean / std passed to transforms.Normalize for the MNIST images below.
MEAN = (0.1307,)
STD = (0.3081,)

from torchvision.datasets import MNIST
from torch.utils.data import ConcatDataset, Subset, random_split

def _normalized_mnist_transform():
    """Return a fresh ToTensor + Normalize(MEAN, STD) pipeline for MNIST images."""
    return transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD)])


trainset = MNIST(
    DATA_ROOT,
    train=True,
    transform=_normalized_mnist_transform(),
    download=True,
)
# NOTE(review): the test split is opened with download=False, so it presumably
# relies on the files fetched for the training split above — confirm.
valset = MNIST(
    DATA_ROOT,
    train=False,
    transform=_normalized_mnist_transform(),
    download=False,
)

# One dataset holding both splits; valset's samples are indexed after trainset's.
complete_set = ConcatDataset([trainset, valset])

切分子数据集

  • 通常使用 Subset 类来完成:先随机抽取一组索引,再用 Subset(dataset, indices) 构造子集。具体实现可参考第一段代码中 MyDataset 使用 Subset 的部分。
# Keep only the samples of `dataset` whose indices are listed in `random_indexs`
# (e.g. indices drawn with random.sample, as in the first snippet's MyDataset).
dataset = Subset(dataset, random_indexs)

数据集采样策略

构造一个按照正态分布从数据集中采样

"""
 @file: codes.py
 @Time    : 2023/1/12
 @Author  : Peinuan qin
 """
import random
from collections import Counter
from copy import deepcopy

import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
from torchvision.datasets import MNIST
from torch.utils.data import ConcatDataset



# Directory where torchvision stores / looks for the MNIST files.
DATA_ROOT = "./data"
# Single-channel mean / std passed to transforms.Normalize for the MNIST images below.
MEAN = (0.1307,)
STD = (0.3081,)
# Number of class labels; one ClassDict pool is built per label in range(CLS_NUM).
CLS_NUM = 10
# Batch size for the DataLoaders built per split.
BATCHSIZE=64
# Number of subsets the combined dataset is split into.
SPLIT_NUM = 2


class MyDataset(Dataset):
    """Dataset wrapper that also caches every sample's image and label in memory.

    ``self.x`` / ``self.y`` hold the images and labels of ``dataset``, read
    eagerly in ``__init__``; ``__getitem__`` still reads from ``dataset`` and
    applies ``transform`` to the image.

    :param dataset: indexable dataset whose items are ``(image, label, ...)``.
    :param transform: optional callable applied to the image in ``__getitem__``.
    """

    def __init__(self, dataset, transform=None):
        super(MyDataset, self).__init__()
        self.dataset = dataset
        self.x, self.y = self.get_x_y()
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def get_x_y(self):
        """Return ``(images, labels)`` lists read from the wrapped dataset.

        Each sample is fetched once, so the wrapped dataset's loading and
        transforms run only once per index.
        """
        x, y = [], []
        for i in range(len(self.dataset)):
            sample = self.dataset[i]
            x.append(sample[0])
            y.append(sample[1])
        return x, y

    def get_dict(self):
        """Group the cached images by label: ``{label: [image, ...]}`` in dataset order."""
        groups = {}
        for img, label in tqdm(zip(self.x, self.y), total=len(self.x)):
            groups.setdefault(label, []).append(img)
        return groups

    def get_y_lst(self):
        """Return the cached label list."""
        return self.y

    def plot_distribution(self):
        """Show a histogram of the cached labels."""
        plt.hist(self.y)
        plt.show()

    def __getitem__(self, item):
        # Index once: each lookup of the wrapped dataset may re-run its transforms.
        sample = self.dataset[item]
        img, label = sample[0], sample[1]
        if self.transform:
            img = self.transform(img)

        return img, label


class ClassDict:
    """Pool of one class's samples that can be drawn without replacement.

    :param label: class label shared by every sample in the pool.
    :param x_lst: the class's samples; each is keyed by its position in ``x_lst``.
    """

    def __init__(self, label, x_lst):
        self.label = label
        self.x_lst = x_lst
        # position -> sample; entries are popped as they are sampled
        self.dict = dict(enumerate(x_lst))
        # deep snapshot of the initial pool.
        # NOTE(review): not read anywhere in this snippet.
        self.copy_dict = deepcopy(self.dict)

    def sample(self, num):
        """Remove and return up to ``num`` random samples (fewer if the pool is short)."""
        take = num if num < len(self.dict) else len(self.dict)
        picked_keys = random.sample(list(self.dict), take)
        picked = []
        for key in picked_keys:
            picked.append(self.dict.pop(key))
        print(f"label: {self.label}, remaining samples: {len(self.dict)}")
        print(f"label: {self.label}, sampling lst length: {len(picked)}")
        return picked

    def remain(self):
        """Return the samples not drawn yet, in their original order."""
        return list(self.dict.values())



class NormalSampler:
    """Draw samples whose labels follow a discretized normal distribution.

    :param class_dicts: sequence of per-class pools indexed by label; each element
        must expose ``label``, ``sample(num)`` and ``remain()`` (as ``ClassDict``
        does). The drawn label is used as the index into ``class_dicts``, so
        presumably ``class_dicts[i].label == i`` — confirm at the call site.
    """

    def __init__(self, class_dicts):
        self.class_dicts = class_dicts

    def sample(self, mean, std, num):
        """Draw up to ``num`` samples with labels ~ floor(N(mean, std)).

        Draws that land outside ``[0, len(class_dicts))`` are discarded, and a
        class that runs out of samples contributes only what it has left, so
        fewer than ``num`` samples may be returned. Drawn samples are removed
        from their class pools.

        :return: ``(x_list, y_list)`` of equal length.
        """
        draws = np.random.normal(mean, std, (num,))
        # floor, not int(): int() truncates toward zero, which would also map
        # (-1, 0) onto label 0 and over-sample class 0. With floor, label k
        # receives exactly the probability mass in [k, k + 1).
        label_count_dict = Counter(int(np.floor(v)) for v in draws)
        print(label_count_dict)
        n_classes = len(self.class_dicts)
        in_range = {k: c for k, c in label_count_dict.items() if 0 <= k < n_classes}

        all_x_lst = []
        all_y_lst = []
        for label, count in in_range.items():
            class_dic = self.class_dicts[label]
            class_x_lst = class_dic.sample(count)
            # label the samples with their pool's own label, as remain() does
            all_x_lst.extend(class_x_lst)
            all_y_lst.extend([class_dic.label] * len(class_x_lst))
        return all_x_lst, all_y_lst

    def remain(self):
        """Return ``(x_list, y_list)`` of every sample not drawn yet, class by class."""
        all_x_lst = []
        all_y_lst = []
        for class_dic in self.class_dicts:
            class_x_lst = class_dic.remain()
            all_x_lst.extend(class_x_lst)
            all_y_lst.extend([class_dic.label] * len(class_x_lst))
        return all_x_lst, all_y_lst



class SubDataset(Dataset):
    """In-memory dataset over parallel lists of samples ``x`` and labels ``y``."""

    def __init__(self, x, y):
        super().__init__()
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def plot_distribution(self):
        """Show a histogram of the labels."""
        plt.hist(self.y)
        plt.show()

    def __getitem__(self, item):
        sample, label = self.x[item], self.y[item]
        return sample, label


# Names used below that this snippet's import block does not provide.
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader, SubsetRandomSampler

trainset = MNIST(DATA_ROOT
                     , train=True
                     , transform=transforms.Compose([transforms.ToTensor()
                                                        , transforms.Normalize(MEAN, STD)])
                     , download=True)

valset = MNIST(DATA_ROOT
                 , train=False
                 , transform=transforms.Compose([transforms.ToTensor()
                                                    , transforms.Normalize(MEAN, STD)])
                 , download=False)


complete_set = ConcatDataset([trainset, valset])

complete_set = MyDataset(complete_set, transform=None)
# classes_dict maps each label to the list of all its images; class_dicts holds
# one ClassDict pool per label so the splits can be drawn without replacement.
classes_dict = complete_set.get_dict()
class_dicts = [ClassDict(i, classes_dict[i]) for i in range(CLS_NUM)]
# Sampler whose labels follow a discretized normal distribution.
normal_sampler = NormalSampler(class_dicts)
# Base number of samples per split; the last split takes whatever is left,
# so it may hold more than basic_sample_size.
basic_sample_size = len(complete_set) // SPLIT_NUM
subsets = []
for split_idx in range(SPLIT_NUM):
    if split_idx != SPLIT_NUM - 1:
        # Labels ~ N(CLS_NUM // SPLIT_NUM, 3): each split gets a skewed class mix.
        x, y = normal_sampler.sample(CLS_NUM // SPLIT_NUM, 3, basic_sample_size)
    else:
        # The last split must contain every sample the earlier splits did not take.
        x, y = normal_sampler.remain()
    subset = SubDataset(x, y)
    # subset.plot_distribution()
    subsets.append(subset)

# Runs once per split, after all splits have been sampled.
for subset_idx, raw_subset in enumerate(subsets):
    print("=" * 35)
    print(f"subset {subset_idx}")
    print(f"subset length: {len(raw_subset)}")

    # Samples are already tensors normalized by the MNIST transform above, so no
    # ToTensor here (it only accepts PIL images / ndarrays). The augmentation is
    # applied to the training view only; validation sees the images unchanged.
    # NOTE(review): horizontally flipped digits may not keep their label —
    # confirm this augmentation is intended.
    train_view = MyDataset(raw_subset, transforms.Compose([transforms.RandomHorizontalFlip()]))
    val_view = MyDataset(raw_subset)

    # Stratified split: train and val of each subset share the same class distribution.
    splitter = StratifiedShuffleSplit(n_splits=1
                                      , train_size=0.75
                                      , test_size=0.25
                                      , random_state=1024)
    for train_idxs, val_idxs in splitter.split(raw_subset.x, raw_subset.y):
        train_sampler = SubsetRandomSampler(train_idxs)
        val_sampler = SubsetRandomSampler(val_idxs)
        # sampler= replaces shuffle=: DataLoader does not allow both.
        fold_train_loader = DataLoader(train_view
                                       , batch_size=BATCHSIZE
                                       , sampler=train_sampler
                                       , num_workers=4
                                       , pin_memory=True)

        fold_val_loader = DataLoader(val_view
                                     , batch_size=BATCHSIZE
                                     , sampler=val_sampler
                                     , num_workers=4
                                     , pin_memory=True)
        # NOTE(review): the loaders are rebuilt for every subset — consume them
        # (train / evaluate) inside this loop.

你可能感兴趣的:(日常学习,pytorch,学习,python)