做深度学习:1 改网络模型的结构; 2 改 数据加载过程的代码
无非就是这两种,目前来说,网络模型结构的改动 还是比较清晰!
数据加载过程,主要分为以下三个步骤:
1 数据下载,
2 数据载入, 也就是图片的处理 。对应代码:
那些最常见的数据集,比如 mnist之类的,都是直接调用datasets.数据名()既可!
module_train = import_module('data.' + args.data_train.lower())
trainset = getattr(module_train, args.data_train)(args)
self.loader_train = MSDataLoader(
args,
trainset,
batch_size=args.batch_size,
shuffle=True,
pin_memory=not args.cpu
)
比如这里的 trainset 就算数据载入,一般我们要改数据载入的过程。就是 改这个trainset里面的东西。
从哪里可以看出来是,真正的数据载入呢?
import torch.utils.data as data
要是导入了torch工具里面的data,并且有类继承了data.Dataset既可以知道,这里是真的数据载入的地方。
这个Dataset类,就是一个抽象类,一般继承了的类都会重写函数。
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
functions: Dict[str, Callable] = {}
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
def __getattr__(self, attribute_name):
if attribute_name in Dataset.functions:
function = functools.partial(Dataset.functions[attribute_name], self)
return function
else:
raise AttributeError
3 数据装载:打包送给模型 其实,这个过程主要就算一个打包过程,把这个数据送入到网络模型。一般就会设置以下,batchsize ,shuffle是否打乱顺序啥的。
就比如上面的代码
另外,ArbSR里面的代码:
打包用的是这个MSDataLoader(这是自己写的一个类,来打包)。一般来说,就是用torch库自带的data.DataLoader类。
class MSDataLoader(DataLoader):
def __init__(
self, args, dataset, batch_size=1, shuffle=False,
sampler=None, batch_sampler=None,
collate_fn=_utils.collate.default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
super(MSDataLoader, self).__init__(
dataset, batch_size=batch_size, shuffle=shuffle,
sampler=sampler, batch_sampler=batch_sampler,
num_workers=args.n_threads, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
self.scale = args.scale
def __iter__(self):
return _MSDataLoaderIter(self)
然后,这个就是ArbSR里面的打包代码,这里的iter是一个迭代器,这个执行完了以后,它run的是
trainer代码里面的,for batch了:
for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
lr, hr = self.prepare(lr, hr)
scale = hr.size(2) / lr.size(2)
scale2 = hr.size(3) / lr.size(3)
timer_data.hold()
self.optimizer.zero_grad()
# inference
self.model.get_model().set_scale(scale, scale2)
sr = self.model(lr)
# loss function
loss = self.loss(sr, hr)