PyTorch 实现 K 折交叉验证
一般的交叉验证示例代码针对的是神经网络的回归或分类任务。
本文针对的是图像分类任务；如果是目标检测等任务，只需修改对应的数据读取函数即可。
实现交叉验证的Dataset
import torch
import torch.nn as nn
from torch.utils.data.dataset import *
from PIL import Image
from torch.nn import functional as F
import random
class KZDataset(Dataset):
    """Image-classification dataset serving one fold of a K-fold split.

    All samples listed in an annotation file are split into ``K`` contiguous
    folds. Fold ``ki`` is the validation subset; the remaining folds form the
    training subset. Create one instance with ``typ='train'`` and one with
    ``typ='val'``, using identical ``txt_path``/``ki``/``K``/``rand``, to get
    two disjoint, complementary subsets.
    """

    def __init__(self, txt_path=None, ki=0, K=5, typ='train', transform=None, rand=False):
        """
        Args:
            txt_path: Path to a text file listing every sample, one per line,
                as ``<image path> <integer class label>``, e.g.::

                    img1.png 0
                    ...
                    img100.png 1

            ki: Index of the current fold, 0-based, in ``[0, K)``.
            K: Total number of folds; must be positive.
            typ: ``'train'`` selects every fold except ``ki``;
                ``'val'`` selects fold ``ki`` only.
            transform: Optional callable applied to each loaded PIL image
                (e.g. data augmentation / tensor conversion).
            rand: If True, shuffle the samples with a fixed seed before
                splitting. Every instance gets the same permutation, so the
                train and val subsets remain disjoint.

        Raises:
            ValueError: If ``K`` is not positive, ``ki`` is outside
                ``[0, K)``, or ``typ`` is neither ``'train'`` nor ``'val'``.
        """
        if K <= 0:
            raise ValueError(f"K must be a positive integer, got {K}")
        if not 0 <= ki < K:
            raise ValueError(f"ki must be in [0, {K}), got {ki}")
        if typ not in ('train', 'val'):
            raise ValueError(f"typ must be 'train' or 'val', got {typ!r}")

        self.all_data_info = self.get_img_info(txt_path)
        if rand:
            # A private, fixed-seed RNG yields the same permutation as
            # `random.seed(1); random.shuffle(...)`. It also avoids resetting
            # the global `random` state that other code may depend on.
            random.Random(1).shuffle(self.all_data_info)

        total = len(self.all_data_info)
        fold_len = total // K
        start = fold_len * ki
        # The last fold absorbs the `total % K` leftover samples. Every sample
        # therefore appears in exactly one validation fold across the K runs.
        end = total if ki == K - 1 else fold_len * (ki + 1)

        if typ == 'val':
            self.data_info = self.all_data_info[start:end]
        else:
            self.data_info = self.all_data_info[:start] + self.all_data_info[end:]
        self.transform = transform

    def __getitem__(self, index):
        """Load sample ``index`` as an RGB image and return ``(image, label)``.

        ``image`` is the output of ``self.transform`` when one was given,
        otherwise the PIL image itself.
        """
        img_pth, label = self.data_info[index]
        # Context manager releases the underlying file once converted.
        with Image.open(img_pth) as im:
            img = im.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of samples in this fold's subset."""
        return len(self.data_info)

    @staticmethod
    def get_img_info(txt_path):
        """Parse the annotation file into a list of ``(image_path, label)`` tuples.

        Each non-blank line is split on whitespace. The first field is the
        image path and the second is parsed as an ``int`` label; any further
        fields are ignored. Blank lines (e.g. a trailing newline) are skipped.

        Raises:
            ValueError: If a non-blank line has fewer than two fields, or its
                label is not an integer.
        """
        data_info = []
        with open(txt_path, 'r') as data:
            for lineno, data_line in enumerate(data, start=1):
                fields = data_line.split()
                if not fields:
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{txt_path}:{lineno}: expected '<image path> <label>', "
                        f"got {data_line.rstrip()!r}")
                data_info.append((fields[0], int(fields[1])))
        return data_info
运用KZDataset
这里只给出调用部分的伪代码，请根据自己的需求修改。
# K-fold training driver (pseudocode: adapt to your own setup).
# NOTE(review): `transforms` is assumed to be torchvision.transforms
# (`from torchvision import transforms`), which this file does not import.
from torch.utils.data import DataLoader

K = 5         # number of folds -- placeholder, adjust
batchs = 32   # batch size -- placeholder, adjust
epoches = 10  # training epochs per fold -- placeholder, adjust

# Standard ImageNet per-channel mean / std.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
transfm = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

for ki in range(K):
    # Identical txt_path / K / ki / rand for both subsets, so train and val
    # are built from the same permutation and are complementary.
    trainset = KZDataset(txt_path='data.txt', ki=ki, K=K, typ='train', transform=transfm, rand=True)
    valset = KZDataset(txt_path='data.txt', ki=ki, K=K, typ='val', transform=transfm, rand=True)
    train_loader = DataLoader(dataset=trainset, batch_size=batchs, shuffle=True)
    val_loader = DataLoader(dataset=valset, batch_size=batchs)
    # Re-create the model and optimizer here so each fold trains from scratch.
    for epoch in range(epoches):
        for i, (inputs, labels) in enumerate(train_loader):
            pass  # training step goes here
        # Gradients are not needed for validation.
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(val_loader):
                pass  # validation step goes here