RNN训练mnist数据

RNN训练mnist数据

本文通过用 RNN 做图像分类,来了解 PyTorch 中的 RNN 在图像数据上训练时的用法。

MNIST 数据集中的图像:图像 1 由 28*28 个像素点组成,如下图所示。
RNN训练mnist数据_第1张图片
对于这样的图像,我们可以将每张图像看作长度为 28 的序列,序列中每个元素的特征维度为 28。

RNN的结构
RNN训练mnist数据_第2张图片
首先处理数据

import torch
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader

from torchvision import transforms as tfs
from torchvision.datasets import MNIST

定义数据

# Preprocessing: convert each image to a tensor, then normalise it with
# mean 0.5 and standard deviation 0.5.
data_tf = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.5], std=[0.5]),
])

# Training and test splits, both read from the same local directory.
# NOTE(review): `download` is not requested here, so the MNIST files are
# presumably expected to exist under this path already - confirm.
train_set = MNIST(
    root=r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data',
    train=True,
    transform=data_tf,
)
test_set = MNIST(
    root=r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data',
    train=False,
    transform=data_tf,
)

# Shuffled mini-batches of 64 for training, ordered batches of 128 for testing.
# NOTE(review): with num_workers > 0 on Windows, PyTorch expects the script
# entry point to be guarded by `if __name__ == '__main__':` - confirm how
# this code is launched.
train_data = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_data = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)

定义模型

class rnn_classify(nn.Module):
    """LSTM classifier that reads an image as a sequence of pixel columns.

    Args:
        in_feature: Size of each sequence element (the image height).
        hidden_feature: Hidden size of the LSTM.
        num_class: Number of output classes.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
        super(rnn_classify, self).__init__()
        # Stacked LSTM; it consumes input shaped (seq_len, batch, in_feature).
        self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)
        # Maps the LSTM output of the last time step to class scores.
        self.classifier = nn.Linear(hidden_feature, num_class)

    def forward(self, x):
        """Classify a batch of single-channel images.

        Args:
            x: Tensor of shape (batch, 1, height, width). The width axis
                becomes the sequence axis and height must equal
                ``in_feature``.

        Returns:
            Tensor of shape (batch, num_class) with unnormalised class scores.
        """
        # Drop only the channel axis. A bare squeeze() would also remove a
        # batch axis of size 1 and make the permute below fail.
        x = x.squeeze(1)  # (batch, height, width)
        # Move the last axis to the front: (width, batch, height).
        x = x.permute(2, 0, 1)
        # No initial state is passed, so the LSTM uses its default one.
        out, _ = self.rnn(x)  # (width, batch, hidden_feature)
        out = out[-1, :, :]  # last time step: (batch, hidden_feature)
        out = self.classifier(out)
        return out

定义网络、损失函数、优化器

# Model with its default hyper-parameters and a cross-entropy objective.
net = rnn_classify()
criterion = nn.CrossEntropyLoss()

# Adadelta with learning rate 0.1. The misspelled name `optimzier` is kept
# unchanged so that existing references to it keep working.
optimzier = torch.optim.Adadelta(net.parameters(), lr=1e-1)

定义训练函数

from datetime import datetime

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable


def get_acc(output, label):
    """Return the fraction of rows in `output` whose arg-max matches `label`.

    Args:
        output: 2-D tensor of scores; the maximum is taken along dimension 1.
        label: Tensor compared element-wise against the predicted indices.

    Returns:
        Number of matches divided by ``output.shape[0]``.
    """
    # max(1) returns (values, indices); only the indices are needed.
    predictions = output.max(1)[1]
    hits = (predictions == label).sum().item()
    batch_size = output.shape[0]
    return hits / batch_size


def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train `net` and print loss/accuracy after every epoch.

    Args:
        net: Model to optimise; moved to CUDA when it is available.
        train_data: Sized iterable yielding ``(im, label)`` tensor batches.
        valid_data: Same kind of iterable used for validation, or ``None``
            to skip validation.
        num_epochs: Number of passes over `train_data`.
        optimizer: Optimiser that updates the parameters of `net`.
        criterion: Callable ``criterion(output, label)`` returning the loss.

    Returns:
        None. One summary line is printed per epoch.
    """
    # Resolve the device once instead of querying CUDA for every batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net.train()
        for im, label in train_data:
            im = im.to(device)
            label = label.to(device)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        # The timestamp is taken before validation, so the reported time
        # covers this epoch's training plus the previous epoch's validation.
        cur_time = datetime.now()
        # total_seconds() keeps whole days, which timedelta.seconds drops.
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net.eval()
            # `Variable(..., volatile=True)` no longer disables autograd;
            # no_grad() is what stops graph construction during validation.
            with torch.no_grad():
                for im, label in valid_data:
                    im = im.to(device)
                    label = label.to(device)
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)

开始训练

# Run 10 epochs, using the test loader as the validation set.
train(
    net,
    train_data,
    valid_data=test_data,
    num_epochs=10,
    optimizer=optimzier,
    criterion=criterion,
)

完整代码:

    import torch
    from torch.autograd import Variable
    from torch import nn
    from torch.utils.data import DataLoader
    from datetime import datetime   
    import torch.nn.functional as F
    from torchvision import transforms as tfs
    from torchvision.datasets import MNIST


# 定义数据

    # Preprocessing: convert each image to a tensor, then normalise it with
    # mean 0.5 and standard deviation 0.5.
    data_tf = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Training and test splits, both read from the same local directory.
    # NOTE(review): `download` is not requested here, so the MNIST files are
    # presumably expected to exist under this path already - confirm.
    train_set = MNIST(
        root=r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data',
        train=True,
        transform=data_tf,
    )
    test_set = MNIST(
        root=r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data',
        train=False,
        transform=data_tf,
    )

    # Shuffled mini-batches of 64 for training, ordered batches of 128 for testing.
    # NOTE(review): with num_workers > 0 on Windows, PyTorch expects the script
    # entry point to be guarded by `if __name__ == '__main__':` - confirm how
    # this code is launched.
    train_data = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
    test_data = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)


# 定义模型

    class rnn_classify(nn.Module):
        """LSTM classifier that reads an image as a sequence of pixel columns.

        Args:
            in_feature: Size of each sequence element (the image height).
            hidden_feature: Hidden size of the LSTM.
            num_class: Number of output classes.
            num_layers: Number of stacked LSTM layers.
        """

        def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
            super(rnn_classify, self).__init__()
            # Stacked LSTM; it consumes input shaped (seq_len, batch, in_feature).
            self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)
            # Maps the LSTM output of the last time step to class scores.
            self.classifier = nn.Linear(hidden_feature, num_class)

        def forward(self, x):
            """Classify a batch of single-channel images.

            Args:
                x: Tensor of shape (batch, 1, height, width). The width axis
                    becomes the sequence axis and height must equal
                    ``in_feature``.

            Returns:
                Tensor of shape (batch, num_class) with unnormalised class scores.
            """
            # Drop only the channel axis. A bare squeeze() would also remove a
            # batch axis of size 1 and make the permute below fail.
            x = x.squeeze(1)  # (batch, height, width)
            # Move the last axis to the front: (width, batch, height).
            x = x.permute(2, 0, 1)
            # No initial state is passed, so the LSTM uses its default one.
            out, _ = self.rnn(x)  # (width, batch, hidden_feature)
            out = out[-1, :, :]  # last time step: (batch, hidden_feature)
            out = self.classifier(out)
            return out

# 定义网络、损失函数、优化器

    # Model with its default hyper-parameters and a cross-entropy objective.
    net = rnn_classify()
    criterion = nn.CrossEntropyLoss()

    # Adadelta with learning rate 0.1. The misspelled name `optimzier` is kept
    # unchanged so that existing references to it keep working.
    optimzier = torch.optim.Adadelta(net.parameters(), lr=1e-1)

# 定义训练函数


    
    
    def get_acc(output, label):
        """Return the fraction of rows in `output` whose arg-max matches `label`.

        Args:
            output: 2-D tensor of scores; the maximum is taken along dimension 1.
            label: Tensor compared element-wise against the predicted indices.

        Returns:
            Number of matches divided by ``output.shape[0]``.
        """
        # max(1) returns (values, indices); only the indices are needed.
        predictions = output.max(1)[1]
        hits = (predictions == label).sum().item()
        batch_size = output.shape[0]
        return hits / batch_size
    
    
    def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
        """Train `net` and print loss/accuracy after every epoch.

        Args:
            net: Model to optimise; moved to CUDA when it is available.
            train_data: Sized iterable yielding ``(im, label)`` tensor batches.
            valid_data: Same kind of iterable used for validation, or ``None``
                to skip validation.
            num_epochs: Number of passes over `train_data`.
            optimizer: Optimiser that updates the parameters of `net`.
            criterion: Callable ``criterion(output, label)`` returning the loss.

        Returns:
            None. One summary line is printed per epoch.
        """
        # Resolve the device once instead of querying CUDA for every batch.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = net.to(device)
        prev_time = datetime.now()
        for epoch in range(num_epochs):
            train_loss = 0
            train_acc = 0
            net.train()
            for im, label in train_data:
                im = im.to(device)
                label = label.to(device)
                # forward
                output = net(im)
                loss = criterion(output, label)
                # backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                train_acc += get_acc(output, label)

            # The timestamp is taken before validation, so the reported time
            # covers this epoch's training plus the previous epoch's validation.
            cur_time = datetime.now()
            # total_seconds() keeps whole days, which timedelta.seconds drops.
            elapsed = int((cur_time - prev_time).total_seconds())
            h, remainder = divmod(elapsed, 3600)
            m, s = divmod(remainder, 60)
            time_str = "Time %02d:%02d:%02d" % (h, m, s)
            if valid_data is not None:
                valid_loss = 0
                valid_acc = 0
                net.eval()
                # `Variable(..., volatile=True)` no longer disables autograd;
                # no_grad() is what stops graph construction during validation.
                with torch.no_grad():
                    for im, label in valid_data:
                        im = im.to(device)
                        label = label.to(device)
                        output = net(im)
                        loss = criterion(output, label)
                        valid_loss += loss.item()
                        valid_acc += get_acc(output, label)
                epoch_str = (
                    "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                    % (epoch, train_loss / len(train_data),
                       train_acc / len(train_data), valid_loss / len(valid_data),
                       valid_acc / len(valid_data)))
            else:
                epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                             (epoch, train_loss / len(train_data),
                              train_acc / len(train_data)))
            prev_time = cur_time
            print(epoch_str + time_str)

注:参考廖星宇大神的《深度学习之pytorch》

你可能感兴趣的:(pytorch,pytorch,rnn,mnist)