TinyImageNet数据集读取与计算均值和标准差

最近在使用TinyImageNet数据集做分类任务,网络上对这个数据集的使用不算多,故做一些记录。

读取

借鉴github,读取数据集
https://github.com/Manikvsin/TinyImagenet-pytorch/blob/master/tiny_image_net_torch.py

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models,utils,datasets,transforms
import numpy as np
import sys
import os
from PIL import Image
import numpy as np

class TinyImageNet(Dataset):
    """Dataset for the tiny-imagenet-200 directory layout.

    Expects under ``root``:
      - ``train/<wnid>/images/*.JPEG``
      - ``val/images/*.JPEG`` plus ``val/val_annotations.txt``
      - ``wnids.txt`` (one WordNet id per line)
      - ``words.txt`` (``wnid<TAB>comma-separated names``)
    """

    def __init__(self, root, train=True, transform=None):
        """
        :param root: path to the tiny-imagenet-200 directory
        :param train: True for the train split, False for the val split
        :param transform: optional callable applied to each PIL image
        """
        self.Train = train
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")

        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()

        self._make_dataset(self.Train)

        words_file = os.path.join(self.root_dir, "words.txt")
        wnids_file = os.path.join(self.root_dir, "wnids.txt")

        # wnids.txt lists the 200 class ids this dataset actually uses.
        self.set_nids = set()
        with open(wnids_file, 'r') as fo:
            for entry in fo:
                self.set_nids.add(entry.strip("\n"))

        # Map each used wnid to its first human-readable name in words.txt.
        self.class_to_label = {}
        with open(words_file, 'r') as fo:
            for entry in fo:
                words = entry.split("\t")
                if words[0] in self.set_nids:
                    self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0]

    def _create_class_idx_dict_train(self):
        """Build wnid<->index maps and count .JPEG files for the train split."""
        if sys.version_info >= (3, 5):
            classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()]
        else:
            # BUG FIX: original referenced the undefined bare name ``train_dir``.
            classes = [d for d in os.listdir(self.train_dir)
                       if os.path.isdir(os.path.join(self.train_dir, d))]
        classes = sorted(classes)

        num_images = 0
        for root, dirs, files in os.walk(self.train_dir):
            for f in files:
                if f.endswith(".JPEG"):
                    num_images += 1
        self.len_dataset = num_images

        self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))}
        self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))}

    def _create_class_idx_dict_val(self):
        """Build wnid<->index maps for the val split from val_annotations.txt."""
        val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        self.val_img_to_class = {}
        set_of_classes = set()
        with open(val_annotations_file, 'r') as fo:
            for line in fo:
                # Columns: filename, wnid, x, y, w, h (bbox fields unused here).
                words = line.split("\t")
                self.val_img_to_class[words[0]] = words[1]
                set_of_classes.add(words[1])

        self.len_dataset = len(self.val_img_to_class)
        classes = sorted(set_of_classes)
        self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))}
        self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))}

    def _make_dataset(self, Train=True):
        """Collect (image_path, target_index) pairs into self.images."""
        self.images = []
        if Train:
            img_root_dir = self.train_dir
            # One subdirectory per class wnid.
            list_of_dirs = [target for target in self.class_to_tgt_idx.keys()]
        else:
            img_root_dir = self.val_dir
            list_of_dirs = ["images"]

        for tgt in list_of_dirs:
            dirs = os.path.join(img_root_dir, tgt)
            if not os.path.isdir(dirs):
                continue

            for root, _, files in sorted(os.walk(dirs)):
                for fname in sorted(files):
                    if fname.endswith(".JPEG"):
                        path = os.path.join(root, fname)
                        if Train:
                            item = (path, self.class_to_tgt_idx[tgt])
                        else:
                            # Val labels come from val_annotations.txt.
                            item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
                        self.images.append(item)

    def return_label(self, idx):
        """Map an iterable of target-index tensors to human-readable labels."""
        return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]

    def __len__(self):
        return self.len_dataset

    def __getitem__(self, idx):
        """Return (image, target_index); image is transformed if a transform was given."""
        img_path, tgt = self.images[idx]
        # BUG FIX: original opened ``f`` but then re-opened the path with
        # Image.open(img_path), leaking that second handle; read through the
        # managed handle instead. convert() loads pixel data before close.
        with open(img_path, 'rb') as f:
            sample = Image.open(f)
            sample = sample.convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, tgt

合并数据集

我想把数据集合成一个文件,存储成npy格式
训练集10万张图片,转化成npy文件,大小约为9.2G
本来最开始一重循环读取,结果时间越来越长,到1万张左右时,因为数组一直在内存中,拼接操作花费的时间直线上升。
所以权衡之下,500张图片做一次拼接。

import time

if __name__ == "__main__":
    # Convert the 100k training images into single data.npy / target.npy
    # files, appending 500 images per outer iteration so no single
    # concatenation grows with the number of images processed so far.
    CHUNK = 500
    start = time.time()
    tynet = TinyImageNet(root='../dataset/tiny-imagenet-200', transform=None)
    for j in range(200):
        # PERF FIX: the original concatenated image-by-image onto a dummy
        # zeros row (and a random target) that had to be sliced off later;
        # that is quadratic per chunk. Collect into lists and stack once.
        chunk_imgs = []
        chunk_tgts = []
        for i in range(CHUNK):
            dat, target = tynet[i + j * CHUNK]
            chunk_imgs.append(np.array(dat))
            chunk_tgts.append(target)
        # float64 to match the dtype the original concatenation produced.
        dst = np.stack(chunk_imgs).astype(np.float64)
        tar = np.asarray(chunk_tgts, dtype=np.float64)
        print(CHUNK - 1)
        print(time.time() - start)

        if j > 0:
            # Append this chunk to the arrays already on disk.
            np.save("data.npy", np.concatenate((np.load("data.npy"), dst)))
            np.save("target.npy", np.concatenate((np.load("target.npy"), tar)))
        else:
            np.save("data.npy", dst)
            np.save("target.npy", tar)
        new_dst = np.load("data.npy")

        print(new_dst.shape)
        print("loaded.")

计算均值与方差

def getStat(data_loader):
    '''
    Compute per-channel mean and std over the images yielded by data_loader.

    Each batch contributes its own channel statistics and the per-batch
    values are averaged with equal weight, so the result is exact only when
    every batch has the same size (batch_size divides len(dataset), or
    drop_last is set); otherwise the last batch is slightly over-weighted.

    :param data_loader: iterable of (images, labels) with images of shape (N, 3, H, W)
    :return: (mean, std) as two lists of 3 floats
    '''
    print('Compute mean and std for training data.')
    print(len(data_loader))
    mean = torch.zeros(3)
    std = torch.zeros(3)

    for train_data, _ in data_loader:
        # Vectorized per-channel statistics over batch and spatial dims;
        # identical to looping over the 3 channels, without Python overhead.
        mean += train_data.mean(dim=(0, 2, 3))
        std += train_data.std(dim=(0, 2, 3))
    mean = mean / len(data_loader)
    std = std / len(data_loader)
    return list(mean.numpy()), list(std.numpy())
if __name__ == "__main__":
    # ToTensor scales pixels to [0, 1], so the statistics below are in
    # that range and can be fed straight into transforms.Normalize.
    to_tensor = transforms.Compose([transforms.ToTensor()])
    dataset = TinyImageNet(root='../dataset/tiny-imagenet-200', transform=to_tensor)
    print(getStat(DataLoader(dataset, batch_size=100)))

最后结果,均值和标准差

batch = 100
([0.48024505, 0.4480726, 0.39754787], [0.2717199, 0.26526922, 0.27396977])

batch = 200
([0.48024592, 0.4480722, 0.39754784], [0.27227214, 0.26570824, 0.27469844])

batch = 250
([0.4802457, 0.44807217, 0.3975479], [0.27193516, 0.26546907, 0.27422717])

batch = 500
([0.48024562, 0.4480722, 0.3975478], [0.27201378, 0.26554194, 0.27431726])

如有错误,欢迎指正。

你可能感兴趣的:(日记,python,深度学习)