PyTorch + Dataset + DataLoader 实现 k 折交叉验证

 第一步:继承Dataset并创建自己的PaddyDataSet——一般是重写__init__()方法和__getitem__()方法,并补充__len__()方法。

# PaddyDataSet
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np
from torch.utils.data import DataLoader
import numpy as np
import torchvision.transforms as transforms

# Maps each class-directory name to its integer label.
paddy_labels = {
    'bacterial_leaf_blight': 0,
    'bacterial_leaf_streak': 1,
    'bacterial_panicle_blight': 2,
    'blast': 3,
    'brown_spot': 4,
    'dead_heart': 5,
    'downy_mildew': 6,
    'hispa': 7,
    'normal': 8,
    'tungro': 9,
}


class PaddyDataSet(Dataset):
    """Image dataset read from a directory that holds one sub-directory per class.

    Each item is an ``(image, label)`` pair, where ``label`` is the integer
    that ``paddy_labels`` assigns to the name of the image's parent directory.
    """

    def __init__(self, data_dir, transform=None):
        """Index every ``.jpg`` image found under ``data_dir``.

        Args:
            data_dir: Root directory whose sub-directories are named after
                the keys of ``paddy_labels``.
            transform: Optional callable applied to the PIL image in
                ``__getitem__``.
        """
        # Per-instance copy so that mutating it never alters the module table.
        self.label_name = dict(paddy_labels)

        # data_info stores (path, label) for every image; __getitem__ looks
        # samples up in it by index.
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform
        # Only the shape (640, 480) of this array is used: it is compared
        # against PIL's (width, height) size tuple in __getitem__.
        self.temp = np.zeros((640, 480))

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        path_img, label = self.data_info[index]
        # Open inside a context manager so the file handle is released as
        # soon as the pixel data has been converted into a new image.
        with Image.open(path_img) as handle:
            img = handle.convert('RGB')

        # Images whose (width, height) is (640, 480) are resized to (480, 640).
        if img.size == self.temp.shape:
            img = img.resize((480, 640))
        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        """Collect ``(image_path, label)`` pairs for all ``.jpg`` files.

        Directories and file names are visited in sorted order so that the
        sample order, and therefore any index-based split with a fixed seed,
        is the same on every run and every filesystem.

        Args:
            data_dir: Root directory to walk.

        Returns:
            A list of ``(path, int_label)`` tuples.

        Raises:
            KeyError: If a directory containing ``.jpg`` files has a name
                that is not a key of ``paddy_labels``.
        """
        data_info = []
        for root, dirs, _ in os.walk(data_dir):
            # Sorting in place also fixes the order in which os.walk descends.
            dirs.sort()
            # Iterate over the class directories.
            for sub_dir in dirs:
                class_dir = os.path.join(root, sub_dir)
                # Iterate over the images of this class.
                for img_name in sorted(os.listdir(class_dir)):
                    if not img_name.endswith('.jpg'):
                        continue
                    path_img = os.path.join(class_dir, img_name)
                    data_info.append((path_img, int(paddy_labels[sub_dir])))

        return data_info
                

第二步:加载数据集和模型

# 读取模型
import torchvision
from torch import nn
import torch.optim as optim
import torch.nn.functional as F


# Load a ResNet-50 with pretrained weights.
ResNet = torchvision.models.resnet50(pretrained=True)
# Replace the final fully connected layer with a 10-class head.
# add_module() would only register an extra sub-module that the model's
# forward() never calls, leaving the output at the original 1000 logits.
ResNet.fc = nn.Linear(ResNet.fc.in_features, 10)
# Optimise with Adam; created after the head swap so the new layer's
# parameters are included.
optimizer = optim.Adam(ResNet.parameters(), lr=0.0001)

# Load the whole dataset first; it is split into folds afterwards.
data_dir = "./DataSet/train_images"
data = PaddyDataSet(data_dir=data_dir, transform=transforms.Compose([transforms.ToTensor()]))


第三步:划分训练集和验证集,并进行模型的训练和验证

def get_num_correct(preds, labels):
    """Count how many rows of ``preds`` have their arg-max equal to ``labels``.

    Args:
        preds: Tensor of per-class scores; the arg-max is taken over dim 1.
        labels: Tensor of target class indices, compared element-wise with
            the predicted indices.

    Returns:
        The number of matching positions as a Python number.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes == labels
    return matches.sum().item()



# Split the dataset into ten folds.

from sklearn.model_selection import KFold

kf = KFold(n_splits=10, shuffle=True, random_state=0)

# NOTE(review): the same model and optimizer are carried over from one fold
# to the next, so from the second fold on the validation samples may already
# have been trained on. Re-initialise both per fold if independent fold
# scores are required -- confirm the intended protocol.
epoch = 1
# KFold only needs the number of samples, so split an explicit index array
# instead of handing the Dataset object itself to scikit-learn.
for train_index, val_index in kf.split(np.arange(len(data))):
    train_fold = torch.utils.data.dataset.Subset(data, train_index.tolist())
    val_fold = torch.utils.data.dataset.Subset(data, val_index.tolist())

    # Wrap both subsets in DataLoaders.
    train_loader = DataLoader(dataset=train_fold, batch_size=16, shuffle=True)
    # Validation metrics do not depend on sample order, so no shuffling.
    val_loader = DataLoader(dataset=val_fold, batch_size=16, shuffle=False)
    # Accuracy is computed per sample, so the denominators must be the number
    # of samples in each fold, not the number of batches in each loader.
    train_size = len(train_fold)
    val_size = len(val_fold)
    ResNet.train()
    train_loss = 0
    train_correct = 0
    batch_num = 1
    # Training pass.
    print("训练开始.....")
    for batch in train_loader:
        print("epoch: {} --- >>> batch_num: {}".format(epoch, batch_num))
        batch_num += 1
        images, labels = batch
        preds = ResNet(images)
        loss = F.cross_entropy(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        # Apply the gradients; without this call the weights never change.
        optimizer.step()

        # train_loss accumulates the per-batch mean losses.
        train_loss += loss.item()
        train_correct += get_num_correct(preds, labels)
    # Validation pass.
    print("验证开始.....")
    ResNet.eval()
    val_loss = 0
    val_correct = 0
    with torch.no_grad():
        for val_data in val_loader:
            images, labels = val_data
            outputs = ResNet(images)
            loss = F.cross_entropy(outputs, labels)
            val_loss += loss.item()
            val_correct += get_num_correct(outputs, labels)
    print("epoch {}: train_loss: {}   train_acccuracy: {}%    val_loss: {}    val_accuracy: {}% ".format(epoch, train_loss,
                                                                                                     train_correct/train_size * 100,
                                                                                                     val_loss, val_correct/val_size * 100
                                                                                                     ))
    epoch += 1

你可能感兴趣的:(pytorch,深度学习,人工智能)