Official documentation
Project repository
DataModules decouple the data from the model, so you can focus on the model itself instead of on data handling.
To define a custom DataModule, subclass LightningDataModule and implement the following methods:
def __init__(self): # typically specifies data_dir (the data directory), defines the transforms, and sets a default self.dims for convenient use later
def prepare_data(self): # downloads the data; do not perform any other operations on the data in this function
def setup(self, stage): # loads the previously downloaded data and assigns the train/val/test splits; stage can be 'fit' or 'test': 'fit' assigns only the training data, 'test' assigns the test set, and None assigns both
def train_dataloader(self): # returns the training set dataloader
def val_dataloader(self): # returns the validation set dataloader
def test_dataloader(self): # returns the test set dataloader
A custom DataModule for CIFAR-10:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10

BATCH_SIZE = 32  # assumed value; the original snippet leaves BATCH_SIZE undefined

class CIFAR10DataModule(LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
self.dims = (3, 32, 32)
self.num_classes = 10
def prepare_data(self):
# download
CIFAR10(self.data_dir, train=True, download=True)
CIFAR10(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=BATCH_SIZE)
def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=BATCH_SIZE)
def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=BATCH_SIZE)
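Once the DataModule is defined, the Trainer drives its hooks in order: prepare_data, then setup, then the dataloaders. A minimal usage sketch, where LitModel is a hypothetical LightningModule standing in for your own model:

from pytorch_lightning import Trainer

dm = CIFAR10DataModule(data_dir="./")
model = LitModel()  # hypothetical model; substitute your own LightningModule
trainer = Trainer(max_epochs=1)
trainer.fit(model, datamodule=dm)   # calls prepare_data(), setup("fit"), then the train/val dataloaders
trainer.test(model, datamodule=dm)  # calls setup("test") and test_dataloader()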
The project implements several self-supervised learning methods, and defines a vision base class for their data modules:
import os
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Union
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split
class VisionDataModule(LightningDataModule): # subclasses LightningDataModule
EXTRA_ARGS: dict = {}
name: str = ""
#: Dataset class to use
dataset_cls: type # used to download and split the data; e.g. when loading CIFAR10 below, dataset_cls is CIFAR10; to load your own dataset, provide a corresponding implementation
#: A tuple describing the shape of the data
dims: tuple
def __init__(
self,
data_dir: Optional[str] = None,
val_split: Union[int, float] = 0.2, # if int, the number of validation samples; if float, the fraction of the training set used for validation
num_workers: int = 0,
normalize: bool = False,
batch_size: int = 32,
seed: int = 42,
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalize
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
super().__init__(*args, **kwargs)
self.data_dir = data_dir if data_dir is not None else os.getcwd()
self.val_split = val_split
self.num_workers = num_workers
self.normalize = normalize
self.batch_size = batch_size
self.seed = seed
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
def prepare_data(self, *args: Any, **kwargs: Any) -> None: # download the data
"""Saves files to data_dir."""
self.dataset_cls(self.data_dir, train=True, download=True)
self.dataset_cls(self.data_dir, train=False, download=True)
def setup(self, stage: Optional[str] = None) -> None:
"""Creates train, val, and test dataset."""
if stage == "fit" or stage is None: # 分配训练集数据
train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
dataset_train = self.dataset_cls(self.data_dir, train=True, transform=train_transforms, **self.EXTRA_ARGS)
dataset_val = self.dataset_cls(self.data_dir, train=True, transform=val_transforms, **self.EXTRA_ARGS)
# split the dataset
self.dataset_train = self._split_dataset(dataset_train)
self.dataset_val = self._split_dataset(dataset_val, train=False)
if stage == "test" or stage is None: # 分配测试数据
test_transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
self.dataset_test = self.dataset_cls(
self.data_dir, train=False, transform=test_transforms, **self.EXTRA_ARGS
)
def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset:
"""Splits the dataset into train and validation set."""
len_dataset = len(dataset) # type: ignore[arg-type]
splits = self._get_splits(len_dataset) # lengths of the train and validation splits
dataset_train, dataset_val = random_split(dataset, splits, generator=torch.Generator().manual_seed(self.seed))
if train:
return dataset_train
return dataset_val
def _get_splits(self, len_dataset: int) -> List[int]:
"""Computes split lengths for train and validation set."""
if isinstance(self.val_split, int):
train_len = len_dataset - self.val_split # training set length
splits = [train_len, self.val_split] # train and validation split lengths
elif isinstance(self.val_split, float):
val_len = int(self.val_split * len_dataset)
train_len = len_dataset - val_len
splits = [train_len, val_len]
else:
raise ValueError(f"Unsupported type {type(self.val_split)}")
return splits
@abstractmethod
def default_transforms(self) -> Callable: # implemented by subclasses
"""Default transform for the dataset."""
# dataloader
def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
"""The train dataloader."""
return self._data_loader(self.dataset_train, shuffle=self.shuffle)
def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
"""The val dataloader."""
return self._data_loader(self.dataset_val)
def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
"""The test dataloader."""
return self._data_loader(self.dataset_test)
def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
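With this base class, supporting a new dataset mostly comes down to pointing dataset_cls at a dataset class with a torchvision-style constructor and implementing the one abstract method, default_transforms. A minimal sketch, where MyDataset is a hypothetical dataset class accepting (root, train, transform, download) arguments like CIFAR10 does:

from typing import Callable
from torchvision import transforms as transform_lib

class MyDataModule(VisionDataModule):
    name = "my_dataset"
    dataset_cls = MyDataset  # hypothetical class with a CIFAR10-like constructor
    dims = (3, 64, 64)       # assumed shape of one sample

    def default_transforms(self) -> Callable:
        # the only abstract method a subclass must provide
        return transform_lib.Compose([transform_lib.ToTensor()])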
The project's CIFAR10DataModule follows this pattern, subclassing VisionDataModule:
from typing import Any, Callable, Optional, Sequence, Union
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import CIFAR10
else: # pragma: no cover
warn_missing_pkg("torchvision")
CIFAR10 = None
class CIFAR10DataModule(VisionDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/
Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png
:width: 400
:alt: CIFAR-10
Specs:
- 10 classes (6,000 images per class)
- Each image is (3 x 32 x 32)
Standard CIFAR10, train, val, test splits and transforms
Transforms::
cifar10_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Normalize(
mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
)
])
Example::
from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule(PATH)
model = LitModel()
Trainer().fit(model, datamodule=dm)
Or you can set your own transforms
Example::
dm.train_transforms = ...
dm.test_transforms = ...
dm.val_transforms = ...
"""
name = "cifar10"
dataset_cls = CIFAR10 # downloads and prepares the data
dims = (3, 32, 32)
def __init__(
self,
data_dir: Optional[str] = None,
val_split: Union[int, float] = 0.2,
num_workers: int = 0,
normalize: bool = False,
batch_size: int = 32,
seed: int = 42,
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False, # if True, drop the last incomplete batch
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalize
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
super().__init__( # type: ignore[misc]
data_dir=data_dir,
val_split=val_split,
num_workers=num_workers,
normalize=normalize,
batch_size=batch_size,
seed=seed,
shuffle=shuffle,
pin_memory=pin_memory,
drop_last=drop_last,
*args,
**kwargs,
)
@property
def num_samples(self) -> int:
train_len, _ = self._get_splits(len_dataset=50_000)
return train_len
@property
def num_classes(self) -> int:
"""
Return:
10
"""
return 10
def default_transforms(self) -> Callable: # overrides the base-class abstract method
if self.normalize:
cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
else:
cf10_transforms = transform_lib.Compose([transform_lib.ToTensor()])
return cf10_transforms
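The interaction between val_split and num_samples follows directly from _get_splits: with the default val_split=0.2, CIFAR-10's 50,000 training images are split 40,000/10,000, while an int val_split is taken as the absolute validation size. A quick sketch:

dm = CIFAR10DataModule(data_dir="./", val_split=0.2, normalize=True)
print(dm.num_samples)  # 50_000 - int(0.2 * 50_000) = 40_000
print(dm.num_classes)  # 10

dm = CIFAR10DataModule(data_dir="./", val_split=5000)
print(dm.num_samples)  # 50_000 - 5_000 = 45_000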
The CIFAR10 used above comes straight from torchvision. The project also implements its own version, which is a useful reference when you need to load a dataset that torchvision does not provide:

import logging
import os
import pickle
import tarfile
import urllib.request
from abc import ABC
from typing import Callable, Optional, Sequence, Tuple
from urllib.error import HTTPError

import torch
from torch import Tensor
from torch.utils.data import Dataset

from pl_bolts.utils import _PIL_AVAILABLE

if _PIL_AVAILABLE:
    from PIL import Image

class LightDataset(ABC, Dataset):
data: Tensor
targets: Tensor
normalize: tuple
dir_path: str
cache_folder_name: str
DATASET_NAME = "light"
def __len__(self) -> int:
return len(self.data)
@property
def cached_folder_path(self) -> str:
return os.path.join(self.dir_path, self.DATASET_NAME, self.cache_folder_name)
@staticmethod
def _prepare_subset(
full_data: Tensor,
full_targets: Tensor,
num_samples: int,
labels: Sequence,
) -> Tuple[Tensor, Tensor]:
"""Prepare a subset of a common dataset."""
classes = {d: 0 for d in labels}
indexes = []
for idx, target in enumerate(full_targets):
label = target.item()
if classes.get(label, float("inf")) >= num_samples:
continue
indexes.append(idx)
classes[label] += 1
if all(classes[k] >= num_samples for k in classes):
break
data = full_data[indexes]
targets = full_targets[indexes]
return data, targets
def _download_from_url(self, base_url: str, data_folder: str, file_name: str):
url = os.path.join(base_url, file_name)
logging.info(f"Downloading {url}")
fpath = os.path.join(data_folder, file_name)
try:
urllib.request.urlretrieve(url, fpath)
except HTTPError as err:
raise RuntimeError(f"Failed download from {url}") from err
class CIFAR10(LightDataset):
"""Customized `CIFAR10 `_ dataset for testing Pytorch Lightning
without the torchvision dependency.
Part of the code was copied from
https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/
Args:
data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt``
and ``CIFAR10/processed/test.pt`` exist.
train: If ``True``, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download: If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
Examples:
>>> from torchvision import transforms
>>> from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
>>> cf10_transforms = transforms.Compose([transforms.ToTensor(), cifar10_normalization()])
>>> dataset = CIFAR10(download=True, transform=cf10_transforms, data_dir="datasets")
>>> len(dataset)
50000
>>> torch.bincount(dataset.targets)
tensor([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000])
>>> data, label = dataset[0]
>>> data.shape
torch.Size([3, 32, 32])
>>> label
6
Labels::
airplane: 0
automobile: 1
bird: 2
cat: 3
deer: 4
dog: 5
frog: 6
horse: 7
ship: 8
truck: 9
"""
BASE_URL = "https://www.cs.toronto.edu/~kriz/" # download base URL
FILE_NAME = "cifar-10-python.tar.gz" # target archive file
cache_folder_name = "complete" # folder holding the extracted .pt files
TRAIN_FILE_NAME = "training.pt" # training set
TEST_FILE_NAME = "test.pt" # test set
DATASET_NAME = "CIFAR10" # dataset root folder
labels = set(range(10))
relabel = False
def __init__(
self, data_dir: str = ".", train: bool = True, transform: Optional[Callable] = None, download: bool = True
):
super().__init__()
self.dir_path = data_dir
self.train = train # training set or test set
self.transform = transform
if not _PIL_AVAILABLE:
raise ImportError("You want to use PIL.Image for loading but it is not installed yet.")
os.makedirs(self.cached_folder_path, exist_ok=True)
self.prepare_data(download)
if not self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)):
raise RuntimeError("Dataset not found.")
data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file))
def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
img = self.data[idx].reshape(3, 32, 32)
target = int(self.targets[idx])
if self.transform is not None:
img = img.numpy().transpose((1, 2, 0)) # convert to HWC
img = self.transform(Image.fromarray(img))
if self.relabel:
target = list(self.labels).index(target)
return img, target
@classmethod
def _check_exists(cls, data_folder: str, file_names: Sequence[str]) -> bool:
if isinstance(file_names, str):
file_names = [file_names]
return all(os.path.isfile(os.path.join(data_folder, fname)) for fname in file_names)
def _unpickle(self, path_folder: str, file_name: str) -> Tuple[Tensor, Tensor]:
with open(os.path.join(path_folder, file_name), "rb") as fo:
pkl = pickle.load(fo, encoding="bytes")
return torch.tensor(pkl[b"data"]), torch.tensor(pkl[b"labels"])
def _extract_archive_save_torch(self, download_path):
# extract the archive
with tarfile.open(os.path.join(download_path, self.FILE_NAME), "r:gz") as tar:
tar.extractall(path=download_path)
# this is internal path in the archive
path_content = os.path.join(download_path, "cifar-10-batches-py")
# load Test and save as PT
torch.save(
self._unpickle(path_content, "test_batch"), os.path.join(self.cached_folder_path, self.TEST_FILE_NAME)
)
# load Train and save as PT
data, labels = [], []
for i in range(5):
fname = f"data_batch_{i + 1}"
_data, _labels = self._unpickle(path_content, fname)
data.append(_data)
labels.append(_labels)
# stash all to one
data = torch.cat(data, dim=0)
labels = torch.cat(labels, dim=0)
# and save as PT
torch.save((data, labels), os.path.join(self.cached_folder_path, self.TRAIN_FILE_NAME))
def prepare_data(self, download: bool):
if self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)):
return
base_path = os.path.join(self.dir_path, self.DATASET_NAME)
if download:
self.download(base_path)
self._extract_archive_save_torch(base_path)
def download(self, data_folder: str) -> None:
"""Download the data if it doesn't exist in cached_folder_path already."""
if self._check_exists(data_folder, self.FILE_NAME):
return
self._download_from_url(self.BASE_URL, data_folder, self.FILE_NAME)
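Because VisionDataModule only ever calls dataset_cls through the (root, train, transform, download)-style constructor, this torchvision-free CIFAR10 can be swapped into the earlier CIFAR10DataModule by overriding dataset_cls. A sketch, assuming both classes are importable as defined above (note the module's default transforms still rely on torchvision):

class CustomCIFAR10DataModule(CIFAR10DataModule):
    dataset_cls = CIFAR10  # the reimplementation above, not torchvision's

dm = CustomCIFAR10DataModule(data_dir="datasets")
dm.prepare_data()  # downloads the archive and caches training.pt / test.pt
dm.setup("fit")
images, labels = next(iter(dm.train_dataloader()))  # images.shape == (32, 3, 32, 32)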