Definition
import os
import warnings
from typing import Any, Callable, Dict, Optional, Tuple
from urllib.error import URLError

import torch
from PIL import Image
from torchvision.datasets.mnist import read_image_file, read_label_file
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.datasets.vision import VisionDataset

class MNIST(VisionDataset):
"""`MNIST `_ Dataset.
Args:
root (string): Root directory of dataset where ``MNIST/processed/training.pt``
and ``MNIST/processed/test.pt`` exist.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = [
'http://yann.lecun.com/exdb/mnist/',
'https://ossci-datasets.s3.amazonaws.com/mnist/',
]
resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
]
training_file = 'training.pt'
test_file = 'test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
@property
def train_labels(self):
warnings.warn("train_labels has been renamed targets")
return self.targets
@property
def test_labels(self):
warnings.warn("test_labels has been renamed targets")
return self.targets
@property
def train_data(self):
warnings.warn("train_data has been renamed data")
return self.data
@property
def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(MNIST, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
self.data, self.targets = self._load_data()
def _check_legacy_exist(self):
processed_folder_exists = os.path.exists(self.processed_folder)
if not processed_folder_exists:
return False
return all(
check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
)
def _load_legacy_data(self):
# This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
# directly.
data_file = self.training_file if self.train else self.test_file
return torch.load(os.path.join(self.processed_folder, data_file))
def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))
label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))
return data, targets
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.data)
@property
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'raw')
@property
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'processed')
@property
def class_to_idx(self) -> Dict[str, int]:
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self) -> bool:
return all(
check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
for url, _ in self.resources
)
def download(self) -> None:
"""Download the MNIST data if it doesn't exist already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
# download files
for filename, md5 in self.resources:
for mirror in self.mirrors:
url = "{}{}".format(mirror, filename)
try:
print("Downloading {}".format(url))
download_and_extract_archive(
url, download_root=self.raw_folder,
filename=filename,
md5=md5
)
except URLError as error:
print(
"Failed to download (trying next):\n{}".format(error)
)
continue
finally:
print()
break
else:
raise RuntimeError("Error downloading {}".format(filename))
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")
FashionMNIST, KMNIST, and QMNIST can all be read directly from torchvision.datasets in the same way.
They can be loaded as follows:
import torch
import torchvision

batch_size_train = 64    # example values; choose batch sizes to suit your setup
batch_size_test = 1000

train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./data/', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./data/', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size_test, shuffle=True)
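FashionMNIST, KMNIST, and QMNIST follow the same interface, so only the class name changes. A minimal sketch; the './data/' root, ToTensor-only transform, and batch size are example choices, not fixed by the classes:
import torch
import torchvision

# FashionMNIST / KMNIST / QMNIST expose the same constructor arguments as MNIST.
fashion_train = torchvision.datasets.FashionMNIST(
    './data/', train=True, download=True,
    transform=torchvision.transforms.ToTensor())
kmnist_train = torchvision.datasets.KMNIST(
    './data/', train=True, download=True,
    transform=torchvision.transforms.ToTensor())
qmnist_train = torchvision.datasets.QMNIST(
    './data/', what='train', download=True,
    transform=torchvision.transforms.ToTensor())

fashion_loader = torch.utils.data.DataLoader(fashion_train, batch_size=64, shuffle=True)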
import os
import pickle
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.datasets.vision import VisionDataset

class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
meta = {
'filename': 'batches.meta',
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
self.data: Any = []
self.targets = []
# now load the picked numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
self._load_meta()
def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it')
with open(path, 'rb') as infile:
data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.data)
def _check_integrity(self) -> bool:
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5):
return False
return True
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")
CIFAR100 is handled in exactly the same way.
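Both can then be wrapped in loaders just like MNIST. A minimal sketch; the normalization statistics are the commonly quoted CIFAR-10 channel values and the batch sizes are examples, neither comes from the class above:
import torch
import torchvision
import torchvision.transforms as transforms

# Commonly quoted CIFAR-10 channel statistics; recompute them if exact values matter.
cifar10_mean, cifar10_std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
cifar10_tf = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize(cifar10_mean, cifar10_std)])

cifar10_train = torchvision.datasets.CIFAR10('./data/', train=True, download=True, transform=cifar10_tf)
cifar100_train = torchvision.datasets.CIFAR100('./data/', train=True, download=True,
                                               transform=transforms.ToTensor())

cifar10_loader = torch.utils.data.DataLoader(cifar10_train, batch_size=128, shuffle=True)
cifar100_loader = torch.utils.data.DataLoader(cifar100_train, batch_size=128, shuffle=True)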
Since PyTorch 1.8, torchvision predefines further datasets, shown below;
the complete data still has to be downloaded by the user.
from pathlib import Path
import json
from typing import Any, Tuple, Callable, Optional
import torch
import PIL.Image
from torchvision.datasets.utils import check_integrity,download_and_extract_archive, download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset
class Food101(VisionDataset):
"""`The Food-101 Data Set `_.
The Food-101 is a challenging data set of 101 food categories, with 101'000 images.
For each class, 250 manually reviewed test images are provided as well as 750 training images.
On purpose, the training images were not cleaned, and thus still contain some amount of noise.
This comes mostly in the form of intense colors and sometimes wrong labels. All images were
rescaled to have a maximum side length of 512 pixels.
Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
"""
_URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
_MD5 = "85eeb15f3717b99a5da872d97d918f87"
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "test"))
self._base_folder = Path(self.root) / "food-101"
self._meta_folder = self._base_folder / "meta"
self._images_folder = self._base_folder / "images"
self.class_names_str = ['Apple pie', 'Baby back ribs', 'Baklava', 'Beef carpaccio', 'Beef tartare', 'Beet salad', 'Beignets', 'Bibimbap', 'Bread pudding', 'Breakfast burrito', 'Bruschetta', 'Caesar salad', 'Cannoli', 'Caprese salad', 'Carrot cake', 'Ceviche', 'Cheesecake', 'Cheese plate', 'Chicken curry', 'Chicken quesadilla', 'Chicken wings', 'Chocolate cake', 'Chocolate mousse', 'Churros', 'Clam chowder', 'Club sandwich', 'Crab cakes', 'Creme brulee', 'Croque madame', 'Cup cakes', 'Deviled eggs', 'Donuts', 'Dumplings', 'Edamame', 'Eggs benedict', 'Escargots', 'Falafel', 'Filet mignon', 'Fish and chips', 'Foie gras', 'French fries', 'French onion soup', 'French toast', 'Fried calamari', 'Fried rice', 'Frozen yogurt', 'Garlic bread', 'Gnocchi', 'Greek salad', 'Grilled cheese sandwich', 'Grilled salmon', 'Guacamole', 'Gyoza', 'Hamburger', 'Hot and sour soup', 'Hot dog', 'Huevos rancheros', 'Hummus', 'Ice cream', 'Lasagna', 'Lobster bisque', 'Lobster roll sandwich', 'Macaroni and cheese', 'Macarons', 'Miso soup', 'Mussels', 'Nachos', 'Omelette', 'Onion rings', 'Oysters', 'Pad thai', 'Paella', 'Pancakes', 'Panna cotta', 'Peking duck', 'Pho', 'Pizza', 'Pork chop', 'Poutine', 'Prime rib', 'Pulled pork sandwich', 'Ramen', 'Ravioli', 'Red velvet cake', 'Risotto', 'Samosa', 'Sashimi', 'Scallops', 'Seaweed salad', 'Shrimp and grits', 'Spaghetti bolognese', 'Spaghetti carbonara', 'Spring rolls', 'Steak', 'Strawberry shortcake', 'Sushi', 'Tacos', 'Takoyaki', 'Tiramisu', 'Tuna tartare', 'Waffles']
if download:
self._download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
self._labels = []
self._image_files = []
with open(self._meta_folder / f"{split}.json") as f:
metadata = json.loads(f.read())
self.classes = sorted(metadata.keys())
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
for class_label, im_rel_paths in metadata.items():
self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
self._image_files += [
self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
]
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
def extra_repr(self) -> str:
return f"split={self._split}"
def _check_exists(self) -> bool:
return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
def _download(self) -> None:
if self._check_exists():
return
download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
def examine_count(counter, name = "train"):
print(f"in the {name} set")
for label in counter:
print(label, counter[label])
if __name__ == "__main__":
label_names = []
with open('debug/food101_labels.txt') as f:
for name in f:
label_names.append(name.strip())
print(label_names)
train_set = Food101(root = "/nobackup/dataset_myf", split = "train", download = True)
test_set = Food101(root = "/nobackup/dataset_myf", split = "test")
print(f"train set len {len(train_set)}")
print(f"test set len {len(test_set)}")
from collections import Counter
train_label_count = Counter(train_set._labels)
test_label_count = Counter(test_set._labels)
# examine_count(train_label_count, name = "train")
# examine_count(test_label_count, name = "test")
kwargs = {'num_workers': 4, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(train_set ,
batch_size=16, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(test_set,
batch_size=16, shuffle=False, **kwargs)
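Because Food-101 images vary in size (only the maximum side is fixed at 512), a transform is normally passed so that batches have a uniform shape. A sketch with ImageNet-style preprocessing; the crop size and normalization statistics are conventional choices, not part of the class above:
from torchvision import transforms

imagenet_norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
food_train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    imagenet_norm,
])
food_test_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    imagenet_norm,
])
train_set = Food101(root="/nobackup/dataset_myf", split="train", transform=food_train_tf)
test_set = Food101(root="/nobackup/dataset_myf", split="test", transform=food_test_tf)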
from pathlib import Path
from typing import Any, Tuple, Callable, Optional
import torch
import PIL.Image
from torchvision.datasets.utils import check_integrity,download_and_extract_archive, download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset
class Flowers102(VisionDataset):
"""`Oxford 102 Flower `_ Dataset.
.. warning::
        This class needs `scipy <https://scipy.org/>`_ to load target files from `.mat` format.
Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
between 40 and 258 images.
The images have large scale, pose and light variations. In addition, there are categories that
have large variations within the category, and several very similar categories.
Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g., ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
_download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
_file_dict = { # filename, md5
"image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
"label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
"setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
}
_splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}
# https://gist.github.com/JosephKJ/94c7728ed1a8e0cd87fe6a029769cde1
label_names = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion', 'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus', 'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily', 'hippeastrum ', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily']
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
self._base_folder = Path(self.root) / "flowers-102"
self._images_folder = self._base_folder / "jpg"
if download:
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
from scipy.io import loadmat
set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
image_ids = set_ids[self._splits_map[self._split]].tolist()
labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))  # labels in the .mat file are 1-based; shift them to start at 0
self._labels = []
self._image_files = []
for image_id in image_ids:
self._labels.append(image_id_to_label[image_id])
self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")
self.class_names_str = self.label_names
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
def extra_repr(self) -> str:
return f"split={self._split}"
def _check_integrity(self):
if not (self._images_folder.exists() and self._images_folder.is_dir()):
return False
for id in ["label", "setid"]:
filename, md5 = self._file_dict[id]
if not check_integrity(str(self._base_folder / filename), md5):
return False
return True
def download(self):
if self._check_integrity():
return
download_and_extract_archive(
f"{self._download_url_prefix}{self._file_dict['image'][0]}",
str(self._base_folder),
md5=self._file_dict["image"][1],
)
for id in ["label", "setid"]:
filename, md5 = self._file_dict[id]
download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
def examine_count(counter, name = "train"):
print(f"in the {name} set")
for label in counter:
print(label, counter[label])
if __name__ == "__main__":
# label_names = []
# with open('debug/flowers102_labels.txt') as f:
# for name in f:
# label_names.append(name.strip()[1:-1])
# print(label_names)
train_set = Flowers102(root = "/nobackup/dataset_myf", split = "train", download = True)
val_set = Flowers102(root = "/nobackup/dataset_myf", split = "val")
test_set = Flowers102(root = "/nobackup/dataset_myf", split = "test")
from collections import Counter
train_label_count = Counter(train_set._labels)
val_label_count = Counter(val_set._labels)
test_label_count = Counter(test_set._labels)
examine_count(train_label_count, name = "train")
examine_count(val_label_count, name = "val")
examine_count(test_label_count, name = "test")
kwargs = {'num_workers': 4, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(train_set ,
batch_size=16, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(val_set,
batch_size=16, shuffle=False, **kwargs)
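With the 0-based labels, the readable class name of any sample can be checked against class_names_str. A small sanity-check sketch; the sample index 0 is arbitrary:
img, label = train_set[0]
print(label, train_set.class_names_str[label])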
import pathlib
from typing import Callable, Optional, Any, Tuple
from PIL import Image
import torch
from torchvision.datasets.utils import check_integrity,download_and_extract_archive, download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset
class StanfordCars(VisionDataset):
"""`Stanford Cars `_ Dataset
The Cars dataset contains 16,185 images of 196 classes of cars. The data is
split into 8,144 training images and 8,041 testing images, where each class
has been split roughly in a 50-50 split
.. note::
        This class needs `scipy <https://scipy.org/>`_ to load target files from `.mat` format.
Args:
root (string): Root directory of dataset
split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again."""
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
try:
import scipy.io as sio
except ImportError:
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "test"))
self._base_folder = pathlib.Path(root) / "stanford_cars"
devkit = self._base_folder / "devkit"
if self._split == "train":
self._annotations_mat_path = devkit / "cars_train_annos.mat"
self._images_base_path = self._base_folder / "cars_train"
else:
self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
self._images_base_path = self._base_folder / "cars_test"
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
self._samples = [
(
str(self._images_base_path / annotation["fname"]),
annotation["class"] - 1, # Original target mapping starts from 1, hence -1
)
for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
]
self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
self.class_names_str = self.classes
def __len__(self) -> int:
return len(self._samples)
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
"""Returns pil_image and class_id for given index"""
image_path, target = self._samples[idx]
pil_image = Image.open(image_path).convert("RGB")
if self.transform is not None:
pil_image = self.transform(pil_image)
if self.target_transform is not None:
target = self.target_transform(target)
return pil_image, target
def download(self) -> None:
if self._check_exists():
return
download_and_extract_archive(
url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
download_root=str(self._base_folder),
md5="c3b158d763b6e2245038c8ad08e45376",
)
if self._split == "train":
download_and_extract_archive(
url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
download_root=str(self._base_folder),
md5="065e5b463ae28d29e77c1b4b166cfe61",
)
else:
download_and_extract_archive(
url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
download_root=str(self._base_folder),
md5="4ce7ebf6a94d07f1952d94dd34c4d501",
)
download_url(
url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
root=str(self._base_folder),
md5="b0a2b23655a3edd16d84508592a98d10",
)
def _check_exists(self) -> bool:
if not (self._base_folder / "devkit").is_dir():
return False
return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
def examine_count(counter, name = "train"):
print(f"in the {name} set")
for label in counter:
print(label, counter[label])
if __name__ == "__main__":
train_set = StanfordCars(root = "/nobackup/dataset_myf", split = "train", download = True)
test_set = StanfordCars(root = "/nobackup/dataset_myf", split = "test", download = True)
print(f"train set len {len(train_set)}")
print(f"test set len {len(test_set)}")
from collections import Counter
train_label_count = Counter([label for img, label in train_set._samples])
test_label_count = Counter([label for img, label in test_set._samples])
examine_count(train_label_count, name = "train")
examine_count(test_label_count, name = "test")
kwargs = {'num_workers': 4, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(train_set ,
batch_size=16, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(test_set,
batch_size=16, shuffle=False, **kwargs)
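The devkit metadata makes it easy to map a numeric target back to its readable class name. A small sanity-check sketch; the sample index 0 is arbitrary:
img, target = train_set[0]
print(target, train_set.classes[target])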
import numpy as np
# Read the data
import matplotlib.image
import os
from PIL import Image
from torchvision import transforms
import torch
class CUB():
def __init__(self, root, is_train=True, data_len=None,transform=None, target_transform=None):
self.root = root
self.is_train = is_train
self.transform = transform
self.target_transform = target_transform
img_txt_file = open(os.path.join(self.root, 'images.txt'))
label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
        # image file name list
img_name_list = []
for line in img_txt_file:
            # the last character of each line is a newline
img_name_list.append(line[:-1].split(' ')[-1])
        # class labels; subtract 1 from each so labels start from 0
label_list = []
for line in label_txt_file:
label_list.append(int(line[:-1].split(' ')[-1]) - 1)
        # train/test split flags
train_test_list = []
for line in train_val_file:
train_test_list.append(int(line[:-1].split(' ')[-1]))
        # zip() pairs corresponding elements of the iterables into tuples (lazily, which
        # is memory-friendly; wrap it in list() to materialize the result).
        # Here it pairs each file name / label with its split flag:
        # a flag of 1 marks a training image, 0 a test image.
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
train_label_list = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
test_label_list = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
if self.is_train:
            # matplotlib.image.imread returns the image as a numpy array
self.train_img = [matplotlib.image.imread(os.path.join(self.root, 'images', train_file)) for train_file in
train_file_list[:data_len]]
            # training-set labels
self.train_label = train_label_list
if not self.is_train:
self.test_img = [matplotlib.image.imread(os.path.join(self.root, 'images', test_file)) for test_file in
test_file_list[:data_len]]
self.test_label = test_label_list
    # transforms / data augmentation are applied per sample in __getitem__
def __getitem__(self,index):
        # training set
if self.is_train:
img, target = self.train_img[index], self.train_label[index]
        # test set
else:
img, target = self.test_img[index], self.test_label[index]
if len(img.shape) == 2:
            # replicate grayscale images to 3 channels
img = np.stack([img]*3,2)
        # convert to a PIL RGB image
img = Image.fromarray(img,mode='RGB')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
if self.is_train:
return len(self.train_label)
else:
return len(self.test_label)
if __name__ == '__main__':
'''
dataset = CUB(root='./CUB_200_2011')
for data in dataset:
print(data[0].size(),data[1])
'''
    # Read the dataset through PyTorch's DataLoader
transform_train = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomCrop(224, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
])
dataset = CUB(root='../dataset/CUB_200_2011', is_train=True, transform=transform_train,)
print(len(dataset))
trainloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0,
drop_last=True)
print(len(trainloader))
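For evaluation, the same class can be instantiated with is_train=False and a deterministic transform. A sketch mirroring the training transform above; the batch size is an example value:
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
testset = CUB(root='../dataset/CUB_200_2011', is_train=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=2, shuffle=False, num_workers=0)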
import numpy as np
import os
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torchvision.datasets.utils import extract_archive
class Aircraft(VisionDataset):
"""`FGVC-Aircraft `_ Dataset.
Args:
root (string): Root directory of the dataset.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
class_type (string, optional): choose from ('variant', 'family', 'manufacturer').
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
class_types = ('variant', 'family', 'manufacturer')
splits = ('train', 'val', 'trainval', 'test')
img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')
def __init__(self, root, train=True, class_type='variant', transform=None,
target_transform=None, download=False):
super(Aircraft, self).__init__(root, transform=transform, target_transform=target_transform)
split = 'trainval' if train else 'test'
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
if class_type not in self.class_types:
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
class_type, ', '.join(self.class_types),
))
self.class_type = class_type
self.split = split
self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
'images_%s_%s.txt' % (self.class_type, self.split))
if download:
self.download()
(image_ids, targets, classes, class_to_idx) = self.find_classes()
samples = self.make_dataset(image_ids, targets)
self.loader = default_loader
self.samples = samples
self.classes = classes
self.class_to_idx = class_to_idx
def __getitem__(self, index):
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.img_folder)) and \
os.path.exists(self.classes_file)
def download(self):
if self._check_exists():
return
# prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
print('Downloading %s...' % self.url)
tar_name = self.url.rpartition('/')[-1]
download_url(self.url, root=self.root, filename=tar_name)
tar_path = os.path.join(self.root, tar_name)
print('Extracting %s...' % tar_path)
extract_archive(tar_path)
print('Done!')
def find_classes(self):
# read classes file, separating out image IDs and class names
image_ids = []
targets = []
with open(self.classes_file, 'r') as f:
for line in f:
split_line = line.split(' ')
image_ids.append(split_line[0])
targets.append(' '.join(split_line[1:]))
# index class names
classes = np.unique(targets)
class_to_idx = {classes[i]: i for i in range(len(classes))}
targets = [class_to_idx[c] for c in targets]
return image_ids, targets, classes, class_to_idx
def make_dataset(self, image_ids, targets):
assert (len(image_ids) == len(targets))
images = []
for i in range(len(image_ids)):
item = (os.path.join(self.root, self.img_folder,
'%s.jpg' % image_ids[i]), targets[i])
images.append(item)
return images
if __name__ == '__main__':
train_dataset = Aircraft('./aircraft', train=True, download=False)
test_dataset = Aircraft('./aircraft', train=False, download=False)
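Aircraft images also vary in size, so a resize transform is usually attached before batching. A minimal sketch; the input size, normalization statistics, and loader settings are conventional assumptions:
import torch
from torchvision import transforms

aircraft_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_dataset = Aircraft('./aircraft', train=True, transform=aircraft_tf)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)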
import numpy as np
import torch
import torchvision

class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
def __init__(self, source='data/mnist_data', train = True, shuffle_seed = None):
super(PermutedMNISTDataLoader, self).__init__(source, train, download=True)
self.train = train
self.num_data = 0
if self.train:
self.permuted_train_data = torch.stack(
[img.type(dtype=torch.float32).view(-1)[shuffle_seed] / 255.0
for img in self.train_data])
self.num_data = self.permuted_train_data.shape[0]
else:
self.permuted_test_data = torch.stack(
[img.type(dtype=torch.float32).view(-1)[shuffle_seed] / 255.0
for img in self.test_data])
self.num_data = self.permuted_test_data.shape[0]
def __getitem__(self, index):
if self.train:
input, label = self.permuted_train_data[index], self.train_labels[index]
else:
input, label = self.permuted_test_data[index], self.test_labels[index]
return input, label
def getNumData(self):
return self.num_data
batch_size = 64
learning_rate = 1e-3
num_task = 10
criterion = torch.nn.CrossEntropyLoss()
cuda_available = False
if torch.cuda.is_available():
cuda_available = True
def permute_mnist():
train_loader = {}
test_loader = {}
train_data_num = 0
test_data_num = 0
for i in range(num_task):
shuffle_seed = np.arange(28*28)
np.random.shuffle(shuffle_seed)
train_PMNIST_DataLoader = PermutedMNISTDataLoader(train=True, shuffle_seed=shuffle_seed)
test_PMNIST_DataLoader = PermutedMNISTDataLoader(train=False, shuffle_seed=shuffle_seed)
train_data_num += train_PMNIST_DataLoader.getNumData()
test_data_num += test_PMNIST_DataLoader.getNumData()
train_loader[i] = torch.utils.data.DataLoader(
train_PMNIST_DataLoader,
batch_size=batch_size)
test_loader[i] = torch.utils.data.DataLoader(
test_PMNIST_DataLoader,
batch_size=batch_size)
return train_loader, test_loader, int(train_data_num/num_task), int(test_data_num/num_task)
train_loader, test_loader, train_data_num, test_data_num = permute_mnist()
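A minimal sketch of how the per-task loaders might be consumed; the two-layer MLP is a placeholder model added for illustration and is not part of the loader code above:
model = torch.nn.Sequential(
    torch.nn.Linear(28 * 28, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for task_id in range(num_task):
    for inputs, labels in train_loader[task_id]:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    break  # illustration only: train on the first permuted task, then stop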
import os
import pandas as pd
import warnings
from torchvision.datasets import ImageFolder
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_arg
class TinyImageNet(VisionDataset):
"""`tiny-imageNet `_ Dataset.
Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'tiny-imagenet-200/'
url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
filename = 'tiny-imagenet-200.zip'
md5 = '90528d7ca1a48142e341f4ef8d21d0de'
def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)
self.dataset_path = os.path.join(root, self.base_folder)
self.loader = default_loader
self.split = verify_str_arg(split, "split", ("train", "val",))
if self._check_integrity():
print('Files already downloaded and verified.')
elif download:
self._download()
else:
raise RuntimeError(
'Dataset not found. You can use download=True to download it.')
if not os.path.isdir(self.dataset_path):
print('Extracting...')
extract_archive(os.path.join(root, self.filename))
_, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))
self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)
def _download(self):
print('Downloading...')
download_url(self.url, root=self.root, filename=self.filename)
print('Extracting...')
extract_archive(os.path.join(self.root, self.filename))
def _check_integrity(self):
return check_integrity(os.path.join(self.root, self.filename), self.md5)
def __getitem__(self, index):
img_path, target = self.data[index]
image = self.loader(img_path)
if self.transform is not None:
image = self.transform(image)
if self.target_transform is not None:
target = self.target_transform(target)
return image, target
def __len__(self):
return len(self.data)
def find_classes(class_file):
with open(class_file) as r:
classes = list(map(lambda s: s.strip(), r.readlines()))
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def make_dataset(root, base_folder, dirname, class_to_idx):
images = []
dir_path = os.path.join(root, base_folder, dirname)
if dirname == 'train':
for fname in sorted(os.listdir(dir_path)):
cls_fpath = os.path.join(dir_path, fname)
if os.path.isdir(cls_fpath):
cls_imgs_path = os.path.join(cls_fpath, 'images')
for imgname in sorted(os.listdir(cls_imgs_path)):
path = os.path.join(cls_imgs_path, imgname)
item = (path, class_to_idx[fname])
images.append(item)
else:
imgs_path = os.path.join(dir_path, 'images')
imgs_annotations = os.path.join(dir_path, 'val_annotations.txt')
with open(imgs_annotations) as r:
data_info = map(lambda s: s.split('\t'), r.readlines())
cls_map = {line_data[0]: line_data[1] for line_data in data_info}
for imgname in sorted(os.listdir(imgs_path)):
path = os.path.join(imgs_path, imgname)
item = (path, class_to_idx[cls_map[imgname]])
images.append(item)
return images
if __name__ == '__main__':
train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=False)
test_dataset = TinyImageNet('./tiny-imagenet', split='val', download=False)
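As with the other folder-based datasets, attaching a transform and a DataLoader turns TinyImageNet into batches of 64x64 tensors. A sketch; the augmentation, normalization statistics, and batch size are conventional assumptions:
import torch
from torchvision import transforms

tin_tf = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=True, transform=tin_tf)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)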
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yaoyao Liu
## NUS School of Computing
## Email: [email protected]
## Copyright (c) 2019
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import random
import numpy as np
from tqdm import trange
import imageio
class MiniImageNetDataLoader(object):
def __init__(self, shot_num, way_num, episode_test_sample_num, shuffle_images = False):
self.shot_num = shot_num
self.way_num = way_num
self.episode_test_sample_num = episode_test_sample_num
self.num_samples_per_class = episode_test_sample_num + shot_num
self.shuffle_images = shuffle_images
metatrain_folder = './processed_images/train'
metaval_folder = './processed_images/val'
metatest_folder = './processed_images/test'
npy_dir = './episode_filename_list/'
if not os.path.exists(npy_dir):
os.mkdir(npy_dir)
self.npy_base_dir = npy_dir + str(self.shot_num) + 'shot_' + str(self.way_num) + 'way_' + str(episode_test_sample_num) + 'shuffled_' + str(self.shuffle_images) + '/'
if not os.path.exists(self.npy_base_dir):
os.mkdir(self.npy_base_dir)
self.metatrain_folders = [os.path.join(metatrain_folder, label) \
for label in os.listdir(metatrain_folder) \
if os.path.isdir(os.path.join(metatrain_folder, label)) \
]
self.metaval_folders = [os.path.join(metaval_folder, label) \
for label in os.listdir(metaval_folder) \
if os.path.isdir(os.path.join(metaval_folder, label)) \
]
self.metatest_folders = [os.path.join(metatest_folder, label) \
for label in os.listdir(metatest_folder) \
if os.path.isdir(os.path.join(metatest_folder, label)) \
]
def get_images(self, paths, labels, nb_samples=None, shuffle=True):
if nb_samples is not None:
sampler = lambda x: random.sample(x, nb_samples)
else:
sampler = lambda x: x
images = [(i, os.path.join(path, image)) \
for i, path in zip(labels, paths) \
for image in sampler(os.listdir(path))]
if shuffle:
random.shuffle(images)
return images
def generate_data_list(self, phase='train', episode_num=None):
if phase=='train':
folders = self.metatrain_folders
if episode_num is None:
episode_num = 20000
if not os.path.exists(self.npy_base_dir+'/train_filenames.npy'):
print('Generating train filenames')
all_filenames = []
for _ in trange(episode_num):
sampled_character_folders = random.sample(folders, self.way_num)
random.shuffle(sampled_character_folders)
labels_and_images = self.get_images(sampled_character_folders, range(self.way_num), nb_samples=self.num_samples_per_class, shuffle=self.shuffle_images)
labels = [li[0] for li in labels_and_images]
filenames = [li[1] for li in labels_and_images]
all_filenames.extend(filenames)
np.save(self.npy_base_dir+'/train_labels.npy', labels)
np.save(self.npy_base_dir+'/train_filenames.npy', all_filenames)
print('Train filename and label lists are saved')
elif phase=='val':
folders = self.metaval_folders
if episode_num is None:
episode_num = 600
if not os.path.exists(self.npy_base_dir+'/val_filenames.npy'):
print('Generating val filenames')
all_filenames = []
for _ in trange(episode_num):
sampled_character_folders = random.sample(folders, self.way_num)
random.shuffle(sampled_character_folders)
labels_and_images = self.get_images(sampled_character_folders, range(self.way_num), nb_samples=self.num_samples_per_class, shuffle=self.shuffle_images)
labels = [li[0] for li in labels_and_images]
filenames = [li[1] for li in labels_and_images]
all_filenames.extend(filenames)
np.save(self.npy_base_dir+'/val_labels.npy', labels)
np.save(self.npy_base_dir+'/val_filenames.npy', all_filenames)
print('Val filename and label lists are saved')
elif phase=='test':
folders = self.metatest_folders
if episode_num is None:
episode_num = 600
if not os.path.exists(self.npy_base_dir+'/test_filenames.npy'):
print('Generating test filenames')
all_filenames = []
for _ in trange(episode_num):
sampled_character_folders = random.sample(folders, self.way_num)
random.shuffle(sampled_character_folders)
labels_and_images = self.get_images(sampled_character_folders, range(self.way_num), nb_samples=self.num_samples_per_class, shuffle=self.shuffle_images)
labels = [li[0] for li in labels_and_images]
filenames = [li[1] for li in labels_and_images]
all_filenames.extend(filenames)
np.save(self.npy_base_dir+'/test_labels.npy', labels)
np.save(self.npy_base_dir+'/test_filenames.npy', all_filenames)
print('Test filename and label lists are saved')
else:
            print('Please select a valid phase')
def load_list(self, phase='train'):
if phase=='train':
self.train_filenames = np.load(self.npy_base_dir + 'train_filenames.npy').tolist()
self.train_labels = np.load(self.npy_base_dir + 'train_labels.npy').tolist()
elif phase=='val':
self.val_filenames = np.load(self.npy_base_dir + 'val_filenames.npy').tolist()
self.val_labels = np.load(self.npy_base_dir + 'val_labels.npy').tolist()
elif phase=='test':
self.test_filenames = np.load(self.npy_base_dir + 'test_filenames.npy').tolist()
self.test_labels = np.load(self.npy_base_dir + 'test_labels.npy').tolist()
elif phase=='all':
self.train_filenames = np.load(self.npy_base_dir + 'train_filenames.npy').tolist()
self.train_labels = np.load(self.npy_base_dir + 'train_labels.npy').tolist()
self.val_filenames = np.load(self.npy_base_dir + 'val_filenames.npy').tolist()
self.val_labels = np.load(self.npy_base_dir + 'val_labels.npy').tolist()
self.test_filenames = np.load(self.npy_base_dir + 'test_filenames.npy').tolist()
self.test_labels = np.load(self.npy_base_dir + 'test_labels.npy').tolist()
else:
            print('Please select a valid phase')
def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True):
new_path_list = []
new_label_list = []
for k in range(batch_sample_num):
class_idxs = list(range(0, self.way_num))
random.shuffle(class_idxs)
for class_idx in class_idxs:
true_idx = class_idx*batch_sample_num + k
new_path_list.append(input_filename_list[true_idx])
new_label_list.append(input_label_list[true_idx])
img_list = []
for filepath in new_path_list:
this_img = imageio.imread(filepath)
this_img = this_img / 255.0
img_list.append(this_img)
if reshape_with_one:
img_array = np.array(img_list)
label_array = self.one_hot(np.array(new_label_list)).reshape([1, self.way_num*batch_sample_num, -1])
else:
img_array = np.array(img_list)
label_array = self.one_hot(np.array(new_label_list)).reshape([self.way_num*batch_sample_num, -1])
return img_array, label_array
def one_hot(self, inp):
n_class = inp.max() + 1
n_sample = inp.shape[0]
out = np.zeros((n_sample, n_class))
for idx in range(n_sample):
out[idx, inp[idx]] = 1
return out
def get_batch(self, phase='train', idx=0):
if phase=='train':
all_filenames = self.train_filenames
labels = self.train_labels
elif phase=='val':
all_filenames = self.val_filenames
labels = self.val_labels
elif phase=='test':
all_filenames = self.test_filenames
labels = self.test_labels
else:
            print('Please select a valid phase')
one_episode_sample_num = self.num_samples_per_class*self.way_num
this_task_filenames = all_filenames[idx*one_episode_sample_num:(idx+1)*one_episode_sample_num]
epitr_sample_num = self.shot_num
epite_sample_num = self.episode_test_sample_num
this_task_tr_filenames = []
this_task_tr_labels = []
this_task_te_filenames = []
this_task_te_labels = []
for class_k in range(self.way_num):
this_class_filenames = this_task_filenames[class_k*self.num_samples_per_class:(class_k+1)*self.num_samples_per_class]
this_class_label = labels[class_k*self.num_samples_per_class:(class_k+1)*self.num_samples_per_class]
this_task_tr_filenames += this_class_filenames[0:epitr_sample_num]
this_task_tr_labels += this_class_label[0:epitr_sample_num]
this_task_te_filenames += this_class_filenames[epitr_sample_num:]
this_task_te_labels += this_class_label[epitr_sample_num:]
this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False)
this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False)
return this_inputa, this_labela, this_inputb, this_labelb
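A usage sketch for the loader above: generate the episode file lists once, load them, then fetch one episode's support/query batch. The shot/way numbers and episode counts are example values:
dataloader = MiniImageNetDataLoader(shot_num=5, way_num=5, episode_test_sample_num=15)

dataloader.generate_data_list(phase='train', episode_num=1000)
dataloader.generate_data_list(phase='val', episode_num=600)
dataloader.generate_data_list(phase='test', episode_num=600)
dataloader.load_list(phase='all')

# Episode 0: support set (inputa/labela) and query set (inputb/labelb)
episode_train_img, episode_train_label, episode_test_img, episode_test_label = \
    dataloader.get_batch(phase='train', idx=0)
print(episode_train_img.shape, episode_train_label.shape)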
For reference: CINIC-10 can be loaded with ImageFolder as below.
import torch
import torchvision
import torchvision.transforms as transforms
cinic_directory = '/path/to/cinic/directory'
cinic_mean = [0.47889522, 0.47227842, 0.43047404]
cinic_std = [0.24205776, 0.23828046, 0.25874835]
cinic_train = torch.utils.data.DataLoader(
torchvision.datasets.ImageFolder(cinic_directory + '/train',
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=cinic_mean,std=cinic_std)])),
batch_size=128, shuffle=True)
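The same pattern gives the validation and test loaders; CINIC-10 ships with 'train', 'valid', and 'test' subdirectories:
cinic_valid = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(cinic_directory + '/valid',
        transform=transforms.Compose([transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std)])),
    batch_size=128, shuffle=False)
cinic_test = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(cinic_directory + '/test',
        transform=transforms.Compose([transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std)])),
    batch_size=128, shuffle=False)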