本文为[365天深度学习训练营]学习记录博客
参考文章:365天深度学习训练营
原作者:[K同学啊 | 接辅导、项目定制]
文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)
import os
import sys
import PIL
from PIL import Image
import time
import copy
import random
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import AG_NEWS
import torchvision
from torchinfo import summary
import torchsummary
import matplotlib.pyplot as plt
import numpy as np
import warnings
def getDataset(root, dataset):
    """Return the AG News train and test splits, downloading them if needed.

    Args:
        root: Directory used as the download target. It is created when it
            is not already a directory.
        dataset: Directory expected to hold a local copy of the data. When it
            is an existing directory, the splits are read from it instead of
            from ``root``.

    Returns:
        The ``(train_ds, test_ds)`` pair produced by ``AG_NEWS``.
    """
    # isdir() is already False for a missing path, so a single check covers
    # both "does not exist" and "exists but is not a directory".
    if not os.path.isdir(root):
        os.makedirs(root)

    if os.path.isdir(dataset):
        print('Dataset already downloaded, reading...\n')
        # Author's note: train.csv and test.csv were downloaded by hand so
        # the data can be loaded from this local directory.
        source_dir = dataset
    else:
        print('Downloading dataset...\n')
        # Author's note: the automatic download failed with a network error
        # when this script was run directly.
        source_dir = root

    train_ds, test_ds = AG_NEWS(root=source_dir, split=("train", "test"))
    # Debug aids kept from the original:
    #print("Train:", next(train_ds), len(list(train_ds))+1)
    #print("Test :", next(test_ds), len(list(test_ds))+1)
    return train_ds, test_ds
# Select the compute device: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using {} device\n".format(device))

# Load the data: `root` is the download target, `data_dir` the local copy.
root = './data/'
data_dir = os.path.join(root, 'AG_NEWS.data')
train_ds, test_ds = getDataset(root, data_dir)
运行结果:
Using cuda device
Dataset already downloaded, reading...
Train: (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.") 120000
Test : (3, "Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.") 7600
def buildDict(train_ds):
    """Build the vocabulary and the text/label preprocessing callables.

    Args:
        train_ds: Iterable of ``(label, text)`` pairs. Every text is
            tokenized once to populate the vocabulary.

    Returns:
        A ``(vocab, text_pipeline, label_pipeline)`` tuple where
        ``text_pipeline(str)`` returns the list of vocabulary indices for the
        tokenized string and ``label_pipeline(x)`` returns ``int(x) - 1``.
    """
    # Imported locally: the file-level imports only bring in AG_NEWS from
    # torchtext, so these two names would otherwise be undefined.
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator

    tokenizer = get_tokenizer('basic_english')  # returns a tokenizer function

    def yield_tokens(data_iter):
        """Yield the token list of each sample, ignoring its label."""
        for _, text in data_iter:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_ds))

    def text_pipeline(x):
        """Tokenize ``x`` and map each token to its vocabulary index."""
        return vocab.lookup_indices(tokenizer(x))

    def label_pipeline(x):
        """Convert a raw class id to a 0-based integer target."""
        # AG_NEWS class ids are 1-based (1..4) per the torchtext dataset
        # documentation, while CrossEntropyLoss expects targets in
        # [0, num_class - 1]; without the shift the last class is out of range.
        return int(x) - 1

    # Debug aids kept from the original:
    #print(vocab.UNK, vocab._default_unk_index())  # default index used for unknown words
    #print(vocab.lookup_indices(['here', 'is', 'an', 'example']))
    #print(text_pipeline('here is the an example'))
    #print(label_pipeline('10'))  # prints 9 because of the 0-based shift
    return vocab, text_pipeline, label_pipeline
# Build the vocabulary. buildDict returns three values
# (vocab, text_pipeline, label_pipeline); all three must be unpacked, and
# `vocab` is kept at module level because len(vocab) is used further down
# to size the model.
vocab, text_pipeline, label_pipeline = buildDict(train_ds)
运行结果:
120001lines [00:04, 27817.88lines/s]
0
[471, 22, 31, 5177]
[471, 22, 3, 31, 5177]
10
def loadData(train_ds, test_ds, batch_size=8, device='cpu'):
    """Wrap the train and test splits in batched DataLoaders.

    Args:
        train_ds: Training split of ``(label, text)`` pairs; also used to
            build the vocabulary.
        test_ds: Test split of ``(label, text)`` pairs.
        batch_size: Number of samples per batch.
        device: Device every collated tensor is moved to.

    Returns:
        A ``(vocab, train_dl, test_dl)`` tuple.
    """
    vocab, text_pipeline, label_pipeline = buildDict(train_ds)

    def collate_batch(batch):
        """Pack samples into (labels, flat token ids, per-sample start offsets)."""
        labels, token_chunks, lengths = [], [], [0]
        for raw_label, raw_text in batch:
            labels.append(label_pipeline(raw_label))
            token_ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
            token_chunks.append(token_ids)
            lengths.append(token_ids.size(0))

        label_tensor = torch.tensor(labels, dtype=torch.int64)
        flat_tokens = torch.cat(token_chunks)
        # Leading 0 plus every length except the last, accumulated, gives the
        # index in `flat_tokens` at which each sample starts.
        start_offsets = torch.tensor(lengths[:-1]).cumsum(dim=0)
        return (label_tensor.to(device),
                flat_tokens.to(device),
                start_offsets.to(device))

    def make_loader(split):
        """Build an unshuffled, single-process loader over ``split``."""
        return torch.utils.data.DataLoader(split,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           collate_fn=collate_batch,
                                           num_workers=0)

    train_dl = make_loader(train_ds)
    test_dl = make_loader(test_ds)

    # Debug aid kept from the original: fetch one batch to inspect its format.
    #data = train_dl.__iter__()
    #print(type(data), data, '\n')
    return vocab, train_dl, test_dl
# Create the batched loaders. loadData returns three values
# (vocab, train_dl, test_dl), so the vocabulary is unpacked as well;
# len(vocab) is needed below to size the embedding table.
batch_size = 64
vocab, train_dl, test_dl = loadData(train_ds, test_ds, batch_size=batch_size, device=device)
运行结果:
120001lines [00:04, 27749.13lines/s]
class TextClassificationModel(nn.Module):
    """Text classifier made of an EmbeddingBag followed by one Linear layer."""

    def __init__(self, vocab_size, embed_dim, num_class):
        """Create the layers and initialize their weights.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_dim: Size of each embedding vector.
            num_class: Number of output scores per sample.
        """
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Fill both weight matrices uniformly in [-0.5, 0.5] and zero the bias."""
        bound = 0.5
        # Embedding first, then the linear layer: the order fixes how the
        # random number generator is consumed.
        for weight in (self.embedding.weight, self.fc.weight):
            weight.data.uniform_(-bound, bound)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Return one row of class scores per sample.

        Args:
            text: Token indices passed to the EmbeddingBag.
            offsets: Start positions passed to the EmbeddingBag.
        """
        pooled = self.embedding(text, offsets)  # one pooled vector per offset
        return self.fc(pooled)                  # num_class scores per vector
# Instantiate the model.
em_size = 64

# One pass over the training split collects the distinct labels; their
# count is the number of output classes.
train_iter = AG_NEWS(root='./data/AG_NEWS.data', split="train")
num_class = len({label for label, _ in train_iter})
vocab_size = len(vocab)

model = TextClassificationModel(vocab_size, em_size, num_class).to(device)
print('num_class', num_class)
print('vocab_size', vocab_size)
print(model)
def train(dataloader):
    """Run one training pass of the module-level ``model`` over ``dataloader``.

    Reads the module-level names ``model``, ``criterion``, ``optimizer`` and
    ``epoch`` (the latter only for the progress message), so those must be
    bound before this function is called.

    Args:
        dataloader: Iterable yielding ``(label, text, offsets)`` batches.

    Returns:
        None. Progress is printed every 500 batches.
    """
    model.train()  # training mode
    log_interval = 500
    total_acc, total_count = 0, 0

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Cap the total gradient norm at max_norm=0.1 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count))
            # Restart the running accuracy so each message covers only the
            # batches since the previous one.
            total_acc, total_count = 0, 0
def evaluate(dataloader):
    """Return the accuracy of the module-level ``model`` on ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(label, text, offsets)`` batches.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for label, text, offsets in dataloader:
            scores = model(text, offsets)
            correct += (scores.argmax(1) == label).sum().item()
            seen += label.size(0)
    return correct / seen
if __name__ == '__main__':
    # Kept as flat module-level code on purpose: train() and evaluate() read
    # `model`, `criterion`, `optimizer` and `epoch` as module globals.

    # Hyperparameters
    EPOCHS = 10      # number of epochs
    LR = 5           # learning rate
    BATCH_SIZE = 64  # batch size for training

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None

    # Read both splits from the local data directory defined at module level.
    train_iter, test_iter = AG_NEWS(root=data_dir, split=("train", "test"))
    train_dataset = list(train_iter)
    test_dataset = list(test_iter)

    # Hold out 5% of the training samples for validation.
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = torch.utils.data.random_split(
        train_dataset, [num_train, len(train_dataset) - num_train])

    # The collate function is local to loadData(); the loaders it returned
    # still carry it, so it is reused here instead of being duplicated.
    collate_batch = train_dl.collate_fn
    train_dataloader = torch.utils.data.DataLoader(
        split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)  # shuffle randomizes sample order
    valid_dataloader = torch.utils.data.DataLoader(
        split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val = evaluate(valid_dataloader)
        # Decay the learning rate only when validation accuracy fell below
        # the stored value; otherwise store the new value.
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
        print('-' * 59)

    # Build the path with os.path.join so it is valid on every platform, and
    # make sure the target directory exists before saving.
    save_dir = 'output'
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, 'model_TextClassification.pth'))
| epoch 1 | 500/ 1782 batches, accuracy 0.687
| epoch 1 | 1000/ 1782 batches, accuracy 0.856
| epoch 1 | 1500/ 1782 batches, accuracy 0.875
-----------------------------------------------------------
| end of epoch 1 | time: 23.15s | valid accuracy 0.881
-----------------------------------------------------------
| epoch 2 | 500/ 1782 batches, accuracy 0.898
| epoch 2 | 1000/ 1782 batches, accuracy 0.898
| epoch 2 | 1500/ 1782 batches, accuracy 0.903
-----------------------------------------------------------
| end of epoch 2 | time: 16.20s | valid accuracy 0.897
-----------------------------------------------------------
| epoch 3 | 500/ 1782 batches, accuracy 0.917
| epoch 3 | 1000/ 1782 batches, accuracy 0.915
| epoch 3 | 1500/ 1782 batches, accuracy 0.914
-----------------------------------------------------------
| end of epoch 3 | time: 15.98s | valid accuracy 0.902
-----------------------------------------------------------
| epoch 4 | 500/ 1782 batches, accuracy 0.924
| epoch 4 | 1000/ 1782 batches, accuracy 0.924
| epoch 4 | 1500/ 1782 batches, accuracy 0.922
-----------------------------------------------------------
| end of epoch 4 | time: 16.63s | valid accuracy 0.901
-----------------------------------------------------------
| epoch 5 | 500/ 1782 batches, accuracy 0.937
| epoch 5 | 1000/ 1782 batches, accuracy 0.937
| epoch 5 | 1500/ 1782 batches, accuracy 0.938
-----------------------------------------------------------
| end of epoch 5 | time: 16.37s | valid accuracy 0.912
-----------------------------------------------------------
| epoch 6 | 500/ 1782 batches, accuracy 0.938
| epoch 6 | 1000/ 1782 batches, accuracy 0.939
| epoch 6 | 1500/ 1782 batches, accuracy 0.940
-----------------------------------------------------------
| end of epoch 6 | time: 16.17s | valid accuracy 0.912
-----------------------------------------------------------
| epoch 7 | 500/ 1782 batches, accuracy 0.940
| epoch 7 | 1000/ 1782 batches, accuracy 0.938
| epoch 7 | 1500/ 1782 batches, accuracy 0.943
-----------------------------------------------------------
| end of epoch 7 | time: 16.20s | valid accuracy 0.911
-----------------------------------------------------------
| epoch 8 | 500/ 1782 batches, accuracy 0.941
| epoch 8 | 1000/ 1782 batches, accuracy 0.940
| epoch 8 | 1500/ 1782 batches, accuracy 0.942
-----------------------------------------------------------
| end of epoch 8 | time: 16.46s | valid accuracy 0.911
-----------------------------------------------------------
| epoch 9 | 500/ 1782 batches, accuracy 0.941
| epoch 9 | 1000/ 1782 batches, accuracy 0.941
| epoch 9 | 1500/ 1782 batches, accuracy 0.943
-----------------------------------------------------------
| end of epoch 9 | time: 17.50s | valid accuracy 0.912
-----------------------------------------------------------
| epoch 10 | 500/ 1782 batches, accuracy 0.940
| epoch 10 | 1000/ 1782 batches, accuracy 0.942
| epoch 10 | 1500/ 1782 batches, accuracy 0.942
-----------------------------------------------------------
| end of epoch 10 | time: 16.12s | valid accuracy 0.912
-----------------------------------------------------------
总结:
- 构建了词典(vocab),用于将文本转换为数字表示。
- 定义了文本和标签的处理函数(text_pipeline 和 label_pipeline)。
- 使用 EmbeddingBag 和 Linear 层构建了一个简单的文本分类模型。
- 使用了交叉熵损失函数(CrossEntropyLoss)和随机梯度下降优化器(SGD)。
- 定义了训练(train)和评估(evaluate)函数。
- 训练结束后将模型参数保存为 model_TextClassification.pth。