【37】使用LSTM实现文本分类、图像分类、图像生成任务


如有错误,恳请指出。


文章目录

  • 1. 文本分类任务——基于IMDB的情感分析
    • 1.1 实现思路
    • 1.2 模型定义
    • 1.3 完整代码
  • 2. 图像分类任务——手写数字MNIST分类
    • 2.1 实现思路
    • 2.2 模型定义
    • 2.3 完整代码
  • 3. 图像生成任务——手写数字MNIST生成
    • 3.1 实现思路
    • 3.2 模型定义
    • 3.3 训练与预测伪代码
    • 3.4 完整代码

在上一篇文章中,使用了LSTM来预测时序信息,接下来就继续对LSTM进行一些拓展应用:

  • 1)使用LSTM网络来对文本分类
  • 2)使用LSTM网络来对图像分类
  • 3)使用LSTM网络来生成手写数字图像

我们可以把文本、图像统统看成是序列信息,就可以让LSTM来处理。这里就随便写了几个小应用稍微玩一下,发掘一下LSTM的用途。

ps:也反映了时序网络的功能强大,以及Transformer的多种可能性、多模态的可行性。


1. 文本分类任务——基于IMDB的情感分析

1.1 实现思路

在自然语言处理领域中,LSTM可以用来进行文本预测,大体上的思路是将batch个文本,每个文本的seq个词汇编码为hidden dim长度的特征向量,所以最后的特征维度是[seq, batch, hidden dim],就可以作为输入数据。提取LSTM最后的一个隐藏单元[batch, hidden dim],再使用一个全连接层就可以实现分类。

1.2 模型定义

class LSTM(nn.Module):
    """Bidirectional 2-layer LSTM sentiment classifier for IMDB.

    Input:  token-id tensor of shape [seq_len, batch]
    Output: raw logit of shape [batch, 1] (pair with BCEWithLogitsLoss)
    """

    def __init__(self, vocab_size, embedding_dim=100, hidden_dim=256):
        super(LSTM, self).__init__()

        # Lookup table mapping every token id to an embedding_dim vector.
        # [0 .. vocab_size-1] => [embedding_dim]
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Bidirectional, 2-layer LSTM with inter-layer dropout.
        # [embedding_dim] => [hidden_dim] per direction
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
                            bidirectional=True, dropout=0.5)

        # Classifier head over the concatenated forward/backward final states.
        # [hidden_dim * 2] => [1]
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """
        x:   [seq_len, b] int64 token ids
        out: [b, 1] raw logits
        """

        # [seq, b] => [seq, b, embedding_dim]
        embedding = self.dropout(self.embedding(x))

        # output: [seq, b, hid_dim*2]
        # hidden/h: [num_layers*2, b, hid_dim]
        # cell/c:   [num_layers*2, b, hid_dim]
        output, (hidden, cell) = self.lstm(embedding)

        # Bug fix: output[-1] takes the LAST time step for BOTH directions,
        # but at that step the backward direction has only processed a
        # single token, so half of the feature vector is nearly useless
        # (the training log plateauing at ~0.5 accuracy is consistent with
        # this). Use the final hidden state of each direction's top layer:
        # hidden[-2] is the forward pass, hidden[-1] the backward pass.
        # [num_layers*2, b, hid_dim] => [b, hid_dim*2]
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)

        # [b, hid_dim*2] => [b, 1]
        hidden = self.dropout(hidden)
        out = self.fc(hidden)

        return out

1.3 完整代码

import torch
from torch import nn, optim
from torchtext import data, datasets
import numpy as np

print('GPU:', torch.cuda.is_available())

# Fix the RNG seed for reproducible runs.
torch.manual_seed(123)

# Tokenize reviews with spaCy; labels are float 0/1 for BCEWithLogitsLoss.
TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

print('len of train data:', len(train_data))
print('len of test data:', len(test_data))

# Peek at one raw example (tokens + label).
print(train_data.examples[15].text)
print(train_data.examples[15].label)

# Build the vocabulary (top 10k words, plus <unk>/<pad> => 10002 entries)
# and attach pretrained GloVe vectors (word2vec/glove-style embeddings).
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
LABEL.build_vocab(train_data)

batchsz = 30
embedding_dim = 100
hidden_dim = 256
# NOTE(review): hard-coded to the second GPU ('cuda:1') — confirm it exists.
device = torch.device('cuda:1')
train_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, test_data),
    batch_size=batchsz,
    device=device
)

class LSTM(nn.Module):
    """Bidirectional 2-layer LSTM sentiment classifier for IMDB.

    Input:  token-id tensor of shape [seq_len, batch]
    Output: raw logit of shape [batch, 1] (pair with BCEWithLogitsLoss)
    """

    def __init__(self, vocab_size, embedding_dim=100, hidden_dim=256):
        super(LSTM, self).__init__()

        # Lookup table mapping every token id to an embedding_dim vector.
        # [0 .. vocab_size-1] => [embedding_dim]
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Bidirectional, 2-layer LSTM with inter-layer dropout.
        # [embedding_dim] => [hidden_dim] per direction
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
                            bidirectional=True, dropout=0.5)

        # Classifier head over the concatenated forward/backward final states.
        # [hidden_dim * 2] => [1]
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """
        x:   [seq_len, b] int64 token ids
        out: [b, 1] raw logits
        """

        # [seq, b] => [seq, b, embedding_dim]
        embedding = self.dropout(self.embedding(x))

        # output: [seq, b, hid_dim*2]
        # hidden/h: [num_layers*2, b, hid_dim]
        # cell/c:   [num_layers*2, b, hid_dim]
        output, (hidden, cell) = self.lstm(embedding)

        # Bug fix: output[-1] takes the LAST time step for BOTH directions,
        # but at that step the backward direction has only processed a
        # single token, so half of the feature vector is nearly useless
        # (the training log plateauing at ~0.5 accuracy is consistent with
        # this). Use the final hidden state of each direction's top layer:
        # hidden[-2] is the forward pass, hidden[-1] the backward pass.
        # [num_layers*2, b, hid_dim] => [b, hid_dim*2]
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)

        # [b, hid_dim*2] => [b, 1]
        hidden = self.dropout(hidden)
        out = self.fc(hidden)

        return out


# len(TEXT.vocab): 10002, embedding_dim:100, hidden_dim:256
rnn = LSTM(vocab_size=len(TEXT.vocab), embedding_dim=100, hidden_dim=256)
# rnn.cell = (torch.zeros([2 * 2, batchsz, hidden_dim]).to(device),
#             torch.zeros([2 * 2, batchsz, hidden_dim]).to(device))
# x = torch.rand([800, 30]).long().to(device)
# print("init rnn(x).shape: ", rnn(x).shape)


# Initialize the embedding layer with the pretrained GloVe vectors
# loaded into TEXT.vocab.vectors above.
pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')

optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)


def binary_acc(preds, y):
    """Accuracy for binary classification.

    `preds` are raw logits; they are squashed through a sigmoid and
    thresholded at 0.5 before being compared with the 0/1 labels `y`.
    """
    hard_labels = torch.round(torch.sigmoid(preds))
    hits = torch.eq(hard_labels, y).float()
    return hits.sum() / len(hits)


def train(epoch, rnn, iterator, optimizer, criteon):
    """Run one training epoch over `iterator`.

    Prints running accuracy/loss every 80 batches and checkpoints the
    model weights at the end of the epoch.
    """
    rnn.train()
    acc_history = []

    for step, batch in enumerate(iterator):
        # [seq, b] => [b, 1] => [b]
        logits = rnn(batch.text).squeeze(1)
        loss = criteon(logits, batch.label)
        batch_acc = binary_acc(logits, batch.label).item()
        acc_history.append(batch_acc)

        # Standard backward pass / parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 80 == 0:
            print("i:{}  acc:{}  loss:{}".format(step, batch_acc, loss.item()))

    print('[epoch:{}] avg acc:{}'.format(epoch, np.array(acc_history).mean()))

    # Per-epoch checkpoint.
    torch.save(rnn.state_dict(), "lstm_IMDB_{}.mdl".format(epoch))


def eval(epoch, rnn, iterator, criteon):
    """Evaluate `rnn` on `iterator` and print the average accuracy.

    `criteon` is accepted for signature symmetry with train() but the
    loss is not needed for accuracy reporting.
    """
    rnn.eval()
    acc_history = []

    with torch.no_grad():
        for batch in iterator:
            # [seq, b] -> [b, 1] -> [b]
            logits = rnn(batch.text).squeeze(1)
            acc_history.append(binary_acc(logits, batch.label).item())

    print('>>test avg_acc:', np.array(acc_history).mean())
    print('-' * 80)


# Main loop: 10 epochs, evaluating on the test split after each.
for epoch in range(10):
    train(epoch, rnn, train_iterator, optimizer, criteon)
    eval(epoch, rnn, test_iterator, criteon)
  • 最后的输出结果:
GPU: True
len of train data: 25000
len of test data: 25000
['EUROPA', '(', 'ZENTROPA', ')', 'is', 'a', 'masterpiece', 'that', 'gives', 'the', 'viewer', 'the', 'excitement', 'that', 'must', 'have', 'come', 'with', 'the', 'birth', 'of', 'the', 'narrative', 'film', 'nearly', 'a', 'century', 'ago', '.', 'This', 'film', 'is', 'truly', 'unique', ',', 'and', 'a', 'work', 'of', 'genius', '.', 'The', 'camerawork', 'and', 'the', 'editing', 'are', 'brilliant', ',', 'and', 'combined', 'with', 'the', 'narrative', 'tropes', 'of', 'alienation', 'used', 'in', 'the', 'film', ',', 'creates', 'an', 'eerie', 'and', 'unforgettable', 'cinematic', 'experience., '/>, '/>The', 'participation', 'of', 'Barbara', 'Suwkowa', 'and', 'Eddie', 'Constantine', 'in', 'the', 'cast', 'are', 'two', 'guilty', 'pleasures', 'that', 'should', 'be', 'seen', 'and', 'enjoyed', '.', 'Max', 'Von', 'Sydow', 'provides', 'his', 'great', 'voice', 'as', 'the', 'narrator., '/>, '/>A', 'one', 'of', 'a', 'kind', 'movie', '!', 'Four', 'stars', '(', 'highest', 'rating', ')', '.']
pos
pretrained_embedding: torch.Size([10002, 100])
embedding layer inited.
i:0  acc:0.46666669845581055  loss:0.6949790120124817
i:80  acc:0.4333333671092987  loss:0.6981948018074036
i:160  acc:0.5333333611488342  loss:0.6852760910987854
i:240  acc:0.4333333671092987  loss:0.7037591934204102
i:320  acc:0.46666669845581055  loss:0.6954788565635681
i:400  acc:0.5  loss:0.6939224004745483
i:480  acc:0.40000003576278687  loss:0.6968072652816772
i:560  acc:0.46666669845581055  loss:0.6873948574066162
i:640  acc:0.6666666865348816  loss:0.6922653317451477
i:720  acc:0.4333333671092987  loss:0.6944605112075806
i:800  acc:0.5  loss:0.6933766007423401
[epoch:0] avg acc:0.4992806038482012
--------------------------------------------------------------------------------
>>test avg_acc: 0.5025180090370652
i:0  acc:0.4333333671092987  loss:0.7024383544921875
i:80  acc:0.36666667461395264  loss:0.6963495016098022
i:160  acc:0.6666666865348816  loss:0.6909230947494507
i:240  acc:0.5666667222976685  loss:0.6839739680290222
i:320  acc:0.40000003576278687  loss:0.6974653601646423
i:400  acc:0.5666667222976685  loss:0.6987729072570801
i:480  acc:0.6000000238418579  loss:0.6852691769599915
i:560  acc:0.6333333849906921  loss:0.6866227984428406
i:640  acc:0.46666669845581055  loss:0.6967718005180359
i:720  acc:0.40000003576278687  loss:0.6900659799575806
i:800  acc:0.5333333611488342  loss:0.6896941065788269
[epoch:1] avg acc:0.5023981108010815
--------------------------------------------------------------------------------
>>test avg_acc: 0.5060751622306
i:0  acc:0.5  loss:0.6841410398483276
i:80  acc:0.46666669845581055  loss:0.6943145990371704
i:160  acc:0.46666669845581055  loss:0.6860478520393372
i:240  acc:0.5333333611488342  loss:0.6897433400154114
i:320  acc:0.5666667222976685  loss:0.6863185167312622
i:400  acc:0.40000003576278687  loss:0.7020732164382935
i:480  acc:0.4333333671092987  loss:0.6965914964675903
i:560  acc:0.5666667222976685  loss:0.6850005388259888
i:640  acc:0.5  loss:0.6887240409851074
i:720  acc:0.4333333671092987  loss:0.6993931531906128
i:800  acc:0.40000003576278687  loss:0.6990763545036316
[epoch:2] avg acc:0.49676261795796367
--------------------------------------------------------------------------------
>>test avg_acc: 0.5240208123736888
i:0  acc:0.6000000238418579  loss:0.6799271702766418
i:80  acc:0.6000000238418579  loss:0.6885167956352234
i:160  acc:0.40000003576278687  loss:0.6957201957702637
i:240  acc:0.5333333611488342  loss:0.6934034824371338
i:320  acc:0.6000000238418579  loss:0.7242684960365295
i:400  acc:0.5333333611488342  loss:0.6945787668228149
i:480  acc:0.4333333671092987  loss:0.7100253701210022
i:560  acc:0.5666667222976685  loss:0.6828211545944214
i:640  acc:0.5666667222976685  loss:0.691743016242981
i:720  acc:0.46666669845581055  loss:0.7095474600791931
i:800  acc:0.5333333611488342  loss:0.6848457455635071
[epoch:3] avg acc:0.5040767665604036
--------------------------------------------------------------------------------
>>test avg_acc: 0.5380096217032245
i:0  acc:0.4333333671092987  loss:0.7026119828224182
i:80  acc:0.40000003576278687  loss:0.7223049402236938
i:160  acc:0.5  loss:0.6943424344062805
i:240  acc:0.6333333849906921  loss:0.6763809323310852
i:320  acc:0.6333333849906921  loss:0.6853126287460327
i:400  acc:0.40000003576278687  loss:0.7056844234466553
i:480  acc:0.4333333671092987  loss:0.7089643478393555
i:560  acc:0.5  loss:0.6923120021820068
i:640  acc:0.46666669845581055  loss:0.6989496946334839
i:720  acc:0.36666667461395264  loss:0.704035222530365
i:800  acc:0.5666667222976685  loss:0.6810514330863953
[epoch:4] avg acc:0.5044764476499969
--------------------------------------------------------------------------------
>>test avg_acc: 0.5951239331770573
i:0  acc:0.5333333611488342  loss:0.6941536068916321
i:80  acc:0.6000000238418579  loss:0.6859279870986938
i:160  acc:0.5333333611488342  loss:0.690399706363678
i:240  acc:0.5  loss:0.6989421844482422
i:320  acc:0.5333333611488342  loss:0.6894619464874268
i:400  acc:0.5  loss:0.6963599324226379
i:480  acc:0.4333333671092987  loss:0.698384165763855
i:560  acc:0.5  loss:0.6836084127426147
i:640  acc:0.5  loss:0.6970310807228088
i:720  acc:0.5666667222976685  loss:0.6944922208786011
i:800  acc:0.6000000238418579  loss:0.6284491419792175
[epoch:5] avg acc:0.5089128978675503
--------------------------------------------------------------------------------
>>test avg_acc: 0.5634292881373021
i:0  acc:0.5  loss:0.6751338243484497
i:80  acc:0.4333333671092987  loss:0.7002186179161072
i:160  acc:0.6333333849906921  loss:0.7004081010818481
i:240  acc:0.40000003576278687  loss:0.6799851059913635
i:320  acc:0.3333333432674408  loss:0.7024126052856445
i:400  acc:0.46666669845581055  loss:0.706433892250061
i:480  acc:0.4333333671092987  loss:0.689089298248291
i:560  acc:0.5333333611488342  loss:0.6774276494979858
i:640  acc:0.5333333611488342  loss:0.6818891167640686
i:720  acc:0.7000000476837158  loss:0.6483676433563232
i:800  acc:0.5333333611488342  loss:0.6753848791122437
[epoch:6] avg acc:0.5055555829898917
--------------------------------------------------------------------------------
>>test avg_acc: 0.6294964369145228
i:0  acc:0.6000000238418579  loss:0.688803493976593
i:80  acc:0.5666667222976685  loss:0.6731964349746704
i:160  acc:0.46666669845581055  loss:0.7247046828269958
i:240  acc:0.5666667222976685  loss:0.6654935479164124
i:320  acc:0.36666667461395264  loss:0.714630126953125
i:400  acc:0.40000003576278687  loss:0.6986389756202698
i:480  acc:0.5666667222976685  loss:0.6774072647094727
i:560  acc:0.30000001192092896  loss:0.7261555790901184
i:640  acc:0.6000000238418579  loss:0.6980276107788086
i:720  acc:0.4333333671092987  loss:0.6940374374389648
i:800  acc:0.46666669845581055  loss:0.7052587270736694
[epoch:7] avg acc:0.5096722914875268
--------------------------------------------------------------------------------
>>test avg_acc: 0.6216627034745056
i:0  acc:0.40000003576278687  loss:0.6944260001182556
i:80  acc:0.5  loss:0.6836961507797241
i:160  acc:0.5666667222976685  loss:0.6769872307777405
i:240  acc:0.4333333671092987  loss:0.6808486580848694
i:320  acc:0.6666666865348816  loss:0.6665003895759583
i:400  acc:0.7333333492279053  loss:0.6659902930259705
i:480  acc:0.5333333611488342  loss:0.6798592209815979
i:560  acc:0.6000000238418579  loss:0.67735356092453
i:640  acc:0.36666667461395264  loss:0.6860645413398743
i:720  acc:0.5  loss:0.6970022916793823
i:800  acc:0.5  loss:0.7039815783500671
[epoch:8] avg acc:0.5031575022055377
--------------------------------------------------------------------------------
>>test avg_acc: 0.5886491138300449
i:0  acc:0.6000000238418579  loss:0.6760574579238892
i:80  acc:0.5  loss:0.7213529944419861
i:160  acc:0.7000000476837158  loss:0.6760198473930359
i:240  acc:0.46666669845581055  loss:0.7069780826568604
i:320  acc:0.46666669845581055  loss:0.6837179660797119
i:400  acc:0.5333333611488342  loss:0.6772727966308594
i:480  acc:0.5666667222976685  loss:0.6746792197227478
i:560  acc:0.46666669845581055  loss:0.7091919779777527
i:640  acc:0.5333333611488342  loss:0.6950925588607788
i:720  acc:0.46666669845581055  loss:0.7464666366577148
i:800  acc:0.5666667222976685  loss:0.6813923120498657
[epoch:9] avg acc:0.5125100204913164
--------------------------------------------------------------------------------
>>test avg_acc: 0.5081135328117606

Process finished with exit code 0

可以看见,最高准确率是0.6294964369145228,进行训练之后就明显的有点过拟合了


2. 图像分类任务——手写数字MNIST分类

2.1 实现思路

对于图像来说,其信息维度大家都很熟悉了,即[batch, c, h, w],那么其实也可以将图片看成是序列信息。比如这里我把每一行看成是一个词汇,那么剩下的c*w就可以看成是这一行的编码信息,所以一张[c, h, w]的图像就可以看成是[h, c*w]的序列信息的组合。也就是将一批图像[batch, c, h, w]构建成[h, batch, c*w]这么一个时序信息,就可以丢到LSTM上面去处理了。

不过显然,这么用时序网络来处理图像是有问题的,因为会丢失图像的空域信息。接下来处理就是类似为文本分类了,提取lstm网络的最后一个隐藏单元输出[batch, c*w],就可以构建一个全连接层分类器来进行分类处理。

2.2 模型定义

# Model definition
class LSTM(nn.Module):
    """LSTM image classifier: treat each image row as one time step.

    [b, c, h, w] becomes a sequence of h steps, each a (c*w)-dim vector;
    the last LSTM output is fed to a linear head producing nclass logits.
    """

    def __init__(self, input_size, hidden_size=256, nclass=10):
        super(LSTM, self).__init__()

        # 4-layer unidirectional LSTM over the row sequence.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=4)
        # [hidden_size] => [nclass] logits
        self.fc = nn.Linear(hidden_size, nclass)

    def forward(self, x):
        # [b, c, h, w] => [h, b, c*w] with native tensor ops — same result
        # as einops.rearrange(x, 'b c h w -> h b (c w)') without the
        # extra third-party dependency.
        b, c, h, w = x.shape
        x = x.permute(2, 0, 1, 3).reshape(h, b, c * w)
        output, (hn, cn) = self.lstm(x)
        # The last time step has seen every row of the image.
        out = self.fc(output[-1])

        return out

2.3 完整代码

import torch
import torch.nn as nn
import einops
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import optim


# Hyperparameters
batch_size = 128
learning_rate = 1e-3
epochsize = 10


# Model definition
class LSTM(nn.Module):
    """LSTM image classifier: treat each image row as one time step.

    [b, c, h, w] becomes a sequence of h steps, each a (c*w)-dim vector;
    the last LSTM output is fed to a linear head producing nclass logits.
    """

    def __init__(self, input_size, hidden_size=256, nclass=10):
        super(LSTM, self).__init__()

        # 4-layer unidirectional LSTM over the row sequence.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=4)
        # [hidden_size] => [nclass] logits
        self.fc = nn.Linear(hidden_size, nclass)

    def forward(self, x):
        # [b, c, h, w] => [h, b, c*w] with native tensor ops — same result
        # as einops.rearrange(x, 'b c h w -> h b (c w)') without the
        # extra third-party dependency.
        b, c, h, w = x.shape
        x = x.permute(2, 0, 1, 3).reshape(h, b, c * w)
        output, (hn, cn) = self.lstm(x)
        # The last time step has seen every row of the image.
        out = self.fc(output[-1])

        return out


# Dataset preparation (training split)
traindata = datasets.MNIST('./dataset/mnist', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    # NOTE(review): MNIST std is conventionally 0.3081; std=0.1307 here
    # looks like a copy of the mean — confirm intended.
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=False)
trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=True)

# Test split
testdata = datasets.MNIST('./dataset/mnist', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=False)
testloader = DataLoader(testdata, batch_size=batch_size, shuffle=True)


# Model, loss and optimizer (CPU run)
device = torch.device('cpu')
model = LSTM(input_size=1*28).to(device)  # per-step input = c*w = 1*28
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# Training loop for one epoch
def train(epoch, model, criteon, optimizer):
    """Run one epoch over the module-level `trainloader`."""
    model.train()
    for step, (image, label) in enumerate(trainloader):
        image, label = image.to(device), label.to(device)

        # Forward pass: logits over the 10 digit classes.
        logits = model(image)
        loss = criteon(logits, label)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 30 == 0:
            print("[{}/{}] loss:{}".format(step, len(trainloader), loss.item()))


# Evaluation loop
def eval(epoch, model):
    """Measure accuracy on the module-level `testloader` and save weights."""
    model.eval()
    with torch.no_grad():
        n_correct = 0  # running count of correct predictions
        n_seen = 0     # running count of evaluated samples

        for image, label in testloader:
            image, label = image.to(device), label.to(device)
            logits = model(image)

            # Predicted class = argmax over the 10 logits.
            pred = logits.argmax(dim=1)

            n_correct += torch.eq(pred, label).detach().float().sum().item()
            n_seen += image.size(0)

        # Accuracy over the full test set.
        print('epoch:', epoch, 'test_acc:', n_correct / n_seen)

        # Persist the model weights.
        torch.save(model.state_dict(), 'lstm_mnist.mdl')


# Main loop: train one epoch, then evaluate.
for epoch in range(epochsize):
    train(epoch, model, criteon, optimizer)
    eval(epoch, model)
  • 输出效果:
[0/469] loss:2.3025128841400146
[30/469] loss:0.9521487951278687
[60/469] loss:0.7575526237487793
[90/469] loss:0.4454754590988159
[120/469] loss:0.3473556637763977
[150/469] loss:0.2933858335018158
[180/469] loss:0.20550106465816498
[210/469] loss:0.179958313703537
[240/469] loss:0.22851471602916718
[270/469] loss:0.20810575783252716
[300/469] loss:0.09591241180896759
[330/469] loss:0.15426291525363922
[360/469] loss:0.13518500328063965
[390/469] loss:0.12554696202278137
[420/469] loss:0.1267729550600052
[450/469] loss:0.1595761477947235
epoch: 0 test_acc: 0.9718
[0/469] loss:0.09293190389871597
[30/469] loss:0.11132104694843292
[60/469] loss:0.09998508542776108
[90/469] loss:0.08452940732240677
[120/469] loss:0.08355244249105453
[150/469] loss:0.1536688357591629
[180/469] loss:0.04564380645751953
[210/469] loss:0.06441441178321838
[240/469] loss:0.09115603566169739
[270/469] loss:0.08784651011228561
[300/469] loss:0.03681361302733421
[330/469] loss:0.07171863317489624
[360/469] loss:0.018698034808039665
[390/469] loss:0.10782834142446518
[420/469] loss:0.08012165129184723
[450/469] loss:0.019021611660718918
epoch: 1 test_acc: 0.9795
[0/469] loss:0.0318123959004879
[30/469] loss:0.021020444110035896
[60/469] loss:0.020755767822265625
[90/469] loss:0.024244684725999832
[120/469] loss:0.017478983849287033
[150/469] loss:0.07532914727926254
[180/469] loss:0.015447300858795643
[210/469] loss:0.04213869571685791
[240/469] loss:0.03204599395394325
[270/469] loss:0.052437808364629745
[300/469] loss:0.08051534742116928
[330/469] loss:0.07634095102548599
[360/469] loss:0.08546453714370728
[390/469] loss:0.05848413705825806
[420/469] loss:0.07495103031396866
[450/469] loss:0.027593351900577545
epoch: 2 test_acc: 0.9839
[0/469] loss:0.07221744954586029
[30/469] loss:0.06388659030199051
[60/469] loss:0.03984460607171059
[90/469] loss:0.01048417016863823
[120/469] loss:0.005487970542162657
[150/469] loss:0.08362901955842972
[180/469] loss:0.014124182052910328
[210/469] loss:0.050419360399246216
[240/469] loss:0.023143572732806206
[270/469] loss:0.02464546263217926
[300/469] loss:0.031222017481923103
[330/469] loss:0.019614635035395622
[360/469] loss:0.0390796884894371
[390/469] loss:0.025058617815375328
[420/469] loss:0.021118540316820145
[450/469] loss:0.0783773735165596
epoch: 3 test_acc: 0.9853
[0/469] loss:0.08626232296228409
[30/469] loss:0.11992666125297546
[60/469] loss:0.00394081138074398
[90/469] loss:0.023075245320796967
[120/469] loss:0.04086756333708763
[150/469] loss:0.026740970090031624
[180/469] loss:0.020538218319416046
[210/469] loss:0.006089959293603897
[240/469] loss:0.018706560134887695
[270/469] loss:0.020065028220415115
[300/469] loss:0.011625857092440128
[330/469] loss:0.051783740520477295
[360/469] loss:0.02034861594438553
[390/469] loss:0.08164804428815842
[420/469] loss:0.012055248022079468
[450/469] loss:0.03791734576225281
epoch: 4 test_acc: 0.9812
[0/469] loss:0.021869802847504616
[30/469] loss:0.10099687427282333
[60/469] loss:0.034956786781549454
[90/469] loss:0.0021370809990912676
[120/469] loss:0.011838294565677643
[150/469] loss:0.004359148442745209
[180/469] loss:0.06643899530172348
[210/469] loss:0.06558196246623993
[240/469] loss:0.050053443759679794
[270/469] loss:0.03635893762111664
[300/469] loss:0.06303692609071732
[330/469] loss:0.02340596169233322
[360/469] loss:0.029642755165696144
[390/469] loss:0.03899005800485611
[420/469] loss:0.05554909259080887
[450/469] loss:0.03344707190990448
epoch: 5 test_acc: 0.9857
[0/469] loss:0.028613073751330376
[30/469] loss:0.00897591095417738
[60/469] loss:0.012265876866877079
[90/469] loss:0.010190272703766823
[120/469] loss:0.050869863480329514
[150/469] loss:0.001720290631055832
[180/469] loss:0.008847932331264019
[210/469] loss:0.003561723046004772
[240/469] loss:0.006472764071077108
[270/469] loss:0.029855765402317047
[300/469] loss:0.0309893861413002
[330/469] loss:0.024405676871538162
[360/469] loss:0.0010678835678845644
[390/469] loss:0.06153138354420662
[420/469] loss:0.04718944802880287
[450/469] loss:0.00825027097016573
epoch: 6 test_acc: 0.9871
[0/469] loss:0.003854188835248351
[30/469] loss:0.032898884266614914
[60/469] loss:0.02628408558666706
[90/469] loss:0.008407886140048504
[120/469] loss:0.03922632709145546
[150/469] loss:0.006640949752181768
[180/469] loss:0.002945204498246312
[210/469] loss:0.07272098958492279
[240/469] loss:0.007726807612925768
[270/469] loss:0.0024011172354221344
[300/469] loss:0.007132356055080891
[330/469] loss:0.012751876376569271
[360/469] loss:0.04916739836335182
[390/469] loss:0.005492747761309147
[420/469] loss:0.031216543167829514
[450/469] loss:0.002003708854317665
epoch: 7 test_acc: 0.9906
[0/469] loss:0.004876357968896627
[30/469] loss:0.003498057136312127
[60/469] loss:0.023187248036265373
[90/469] loss:0.0395420640707016
[120/469] loss:0.003954981919378042
[150/469] loss:0.0054273889400064945
[180/469] loss:0.014193476177752018
[210/469] loss:0.02274252474308014
[240/469] loss:0.011735322885215282
[270/469] loss:0.043604783713817596
[300/469] loss:0.00319128786213696
[330/469] loss:0.0030225743539631367
[360/469] loss:0.0269723292440176
[390/469] loss:0.051725342869758606
[420/469] loss:0.06526211649179459
[450/469] loss:0.059913888573646545
epoch: 8 test_acc: 0.9886
[0/469] loss:0.0034596475306898355
[30/469] loss:0.001289673033170402
[60/469] loss:0.00738430954515934
[90/469] loss:0.002652142196893692
[120/469] loss:0.031935252249240875
[150/469] loss:0.0020228298380970955
[180/469] loss:0.0021983913611620665
[210/469] loss:0.009495465084910393
[240/469] loss:0.003172997385263443
[270/469] loss:0.07073160260915756
[300/469] loss:0.005176837090402842
[330/469] loss:0.0320693738758564
[360/469] loss:0.02262166514992714
[390/469] loss:0.010343664325773716
[420/469] loss:0.014142563566565514
[450/469] loss:0.02190088853240013
epoch: 9 test_acc: 0.9875

3. 图像生成任务——手写数字MNIST生成

3.1 实现思路

在上面使用lstm对图像进行处理后,那么有一个想法就是。既然时序网络可以根据输入的一部分内容获取一个输出或者一段内容的更新输出,那么对图像来说,就可以给定一部分的图像序列来预测下一个像素点。

具体来说,这里我使用了手写数字一半的像素点,以手写数字为例:[1, 28, 28],我就使用了一半[1, 14, 28]来对下一个像素值进行预测。那么预测了下一个点,又可以将新的像素点囊括在序列信息内,而排除原始序列信息的第一个像素点,从而进行下一个像素点的预测。周而复始,实现下半部分手写数字的全部生成。

在实现的过程中,一张图像就是一个训练的样本,将其看成一个序列信息,显然这样又会丢失空域特征。利用前半部分序列信息来预测这段序列的下一个像素点的信息,直到最后预测完下半部分的全部像素点,再重新reshape一下。

思路是可以实现的,但是会存在一些问题。比如图像的空域信息会丢失;而且在一段序列信息中,以手写数字为例,有很多个黑色的像素点,只有少量的白色像素点来主要构成数字。但是在训练的过程中,白色像素点与大多数黑色像素点的重要性显然是天差地别,我们应该更加注重白色像素点的信息,这也就涉及到权重的动态分配问题。

所以,对于手写数字这种黑白且像素不均衡的数据集,很有可能网络为了最小化损失值,预测出来的点全部是黑色像素。这样的话大部分黑色像素的误差都是0,而白色像素与黑色像素的误差也不会太大,从而完全不能生成任何东西。

3.2 模型定义

# Next-pixel LSTM predictor (maps a 1-D pixel sequence to per-step values)
class LSTM(nn.Module):
    """Autoregressive pixel model: input a 1-D sequence [seq], output
    [seq] where element i is the prediction for pixel i+1."""

    def __init__(self, input_size=1, hidden_size=200):
        super(LSTM, self).__init__()

        # 4-layer LSTM over single-scalar steps (batch of 1).
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=4)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # [seq] => [seq, 1, 1] with native ops — same result as
        # einops.rearrange(x, 'seq -> seq 1 1') without the dependency.
        x = x.view(-1, 1, 1)
        output, (h, c) = self.lstm(x)
        # [seq, 1, hidden] => [seq, hidden]
        output = output.squeeze(1)
        # [seq, hidden] => [seq]
        output = self.fc(output).view(-1)
        return output

这里本应该贴上训练与预测的伪代码的,但是删了…

3.3 训练与预测伪代码

#  Training pseudocode (illustrative only — not runnable as-is)
def train():
    # Flatten one image into a 1-D pixel sequence.
    image = [1, 28, 28].flatten()
    mid = len(image)//2
    model = LSTM()
    for i in range(mid):

        # Predict the next pixel for every position in [i, mid);
        # the target is the same window shifted one pixel right.
        X = image[i : mid]
        pred = model(X)
        y = image[i+1 : mid+1]
        loss = MSE(pred, y)

        # The usual three-step update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Visualization pseudocode: autoregressively generate the bottom half.
def eval():
    image = [1, 28, 28].flatten()
    mid = len(image)//2
    # X = image[:mid]
    result = image[:mid]

    # Append one predicted pixel at a time, always feeding the model
    # the most recent half-image window.
    for i in range(mid):
        preddata = model(result[-14*28:])
        predpoint = preddata[-1]
        result.append(predpoint)

    image.show()
    result.show()

3.4 完整代码

import torch
import torch.nn as nn
import einops
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import optim
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
import cv2
from torchvision.utils import save_image
import os


# Hyperparameters
batch_size = 128
learning_rate = 1e-3
epochsize = 5
# Restrict CUDA to the second physical GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"


# Next-pixel LSTM predictor (maps a 1-D pixel sequence to per-step values)
class LSTM(nn.Module):
    """Autoregressive pixel model: input a 1-D sequence [seq], output
    [seq] where element i is the prediction for pixel i+1."""

    def __init__(self, input_size=1, hidden_size=200):
        super(LSTM, self).__init__()

        # 4-layer LSTM over single-scalar steps (batch of 1).
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=4)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # [seq] => [seq, 1, 1] with native ops — same result as
        # einops.rearrange(x, 'seq -> seq 1 1') without the dependency.
        x = x.view(-1, 1, 1)
        output, (h, c) = self.lstm(x)
        # [seq, 1, hidden] => [seq, hidden]
        output = output.squeeze(1)
        # [seq, hidden] => [seq]
        output = self.fc(output).view(-1)
        return output


# Dataset preparation
traindata = datasets.MNIST('./dataset/mnist', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=False)
testdata = datasets.MNIST('./dataset/mnist', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=False)

# Model, loss and optimizer
device = torch.device('cuda')
model = LSTM().to(device)
# NOTE(review): resumes from an existing checkpoint; this raises on a
# fresh run if 'lstm_mnistgen_0.mdl' does not exist — confirm intended.
model.load_state_dict(torch.load('lstm_mnistgen_0.mdl'))
criteon = nn.MSELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# Training loop for one epoch
def train(epoch):
    """One epoch of next-pixel training over the global `traindata`.

    Each image is flattened into a 1-D pixel sequence; at step i the
    model sees pixels [i, mid) and is trained (teacher forcing, MSE)
    to predict pixels [i+1, mid+1).
    """
    image_id = 1
    # One image == one training sequence.
    for image, label in iter(traindata):
        image = image.to(device)
        image_seq = image.flatten()
        mid = len(image_seq) // 2

        loss_list = []
        for i in range(mid):
            X = image_seq[i:mid]
            pred = model(X)
            # Target: the same window shifted one pixel to the right.
            y = image_seq[i+1:mid+1]
            loss = criteon(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_list.append(loss.item())

            if i % 50 == 0:
                print("[i/mid--{}/{}]: loss :{}".format(i, mid, loss.item()))

        loss_mean = np.array(loss_list).mean()
        print("[image{}]: loss mean:{}".format(image_id, loss_mean))
        print("-"*80)
        image_id += 1

    # Bug fix: checkpoint once per EPOCH. The original torch.save sat
    # inside the per-image loop, rewriting the file after every single
    # training image (60k disk writes per epoch).
    torch.save(model.state_dict(), "lstm_mnistgen_{}.mdl".format(epoch))


def eval(epoch, model, SEED=12345):
    """Predict the bottom half of one random test image from its top half.

    Saves the real image and the generated image to disk as JPEGs.
    """
    # Pick one reproducible test image.
    np.random.seed(SEED)
    image_id = np.random.randint(0, len(testdata), 1)
    image, label = testdata[int(image_id)]
    image = image.to(device)
    image_seq = image.flatten()
    mid = len(image_seq) // 2

    # Autoregressively extend the known first half, one pixel at a time,
    # always feeding the most recent `mid` pixels back to the model.
    image_pred = image_seq[:mid].tolist()
    model.eval()
    # Bug fix: run generation under no_grad — the original built an
    # autograd graph across all 392 forward passes, wasting memory.
    with torch.no_grad():
        for i in tqdm(range(mid), desc="pixel pred"):
            X = torch.FloatTensor(image_pred[-mid:]).to(device)
            pred_data = model(X)
            # Bug fix: store a plain Python float (.item()), not a device
            # tensor, so the list stays homogeneous for torch.tensor().
            image_pred.append(pred_data[-1].item())

    image_pred = torch.tensor(image_pred).view(1, 28, 28)
    print("[pred success]")
    # Bug fix: first line was mislabeled "image_pred.shape" for image.shape.
    print("image.shape: ", image.shape)
    print("image_pred.shape: ", image_pred.shape)

    # Save the real and generated images side by side on disk.
    save_image(image, "real_image_{}.jpg".format(epoch))
    save_image(image_pred, "gene_image_{}.jpg".format(epoch))
    print("save image success")


# Main loop: train one epoch, then generate a sample image.
for epoch in range(epochsize):
    train(epoch)
    eval(epoch, model)


# Standalone evaluation snippet (kept for reference):
# if __name__ == '__main__':
#
#     model = LSTM().to(device)
#     model.state_dict(torch.load('lstm_mnistgen_0.mdl'))  # NOTE(review): probably meant load_state_dict
#     eval(epoch=0, model=model, SEED=123)

参考资料:

这里的文本分类任务来自龙曲良的课程资料,其他的两个是自己搞着玩的,仅供参考。

补充:对LSTM的几个应用任务我上传到了资源了,见:使用RNN与LSTM实现的5个应用.

有钱钱有积分的大佬可以看看,或者见之前的两篇文章也足够了。

你可能感兴趣的:(深度学习,lstm,分类,深度学习)