# batch_size的理解
# 有一个被我无数次忘了的前提,图片本身也是一个下面这样的行列,每个元素的数字表示颜色之类的。
# 比如上面的CIFAR10的数据集中每张图片的尺寸是3x32x32,即一个三通道(RGB)的,32x32的行列。
# 这一张图片就是一个数据组,里面包含着3x32x32=3072个数字
import torch
from torch import nn
from torch.utils.data import DataLoader
example_1 = torch.tensor([[[1, 2, 0, 3],
[0, 1, 2, 3],
[1, 2, 1, 0],
[5, 2, 3, 1]],
[[1, 4, 6, 3],
[0, 1, 5, 3],
[2, 2, 1, 0],
[5, 2, 2, 1]],
[[1, 4, 6, 3],
[0, 1, 5, 3],
[2, 2, 1, 0],
[5, 2, 2, 1]],
[[1, 4, 6, 3],
[0, 1, 5, 3],
[2, 2, 1, 0],
[5, 2, 2, 1]]], dtype=torch.float32) # 瞎编4个4x4行列,这里涉及到第一个概念,如果把example_1视为一个数据集的话,这个数据集里一共有4个raw_data(数据组),每个raw_data通道数为1,长宽都为4
loader = DataLoader(example_1, batch_size=2) # 在这里我们设定了batch_size=2,即每个batch里只有2个数据组,所以example_1会被分成两个batch(一句题外话,这里可以意识到,dataloader的作用其实就是把数据集分割成多个batch)
for data in loader: #被dataloader处理后,数据被分成了两个batch,
res = data # 我们让res等于这个batch,
print(res.shape) #然后看一下res的尺寸,结果输出为torch.Size([2, 4, 4]),完美符合预期。(两个数据组,每个都是4x4)
ep_1 = torch.reshape(example_1, (-1, 1, 2, 2)) # 现在我们对res这个batch,如果我们想把它再次分成很多个2x2的行列,通道数为1的数据组_1,那这个batch里总共有多少个数据组_1呢?(batch_size留了-1让他自己补全)
print(ep_1.shape) #torch.Size([16, 1, 2, 2]) 很容易理解,以2x2,1通道的尺寸切割之前那个batch,因为之前那个batch总共只有64个元素,所以需要一次拿16个数据组才能一次把所以元素拿完
# 那如果,还是对于这个batch,如果我想一次只拿一个数据组,这个数据组只有一个通道,且高度为1,那长度应该为多少呢
ep_2 = torch.reshape(example_1, (1, 1, 1, -1)) # 按要求设置
print(ep_2.shape) #torch.Size([1, 1, 1, 64])。即,如果想一次把数据库内的所有数据拿完,且要求按我想的来的话,长度必须要为64才可以一次全部拿完。
# 总结(以CIFAR10作为例子)
# 第一层,我们有一个数据集,里面有640张图片(640个数据组)
# 第二层,我们在dataloader时设置batch_size=64,于是把这640张图片分成了10个batch
# 第三层,现在聚焦在某一个具体的batch里,里面有64个数据组(图片),每张图片3通道,长宽为32x32
# 第四层,所以这个batch一共有64x3x32x32=198806个数字。
# 如果我想令这个batch内只有一个数据组,且这个数据组只有一通道,高度为1的话,那这个数据组的宽度则理所当然为198806个数字.