最近在学习NLP的相关知识,找了资料比较全的黑马程序员NLP课程。但在其中的实战部分——新闻主题分类实战项目中,我发现黑马程序员提供的代码存在大量错误,其中多处是代码逻辑错误:
!!!需要注意的是:torchtext 的版本必须是 0.4。可能是版本更新后 text_classification 这个模块被移走了,如果版本不是 0.4,from torchtext.datasets.text_classification 这句导入可能会报错!!!
from torchtext.datasets.text_classification import *
from torchtext.datasets.text_classification import _csv_iterator, _create_data_from_iterator
import os
import time
from torch import optim
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer
# Hyper-parameters referenced throughout the script.
# NOTE(review): these definitions are missing from the post as published.
# N_GRAMS = 2 matches the n-gram size passed to predict() at the end of the
# script, and N_EPOCH = 20 matches the 20 epochs in the captured output.
# BATCH_SIZE and EMBED_DIM are assumed values -- confirm before relying on them.
N_GRAMS = 2
BATCH_SIZE = 16
EMBED_DIM = 32
N_EPOCH = 20

# Make sure the directory that is expected to hold the dataset archive exists.
os.makedirs('./data', exist_ok=True)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build the training and test datasets
def _setup_data_set(dataset_tar='./data/ag_news_csv.tar.gz',
                    n_grams=N_GRAMS, vocab=None, include_unk=False):
    """Create the train/test ``TextClassificationDataset`` pair from an archive.

    Args:
        dataset_tar: Path of the archive handed to ``extract_archive``.
        n_grams: N-gram size forwarded to ``_csv_iterator``.
        vocab: Optional pre-built ``Vocab``. When ``None`` a vocabulary is
            built from the training CSV.
        include_unk: Forwarded to ``_create_data_from_iterator``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple sharing the same vocabulary.

    Raises:
        FileNotFoundError: If the archive has no ``train.csv`` or ``test.csv``.
        TypeError: If ``vocab`` is given but is not a ``Vocab``.
        ValueError: If the label sets of the two splits differ.
    """
    extracted_files = extract_archive(dataset_tar)
    train_csv_path = ''
    test_csv_path = ''
    for file_name in extracted_files:
        if file_name.endswith('train.csv'):
            train_csv_path = file_name
        if file_name.endswith('test.csv'):
            test_csv_path = file_name
    if not train_csv_path or not test_csv_path:
        raise FileNotFoundError(
            "train.csv / test.csv not found in %s" % dataset_tar)

    if vocab is None:
        print("Building Vocab based on %s" % train_csv_path)
        # Build the vocabulary from the training split only
        vocab = build_vocab_from_iterator(
            _csv_iterator(train_csv_path, ngrams=n_grams))
    elif not isinstance(vocab, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")
    print('Vocab has %d entries' % len(vocab))

    print('Creating training data')
    # The training split must be read from the training CSV, not the test CSV.
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, n_grams, yield_cls=True), include_unk)
    print('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, n_grams, yield_cls=True), include_unk)

    # A non-empty symmetric difference means a label occurs in only one split.
    if train_labels ^ test_labels:
        raise ValueError("Training and test labels don't match")

    # Return the dataset instances
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))
# Build the training and test datasets with the function's default arguments.
train_data_set, test_data_set = _setup_data_set()
# Model definition
class TextSentiment(nn.Module):
    """Bag-of-embeddings text classifier: ``EmbeddingBag`` followed by a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
        num_class: Number of output classes.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Initialise weights uniformly in [-0.5, 0.5] and zero the linear bias."""
        init_range = 0.5
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Return class scores for the bags described by ``text`` and ``offsets``.

        ``text`` and ``offsets`` are passed straight to ``nn.EmbeddingBag``:
        ``text`` holds the concatenated token indices and ``offsets`` the start
        index of each bag within it.
        """
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
# Model dimensions are taken from the training dataset's vocabulary and labels.
VOCAB_SIZE = len(train_data_set.get_vocab())
NUM_CLASS = len(train_data_set.get_labels())
# Instantiate the model and move it to the selected device
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)
# NOTE(review): min_valid_loss is assigned here but never read in this script.
min_valid_loss = float('inf')
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=4.0)
# StepLR with step_size=1 multiplies the learning rate by 0.9 on every scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)
# Hold out 5% of the training set for validation.
train_len = int(len(train_data_set) * 0.95)
sub_train_, sub_valid_ = random_split(train_data_set, [train_len, len(train_data_set) - train_len])
def generate_batch(batch):
    """Collate ``(label, token_tensor)`` entries into EmbeddingBag inputs.

    Args:
        batch: Sequence of entries where ``entry[0]`` is the label and
            ``entry[1]`` is a 1-D tensor of token indices.

    Returns:
        ``(text, offsets, label)``: all token tensors concatenated, the start
        index of each entry inside ``text``, and the labels as a tensor.
    """
    label = torch.tensor([entry[0] for entry in batch])
    texts = [entry[1] for entry in batch]
    # Each entry starts where the previous ones end: the cumulative sum of
    # [0, len_0, ..., len_{n-2}] gives those start positions.
    lengths = [0] + [len(tokens) for tokens in texts[:-1]]
    offsets = torch.tensor(lengths).cumsum(dim=0)
    text = torch.cat(texts)
    return text, offsets, label
def train_function(sub_train_):
    """Train the module-level ``model`` for one pass over ``sub_train_``.

    Uses the module-level ``criterion``, ``optimizer`` and ``scheduler``.

    Args:
        sub_train_: Dataset of ``(label, token_tensor)`` entries.

    Returns:
        ``(loss, accuracy)``: the sum of the per-batch losses divided by the
        number of samples, and the fraction of correctly classified samples.
    """
    train_loss = 0.0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for text, offsets, cls in data:
        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        # Accumulate a plain float so the running total never replaces the
        # batch loss tensor (and does not keep the autograd graph alive).
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()
    # Adjust the learning rate once per pass
    scheduler.step()
    return train_loss / len(sub_train_), train_acc / len(sub_train_)
def test(data_):
    """Evaluate the module-level ``model`` on ``data_`` without gradients.

    Args:
        data_: Dataset of ``(label, token_tensor)`` entries.

    Returns:
        ``(loss, accuracy)``: the sum of the per-batch losses divided by the
        number of samples, and the fraction of correctly classified samples.
    """
    total_loss = 0.0
    total_acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    with torch.no_grad():
        for text, offsets, cls in data:
            text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
            output = model(text, offsets)
            # Keep the running total separate from the batch loss tensor so
            # earlier batches are not discarded.
            total_loss += criterion(output, cls).item()
            total_acc += (output.argmax(1) == cls).sum().item()
    return total_loss / len(data_), total_acc / len(data_)
# Train for N_EPOCH passes, reporting timing and train/validation metrics.
for epoch in range(N_EPOCH):
    start_time = time.time()
    train_loss, train_acc = train_function(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    # Split the elapsed whole seconds into integer minutes and seconds.
    mins, secs = divmod(int(time.time() - start_time), 60)

    print('Epoch:%d' % (epoch + 1), "| time in %d minutes, %d seconds" % (mins, secs))
    print(f"\tLoss:{train_loss:.4f}(train)\t|\tAcc:{train_acc * 100:.1f}%(train)")
    print(f"\tLoss:{valid_loss:.6f}(valid)\t|\tAcc:{valid_acc * 100:.6f}%(valid)")
# 测试模型
# Maps the 1-based class id returned by predict() to a readable topic name.
ag_news_label = {
    1: "World",
    2: "Sports",
    3: "Business",
    4: "Sci/Tec",
}
def predict(text, model, vocab, ngrams):
    """Classify one raw text string and return its 1-based class id.

    Args:
        text: Raw input string.
        model: Callable taking ``(token_indices, offsets)`` and returning
            class scores with the classes along dimension 1.
        vocab: Mapping from token to index, looked up with ``vocab[token]``.
        ngrams: N-gram size passed to ``ngrams_iterator``.

    Returns:
        The index of the highest-scoring class plus one.
    """
    tokenizer = get_tokenizer("basic_english")
    with torch.no_grad():
        # An explicit integer dtype keeps the index tensor valid even when
        # the text yields no tokens.
        token_ids = torch.tensor(
            [vocab[token] for token in ngrams_iterator(tokenizer(text), ngrams)],
            dtype=torch.long)
        # The whole text is a single bag starting at offset 0.
        output = model(token_ids, torch.tensor([0]))
        return output.argmax(1).item() + 1
# Sample article used to try out the trained model. The backslash
# continuations join the lines into one string with no embedded newlines.
ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
enduring the season’s worst weather conditions on Sunday at The \
Open on his way to a closing 75 at Royal Portrush, which \
considering the wind and the rain was a respectable showing. \
Thursday’s first round at the WGC-FedEx St. Jude Invitational \
was another story. With temperatures in the mid-80s and hardly any \
wind, the Spaniard was 13 strokes better in a flawless round. \
Thanks to his best putting performance on the PGA Tour, Rahm \
finished with an 8-under 62 for a three-stroke lead, which \
was even more impressive considering he’d never played the \
front nine at TPC Southwind."
vocab = train_data_set.get_vocab()
# predict() builds its input tensors without a device argument, so the model
# is moved to the CPU before inference.
model = model.to("cpu")
# The last argument (2) is the n-gram size used to tokenise the sample text.
print("This is a %s news" % ag_news_label[predict(ex_text_str, model, vocab, 2)])
Building Vocab based on ./data/ag_news_csv/train.csv
120000lines [00:06, 17621.91lines/s]
Vocab has 1308844 entries
Creating training data
7600lines [00:00, 8405.65lines/s]
Creating testing data
7600lines [00:00, 9790.11lines/s]
Epoch:1 | time in 0 minutes, 0 seconds
Loss:0.0003(train) | Acc:40.9%(train)
Loss:0.0038(valid) | Acc:47.1%(valid)
Epoch:2 | time in 0 minutes, 0 seconds
Loss:0.0002(train) | Acc:70.4%(train)
Loss:0.0031(valid) | Acc:67.9%(valid)
Epoch:3 | time in 0 minutes, 0 seconds
Loss:0.0004(train) | Acc:82.4%(train)
Loss:0.0126(valid) | Acc:52.6%(valid)
Epoch:4 | time in 0 minutes, 0 seconds
Loss:0.0001(train) | Acc:88.3%(train)
Loss:0.0026(valid) | Acc:60.8%(valid)
Epoch:5 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:91.9%(train)
Loss:0.0002(valid) | Acc:79.7%(valid)
Epoch:6 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:94.8%(train)
Loss:0.0001(valid) | Acc:81.8%(valid)
Epoch:7 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:96.7%(train)
Loss:0.0001(valid) | Acc:83.4%(valid)
Epoch:8 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:98.5%(train)
Loss:0.0001(valid) | Acc:83.4%(valid)
Epoch:9 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:99.3%(train)
Loss:0.0001(valid) | Acc:81.1%(valid)
Epoch:10 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:99.6%(train)
Loss:0.0002(valid) | Acc:82.1%(valid)
Epoch:11 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:99.8%(train)
Loss:0.0001(valid) | Acc:84.7%(valid)
Epoch:12 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:99.9%(train)
Loss:0.0001(valid) | Acc:83.2%(valid)
Epoch:13 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:100.0%(train)
Loss:0.0001(valid) | Acc:83.7%(valid)
Epoch:14 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:100.0%(train)
Loss:0.0001(valid) | Acc:83.2%(valid)
Epoch:15 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:100.0%(train)
Loss:0.0001(valid) | Acc:84.2%(valid)
Epoch:16 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:100.0%(train)
Loss:0.0001(valid) | Acc:85.0%(valid)
Epoch:17 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:100.0%(train)
Loss:0.0001(valid) | Acc:85.0%(valid)
Epoch:18 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:100.0%(train)
Loss:0.0000(valid) | Acc:85.5%(valid)
Epoch:19 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:100.0%(train)
Loss:0.0001(valid) | Acc:84.2%(valid)
Epoch:20 | time in 0 minutes, 0 seconds
Loss:0.0000(train) | Acc:100.0%(train)
Loss:0.0001(valid) | Acc:84.5%(valid)
This is a Sports news