深度学习与Pytorch入门实战(十六)情感分类实战(基于IMDB数据集)

笔记摘抄

提前安装torchtext和scapy,运行下面语句(压缩包地址链接:https://pan.baidu.com/s/1_syic9B-SXKQvkvHlEf78w 提取码:ahh3):

pip install torchtext

pip install scapy

pip install 你的地址\en_core_web_md-2.2.5.tar.gz
  • 在torchtext中使用spacy时,由于field的默认属性是tokenizer_language='en'

  • 当使用 en_core_web_md 时要改 field.py文件中 创建的field属性为tokenizer_language='en_core_web_md',且data.Field()中的参数也要改为tokenizer_language='en_core_web_md'

1. 加载数据

1.1 分割训练集测试集

import numpy as np
import torch
from torch import nn, optim
from torchtext import data, datasets

# 为CPU设置随机种子
torch.manual_seed(123)

# 两个Field对象定义字段的处理方法(文本字段、标签字段)
TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_md')  # 分词
LABEL = data.LabelField(dtype=torch.float)

# IMDB共50000影评,包含正面和负面两个类别。数据被前面的Field处理
# 按照(TEXT, LABEL) 分割成 训练集,测试集
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

print('len of train data:', len(train_data))        # 25000
print('len of test data:', len(test_data))          # 25000

# torchtext.data.Example : 用来表示一个样本,数据+标签
print(train_data.examples[15].text)                 # 文本:句子的单词列表
print(train_data.examples[15].label)                # 标签: 积极
len of train data: 25000
len of test data: 25000
['Like', 'one', 'of', 'the', 'previous', 'commenters', 'said', ',', 'this', 'had', 'the', 'foundations', 'of', 'a', 'great', 'movie', 'but', 'something', 'happened', 'on', 'the', 'way', 'to', 'delivery', '.', 'Such', 'a', 'waste', 'because', 'Collette', "'s", 'performance', 'was', 'eerie', 'and', 'Williams', 'was', 'believable', '.', 'I', 'just', 'kept', 'waiting', 'for', 'it', 'to', 'get', 'better', '.', 'I', 'do', "n't", 'think', 'it', 'was', 'bad', 'editing', 'or', 'needed', 'another', 'director', ',', 'it', 'could', 'have', 'just', 'been', 'the', 'film', '.', 'It', 'came', 'across', 'as', 'a', 'Canadian', 'movie', ',', 'something', 'like', 'the', 'first', 'few', 'seasons', 'of', 'X', '-', 'Files', '.', 'Not', 'cheap', ',', 'just', 'hokey', '.', 'Also', ',', 'it', 'needed', 'a', 'little', 'more', 'suspense', '.', 'Something', 'that', 'makes', 'you', 'jump', 'off', 'your', 'seat', '.', 'The', 'movie', 'reached', 'that', 'moment', 'then', 'faded', 'away', ';', 'kind', 'of', 'like', 'a', 'false', 'climax', '.', 'I', 'can', 'see', 'how', 'being', 'too', 'suspenseful', 'would', 'have', 'taken', 'away', 'from', 'the', '"', 'reality', '"', 'of', 'the', 'story', 'but', 'I', 'thought', 'that', 'part', 'was', 'reached', 'when', 'Gabriel', 'was', 'in', 'the', 'hospital', 'looking', 'for', 'the', 'boy', '.', 'This', 'movie', 'needs', 'to', 'have', 'a', 'Director', "'s", 'cut', 'that', 'tries', 'to', 'fix', 'these', 'problems', '.']
pos
  • 当我们把句子传进模型的时候,是按照一个个batch传进去的,而且每个batch中的句子必须是相同的长度。

  • 为了确保句子的长度相同,TorchText会把短的句子 pad到和最长的句子 等长。

1.2 创建vocabulary

  • vocabulary把每个单词一一映射到一个数字。使用10k个单词来构建单词表(用max_size这个参数可以设定),所有其他的单词都用来表示。

  • 词典中应当有10002个单词,且有两个label,可以通过TEXT.vocabTEXT.label查询,可以直接用stoi(stringtoint) 或者 itos(inttostring) 来查看单词表。

TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
LABEL.build_vocab(train_data)

print(len(TEXT.vocab))             # 10002
print(TEXT.vocab.itos[:12])        # ['', '', 'the', ',', '.', 'and', 'a', 'of', 'to', 'is', 'in', 'I']
print(TEXT.vocab.stoi['and'])      # 5
print(LABEL.vocab.stoi)            # defaultdict(None, {'neg': 0, 'pos': 1})
['', '', 'the', ',', '.', 'and', 'a', 'of', 'to', 'is', 'in', 'I']
5
defaultdict(, {'neg': 0, 'pos': 1})

1.3 创建iteratiors

  • 每个iterator中各有两部分:词(.text)和标签(.label),其中 text 全部转换成数字了

  • BucketIterator会把长度差不多的句子放到同一个batch中,确保每个batch中不出现太多的padding。

  • 这里因为pad比较少,所以把 也当做了模型的输入进行训练。

  • 如果有GPU,还可以指定每个iteration返回的tensor 都在GPU上。

batchsz = 30
train_iterator, test_iterator = data.BucketIterator.splits(
                                (train_data, test_data),
                                batch_size = batchsz,
                               )

如果要使用gpu加速,改成:

batchsz = 30
device = torch.device('cuda')
train_iterator, test_iterator = data.BucketIterator.splits(
                                (train_data, test_data),
                                batch_size = batchsz,
                                device=device
                               )

2. 定义模型

class RNN(nn.Module):

  def __init__(self, vocab_size, embedding_dim, hidden_dim):
    super(RNN, self).__init__()

    # [0-10001] => [100]
    # 参数1:embedding个数(单词数), 参数2:embedding的维度(词向量维度)
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    # [100] => [256]
    # 双向LSTM,所以下面FC层使用 hidden_dim*2
    self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
                       bidirectional=True, dropout=0.5) 
    # [256*2] => [1]
    self.fc = nn.Linear(hidden_dim*2, 1)
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    """
    x: [seq_len, b] vs [b, 3, 28, 28]
    """
    # [seq_len, b, 1] => [seq_len, b, 100]
    embedding = self.dropout(self.embedding(x))

    # output: [seq, b, hid_dim*2]
    # hidden/h: [num_layers*2, b, hid_dim]
    # cell/c: [num_layers*2, b, hid_dim]
    output, (hidden, cell) = self.rnn(embedding)
    # [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]
    # 双向,所以要把最后两个输出连接
    hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
    # [b, hid_dim*2] => [b, 1]
    hidden = self.dropout(hidden)
    out = self.fc(hidden)

    return out
  • 使用 预训练过的embedding 来替换随机初始化

  • Tip:.copy_() 这种 带着下划线的函数 均代表 替换inplace

rnn = RNN(len(TEXT.vocab), 100, 256)                          #词个数,词嵌入维度,输出维度

pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)    # torch.Size([10002, 100])

# 使用预训练过的embedding来替换随机初始化
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')
pretrained_embedding: torch.Size([10002, 100])
embedding layer inited.

3. 训练模型

  • 首先定义模型和损失函数。
optimizer = optim.Adam(rnn.parameters(), lr=1e-3)

# BCEWithLogitsLoss是针对二分类的CrossEntropy
criteon = nn.BCEWithLogitsLoss()

如果使用GPU加速,改成:

optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
# BCEWithLogitsLoss是针对二分类的CrossEntropy
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)
RNN(
  (embedding): Embedding(10002, 100)
  (rnn): LSTM(100, 256, num_layers=2, dropout=0.5, bidirectional=True)
  (fc): Linear(in_features=512, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
  • 定义一个函数用于计算准确率
def binary_acc(preds, y):

    preds = torch.round(torch.sigmoid(preds))
    correct = torch.eq(preds, y).float()
    acc = correct.sum() / len(correct)
    return acc
  • 定义一个训练函数
def train(rnn, iterator, optimizer, criteon):
    avg_acc = []
    rnn.train()   # 表示进入训练模式

    for i, batch in enumerate(iterator):
        # [seq, b] => [b, 1] => [b]
        # batch.text 就是上面forward函数的参数text,压缩维度是为了和batch.label维度一致
        pred = rnn(batch.text).squeeze(1)

        loss = criteon(pred, batch.label)
        # 计算每个batch的准确率
        acc = binary_acc(pred, batch.label).item()
        avg_acc.append(acc)

        optimizer.zero_grad()  # 清零梯度准备计算
        loss.backward()        # 反向传播
        optimizer.step()       # 更新训练参数

        if i % 10 == 0:
            print(i, acc)

    avg_acc = np.array(avg_acc).mean()
    print('avg acc:', avg_acc)

4. 评估模型

  • 定义一个评估函数,和训练函数高度重合

  • 区别是要把rnn.train()改为rnn.val(),不需要反向传播过程。

def evaluate(rnn, iterator, criteon):
    avg_acc = []
    rnn.eval()         # 表示进入测试模式

    with torch.no_grad():
        for batch in iterator:
            pred = rnn(batch.text).squeeze(1)      # [b, 1] => [b]
            loss = criteon(pred, batch.label)
            acc = binary_acc(pred, batch.label).item()
            avg_acc.append(acc)

    avg_acc = np.array(avg_acc).mean()

    print('test acc:', avg_acc)

5. 运行

for epoch in range(10):
    # 训练模型
    train(rnn, train_iterator, optimizer, criteon)
    # 评估模型
    evaluate(rnn, test_iterator, criteon)
view result
0 0.8666667342185974
10 0.9666666984558105
20 0.8000000715255737
30 0.8666667342185974
40 0.8666667342185974
50 0.8000000715255737
60 0.9333333969116211
70 0.7666667103767395
80 0.9000000357627869
90 0.8666667342185974
100 0.9000000357627869
110 0.7666667103767395
120 0.8000000715255737
130 0.9666666984558105
140 0.8666667342185974
150 0.9000000357627869
160 0.9000000357627869
170 0.9000000357627869
180 0.8000000715255737
190 0.8000000715255737
200 0.9333333969116211
210 0.9000000357627869
220 0.9333333969116211
230 0.8666667342185974
240 0.9000000357627869
250 0.7666667103767395
260 0.9333333969116211
270 0.9000000357627869
280 0.8000000715255737
290 0.8666667342185974
300 0.9333333969116211
310 0.7666667103767395
320 0.9000000357627869
330 0.9666666984558105
340 0.9666666984558105
350 0.8333333730697632
360 0.9000000357627869
370 0.8000000715255737
380 0.9000000357627869
390 0.8666667342185974
400 0.8333333730697632
410 0.9000000357627869
420 0.9333333969116211
430 0.8333333730697632
440 0.8666667342185974
450 0.8000000715255737
460 0.9333333969116211
470 0.8666667342185974
480 0.9333333969116211
490 0.9333333969116211
500 0.9000000357627869
510 0.8333333730697632
520 0.8666667342185974
530 0.9333333969116211
540 0.9333333969116211
550 0.7666667103767395
560 0.8333333730697632
570 0.9333333969116211
580 0.9000000357627869
590 0.9333333969116211
600 0.9000000357627869
610 0.8333333730697632
620 0.7333333492279053
630 0.8333333730697632
640 0.8333333730697632
650 0.9000000357627869
660 0.9333333969116211
670 0.8000000715255737
680 0.9000000357627869
690 0.9000000357627869
700 0.9000000357627869
710 0.9333333969116211
720 0.8000000715255737
730 0.9333333969116211
740 0.9666666984558105
750 0.9666666984558105
760 0.9333333969116211
770 0.8666667342185974
780 0.8666667342185974
790 0.8666667342185974
800 0.9666666984558105
810 0.9000000357627869
820 0.9000000357627869
830 0.9333333969116211
avg acc: 0.8855715916454078
test acc: 0.8775779855051201
0 0.9000000357627869
10 0.9666666984558105
20 0.9000000357627869
30 0.9000000357627869
40 0.9666666984558105
50 0.9666666984558105
60 0.7666667103767395
70 0.8666667342185974
80 0.9333333969116211
90 0.9000000357627869
100 0.9333333969116211
110 0.8666667342185974
120 0.9000000357627869
130 0.9000000357627869
140 0.8666667342185974
150 0.8333333730697632
160 0.8333333730697632
170 0.9333333969116211
180 0.8333333730697632
190 0.9000000357627869
200 0.8666667342185974
210 1.0
220 1.0
230 0.9666666984558105
240 0.9000000357627869
250 0.8000000715255737
260 0.9333333969116211
270 0.9666666984558105
280 0.9333333969116211
290 0.9666666984558105
300 0.9000000357627869
310 0.9333333969116211
320 0.9333333969116211
330 0.9666666984558105
340 0.9666666984558105
350 0.9666666984558105
360 0.9333333969116211
370 0.9666666984558105
380 0.8333333730697632
390 0.7333333492279053
400 0.9000000357627869
410 0.9000000357627869
420 0.8000000715255737
430 0.9333333969116211
440 0.8666667342185974
450 0.9333333969116211
460 0.8333333730697632
470 0.9333333969116211
480 0.9333333969116211
490 0.8000000715255737
500 0.9666666984558105
510 0.9000000357627869
520 1.0
530 0.9666666984558105
540 1.0
550 0.9333333969116211
560 0.9000000357627869
570 1.0
580 0.9000000357627869
590 0.9000000357627869
600 0.8666667342185974
610 0.8333333730697632
620 0.9000000357627869
630 0.9000000357627869
640 0.8666667342185974
650 0.9000000357627869
660 0.9666666984558105
670 0.9333333969116211
680 0.8666667342185974
690 0.9000000357627869
700 0.8666667342185974
710 0.9333333969116211
720 0.9666666984558105
730 0.9666666984558105
740 0.9666666984558105
750 0.9000000357627869
760 0.9000000357627869
770 0.9000000357627869
780 0.9333333969116211
790 0.9333333969116211
800 0.9333333969116211
810 0.8666667342185974
820 0.9000000357627869
830 0.9000000357627869
avg acc: 0.9071942910873633
test acc: 0.8886890964542361
0 0.9333333969116211
10 0.9333333969116211
20 0.9666666984558105
30 0.9333333969116211
40 0.9333333969116211
50 0.8666667342185974
60 1.0
70 0.8333333730697632
80 0.9666666984558105
90 0.9000000357627869
100 0.9666666984558105
110 0.9666666984558105
120 0.9333333969116211
130 0.9333333969116211
140 0.9000000357627869
150 0.9666666984558105
160 0.8666667342185974
170 0.9666666984558105
180 0.9666666984558105
190 0.9333333969116211
200 0.9333333969116211
210 0.8666667342185974
220 0.9000000357627869
230 0.8333333730697632
240 0.9333333969116211
250 0.8000000715255737
260 0.8666667342185974
270 0.9000000357627869
280 0.9000000357627869
290 0.9666666984558105
300 0.9333333969116211
310 0.9000000357627869
320 0.9333333969116211
330 0.9666666984558105
340 0.9000000357627869
350 1.0
360 0.9666666984558105
370 0.9333333969116211
380 0.9333333969116211
390 0.9666666984558105
400 0.9666666984558105
410 0.9666666984558105
420 1.0
430 0.9000000357627869
440 1.0
450 0.9000000357627869
460 0.9333333969116211
470 1.0
480 0.9000000357627869
490 0.9333333969116211
500 0.9000000357627869
510 0.9000000357627869
520 0.9333333969116211
530 0.9333333969116211
540 0.9666666984558105
550 0.9666666984558105
560 0.9666666984558105
570 0.9666666984558105
580 0.8333333730697632
590 0.9666666984558105
600 0.9333333969116211
610 0.9333333969116211
620 0.9333333969116211
630 1.0
640 0.9000000357627869
650 0.8666667342185974
660 0.9333333969116211
670 0.8666667342185974
680 0.9666666984558105
690 0.9333333969116211
700 1.0
710 0.9666666984558105
720 0.9666666984558105
730 0.9000000357627869
740 0.9333333969116211
750 0.9666666984558105
760 1.0
770 0.8666667342185974
780 0.9000000357627869
790 0.9333333969116211
800 0.9666666984558105
810 0.9000000357627869
820 0.9666666984558105
830 0.8000000715255737
avg acc: 0.9266587171337302
test acc: 0.8872902161068768
0 0.9333333969116211
10 1.0
20 1.0
30 0.9666666984558105
40 0.9666666984558105
50 1.0
60 0.9333333969116211
70 0.9666666984558105
80 0.8666667342185974
90 0.9666666984558105
100 0.9333333969116211
110 0.8666667342185974
120 0.9333333969116211
130 0.9000000357627869
140 0.8333333730697632
150 0.9666666984558105
160 0.9666666984558105
170 0.8666667342185974
180 0.9666666984558105
190 0.9666666984558105
200 0.9333333969116211
210 0.9333333969116211
220 0.9666666984558105
230 0.9666666984558105
240 0.9000000357627869
250 1.0
260 0.9333333969116211
270 0.9666666984558105
280 0.9333333969116211
290 0.9000000357627869
300 1.0
310 0.9333333969116211
320 0.9666666984558105
330 0.9666666984558105
340 0.9333333969116211
350 0.9333333969116211
360 0.9333333969116211
370 0.9333333969116211
380 1.0
390 1.0
400 0.9333333969116211
410 1.0
420 0.9333333969116211
430 0.9666666984558105
440 0.9333333969116211
450 0.9333333969116211
460 0.9666666984558105
470 0.8333333730697632
480 1.0
490 0.9333333969116211
500 0.9666666984558105
510 0.9000000357627869
520 0.9000000357627869
530 1.0
540 0.9333333969116211
550 0.9666666984558105
560 0.9000000357627869
570 0.9333333969116211
580 0.9333333969116211
590 0.9666666984558105
600 0.8333333730697632
610 0.9333333969116211
620 0.8666667342185974
630 0.9000000357627869
640 0.9333333969116211
650 0.9666666984558105
660 0.9666666984558105
670 0.9333333969116211
680 0.9333333969116211
690 0.9333333969116211
700 0.9666666984558105
710 0.9000000357627869
720 0.9333333969116211
730 1.0
740 0.9666666984558105
750 0.9333333969116211
760 0.9666666984558105
770 0.8333333730697632
780 0.9666666984558105
790 0.9000000357627869
800 0.9000000357627869
810 0.9000000357627869
820 0.9666666984558105
830 0.9666666984558105
avg acc: 0.9356515197445163
test acc: 0.890008042184569
0 1.0
10 1.0
20 0.9000000357627869
30 0.8666667342185974
40 0.9000000357627869
50 0.9333333969116211
60 0.9000000357627869
70 0.9666666984558105
80 0.8666667342185974
90 0.9000000357627869
100 0.9333333969116211
110 1.0
120 0.9666666984558105
130 0.9666666984558105
140 1.0
150 0.9333333969116211
160 0.9333333969116211
170 0.9333333969116211
180 1.0
190 0.9666666984558105
200 0.9333333969116211
210 1.0
220 0.9666666984558105
230 1.0
240 0.9333333969116211
250 0.8333333730697632
260 0.9666666984558105
270 0.9333333969116211
280 0.9000000357627869
290 1.0
300 0.9666666984558105
310 0.9333333969116211
320 0.9000000357627869
330 0.9000000357627869
340 1.0
350 0.9666666984558105
360 1.0
370 0.9666666984558105
380 0.9000000357627869
390 0.9666666984558105
400 0.9666666984558105
410 0.9333333969116211
420 0.9000000357627869
430 1.0
440 0.9333333969116211
450 0.9666666984558105
460 0.9666666984558105
470 1.0
480 1.0
490 0.9666666984558105
500 1.0
510 1.0
520 1.0
530 1.0
540 0.8666667342185974
550 1.0
560 0.9333333969116211
570 0.9333333969116211
580 0.9666666984558105
590 0.9666666984558105
600 0.9333333969116211
610 0.9000000357627869
620 0.9333333969116211
630 0.9666666984558105
640 0.9666666984558105
650 0.9333333969116211
660 0.9333333969116211
670 0.9000000357627869
680 0.9333333969116211
690 0.9000000357627869
700 0.9333333969116211
710 0.9666666984558105
720 0.9666666984558105
730 0.9333333969116211
740 0.9333333969116211
750 1.0
760 0.9666666984558105
770 0.9333333969116211
780 0.9333333969116211
790 0.9000000357627869
800 1.0
810 0.9000000357627869
820 1.0
830 0.9000000357627869
avg acc: 0.9450040338136595
test acc: 0.8848521674422624
0 1.0
10 1.0
20 0.9666666984558105
30 0.9666666984558105
40 1.0
50 1.0
60 0.9666666984558105
70 1.0
80 0.9666666984558105
100 0.9666666984558105
110 0.9666666984558105
120 0.9333333969116211
130 0.9666666984558105
140 0.9666666984558105
150 1.0
160 0.9666666984558105
170 1.0
180 1.0
190 0.9666666984558105
200 0.8666667342185974
210 1.0
220 0.8666667342185974
230 0.9666666984558105
240 0.9333333969116211
250 0.8333333730697632
260 0.9666666984558105
270 0.9666666984558105
280 0.9000000357627869
290 0.9666666984558105
300 0.9666666984558105
310 0.9333333969116211
320 1.0
330 0.9666666984558105
340 0.9666666984558105
350 0.9333333969116211
360 0.9000000357627869
370 0.8666667342185974
380 0.9333333969116211
390 0.8333333730697632
400 0.9666666984558105
410 1.0
420 0.9666666984558105
430 0.9666666984558105
440 1.0
450 0.9666666984558105
460 0.9333333969116211
470 1.0
480 0.9666666984558105
490 1.0
500 0.9666666984558105
510 0.9333333969116211
520 0.8666667342185974
530 0.9666666984558105
540 1.0
550 1.0
560 0.9333333969116211
570 0.9333333969116211
580 1.0
590 0.9666666984558105
600 0.9666666984558105
610 0.9666666984558105
620 0.9666666984558105
630 0.9666666984558105
640 0.9333333969116211
650 0.9000000357627869
660 0.9333333969116211
670 1.0
680 0.9333333969116211
690 0.9666666984558105
700 0.9333333969116211
710 1.0
720 0.9333333969116211
730 1.0
740 0.9666666984558105
750 0.9666666984558105
760 0.8666667342185974
770 0.9000000357627869
780 0.8000000715255737
790 0.9666666984558105
800 0.9666666984558105
810 0.8666667342185974
820 1.0
830 0.9666666984558105
avg acc: 0.9509592677334802
test acc: 0.8718625588668621
0 1.0
10 1.0
20 0.9666666984558105
30 0.9333333969116211
40 1.0
50 0.9666666984558105
60 0.9666666984558105
70 0.9666666984558105
80 1.0
90 0.9333333969116211
100 1.0
110 0.9666666984558105
120 0.9666666984558105
130 0.9666666984558105
140 0.9666666984558105
150 0.9666666984558105
160 0.9666666984558105
170 1.0
180 0.9666666984558105
190 0.9000000357627869
200 1.0
210 1.0
220 0.9333333969116211
230 1.0
240 0.9666666984558105
250 1.0
260 0.9666666984558105
270 0.9666666984558105
280 0.9333333969116211
290 0.9333333969116211
300 0.9666666984558105
310 0.9666666984558105
320 0.9666666984558105
330 0.9333333969116211
340 1.0
350 0.9333333969116211
360 0.9666666984558105
370 0.9333333969116211
380 0.9666666984558105
390 0.9333333969116211
400 0.9666666984558105
410 0.9666666984558105
420 0.9666666984558105
430 0.9333333969116211
440 0.9333333969116211
450 0.9666666984558105
460 1.0
470 1.0
480 0.9666666984558105
490 0.9333333969116211
500 0.9666666984558105
510 0.9333333969116211
520 0.9666666984558105
530 0.9666666984558105
540 1.0
550 0.9666666984558105
560 0.9333333969116211
570 1.0
580 0.9666666984558105
590 0.9666666984558105
600 1.0
610 0.9000000357627869
620 0.9333333969116211
630 0.9333333969116211
640 0.9333333969116211
650 0.9666666984558105
660 0.9000000357627869
670 0.9000000357627869
680 1.0
690 0.9333333969116211
700 0.9666666984558105
710 0.8000000715255737
720 0.9333333969116211
730 0.8666667342185974
740 0.9333333969116211
750 0.9666666984558105
760 1.0
770 0.9333333969116211
780 0.9000000357627869
790 0.9666666984558105
800 0.9333333969116211
810 0.8666667342185974
820 0.9000000357627869
830 0.9666666984558105
avg acc: 0.9605116213111283
test acc: 0.8822142779827118
0 0.9666666984558105
10 0.9666666984558105
20 1.0
30 0.9666666984558105
40 1.0
50 0.9666666984558105
60 1.0
70 0.9000000357627869
80 1.0
90 0.9666666984558105
100 0.9333333969116211
110 1.0
120 1.0
130 0.9666666984558105
140 0.9666666984558105
150 1.0
160 0.9666666984558105
170 0.9333333969116211
180 0.9666666984558105
190 0.9333333969116211
200 0.9666666984558105
210 1.0
220 0.9666666984558105
230 1.0
240 0.9666666984558105
250 1.0
260 0.9333333969116211
270 0.9666666984558105
280 0.9000000357627869
290 1.0
300 0.9333333969116211
310 0.9666666984558105
320 0.9666666984558105
330 0.9333333969116211
340 1.0
350 0.9333333969116211
360 0.9666666984558105
370 1.0
380 1.0
390 0.9000000357627869
400 1.0
410 1.0
420 1.0
430 1.0
440 1.0
450 0.9666666984558105
460 0.9000000357627869
470 1.0
480 1.0
490 0.8666667342185974
500 1.0
510 1.0
520 1.0
530 0.9666666984558105
540 0.9000000357627869
550 1.0
560 0.9333333969116211
570 0.9666666984558105
580 1.0
590 0.9666666984558105
600 0.9333333969116211
610 0.9666666984558105
620 0.9666666984558105
630 1.0
640 0.9000000357627869
650 0.9666666984558105
660 1.0
670 0.9000000357627869
680 0.9333333969116211
690 1.0
700 1.0
710 1.0
720 0.9666666984558105
730 1.0
740 1.0
750 1.0
760 1.0
770 0.8666667342185974
780 0.9666666984558105
790 0.9333333969116211
800 0.9666666984558105
810 1.0
820 1.0
830 0.9666666984558105
avg acc: 0.9653077817363419
test acc: 0.8769784666222634
0 1.0
10 0.9666666984558105
20 1.0
30 0.9333333969116211
40 1.0
50 1.0
60 1.0
70 1.0
80 0.9666666984558105
90 0.9333333969116211
100 0.9666666984558105
110 0.9666666984558105
120 1.0
130 0.9666666984558105
140 1.0
150 1.0
160 0.9666666984558105
170 1.0
180 0.9333333969116211
190 1.0
200 0.9666666984558105
210 1.0
220 0.8333333730697632
230 1.0
240 1.0
250 0.9666666984558105
260 0.9666666984558105
270 0.9000000357627869
280 0.9666666984558105
290 0.9333333969116211
300 0.9666666984558105
310 0.9666666984558105
320 0.9333333969116211
330 1.0
340 1.0
350 0.9333333969116211
360 0.9666666984558105
370 0.9666666984558105
380 0.9666666984558105
390 1.0
400 0.9333333969116211
410 0.9333333969116211
420 1.0
430 0.9666666984558105
440 0.9666666984558105
450 0.9333333969116211
460 1.0
470 0.9666666984558105
480 1.0
490 1.0
500 0.9333333969116211
510 0.9666666984558105
520 1.0
530 0.9333333969116211
540 0.9666666984558105
550 0.9333333969116211
560 0.9333333969116211
570 0.9333333969116211
580 1.0
590 1.0
600 0.9333333969116211
610 0.9666666984558105
620 1.0
630 1.0
640 1.0
650 0.9666666984558105
660 1.0
670 1.0
680 1.0
690 0.9333333969116211
700 1.0
710 0.9333333969116211
720 1.0
730 1.0
740 0.9666666984558105
750 0.9000000357627869
760 0.9000000357627869
770 0.9333333969116211
780 0.9666666984558105
790 1.0
800 1.0
810 0.9666666984558105
820 0.9666666984558105
830 1.0
avg acc: 0.9697442299885144
test acc: 0.8815348212667506
0 0.9666666984558105
10 0.9666666984558105
20 0.9666666984558105
30 0.9666666984558105
40 0.9666666984558105
50 0.8666667342185974
60 1.0
70 1.0
80 1.0
90 1.0
100 1.0
110 0.9666666984558105
120 1.0
130 1.0
140 1.0
150 0.9666666984558105
160 1.0
170 0.9333333969116211
180 1.0
190 0.9000000357627869
200 1.0
210 0.8666667342185974
220 1.0
230 1.0
240 1.0
250 0.9000000357627869
260 1.0
270 1.0
280 0.9666666984558105
290 0.9666666984558105
300 0.9666666984558105
310 0.9666666984558105
320 0.9666666984558105
330 1.0
340 1.0
350 0.9333333969116211
360 0.9666666984558105
370 1.0
380 0.9666666984558105
390 1.0
400 0.9666666984558105
410 1.0
420 1.0
430 1.0
440 1.0
450 1.0
460 1.0
470 1.0
480 0.9666666984558105
490 1.0
500 1.0
510 1.0
520 0.9666666984558105
530 0.9666666984558105
540 0.9666666984558105
550 0.9000000357627869
560 0.9000000357627869
570 0.9666666984558105
580 1.0
590 0.9666666984558105
600 1.0
610 0.9666666984558105
620 1.0
630 0.9666666984558105
640 0.9666666984558105
650 1.0
660 0.9666666984558105
670 1.0
680 0.9666666984558105
690 0.9666666984558105
700 0.9666666984558105
710 1.0
720 0.8666667342185974
730 1.0
740 0.9666666984558105
750 0.9333333969116211
760 1.0
770 0.9666666984558105
780 1.0
790 0.9666666984558105
800 1.0
810 1.0
820 1.0
830 0.9333333969116211
avg acc: 0.9726618941453435
test acc: 0.8754996503714463

6. 预测

  • 输出的预测:是('pos':1, 'neg':0)字符串的编号
for batch in test_iterator:
    # batch_size个预测
    preds = rnn(batch.text).squeeze(1)
    preds = predice_test(preds)
    # print(preds)

    i = 0
    for text in batch.text:
        # 遍历一句话里的每个单词
        for word in text:
            print(TEXT.vocab.itos[word], end=' ')
    
        print('')
        # 输出3句话
        if i == 3:
            break
        i = i + 1

    i = 0
    for pred in preds:
        idx = int(pred.item())
        print(idx, LABEL.vocab.itos[idx])
        # 输出3个结果(标签)
        if i == 3:
            break
        i = i + 1
    break
Anyone  Great A If  Without The Brilliant This  This If This Ten Absolutely For A This One Add a Just This I More What Brilliant Read  
who Classic story great you hires a  . movie it is you is minutes fantastic pure touching is of this mesmerizing love is hope suspenseful a and the  
gives Waters , film 've a doubt mixed  is with the like quite of !  movie a the little film the a this , script moving book  
this ! great in ever psychopath , with along terrible all greatest  possibly people Whatever vampire . good funniest gem that interplay great group more , performances , interpretation 
1 pos
1 pos
1 pos
1 pos

你可能感兴趣的:(深度学习与Pytorch入门实战(十六)情感分类实战(基于IMDB数据集))