Multilayer Perceptron (MLP): MNIST

Table of Contents

    • Network structure
    • Code
      • common_utils.py
      • network.py
      • provider.py
      • train.py
      • test.py
      • visual.py
    • Experiments
      • Training results
      • Test results
      • Visualization

Network structure

Input    Operation    Output
28×28    Flatten      784
784      Linear       300
300      Linear       100
100      Linear       10
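
As a sanity check on the table above, each Linear layer holds in_features × out_features weights plus out_features biases. A quick sketch of the total parameter count (not part of the project code):

```python
# Parameter count per Linear layer: weights (in * out) plus bias (out).
layers = [(784, 300), (300, 100), (100, 10)]
total = sum(i * o + o for i, o in layers)
for i, o in layers:
    print(f"Linear({i} -> {o}): {i * o + o} params")
print(f"total: {total}")  # 266610
```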

Code

File structure:
(figure: project file layout)

common_utils.py

Creates a logger that writes to the console and, optionally, to a log file.

# common_utils.py
import logging


def create_logger(log_file=None, rank=0, log_level=logging.INFO):
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level if rank == 0 else 'ERROR')
    formatter = logging.Formatter('[%(asctime)s  %(filename)s %(lineno)d '
                                  '%(levelname)5s]  %(message)s')
    console = logging.StreamHandler()
    console.setLevel(log_level if rank == 0 else 'ERROR')
    console.setFormatter(formatter)
    logger.addHandler(console)
    if log_file is not None:
        file_handler = logging.FileHandler(filename=log_file)
        file_handler.setLevel(log_level if rank == 0 else 'ERROR')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger
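
A minimal smoke test of the format string used above, wiring the same Formatter to an in-memory stream instead of the console (the `demo` logger name is hypothetical, purely for illustration):

```python
import io
import logging

# Attach the same formatter used in create_logger to a StringIO handler
# so we can inspect the rendered log line.
buf = io.StringIO()
handler = logging.StreamHandler(buf)
handler.setFormatter(logging.Formatter('[%(asctime)s  %(filename)s %(lineno)d '
                                       '%(levelname)5s]  %(message)s'))
logger = logging.getLogger('demo')
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.info('hello')
print(buf.getvalue())  # e.g. "[<timestamp>  <file> <line>   INFO]  hello"
```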

network.py

Defines the MLP architecture, along with the training routine train_model and the evaluation routine eval_model.

# network.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import provider

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)
        self.relu = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        x = self.softmax(x)
        return x

    def train_model(self, args):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        criterion = nn.NLLLoss()  # the model already ends in LogSoftmax; CrossEntropyLoss expects raw logits
        optimizer = optim.Adam(self.parameters(), lr=args.lr)
        scheduler = StepLR(optimizer, step_size=3, gamma=0.1)  # learning-rate scheduler
        train_loader = provider.GetLoader(batch_size=args.batch_size, loadType='train')
        test_loader = provider.GetLoader(batch_size=args.batch_size, loadType='test')

        best_accuracy = 0.0
        for epoch in range(args.epochs):
            self.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for images, labels in train_loader:
                images = images.to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = self(images)
                loss = criterion(outputs, labels)

                # Backward pass and optimizer step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Track accuracy
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                running_loss += loss.item()

            train_loss = running_loss / len(train_loader)
            train_accuracy = correct / total

            # Evaluate on the test set
            self.eval()
            test_loss = 0.0
            correct = 0
            total = 0

            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)

                    outputs = self(images)
                    loss = criterion(outputs, labels)

                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                    test_loss += loss.item()

            test_loss = test_loss / len(test_loader)
            test_accuracy = correct / total

            # Update the learning rate
            scheduler.step()

            # Save the model that performs best on the test set
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                torch.save({
                    'model_state_dict': self.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                    'best_accuracy': best_accuracy,
                }, 'best_model.pth')

            # Log loss and accuracy for this epoch
            args.logger.info(f"Epoch [{epoch+1}/{args.epochs}] - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # Save the model from the final epoch
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'best_accuracy': best_accuracy,
        }, 'final_model.pth')

    def eval_model(self, dataloader):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        self.eval()
        total = 0
        correct = 0

        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)

                outputs = self(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
        return accuracy
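
Because the forward pass ends in LogSoftmax, its natural loss partner is NLLLoss; CrossEntropyLoss applies log_softmax internally and expects raw logits. Since log_softmax is idempotent (applying it to log-probabilities returns them unchanged), the two pairings happen to compute the same loss value here, which is why the network trains either way. A plain-Python sketch of that identity (no torch needed):

```python
import math

def log_softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def nll(log_probs, target):
    return -log_probs[target]  # negative log-likelihood of the true class

logits = [2.0, 1.0, 0.1]
target = 0

# NLLLoss applied to the model's LogSoftmax output:
loss_nll = nll(log_softmax(logits), target)
# CrossEntropyLoss re-applies log_softmax internally; idempotence makes
# the resulting value identical:
loss_ce = nll(log_softmax(log_softmax(logits)), target)
print(loss_nll, loss_ce)
```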

provider.py

Provides the data-loading function GetLoader and the visualization function visualize_loader.

#  provider.py
import torch
import torchvision
import matplotlib.pyplot as plt


def visualize_loader(loader, model=None):
    # Each batch is [32*1*28*28 images, 32 labels]
    batch = next(iter(loader))
    fig, axes = plt.subplots(4, 8, figsize=(20, 10))
    imgs = batch[0]
    labels = batch[1].numpy()
    if model is None:
        imgName = 'train.png'
        predicted = labels
    else:
        imgName = 'test.png'
        outputs = model(imgs)
        _, predicted = torch.max(outputs.data, 1)
        predicted = predicted.numpy()
    imgs = imgs.squeeze().numpy()
    for i, ax in enumerate(axes.flat):
        ax.imshow(imgs[i])
        ax.set_title(predicted[i], color='black' if predicted[i] == labels[i] else 'red')
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(imgName)  # save before show(), which clears the current figure
    plt.show()


# The training loader yields 1875 batches of [32*1*28*28, 32]
def GetLoader(path='data', batch_size=32, loadType='train'):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = torchvision.datasets.MNIST(root=path, train=loadType == 'train',
                                         transform=transform, download=False)
    # Only shuffle the training split; evaluation order does not matter
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=loadType == 'train')
    return loader
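
The constants in Normalize((0.1307,), (0.3081,)) are the MNIST training-set mean and standard deviation of the [0,1]-scaled pixels: ToTensor divides raw bytes by 255, then Normalize applies (x - mean) / std. A plain-Python sketch of the resulting pixel mapping:

```python
MEAN, STD = 0.1307, 0.3081  # MNIST training-set statistics after ToTensor scaling

def normalize(pixel_0_255):
    x = pixel_0_255 / 255.0   # what ToTensor does
    return (x - MEAN) / STD   # what Normalize does

print(normalize(0))    # background pixels map to about -0.424
print(normalize(255))  # full-intensity ink maps to about 2.82
```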

train.py

Trains the model.

# train.py
import argparse
import datetime
import common_utils
import os
import network
import provider

def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--batch_size', type=int, default=32, required=False, help='batch size for training')
    parser.add_argument('--epochs', type=int, default=7, required=False, help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.01, required=False, help='learning rate')
    
    log_file = 'output/'+ ('log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    os.makedirs('output', exist_ok=True)  # make sure the log directory exists
    logger = common_utils.create_logger(log_file)
    parser.add_argument('--logger', type=type(logger), default=logger, help='logger')
    
    args = parser.parse_args()

    return args

def main():
    args = parse_config()
    # log to file
    args.logger.info('**********************Start logging**********************')
    for key, val in vars(args).items():
        args.logger.info('{:16} {}'.format(key, val))

    args.logger.info('**********************Start training ********************')
    model = network.MLP()
    model.train_model(args)
    args.logger.info('**********************End training **********************')

    # Evaluate the trained model
    args.logger.info('**********************Start eval ************************')
    test_loader = provider.GetLoader(batch_size=args.batch_size, loadType='test')
    test_accuracy = model.eval_model(test_loader)
    args.logger.info(f'Test Accuracy: {test_accuracy:.4f}')
    args.logger.info('**********************End eval **************************')
    args.logger.info('**********************End *******************************\n')


if __name__ == '__main__':
    main()
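
train.py is driven entirely by these flags; a minimal sketch of how the three numeric arguments from parse_config parse (the explicit argument list stands in for a command line such as `python train.py --epochs 10 --lr 0.001`):

```python
import argparse

# Mirror the numeric flags defined in parse_config
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--epochs', type=int, default=7)
parser.add_argument('--lr', type=float, default=0.01)

args = parser.parse_args(['--epochs', '10', '--lr', '0.001'])
print(args.batch_size, args.epochs, args.lr)  # 32 10 0.001
```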

test.py

Evaluates a saved checkpoint on the test set.

# test.py
import argparse
import datetime
import common_utils
import os
import network
import provider
import torch


def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--batch_size', type=int, default=32, required=False, help='batch size for training')
    parser.add_argument('--checkpoint', type=str, default='best_model.pth', help='checkpoint to start from')
    log_file = 'output/'+ ('log_test_%s.txt' % datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    os.makedirs('output', exist_ok=True)  # make sure the log directory exists
    logger = common_utils.create_logger(log_file)
    parser.add_argument('--logger', type=type(logger), default=logger, help='logger')
    args = parser.parse_args()
    return args

def main():
    args = parse_config()
    args.logger.info('**********************Start logging**********************')
    for key, val in vars(args).items():
        args.logger.info('{:16} {}'.format(key, val))
    args.logger.info('**********************Start testing **********************')
    test(args)
    args.logger.info('**********************End testing ************************\n\n')

    
def test(args):
    checkpoint = torch.load(args.checkpoint)
    model = network.MLP()
    model.load_state_dict(checkpoint['model_state_dict'])
    args.logger.info(model)
    test_loader = provider.GetLoader(batch_size=args.batch_size, loadType='test')
    test_accuracy = model.eval_model(test_loader)
    args.logger.info(f'Test Accuracy: {test_accuracy:.4f}')


if __name__ == '__main__':
    main()


visual.py

Visualization script.

# visual.py
import provider
import network
import torch


train_loader = provider.GetLoader(loadType='train')
provider.visualize_loader(train_loader)

test_loader = provider.GetLoader(loadType='test')
checkpoint = torch.load('best_model.pth')
model = network.MLP()
model.load_state_dict(checkpoint['model_state_dict'])
provider.visualize_loader(test_loader, model)


Experiments

Training results

[2023-07-22 10:45:31,237  train.py 30  INFO]  **********************Start logging**********************
[2023-07-22 10:45:31,237  train.py 32  INFO]  batch_size       32
[2023-07-22 10:45:31,237  train.py 32  INFO]  epochs           7
[2023-07-22 10:45:31,237  train.py 32  INFO]  lr               0.01
[2023-07-22 10:45:31,237  train.py 32  INFO]  logger           
[2023-07-22 10:45:31,237  train.py 34  INFO]  **********************Start training ********************
[2023-07-22 10:45:46,963  network.py 106  INFO]  Epoch [1/7] - Train Loss: 0.5768, Train Accuracy: 0.8446, Test Accuracy: 0.9037
[2023-07-22 10:45:59,299  network.py 106  INFO]  Epoch [2/7] - Train Loss: 0.5059, Train Accuracy: 0.8759, Test Accuracy: 0.9299
[2023-07-22 10:46:11,687  network.py 106  INFO]  Epoch [3/7] - Train Loss: 0.4536, Train Accuracy: 0.8884, Test Accuracy: 0.9198
[2023-07-22 10:46:24,010  network.py 106  INFO]  Epoch [4/7] - Train Loss: 0.3161, Train Accuracy: 0.9196, Test Accuracy: 0.9502
[2023-07-22 10:46:36,307  network.py 106  INFO]  Epoch [5/7] - Train Loss: 0.2497, Train Accuracy: 0.9350, Test Accuracy: 0.9528
[2023-07-22 10:46:48,712  network.py 106  INFO]  Epoch [6/7] - Train Loss: 0.2280, Train Accuracy: 0.9395, Test Accuracy: 0.9549
[2023-07-22 10:47:01,138  network.py 106  INFO]  Epoch [7/7] - Train Loss: 0.2078, Train Accuracy: 0.9443, Test Accuracy: 0.9573
[2023-07-22 10:47:01,155  train.py 37  INFO]  **********************End training **********************
[2023-07-22 10:47:01,155  train.py 40  INFO]  **********************Start eval ************************
[2023-07-22 10:47:02,492  train.py 43  INFO]  Test Accuracy: 0.9573
[2023-07-22 10:47:02,493  train.py 44  INFO]  **********************End eval **************************
[2023-07-22 10:47:02,493  train.py 45  INFO]  **********************End *******************************
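
Note the jump in test accuracy between epochs 3 and 4 (0.9198 to 0.9502): it coincides with the first StepLR decay. With step_size=3 and gamma=0.1, the learning rate in effect during each epoch can be tabulated directly (a sketch of the schedule, not the training code):

```python
base_lr, step_size, gamma = 0.01, 3, 0.1

# Learning rate in effect during each of the 7 epochs (StepLR decays
# by gamma after every step_size epochs)
lrs = [base_lr * gamma ** (epoch // step_size) for epoch in range(7)]
for epoch, lr in enumerate(lrs, start=1):
    print(f"epoch {epoch}: lr = {lr:g}")
# epochs 1-3 run at 0.01, epochs 4-6 at 0.001, epoch 7 at 0.0001
```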

Test results

[2023-07-22 10:50:46,173  test.py 24  INFO]  **********************Start logging**********************
[2023-07-22 10:50:46,173  test.py 26  INFO]  batch_size       32
[2023-07-22 10:50:46,173  test.py 26  INFO]  checkpoint       best_model.pth
[2023-07-22 10:50:46,173  test.py 26  INFO]  logger           
[2023-07-22 10:50:46,173  test.py 27  INFO]  **********************Start testing **********************
[2023-07-22 10:50:49,084  test.py 36  INFO]  MLP(
(flatten): Flatten(start_dim=1, end_dim=-1)
(fc1): Linear(in_features=784, out_features=300, bias=True)
(fc2): Linear(in_features=300, out_features=100, bias=True)
(fc3): Linear(in_features=100, out_features=10, bias=True)
(relu): ReLU()
(softmax): LogSoftmax(dim=1)
(dropout): Dropout(p=0.2, inplace=False)
)
[2023-07-22 10:50:50,970  test.py 39  INFO]  Test Accuracy: 0.9573
[2023-07-22 10:50:50,970  test.py 29  INFO]  **********************End testing ************************

Visualization

Test results:
(figure: test.png, a 4×8 grid of test digits titled with the model's predictions; incorrect predictions are titled in red)
