找到的图片需要首先做相同尺寸的裁剪,归一化,否则会因为图片大小不同报错
RuntimeError: stack expects each tensor to be equal size,
but got [3, 667, 406] at entry 0 and [3, 600, 400] at entry 1
pytorch的torchvision.transforms
模块提供了许多用于图片变换/增强的函数。
transforms.Resize((600,600)),
因为主体要识别的图像一般在中心位置,所以使用CenterCrop
,这里设置为(400, 400)
transforms.CenterCrop((400,400)),
这里统一成torch.float64
方便神经网络计算,也可以统一成其他比如uint32等类型
transforms.ConvertImageDtype(torch.float64),
对于图片来说0~255
的范围有点大,并不利于模型梯度计算,我们应该进行归一化。pytorch当中也提供了归一化的函数torchvision.transforms.Normalize(mean,std)
,
[0.5,0.5,0.5]
的mean,std
来把数据归一化至[-1,1]
mean,std
来归一化至均值为0,标准差为1的正态分布,mean=[0.485, 0.456, 0.406]
,std=[0.229, 0.224, 0.225]
的归一化数据,这是在ImageNet
的几百万张图片数据计算得出的结果BN
等方法也具有很出色的归一化表现,我们也会使用到Juliuszh:详解深度学习中的Normalization,BN/LN/WN
Algernon:【基础算法】六问透彻理解BN(Batch Normalization)
我们这里使用简单的[0.5,0.5,0.5]
归一化方法,更新cls_dataset
,加入transform
操作 ,作为图片裁剪的预处理。
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
关于transforms
的操作大体分为裁剪/翻转和旋转/图像变换/transform自身操作,具体见余霆嵩:PyTorch 学习笔记(三):transforms的二十二个方法,这里不进行详细展开。
当数据集较小时,可以通过对已有图片做数据增强,利用之前提到的transforms
中的函数 ,也可以混合使用来根据已有数据创造新数据
self.data_enhancement = transforms.Compose([
transforms.RandomHorizontalFlip(p=1),
transforms.RandomRotation(30)
])
class cls_dataset(Dataset):
def __init__(self) -> None:
# initialization
def __getitem__(self, index):
# return data,label in set
def __len__(self):
# return the length of the dataset
import os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import h5py
from torchvision.io import read_image
train_pic_path = 'test-set'
test_pic_path = 'training-set'
def create_h5_file(file_name):
all_type = ['flower', 'bird']
h5df_file = h5py.File(file_name, "w") #file_name指向比如"train.hdf5"这种文件路径,但这句话之前file_name指向路径为空
#图片统一化处理
transform = transforms.Compose([
transforms.Resize((600, 600)),
transforms.CenterCrop((400, 400)),
transforms.ConvertImageDtype(torch.float64),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
]
)
#数据增强
data_list = [] #建立一个保存图片张量的空列表
target_list = [] #建立一个保存图片标签的空列表
#遍历文件夹建立数据集
'''
文件夹组成
| —— train
| | —— flower
| | | —— 图片1
| | —— bird
| | —— | —— 图片2
| —— test
| | —— flower
| | —— bird
'''
dataset_kind = file_name.split('.')[0]
#先判断缺失的文件是训练集还是测试集
if dataset_kind == 'train':
pic_file_name = train_pic_path
else:
pic_file_name = test_pic_path
#再循环遍历文件夹
for file_name_dir, _, files in tqdm(os.walk(pic_file_name)):
target = file_name_dir.split('/')[-1]
if target in all_type:
for file in files:
pic = read_image(os.path.join(file_name_dir, file)) #以张量形式读取图片对象
pic = transform(pic) #预处理图片
pic = np.array(pic).astype(np.float64)
data_list.append(pic) #将pic对象添加到列表里
target_list.append(target.encode()) #将target编码后添加到列表里
h5df_file.create_dataset("image", data=data_list)
h5df_file.create_dataset("target", data=target_list)
h5df_file.close()
class h5py_dataset(Dataset):
def __init__(self, file_name) -> None:
super().__init__()
self.file_name = file_name #指向文件的路径名
#如果file_name指向的h5文件不存在,就新建一个
if not os.path.exists(file_name):
create_h5_file(file_name)
def __getitem__(self, index):
with h5py.File(self.file_name, 'r') as f:
if f['target'][index].decode() == 'bird': #如果在f文件的target列表中查找到index下标对应的标签是bird
target = torch.tensor(0)
else:
target = torch.tensor(1)
return f['image'][index], target
def __len__(self):
with h5py.File(self.file_name, 'r') as f:
return len(f['target'])
def h5py_loader():
train_file = 'train.hdf5'
test_file = 'test.hdf5'
train_dataset = h5py_dataset(train_file)
test_dataset = h5py_dataset(test_file)
train_data_loader = DataLoader(train_dataset, batch_size=4)
test_data_loader = DataLoader(test_dataset, batch_size=4)
return train_data_loader, test_data_loader
实例化set对象后利用torch.utils.data.DataLoader
卷积后,池化后尺寸计算公式:
(图像尺寸-卷积核尺寸 + 2*填充值)/步长+1
(图像尺寸-池化窗尺寸 + 2*填充值)/步长+1
参考文章
池化参数一般就是(2, 2)
中间的channel数量都是自己设定的,二的次方就行
kernelsize一般3或者5之类的
for _, data in enumerate(train_loader):
if isinstance(data, list):
image = data[0].type(torch.FloatTensor).to(device)
target = data[1].to(device)
elif isinstance(data, dict):
image = data['image'].type(torch.FloatTensor).to(device)
target = data['target'].to(device)
else:
print(type(data))
raise TypeError
for 循环中data的组成来源于构建set时,
h5df_file.create_dataset("image", data=data_list)
h5df_file.create_dataset("target", data=target_list)
写入了h5df文件中两个dataset,但在文件中是以嵌套列表形式保存,其中data[0]等价于引用image这个dataset,data[1]等价于引用target这个集合
投影概率放到网络里面