Overriding Dataset in PyTorch

In PyTorch you can subclass Dataset to load your own custom data.
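
The Mydataset class below expects the data to be laid out as one sub-folder per class under a root directory; the folder name becomes the class label. The class names here are only an example:

dataset/
    class_a/
        0001.jpg
        0002.png
    class_b/
        0003.jpg
    ...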

First, here is the custom Mydataset class:

import os
import glob
import csv
import random

from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

class Mydataset(Dataset):
    def __init__(self, root, resize, mode):
        super(Mydataset, self).__init__()
        self.root = root
        self.resize = resize

        self.name2label = {}  # map each class folder name to a label index 0, 1, 2, ...
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue

            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')

        if mode == 'train':  # first 60% of the shuffled samples
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 20%: from 60% to 80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # 20%: from 80% to 100%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            print(len(images), images)

            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                write = csv.writer(f)
                for img in images:
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    write.writerow([img, label])
                print('written into csv file:', filename)

        # read csv
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # idx ranges over [0, len(self.images))
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)
        return img, label

    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x - mean) / std
        # x = x_hat * std + mean
        # x: [c, h, w]
        # mean: [3] -> [3, 1, 1]

        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x


def main():
    import visdom
    import time
    import torchvision

    viz = visdom.Visdom()

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor()
    ])

    tmp = torchvision.datasets.ImageFolder(root='dataset', transform=transform)
    loader = DataLoader(tmp, batch_size=32, shuffle=True)

    for x, y in loader:
        viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)
 
if __name__ == "__main__":
    main()
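
The main() above sanity-checks the folder layout with torchvision's ImageFolder. To exercise Mydataset itself, a minimal sketch along the same lines (assuming the same dataset/ root as above, an image size of 224, and a running Visdom server) could look like this:

import visdom
from torch.utils.data import DataLoader

viz = visdom.Visdom()

# wrap the custom dataset in a DataLoader
db = Mydataset('dataset', 224, mode='train')
loader = DataLoader(db, batch_size=32, shuffle=True)

x, y = next(iter(loader))
print('batch shapes:', x.shape, y.shape)  # e.g. [32, 3, 224, 224] and [32]

# __getitem__ applies Normalize, so denormalize before displaying
viz.images(db.denormalize(x), nrow=8, win='mydataset_batch', opts=dict(title='Mydataset batch'))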

Next, how to load this data and train on top of resnet18. Start by writing a small Flatten.py helper module:

import torch
import torch.nn as nn

import matplotlib.pyplot as plt


class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2,3, i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title('{}: {}'.format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
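
Flatten simply collapses every dimension after the batch dimension. Recent PyTorch versions also provide a built-in nn.Flatten() layer that does the same thing, so the hand-written class is mainly useful on older releases. A quick shape check with an arbitrary example tensor:

import torch
import torch.nn as nn
from Flatten import Flatten

x = torch.randn(2, 512, 1, 1)     # the shape resnet18's avgpool produces
print(Flatten()(x).shape)         # torch.Size([2, 512])
print(nn.Flatten()(x).shape)      # identical result with the built-in layer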

Then complete the training script train_resnet18.py:

import torch
import visdom
import torch.nn as nn
import torch.optim
from mydataset import Mydataset
from torch.utils.data import Dataset, DataLoader

from Flatten import Flatten
from torchvision.models.resnet import resnet18

batchsize = 32
learning_rate = 1e-5
epoches = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


train_db = Mydataset('datasets', 32, mode='train')
val_db = Mydataset('datasets', 32, mode='val')
test_db = Mydataset('datasets', 32, mode='test')


train_loader = DataLoader(train_db, batch_size=batchsize, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsize, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsize, num_workers=2)

# train the model

viz = visdom.Visdom()


def evaluate(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct/total


def main():
    model = resnet18(pretrained=True)  # start from an ImageNet-pretrained model
    model = nn.Sequential(*list(model.children())[:-1],  # [b, 512, 1, 1] feature extractor
                          Flatten(),  # [b, 512, 1, 1] -> [b, 512]
                          nn.Linear(512, 2)).to(device)  # new 2-class fully connected head

    # x = torch.randn(2, 3, 224, 224)
    # print(model(x).shape)
    # define the loss function
    criterion = nn.CrossEntropyLoss()
    # define the optimizer that updates the parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epoches):
        for step, (x, y) in enumerate(train_loader):
            viz.images(train_db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
            viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
            x, y = x.to(device), y.to(device)
            model.train()
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        if epoch % 1 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                # keep a checkpoint of the best model seen so far
                torch.save(model.state_dict(), 'resnet18-circle25-50.pkl')
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print("best acc:", best_acc, "best epoch:", best_epoch)

    # reload the best checkpoint before evaluating on the test set
    model.load_state_dict(torch.load('resnet18-circle25-50.pkl'))
    print("loaded from ckpt!")
    test_acc = evaluate(model, test_loader)
    print("test acc:", test_acc)


if __name__ == "__main__":
    main()

Visdom is used throughout for visualization, completing the object recognition pipeline.
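
After training, the saved weights can be reused for single-image prediction. The snippet below is only a sketch: it assumes the checkpoint name used above, the same two-class head, the input size of 32 used by the training script, and a placeholder image path.

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models.resnet import resnet18
from Flatten import Flatten

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# rebuild the exact architecture used in training, then load the checkpoint
model = nn.Sequential(*list(resnet18(pretrained=False).children())[:-1],
                      Flatten(),
                      nn.Linear(512, 2)).to(device)
model.load_state_dict(torch.load('resnet18-circle25-50.pkl', map_location=device))
model.eval()

# same preprocessing as evaluation: resize, center crop, normalize (no augmentation)
tf = transforms.Compose([
    transforms.Resize((int(32 * 1.25), int(32 * 1.25))),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

img = tf(Image.open('test.jpg').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    pred = model(img).argmax(dim=1)
print('predicted class index:', pred.item())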
