Datawhale零基础入门CV赛事(街景字符编码识别)-Task4 模型训练与验证

# -*- coding: utf-8 -*-

'''
@Time    : 2020/5/26 20:59
@Author  : HHNa
@FileName: 3.model_train_val.py
@Software: PyCharm
 
'''
import torch
import numpy as np
import torch.nn as nn
import glob, json
from dataset import SVHNDataset
from model import SVHN_Model1
import torchvision.transforms as transforms
from tensorboard_logger import Logger


def train(train_loader, model, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    For every batch, each of the model's five outputs is scored against the
    matching target column with ``criterion``. The five losses are averaged,
    back-propagated, and applied through ``optimizer``. The current loss is
    printed every 100 batches.

    Args:
        train_loader: iterable of ``(images, targets)`` batches; ``targets``
            is indexed as ``targets[:, k]`` for k = 0..4 and cast to long.
        model: module whose forward pass returns exactly five outputs.
        criterion: loss applied to each (output, target column) pair.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``np.mean`` of the per-batch loss values.
    """
    # Put layers such as dropout / batch-norm into training behaviour.
    model.train()
    batch_losses = []

    for step, (images, labels) in enumerate(train_loader):
        c0, c1, c2, c3, c4 = model(images)
        labels = labels.long()
        heads = (c0, c1, c2, c3, c4)
        # Average the per-position losses so each head contributes equally.
        loss = sum(criterion(head, labels[:, k]) for k, head in enumerate(heads)) / 5

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        value = loss.item()
        batch_losses.append(value)
        if step % 100 == 0:
            print(value)

    return np.mean(batch_losses)

def validate(val_loader, model, criterion):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to evaluation mode and run under
    ``torch.no_grad()``, so no autograd graph is recorded. Each batch loss is
    the average of ``criterion`` applied to the five outputs against target
    columns 0..4.

    Args:
        val_loader: iterable of ``(images, targets)`` batches; ``targets`` is
            indexed as ``targets[:, k]`` for k = 0..4 and cast to long.
        model: module whose forward pass returns exactly five outputs.
        criterion: loss applied to each (output, target column) pair.

    Returns:
        ``np.mean`` of the per-batch loss values.
    """
    # Evaluation mode: dropout off, batch-norm uses running statistics.
    model.eval()
    batch_losses = []

    # Gradients are not needed for validation.
    with torch.no_grad():
        for images, labels in val_loader:
            c0, c1, c2, c3, c4 = model(images)
            labels = labels.long()
            total = sum(criterion(head, labels[:, k])
                        for k, head in enumerate((c0, c1, c2, c3, c4)))
            batch_losses.append((total / 5).item())

    return np.mean(batch_losses)


if __name__ == "__main__":
    train_path = glob.glob('./data/train/mchar_train/*.png')
    train_path.sort()
    train_json = json.load(open('./data/train/mchar_train.json'))
    train_label = [train_json[x]['label'] for x in train_json]

    val_path = glob.glob('./data/val/mchar_val/*.png')
    val_path.sort()
    val_json = json.load(open('./data/val/mchar_val.json'))
    val_label = [train_json[x]['label'] for x in val_json]

    train_dataset = SVHNDataset(train_path, train_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),
                    transforms.ColorJitter(0.3, 0.3, 0.2),
                    transforms.RandomRotation(5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ]))
    val_dataset = SVHNDataset(val_json, val_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),
                    transforms.ColorJitter(0.3, 0.3, 0.2),
                    transforms.RandomRotation(5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=10,
        shuffle=True,
        num_workers=10,
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=10,
        shuffle=False,
        num_workers=10,
    )

    model = SVHN_Model1()

    criterion = nn.CrossEntropyLoss(size_average=False)
    optimizer = torch.optim.Adam(model.parameters(), 0.001)
    best_loss = 1000.0
    losses = []
    for epoch in range(20):
        print('Epoch: ', epoch)

        train(train_loader, model, criterion, optimizer)
        val_loss = validate(val_loader, model, criterion)

        # 记录下验证集精度
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), './model.pt')

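代码存在一些 bug（运行截图见下）。排查时请注意：验证集标签应从 val_json（而非 train_json）中读取；构建 val_dataset 时应传入图片路径列表 val_path（而非 val_json）；验证集的预处理不应包含 ColorJitter、RandomRotation 等随机增强，否则验证损失会带有随机性。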

[图 1：Datawhale 零基础入门 CV 赛事（街景字符编码识别）Task4 模型训练与验证——运行截图]

参考:https://github.com/datawhalechina/team-learning/blob/master/03%20%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89%E5%AE%9E%E8%B7%B5%EF%BC%88%E8%A1%97%E6%99%AF%E5%AD%97%E7%AC%A6%E7%BC%96%E7%A0%81%E8%AF%86%E5%88%AB%EF%BC%89/Datawhale%20%E9%9B%B6%E5%9F%BA%E7%A1%80%E5%85%A5%E9%97%A8CV%20-%20Task%2004%20%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83%E4%B8%8E%E9%AA%8C%E8%AF%81%20.md

你可能感兴趣的:(街景字符编码识别)