实验数据集采用已分词与标注的影评文本,文本标签分为两类:0表示正面评价、1表示负面评价。数据集概况如下:
其中使用的预训练词向量文件为 wiki_word2vec_50.bin。
如果文本尚未分词,第一步应先对影评文本进行分词。
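下面是一个最简单的分词示意(假设使用 jieba 库,示例文本与变量名均为虚构,仅作参考):
import jieba

raw_review = "这部电影的剧情很精彩,演员表演也很到位。"
tokens = jieba.lcut(raw_review)     # 分词,返回词语列表
segmented = " ".join(tokens)        # 与数据集格式一致:词与词之间以空格分隔
print(segmented)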
双向LSTM可以理解为同时训练两个LSTM,两个LSTM的方向不同、参数相互独立。当前时刻的 $h_t$ 就是将两个方向的LSTM得到的两个 $h_t$ 向量拼接到一起。我们使用双向LSTM捕捉当前时刻 $t$ 的过去和未来的特征,并通过反向传播来训练双向LSTM网络。
模型搭建核心点:
由于该任务是情感分类任务,只需要对整个句子的信息进行分类,所以这里拼接的是表示整个句子的信息,即正向LSTM与反向LSTM最深一层隐藏层的结果。
单向LSTM与双向LSTM的输出结果差别:
- 双向LSTM当前时刻的 $h_t$ 是将两个方向的LSTM得到的两个 $h_t$ 向量拼接到一起。因此在维度方面,两层正向LSTM最深的隐藏层 $h_t$ 的维度为 [2, batch, hidden_size](这里的 2 对应 num_layers=2),两层反向LSTM最深的隐藏层 $h_0$ 的维度同样为 [2, batch, hidden_size],两者再拼接的话,维度就是 [4, batch, hidden_size],可参考下面的形状验证小例子。
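下面用几行代码验证上述双向LSTM的输出形状(batch=64、seq_len=75、embedding_dim=50、hidden_size=128、num_layers=2 与后文的参数配置一致,仅作示意):
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=50, hidden_size=128, num_layers=2,
               batch_first=True, bidirectional=True)
x = torch.randn(64, 75, 50)                # (batch, seq_len, embedding_dim)
output, (h_n, c_n) = lstm(x)
print(output.shape)                        # torch.Size([64, 75, 256]) => (batch, seq_len, 2*hidden_size)
print(h_n.shape)                           # torch.Size([4, 64, 128])  => (D*num_layers, batch, hidden_size)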
模型搭建代码为:
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMModel(nn.Module):
def __init__(
self,
input_size,
hidden_size,
num_layers,
dropout,
bidirectional,
batch_first,
classes,
pretrained_weight,
update_w2v
):
"""
:param input_size: 输入x的特征数,即embedding的size
:param hidden_size:隐藏层的大小
:param num_layers:LSTM的层数,可形成多层的堆叠LSTM
:param dropout: 如果非0,则在除最后一层外的每个LSTM层的输出上引入Dropout层,Dropout概率等于dropout
:param classes:类别数
:param batch_first:控制输入与输出的形状,如果为True,则输入和输出张量被提供为(batch, seq, feature)
:param bidirectional:如果为True,则为双向LSTM
:param pretrained_weight:预训练的词向量
:param update_w2v:控制是否更新词向量
:return:
"""
super(LSTMModel, self).__init__()
# embedding:向量层,将单词索引转为单词向量
self.embedding = nn.Embedding.from_pretrained(pretrained_weight)
        self.embedding.weight.requires_grad = update_w2v  # 由update_w2v控制词向量是否参与训练
# encoder层
self.encoder = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
dropout=dropout,
bidirectional=bidirectional
)
# decoder层
if bidirectional:
self.decoder1 = nn.Linear(hidden_size * 4, hidden_size)
self.decoder2 = nn.Linear(hidden_size, classes)
else:
self.decoder1 = nn.Linear(hidden_size * 2, hidden_size)
self.decoder2 = nn.Linear(hidden_size, classes)
def forward(self, x):
"""
前向传播
:param x:输入
:return:
"""
# embedding层
# x.shape=(batch,seq_len);embedding.shape=(num_embeddings, embedding_dim) => emb.shape=(batch,seq_len,embedding_dim)
emb = self.embedding(x)
# encoder层
state, hidden = self.encoder(emb)
# states: (batch,seq_len, D*hidden_size), D=2 if bidirectional = True else 1, =>[64,75,256]
# hidden: (h_n, c_n) => h_n / c_n shape:(D∗num_layers, batch, hidden_size) =>[4,64,128]
        # 这里拼接的是首、尾两个时间步的输出,每个时间步的输出已同时包含正向与反向LSTM的隐藏状态
encoding = torch.cat([state[:, 0, :], state[:, -1, :]], dim=1)
# decoder层
# encoding shape: (batch, 2*D*hidden_size): [64,512]
outputs = self.decoder1(encoding)
outputs = self.decoder2(outputs) # outputs shape:(batch, n_class) => [64,2]
return outputs
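可以用随机的词向量与输入对上述模型做一次形状检查(词表大小等取值均为假设,仅作示意):
vocab_size = 1000                                            # 假设的词表大小
fake_w2v = torch.randn(vocab_size, 50)                       # 假设的预训练词向量
model = LSTMModel(input_size=50, hidden_size=128, num_layers=2,
                  dropout=0.2, bidirectional=True, batch_first=True,
                  classes=2, pretrained_weight=fake_w2v, update_w2v=True)
x = torch.randint(0, vocab_size, (64, 75))                   # (batch, seq_len) 的单词索引
print(model(x).shape)                                        # torch.Size([64, 2])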
如果是静态Attention,其网络结构如下:
$h_t$ 是每一个词的 hidden state,而 $\overline{h_s}$ 向量一开始是随机生成的,后面经过反向传播可以得到 $\frac{\partial Loss}{\partial \overline{h_s}}$,通过梯度不断迭代更新。
该分类任务中,注意力得分计算公式为:
$score(h_t, \overline{h_s}) = v_a^{T} \tanh\big(W_a[h_t; \overline{h_s}]\big)$
score 是标量。将一句话中各时刻的得分拼在一起做 softmax 得到概率,再用这些概率对各时刻的 hidden state 进行加权平均,得到整个句子的总向量;该向量经过一个分类层,再经 softmax 得到每一个类别的得分。
这里的注意力机制,就是通过训练给予重要的词一个大的权重,给予不重要的词一个小的权重。
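把上面的打分与加权平均写成公式(与下面代码中的 softmax 与求和一一对应,符号仅为示意):
$\alpha_t = \mathrm{softmax}\big(score(h_t, \overline{h_s})\big) = \frac{\exp\big(score(h_t, \overline{h_s})\big)}{\sum_{k}\exp\big(score(h_k, \overline{h_s})\big)}, \qquad c = \sum_{t}\alpha_t h_t$
其中 $\alpha_t$ 对应代码中的 att_score,$c$ 对应加权求和后的 encoding。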
模型搭建代码为:
class LSTM_attention(nn.Module):
def __init__(self,
input_size,
hidden_size,
num_layers,
dropout,
bidirectional,
batch_first,
classes,
pretrained_weight,
update_w2v,
):
"""
:param input_size: 输入x的特征数,即embedding的size
:param hidden_size:隐藏层的大小
:param num_layers:LSTM的层数,可形成多层的堆叠LSTM
:param dropout: 如果非0,则在除最后一层外的每个LSTM层的输出上引入Dropout层,Dropout概率等于dropout
:param classes:类别数
:param batch_first:控制输入与输出的形状,如果为True,则输入和输出张量被提供为(batch, seq, feature)
:param bidirectional:如果为True,则为双向LSTM
:param pretrained_weight:预训练的词向量
:param update_w2v:控制是否更新词向量
:return:
"""
super(LSTM_attention, self).__init__()
# embedding:向量层,将单词索引转为单词向量
self.embedding = nn.Embedding.from_pretrained(pretrained_weight)
        self.embedding.weight.requires_grad = update_w2v  # 由update_w2v控制词向量是否参与训练
# encoder层
self.encoder = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
dropout=dropout,
bidirectional=bidirectional
)
        # nn.Parameter:将张量注册为模型参数,使其在训练过程中随梯度不断更新,以达到最优化
self.weight_W = nn.Parameter(torch.Tensor(2 * hidden_size, 2 * hidden_size))
self.weight_proj = nn.Parameter(torch.Tensor(2 * hidden_size, 1))
# 向量初始化
nn.init.uniform_(self.weight_W, -0.1, 0.1)
nn.init.uniform_(self.weight_proj, -0.1, 0.1)
# decoder层
if bidirectional:
self.decoder1 = nn.Linear(hidden_size * 2, hidden_size)
self.decoder2 = nn.Linear(hidden_size, classes)
else:
self.decoder1 = nn.Linear(hidden_size, hidden_size)
self.decoder2 = nn.Linear(hidden_size, classes)
def forward(self, x):
"""
前向传播
:param x:输入
:return:
"""
# embedding层
# x.shape=(batch,seq_len);embedding.shape=(num_embeddings, embedding_dim) => emb.shape=(batch,seq_len,embedding_dim)
emb = self.embedding(x)
# encoder层
state, hidden = self.encoder(emb)
# states: (batch,seq_len, D*hidden_size), D=2 if bidirectional = True else 1, =>[64,75,256]
# hidden: (h_n, c_n) => h_n / c_n shape:(D∗num_layers, batch, hidden_size) =>[4,64,128]
# attention:self.weight_proj * tanh(self.weight_W * state)
# (batch,seq_len, 2*hidden_size) => (batch,seq_len, 2*hidden_size)
u = torch.tanh(torch.matmul(state, self.weight_W))
# (batch,seq_len, 2*hidden_size) => (batch,seq_len,1)
att = torch.matmul(u, self.weight_proj)
att_score = F.softmax(att, dim=1)
scored_x = state * att_score
encoding = torch.sum(scored_x, dim=1)
# decoder层
# encoding shape: (batch, D*hidden_size): [64,256]
outputs = self.decoder1(encoding)
outputs = self.decoder2(outputs) # outputs shape:(batch, n_class) => [64,2]
return outputs
论文中的模型结构为:
图中的卷积核提取的是相邻两个单词向量(Two-gram)的特征,我们可以利用不同的卷积核来提取不同窗口大小的特征。如下图,卷积核分别提取了 2-gram、3-gram、4-gram 的信息。
TextCNN模型的核心在于用不同尺寸的卷积核在词向量矩阵上做卷积分别得到输出,将不同的输出结果分别经池化层后进行拼接,得到总的输出,再经全连接层进行分类。
其模型搭建代码为:
class TextCNNModel(nn.Module):
def __init__(self,
num_filters,
kernel_sizes,
embedding_dim,
dropout,
classes,
pretrained_weight,
update_w2v):
"""
搭建TextCNN模型
:param num_filters: 输出通道数
:param kernel_sizes: 多个卷积核的高[2,3,4]
        :param embedding_dim: 词向量维度(即卷积核的宽)
        :param dropout: dropout概率
:param classes: 类别数
:param pretrained_weight: 权重
:param update_w2v: 是否更新w2v
"""
super(TextCNNModel, self).__init__()
# embedding层:加载预训练词向量
self.embedding = nn.Embedding.from_pretrained(pretrained_weight)
self.embedding.weight.data.requires_grad = update_w2v
# 多个卷积层,2-gram;3-gram;4-gram...
self.convs = nn.ModuleList([nn.Conv2d(1, num_filters, (K, embedding_dim)) for K in kernel_sizes]) ## 卷积层
        # dropout层
self.dropout = nn.Dropout(dropout)
# 全连接层
self.fc = nn.Linear(len(kernel_sizes) * num_filters, classes) ##全连接层
def forward(self, x):
"""
前向传播
:param x: 输入
:return:
"""
# # (batch,seq_len) => (batch,seq_len,emb_size)
x = self.embedding(x)
# (batch,seq_len,emb_size) => (batch,1,seq_len,emb_size)
x = x.unsqueeze(1)
# (batch,1,seq_len,emb_size) => (batch,num_filters,seq_len - kernel_size + 1)
x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
# (batch,num_filters,seq_len - kernel_size + 1) => (batch,num_filters)
x = [F.max_pool1d(line, line.size(2)).squeeze(2) for line in x]
# [(batch,num_filters)*len(kernel_sizes)] => (batch,len(kernel_sizes) * num_filters)
x = torch.cat(x, 1)
x = self.dropout(x)
# (batch,len(kernel_sizes) * num_filters) => (batch,classes)
logit = self.fc(x)
return logit
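同样可以对 TextCNN 做一次形状检查(词表大小等取值均为假设,仅作示意):
vocab_size = 1000                                            # 假设的词表大小
fake_w2v = torch.randn(vocab_size, 50)                       # 假设的预训练词向量
cnn = TextCNNModel(num_filters=6, kernel_sizes=[2, 3, 4], embedding_dim=50,
                   dropout=0.2, classes=2, pretrained_weight=fake_w2v, update_w2v=True)
x = torch.randint(0, vocab_size, (64, 75))                   # (batch, seq_len) 的单词索引
print(cnn(x).shape)                                          # torch.Size([64, 2])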
一个深度学习任务的实现,一般需要参数配置、数据预处理、数据读入、模型搭建、训练与测试这几个模块。首先是参数统一配置 Config.py:
#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: liujie
@file: Config.py
@time: 2022/08/29
@desc:参数统一配置
"""
class MyConfig:
num_filters = 6 # CNN的输出通道数
kernel_sizes = [2, 3, 4]
update_w2v = True # 是否在训练中更新w2v
n_class = 2 # 分类数:分别为pos和neg
max_sen_len = 75 # 句子最大长度
embedding_dim = 50 # 词向量维度
batch_size = 64 # 批处理尺寸
hidden_dim = 128 # 隐藏层节点数
n_epoch = 50 # 训练迭代周期,即遍历整个训练样本的次数
    lr = 0.0001  # 学习率;若opt='adadelta',则不需要定义学习率
    drop_keep_prob = 0.2  # dropout概率(传给nn.Dropout与LSTM的dropout参数)
num_layers = 2 # LSTM层数
seed = 2022
batch_first = True
bidirectional = True # 是否使用双向LSTM
model_dir = "./model"
stopword_path = "./data/stopword.txt"
train_path = "./data/train.txt"
val_path = "./data/validation.txt"
test_path = "./data/test.txt"
pre_path = "./data/pre.txt"
word2id_path = "./word2vec/word2id.txt"
pre_word2vec_path = "./word2vec/wiki_word2vec_50.bin"
corpus_word2vec_path = "./word2vec/word_vec.txt"
model_state_dict_path = "./model/sen_model.pkl"
best_model_path = "./model/sen_model_best.pkl"
数据预处理流程如下,其代码 dataProcess.py 为:
#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: liujie
@file: dataProcess.py
@time: 2022/08/29
@desc:
数据预处理流程:
1.加载训练、验证、测试数据集与停用词表
2.建立word2index与index2word映射字典
3.利用预训练word2vec向量来构建字典集对应的word2vec向量,向量的行数代表单词的索引
4.文本转为索引数字模式-将原始文本(包括标签和文本)里的每个词转为word2id对应的索引数字,并以数组返回
"""
import re
import codecs
import gensim
import numpy as np
from Config import MyConfig
class Dataprocess:
def __init__(self):
self.stopWords = self.stopWordList_Load(MyConfig.stopword_path)
self.word2id = self.bulid_word2index(MyConfig.word2id_path) # 建立word2id
self.id2word = self.bulid_index2word(self.word2id) # 建立id2word
self.w2vec = self.bulid_word2vec(MyConfig.pre_word2vec_path, self.word2id,
MyConfig.corpus_word2vec_path) # 建立word2vec
# 构造训练集、验证集、测试集数组
self.result = self.prepare_data(self.word2id,
train_path=MyConfig.train_path,
val_path=MyConfig.val_path,
test_path=MyConfig.test_path,
seq_lenth=MyConfig.max_sen_len)
def org_data_load(self, file_path):
"""
加载原数据集中的lable与text
:param file_path: 文件路径
:return: lable列表与text列表
"""
lable = []
text = []
with codecs.open(file_path, "r", encoding="utf-8") as f:
for line in f.readlines():
# 切割
str = line.strip().split("\t")
lable.append(str[0])
text.append(str[1])
return lable, text
def stopWordList_Load(self, filepath):
"""
加载停用词表
:param filepath: 文件路径
:return: 返回停用词
"""
stopWordList = []
with codecs.open(filepath, "r", encoding="utf-8") as f:
for line in f.readlines():
line = line.strip()
stopWordList.append(line)
return stopWordList
def bulid_word2index(self, file_path):
"""
构造word2index字典文件
:return:
"""
# 读取文件路径
path = [MyConfig.train_path, MyConfig.val_path]
word2id = {"_PAD_": 0}
for _path in path:
with codecs.open(_path, 'r', encoding="utf-8") as f:
for line in f.readlines():
output = []
words = line.strip().split("\t")[1].split(" ")
for word in words:
if word not in self.stopWords:
# 找出长度大于1的汉字字符串
rt = re.findall("[\u4E00-\u9FA5]+", word)
if len(rt) == 0:
continue
else:
output.append(rt[0])
for word in output:
if word not in word2id.keys():
word2id[word] = len(word2id)
# 将word2id写入文件
with codecs.open(file_path, 'w', encoding="utf-8") as f:
for word, index in word2id.items():
f.write(word + "\t" + str(index) + '\n')
return word2id
def bulid_index2word(self, word2id):
"""
构建id2word字典
:param word2id:
:return:
"""
id2word = {}
for word, index in word2id.items():
id2word[index] = word
return id2word
def bulid_word2vec(self, fname, word2id, save_to_path=None):
"""
利用预训练word2vec向量来构建字典集对应的word2vec向量,向量的行数代表单词的索引
:param fname: 预训练模型名称
:param word2id: 字典
:param save_to_path: 存储语料的词向量文件
:return:
"""
n_words = max(word2id.values()) + 1 # 总词数
# 加载预训练的word2vec模型
model = gensim.models.KeyedVectors.load_word2vec_format(fname, binary=True)
# 初始化word2vec向量
words_vec = np.array(np.random.uniform(-1, 1, [n_words, model.vector_size]))
for word in word2id.keys():
# 避免因未登录词造成的错误
try:
words_vec[word2id[word]] = model[word]
except KeyError:
pass
if save_to_path:
with codecs.open(save_to_path, 'w', encoding="utf-8") as f:
for vec in words_vec:
vec = [str(w) for w in vec]
f.write(",".join(vec))
f.write("\n")
return words_vec
def text_of_array(self, word2id, seq_lenth, path):
"""
文本转为索引数字模式-将原始文本(包括标签和文本)里的每个词转为word2id对应的索引数字,并以数组返回
:param word2id: dict, 语料文本中包含的词汇集
:param seq_lenth: int, 序列的限定长度
:param path: str, 待处理的原始文本数据集
:return: 返回原始文本转化索引数字数组后的数据集(array), 标签集(list)
"""
labels = []
i = 0
sens = []
# 获取句子个数
with codecs.open(path, encoding="utf-8") as f:
for line in f.readlines():
words = line.strip().split("\t")[1].split(" ")
new_sen = [word2id.get(word, 0) for word in words if word not in self.stopWords]
new_sen_vec = np.array(new_sen).reshape(1, -1)
sens.append(new_sen_vec)
        # 将原始数据集中的文本转为单词索引,并按seq_lenth统一进行填充或截断
with codecs.open(path, encoding="utf-8") as f:
sentences_array = np.zeros(shape=(len(sens), seq_lenth))
for line in f.readlines():
words = line.strip().split("\t")[1].split(" ")
new_sen = [word2id.get(word, 0) for word in words if word not in self.stopWords]
new_sen_vec = np.array(new_sen).reshape(1, -1)
# 如果句子长度小于seq_lenth,则进行填充处理;反之,进行截断处理
if np.size(new_sen_vec, axis=1) < seq_lenth:
sentences_array[i, seq_lenth - np.size(new_sen_vec, axis=1):] = new_sen_vec[0, :]
else:
sentences_array[i, :] = new_sen_vec[0, 0:seq_lenth]
i += 1
label = line.strip().split("\t")[0]
labels.append(int(label))
return np.array(sentences_array), labels
def text_of_array_nolable(self, word2id, seq_lenth, path):
"""
文本转为索引数字模式-将原始文本(仅包括文本)里的每个词转为word2id对应的索引数字,并以数组返回.
:param word2id: 语料文本中包含的词汇集
:param seq_lenth: 序列的限定长度
:param path: 待处理的原始文本数据集
:return: 原始文本转化索引数字数组后的数据集(array)
"""
i = 0
sens = []
# 获取句子个数
with codecs.open(path, encoding="utf-8") as f:
for line in f.readlines():
words = line.strip().split("\t")[1].split(" ")
new_sen = [word2id.get(word, 0) for word in words if word not in self.stopWords]
new_sen_vec = np.array(new_sen).reshape(1, -1)
sens.append(new_sen_vec)
        # 将原始数据集中的文本转为单词索引,并按seq_lenth统一进行填充或截断
with codecs.open(path, encoding="utf-8") as f:
sentences_array = np.zeros(shape=(len(sens), seq_lenth))
for line in f.readlines():
words = line.strip().split("\t")[1].split(" ")
new_sen = [word2id.get(word, 0) for word in words if word not in self.stopWords]
new_sen_vec = np.array(new_sen).reshape(1, -1)
# 如果句子长度小于seq_lenth,则进行填充处理;反之,进行截断处理
if np.size(new_sen_vec, axis=1) < seq_lenth:
sentences_array[i, seq_lenth - np.size(new_sen_vec, axis=1):] = new_sen_vec[0, :]
else:
sentences_array[i, :] = new_sen_vec[0, 0:seq_lenth]
i += 1
return np.array(sentences_array)
def to_categorical(self, y, num_classes=None):
"""
将类别转化为one-hot编码
:param y: 类别特征列表
:param num_classes: 类别个数
:return: 返回one-hot编码数组,shape:(len(y), num_classes)
"""
y = np.array(y, dtype="int")
input_shape = y.shape
if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
input_shape = tuple(input_shape[:-1])
# ravel方法:将多维数组变成一维数组
y = y.ravel()
if not num_classes:
num_classes = np.max(y) + 1
n = y.shape[0]
categorical = np.zeros((n, num_classes))
categorical[np.arange(n), y] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
return categorical
def prepare_data(self, word2id, train_path, val_path, test_path, seq_lenth):
"""
得到数字索引表示的句子和标签
:param word2id: 语料文本中包含的词汇集.
:param train_path: 文件路径
:param val_path: 文件路径
:param test_path: 文件路径
:param seq_lenth: 序列固定长度
:return: 返回训练集、验证集、测试集数组
"""
train_array, train_label = self.text_of_array(word2id, seq_lenth, train_path)
val_array, val_label = self.text_of_array(word2id, seq_lenth, val_path)
test_array, test_label = self.text_of_array(word2id, seq_lenth, test_path)
# train_label = self.to_categorical(train_label, num_classes=2)
# val_label = self.to_categorical(val_label, num_classes=2)
# test_label = self.to_categorical(test_label, num_classes=2)
train_label = np.array([train_label]).T
val_label = np.array([val_label]).T
test_label = np.array([test_label]).T
return train_array, train_label, val_array, val_label, test_array, test_label
if __name__ == '__main__':
dataprocess = Dataprocess()
train_array, train_label, val_array, val_label, test_array, test_label = dataprocess.result
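text_of_array 中"左侧补零、超长截断"的处理方式可以用一个很小的例子说明(seq_lenth 取 5,索引序列均为示意):
import numpy as np

seq_lenth = 5
short_sen = [12, 7, 3]                                  # 长度3 < 5:左侧补零
row = np.zeros(seq_lenth)
row[seq_lenth - len(short_sen):] = short_sen            # => [ 0.  0. 12.  7.  3.]
long_sen = [4, 9, 8, 2, 6, 11]                          # 长度6 > 5:只保留前5个索引
row2 = np.array(long_sen[:seq_lenth], dtype=float)      # => [4. 9. 8. 2. 6.]
print(row, row2)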
数据读入,主要涉及自定义数据集,其代码 dataSet.py 如下:
#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: liujie
@file: dataSet.py
@time: 2022/08/29
@desc:
"""
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
def __init__(self, data, label):
"""
用于向类中传入外部参数,同时定义样本集
"""
self.data = data
        self.label = label
def __len__(self):
"""
用于返回数据集的样本数
:return:
"""
return len(self.data)
def __getitem__(self, index):
"""
        用于逐个读取样本集合中的元素
        :param index: 样本索引
:return:
"""
if self.label is not None:
data = torch.from_numpy(self.data[index])
label = torch.from_numpy(self.label[index])
return data, label
else:
data = torch.from_numpy(self.data[index])
return data
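该自定义数据集可直接交给 DataLoader 使用,下面是一个小的取批示意(数据为随机构造,batch_size 等取值仅为示意):
import numpy as np
from torch.utils.data import DataLoader

data = np.random.randint(0, 100, size=(10, 75))         # 10条长度为75的索引序列
label = np.random.randint(0, 2, size=(10, 1))           # 对应的0/1标签
loader = DataLoader(MyDataSet(data, label), batch_size=4, shuffle=True)
for batch_data, batch_label in loader:
    print(batch_data.shape, batch_label.shape)          # torch.Size([4, 75]) torch.Size([4, 1])
    break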
这一部分利用前面介绍的模型模块即可迅速搭建,完整代码 models.py 如下:
#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: liujie
@file: models.py
@time: 2022/08/29
@desc: 分别搭建模型LSTM与LSTM+Attention
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
from Config import MyConfig
class LSTMModel(nn.Module):
def __init__(
self,
input_size,
hidden_size,
num_layers,
dropout,
bidirectional,
batch_first,
classes,
pretrained_weight,
update_w2v
):
"""
:param input_size: 输入x的特征数,即embedding的size
:param hidden_size:隐藏层的大小
:param num_layers:LSTM的层数,可形成多层的堆叠LSTM
:param dropout: 如果非0,则在除最后一层外的每个LSTM层的输出上引入Dropout层,Dropout概率等于dropout
:param classes:类别数
:param batch_first:控制输入与输出的形状,如果为True,则输入和输出张量被提供为(batch, seq, feature)
:param bidirectional:如果为True,则为双向LSTM
:param pretrained_weight:预训练的词向量
:param update_w2v:控制是否更新词向量
:return:
"""
super(LSTMModel, self).__init__()
# embedding:向量层,将单词索引转为单词向量
self.embedding = nn.Embedding.from_pretrained(pretrained_weight)
        self.embedding.weight.requires_grad = update_w2v  # 由update_w2v控制词向量是否参与训练
# encoder层
self.encoder = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
dropout=dropout,
bidirectional=bidirectional
)
# decoder层
if bidirectional:
self.decoder1 = nn.Linear(hidden_size * 4, hidden_size)
self.decoder2 = nn.Linear(hidden_size, classes)
else:
self.decoder1 = nn.Linear(hidden_size * 2, hidden_size)
self.decoder2 = nn.Linear(hidden_size, classes)
def forward(self, x):
"""
前向传播
:param x:输入
:return:
"""
# embedding层
# x.shape=(batch,seq_len);embedding.shape=(num_embeddings, embedding_dim) => emb.shape=(batch,seq_len,embedding_dim)
emb = self.embedding(x)
# encoder层
state, hidden = self.encoder(emb)
# states: (batch,seq_len, D*hidden_size), D=2 if bidirectional = True else 1, =>[64,75,256]
# hidden: (h_n, c_n) => h_n / c_n shape:(D∗num_layers, batch, hidden_size) =>[4,64,128]
        # 这里拼接的是首、尾两个时间步的输出,每个时间步的输出已同时包含正向与反向LSTM的隐藏状态
encoding = torch.cat([state[:, 0, :], state[:, -1, :]], dim=1)
# decoder层
# encoding shape: (batch, 2*D*hidden_size): [64,512]
outputs = self.decoder1(encoding)
outputs = self.decoder2(outputs) # outputs shape:(batch, n_class) => [64,2]
return outputs
class LSTM_attention(nn.Module):
def __init__(self,
input_size,
hidden_size,
num_layers,
dropout,
bidirectional,
batch_first,
classes,
pretrained_weight,
update_w2v,
):
"""
:param input_size: 输入x的特征数,即embedding的size
:param hidden_size:隐藏层的大小
:param num_layers:LSTM的层数,可形成多层的堆叠LSTM
:param dropout: 如果非0,则在除最后一层外的每个LSTM层的输出上引入Dropout层,Dropout概率等于dropout
:param classes:类别数
:param batch_first:控制输入与输出的形状,如果为True,则输入和输出张量被提供为(batch, seq, feature)
:param bidirectional:如果为True,则为双向LSTM
:param pretrained_weight:预训练的词向量
:param update_w2v:控制是否更新词向量
:return:
"""
super(LSTM_attention, self).__init__()
# embedding:向量层,将单词索引转为单词向量
self.embedding = nn.Embedding.from_pretrained(pretrained_weight)
        self.embedding.weight.requires_grad = update_w2v  # 由update_w2v控制词向量是否参与训练
# encoder层
self.encoder = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
dropout=dropout,
bidirectional=bidirectional
)
        # nn.Parameter:将张量注册为模型参数,使其在训练过程中随梯度不断更新,以达到最优化
self.weight_W = nn.Parameter(torch.Tensor(2 * hidden_size, 2 * hidden_size))
self.weight_proj = nn.Parameter(torch.Tensor(2 * hidden_size, 1))
# 向量初始化
nn.init.uniform_(self.weight_W, -0.1, 0.1)
nn.init.uniform_(self.weight_proj, -0.1, 0.1)
# decoder层
if bidirectional:
self.decoder1 = nn.Linear(hidden_size * 2, hidden_size)
self.decoder2 = nn.Linear(hidden_size, classes)
else:
self.decoder1 = nn.Linear(hidden_size, hidden_size)
self.decoder2 = nn.Linear(hidden_size, classes)
def forward(self, x):
"""
前向传播
:param x:输入
:return:
"""
# embedding层
# x.shape=(batch,seq_len);embedding.shape=(num_embeddings, embedding_dim) => emb.shape=(batch,seq_len,embedding_dim)
emb = self.embedding(x)
# encoder层
state, hidden = self.encoder(emb)
# states: (batch,seq_len, D*hidden_size), D=2 if bidirectional = True else 1, =>[64,75,256]
# hidden: (h_n, c_n) => h_n / c_n shape:(D∗num_layers, batch, hidden_size) =>[4,64,128]
# attention:self.weight_proj * tanh(self.weight_W * state)
# (batch,seq_len, 2*hidden_size) => (batch,seq_len, 2*hidden_size)
u = torch.tanh(torch.matmul(state, self.weight_W))
# (batch,seq_len, 2*hidden_size) => (batch,seq_len,1)
att = torch.matmul(u, self.weight_proj)
att_score = F.softmax(att, dim=1)
scored_x = state * att_score
encoding = torch.sum(scored_x, dim=1)
# decoder层
# encoding shape: (batch, D*hidden_size): [64,256]
outputs = self.decoder1(encoding)
outputs = self.decoder2(outputs) # outputs shape:(batch, n_class) => [64,2]
return outputs
#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: liujie
@file: main.py
@time: 2022/08/30
@desc: 训练与预测
"""
import os
import tqdm
import torch
import torch.nn as nn
from dataProcess import Dataprocess
from dataSet import MyDataSet
from Config import MyConfig
from torch.utils.data import DataLoader
from models import LSTMModel, LSTM_attention, TextCNNModel
from torch import optim
from torchinfo import summary
from sklearn.metrics import f1_score, recall_score, confusion_matrix
def train_val(train_dataloader, val_dataloader, model, device, epoches, lr):
optimizer = optim.Adam(model.parameters(), lr=lr) # 优化器
criterion = nn.CrossEntropyLoss() # 损失函数
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2) # 学习率调整
best_acc = 0.8
for epoch in range(epoches):
train_loss = 0.0
correct = 0
total = 0
# 显示训练进度
train_dataloader = tqdm.tqdm(train_dataloader)
train_dataloader.set_description(
'[%s%04d/%04d]' % ('Epoch:', epoch + 1, epoches))
# 训练
model.train() # 训练模式更新参数
model.to(device)
for i, data_ in enumerate(train_dataloader):
data, label = data_[0].type(torch.LongTensor).to(device), data_[1].type(torch.LongTensor).to(device)
# 开始当前批次训练时,优化器的梯度置零,否则,梯度会累加
optimizer.zero_grad()
# 模型输出:output, shape:[num_samples, 2]
output = model(data)
# 实际目标label:label, shape:[num_samples, 1]=>[num_samples]
label = label.squeeze(1)
# 利用预先定义的criterion计算损失函数
loss = criterion(output, label)
# 反向传播
loss.backward()
# 利用优化器更新参数
optimizer.step()
# 损失
train_loss += loss.item()
# get predicted label: Returns ``(values, indices)``
_, predicted = torch.max(output, 1)
total += label.size(0)
correct += (label == predicted).sum().item()
F1 = f1_score(label.cpu(), predicted.cpu(), average="weighted")
Recall = recall_score(label.cpu(), predicted.cpu(), average="micro")
# 设置日志
postfic = {
"train_loss: {:.5f},train_acc:{:.3f}%,F1: {:.3f}%,Recall:{:.3f}%"
.format(
train_loss / (i + 1), 100 * correct / total, 100 * F1, 100 * Recall
)
}
train_dataloader.set_postfix(log=postfic)
# 验证
model.eval()
model.to(device)
val_dataloader = tqdm.tqdm(val_dataloader)
with torch.no_grad():
correct = 0 # 预测的和实际的label相同的样本个数
total = 0 # 累计validation样本个数
val_loss = 0.0
for i, val_data_ in enumerate(val_dataloader):
val_data, val_label = val_data_[0].type(torch.LongTensor).to(device), val_data_[1].type(
torch.LongTensor).to(device)
output = model(val_data)
# 实际目标label:label, shape:[num_samples, 1]=>[num_samples]
val_label = val_label.squeeze(1)
loss = criterion(output, val_label)
# 损失
val_loss += loss.item()
# get predicted label: Returns ``(values, indices)``
_, predicted = torch.max(output, 1)
total += val_label.size(0)
correct += (val_label == predicted).sum().item()
F1 = f1_score(val_label.cpu(), predicted.cpu(), average="weighted")
Recall = recall_score(val_label.cpu(), predicted.cpu(), average="micro")
CM = confusion_matrix(val_label.cpu(), predicted.cpu())
# 设置日志
postfic = {
"val_loss: {:.5f},val_acc:{:.3f}%,F1: {:.3f}%,Recall:{:.3f}%,CM:{}"
.format(
val_loss / (i + 1), 100 * correct / total, 100 * F1, 100 * Recall, CM
)
}
val_dataloader.set_postfix(log=postfic)
acc = correct / total
if acc > best_acc:
best_acc = acc
if os.path.exists(MyConfig.model_dir) == False:
os.mkdir(MyConfig.model_dir)
torch.save(model, MyConfig.best_model_path)
torch.save(model.state_dict(), MyConfig.model_state_dict_path)
def test(test_dataloader, model, device):
model.eval()
model.to(device)
criterion = nn.CrossEntropyLoss() # 损失函数
test_dataloader = tqdm.tqdm(test_dataloader)
with torch.no_grad():
correct = 0 # 预测的和实际的label相同的样本个数
total = 0 # 总测试样本个数
for i, test_data_ in enumerate(test_dataloader):
test_data, test_label = test_data_[0].type(torch.LongTensor).to(device), test_data_[1].type(
torch.LongTensor).to(device)
output = model(test_data)
# 实际目标label:label, shape:[num_samples, 1]=>[num_samples]
test_label = test_label.squeeze(1)
loss = criterion(output, test_label)
_, predicted = torch.max(output, 1)
total += test_label.size(0)
correct += (test_label == predicted).sum().item()
F1 = f1_score(test_label.cpu(), predicted.cpu(), average="weighted")
Recall = recall_score(test_label.cpu(), predicted.cpu(), average="micro")
CM = confusion_matrix(test_label.cpu(), predicted.cpu())
# 设置日志
postfic = {
"test_acc:{:.3f}%,F1: {:.3f}%,Recall:{:.3f}%,CM:{}"
.format(
100 * correct / total, 100 * F1, 100 * Recall, CM
)
}
test_dataloader.set_postfix(log=postfic)
if __name__ == '__main__':
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 得到句子表示与标签(Dataprocess只需实例化一次,预处理开销较大)
    dp = Dataprocess()
    (
        train_array,
        train_label,
        val_array,
        val_label,
        test_array,
        test_label,
    ) = dp.result
    # 得到word2vec词向量表示
    w2vec = dp.w2vec
    w2vec = torch.from_numpy(w2vec)
w2vec = torch.from_numpy(w2vec)
    w2vec = w2vec.float()  # 词向量权重需为float32,而numpy数组默认是float64
# 数据载入
train_loader = MyDataSet(train_array, train_label)
train_dataloader = DataLoader(
train_loader, batch_size=MyConfig.batch_size, shuffle=True
)
val_loader = MyDataSet(val_array, val_label)
val_dataloader = DataLoader(
val_loader, batch_size=MyConfig.batch_size, shuffle=True
)
test_loader = MyDataSet(test_array, test_label)
test_dataloader = DataLoader(
test_loader, batch_size=MyConfig.batch_size, shuffle=True
)
# 模型的搭建
# model2 = LSTMModel(
# MyConfig.embedding_dim,
# MyConfig.hidden_dim,
# MyConfig.num_layers,
# MyConfig.drop_keep_prob,
# MyConfig.bidirectional,
# MyConfig.batch_first,
# MyConfig.n_class,
# w2vec,
# MyConfig.update_w2v,
# )
# model2 = LSTM_attention(
# MyConfig.embedding_dim,
# MyConfig.hidden_dim,
# MyConfig.num_layers,
# MyConfig.drop_keep_prob,
# MyConfig.bidirectional,
# MyConfig.batch_first,
# MyConfig.n_class,
# w2vec,
# MyConfig.update_w2v,
# )
# TextCNN
model2 = TextCNNModel(MyConfig.num_filters,
MyConfig.kernel_sizes,
MyConfig.embedding_dim,
MyConfig.drop_keep_prob,
MyConfig.n_class,
w2vec,
MyConfig.update_w2v,
)
# 训练与验证
train_val(train_dataloader,
val_dataloader,
model2,
device,
MyConfig.n_epoch,
MyConfig.lr)
# 测试
test(test_dataloader,
model2,
device)
LSTM_attention 模型运行结果为:
D:\softwares\anaconda3\envs\tfpt368\python.exe D:/PycharmProjects/sxlj/PyTorch_demo/text_classification_base_of_lstm/main.py
[Epoch:0001/0050]: 100%|██████████| 313/313 [00:13<00:00, 23.29it/s, log={'train_loss: 0.56064,train_acc:70.102%,F1: 86.667%,Recall:86.667%'}]
100%|██████████| 88/88 [00:00<00:00, 95.69it/s, log={'val_loss: 0.48481,val_acc:76.941%,F1: 80.307%,Recall:80.328%,CM:[[24 8]\n [ 4 25]]'}]
[Epoch:0002/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.68it/s, log={'train_loss: 0.46692,train_acc:77.558%,F1: 66.518%,Recall:66.667%'}]
100%|██████████| 88/88 [00:00<00:00, 91.15it/s, log={'val_loss: 0.47865,val_acc:77.545%,F1: 88.430%,Recall:88.525%,CM:[[32 2]\n [ 5 22]]'}]
[Epoch:0003/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.41it/s, log={'train_loss: 0.43907,train_acc:79.458%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 92.39it/s, log={'val_loss: 0.46185,val_acc:78.327%,F1: 88.518%,Recall:88.525%,CM:[[28 3]\n [ 4 26]]'}]
[Epoch:0004/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.38it/s, log={'train_loss: 0.41398,train_acc:81.133%,F1: 96.639%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 94.37it/s, log={'val_loss: 0.43865,val_acc:80.227%,F1: 73.970%,Recall:73.770%,CM:[[26 10]\n [ 6 19]]'}]
[Epoch:0005/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.46it/s, log={'train_loss: 0.39387,train_acc:82.488%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.18it/s, log={'val_loss: 0.42749,val_acc:80.814%,F1: 81.909%,Recall:81.967%,CM:[[23 7]\n [ 4 27]]'}]
[Epoch:0006/0050]: 100%|██████████| 313/313 [00:10<00:00, 28.88it/s, log={'train_loss: 0.37042,train_acc:83.698%,F1: 82.922%,Recall:83.333%'}]
100%|██████████| 88/88 [00:00<00:00, 93.87it/s, log={'val_loss: 0.42287,val_acc:81.453%,F1: 78.474%,Recall:78.689%,CM:[[30 5]\n [ 8 18]]'}]
[Epoch:0007/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.63it/s, log={'train_loss: 0.34795,train_acc:84.998%,F1: 96.694%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 93.07it/s, log={'val_loss: 0.41776,val_acc:82.146%,F1: 83.732%,Recall:83.607%,CM:[[29 7]\n [ 3 22]]'}]
[Epoch:0008/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.49it/s, log={'train_loss: 0.32638,train_acc:86.114%,F1: 86.481%,Recall:86.667%'}]
100%|██████████| 88/88 [00:00<00:00, 94.27it/s, log={'val_loss: 0.42515,val_acc:81.915%,F1: 72.236%,Recall:72.131%,CM:[[23 12]\n [ 5 21]]'}]
[Epoch:0009/0050]: 100%|██████████| 313/313 [00:10<00:00, 28.99it/s, log={'train_loss: 0.30633,train_acc:87.354%,F1: 93.167%,Recall:93.333%'}]
100%|██████████| 88/88 [00:01<00:00, 87.88it/s, log={'val_loss: 0.42712,val_acc:81.542%,F1: 91.790%,Recall:91.803%,CM:[[31 2]\n [ 3 25]]'}]
[Epoch:0010/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.48it/s, log={'train_loss: 0.28550,train_acc:88.349%,F1: 86.726%,Recall:86.667%'}]
100%|██████████| 88/88 [00:01<00:00, 83.00it/s, log={'val_loss: 0.41023,val_acc:83.088%,F1: 91.821%,Recall:91.803%,CM:[[24 2]\n [ 3 32]]'}]
[Epoch:0011/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.16it/s, log={'train_loss: 0.26399,train_acc:89.524%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 93.27it/s, log={'val_loss: 0.42133,val_acc:81.649%,F1: 78.735%,Recall:78.689%,CM:[[28 7]\n [ 6 20]]'}]
[Epoch:0012/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.35it/s, log={'train_loss: 0.24763,train_acc:90.169%,F1: 86.787%,Recall:86.667%'}]
100%|██████████| 88/88 [00:00<00:00, 93.87it/s, log={'val_loss: 0.42143,val_acc:82.910%,F1: 86.885%,Recall:86.885%,CM:[[28 4]\n [ 4 25]]'}]
[Epoch:0013/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.85it/s, log={'train_loss: 0.22936,train_acc:91.289%,F1: 89.753%,Recall:90.000%'}]
100%|██████████| 88/88 [00:00<00:00, 91.91it/s, log={'val_loss: 0.44425,val_acc:81.862%,F1: 88.512%,Recall:88.525%,CM:[[29 3]\n [ 4 25]]'}]
[Epoch:0014/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.95it/s, log={'train_loss: 0.21512,train_acc:91.609%,F1: 96.678%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 93.87it/s, log={'val_loss: 0.44040,val_acc:83.336%,F1: 83.615%,Recall:83.607%,CM:[[25 4]\n [ 6 26]]'}]
[Epoch:0015/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.17it/s, log={'train_loss: 0.19601,train_acc:92.679%,F1: 93.122%,Recall:93.333%'}]
100%|██████████| 88/88 [00:00<00:00, 96.12it/s, log={'val_loss: 0.49242,val_acc:82.466%,F1: 76.874%,Recall:77.049%,CM:[[27 5]\n [ 9 20]]'}]
[Epoch:0016/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.89it/s, log={'train_loss: 0.18043,train_acc:93.339%,F1: 93.304%,Recall:93.333%'}]
100%|██████████| 88/88 [00:00<00:00, 93.37it/s, log={'val_loss: 0.45973,val_acc:83.088%,F1: 83.740%,Recall:83.607%,CM:[[23 2]\n [ 8 28]]'}]
[Epoch:0017/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.90it/s, log={'train_loss: 0.16902,train_acc:93.899%,F1: 96.678%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 91.72it/s, log={'val_loss: 0.49138,val_acc:82.341%,F1: 83.688%,Recall:83.607%,CM:[[30 6]\n [ 4 21]]'}]
[Epoch:0018/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.25it/s, log={'train_loss: 0.15129,train_acc:94.869%,F1: 93.333%,Recall:93.333%'}]
100%|██████████| 88/88 [00:00<00:00, 97.28it/s, log={'val_loss: 0.51235,val_acc:82.608%,F1: 83.651%,Recall:83.607%,CM:[[23 4]\n [ 6 28]]'}]
[Epoch:0019/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.05it/s, log={'train_loss: 0.14316,train_acc:95.010%,F1: 96.678%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 96.43it/s, log={'val_loss: 0.51579,val_acc:82.874%,F1: 80.286%,Recall:80.328%,CM:[[24 10]\n [ 2 25]]'}]
[Epoch:0020/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.93it/s, log={'train_loss: 0.12620,train_acc:95.650%,F1: 96.663%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 94.47it/s, log={'val_loss: 0.54916,val_acc:82.537%,F1: 83.615%,Recall:83.607%,CM:[[26 6]\n [ 4 25]]'}]
[Epoch:0021/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.31it/s, log={'train_loss: 0.11475,train_acc:96.200%,F1: 96.678%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 96.43it/s, log={'val_loss: 0.60967,val_acc:82.093%,F1: 77.049%,Recall:77.049%,CM:[[24 7]\n [ 7 23]]'}]
[Epoch:0022/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.00it/s, log={'train_loss: 0.10402,train_acc:96.605%,F1: 93.426%,Recall:93.333%'}]
100%|██████████| 88/88 [00:00<00:00, 93.17it/s, log={'val_loss: 0.58835,val_acc:81.791%,F1: 78.677%,Recall:78.689%,CM:[[23 7]\n [ 6 25]]'}]
[Epoch:0023/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.64it/s, log={'train_loss: 0.09436,train_acc:97.030%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 92.20it/s, log={'val_loss: 0.61472,val_acc:81.968%,F1: 85.196%,Recall:85.246%,CM:[[20 5]\n [ 4 32]]'}]
[Epoch:0024/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.67it/s, log={'train_loss: 0.08209,train_acc:97.620%,F1: 90.011%,Recall:90.000%'}]
100%|██████████| 88/88 [00:00<00:00, 93.27it/s, log={'val_loss: 0.64169,val_acc:82.110%,F1: 86.892%,Recall:86.885%,CM:[[27 5]\n [ 3 26]]'}]
[Epoch:0025/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.55it/s, log={'train_loss: 0.07454,train_acc:97.760%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.08it/s, log={'val_loss: 0.63647,val_acc:81.897%,F1: 81.691%,Recall:81.967%,CM:[[21 9]\n [ 2 29]]'}]
[Epoch:0026/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.02it/s, log={'train_loss: 0.07037,train_acc:97.985%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.49it/s, log={'val_loss: 0.72361,val_acc:82.217%,F1: 91.817%,Recall:91.803%,CM:[[29 4]\n [ 1 27]]'}]
[Epoch:0027/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.99it/s, log={'train_loss: 0.06135,train_acc:98.285%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 93.08it/s, log={'val_loss: 0.67263,val_acc:81.453%,F1: 77.086%,Recall:77.049%,CM:[[25 8]\n [ 6 22]]'}]
[Epoch:0028/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.26it/s, log={'train_loss: 0.05325,train_acc:98.530%,F1: 96.678%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 97.18it/s, log={'val_loss: 0.80234,val_acc:81.791%,F1: 83.444%,Recall:83.607%,CM:[[17 6]\n [ 4 34]]'}]
[Epoch:0029/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.17it/s, log={'train_loss: 0.04880,train_acc:98.635%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 93.37it/s, log={'val_loss: 0.84946,val_acc:81.631%,F1: 83.607%,Recall:83.607%,CM:[[26 5]\n [ 5 25]]'}]
[Epoch:0030/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.87it/s, log={'train_loss: 0.04463,train_acc:98.860%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 92.68it/s, log={'val_loss: 0.77501,val_acc:81.435%,F1: 77.049%,Recall:77.049%,CM:[[24 7]\n [ 7 23]]'}]
[Epoch:0031/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.99it/s, log={'train_loss: 0.04244,train_acc:98.855%,F1: 93.333%,Recall:93.333%'}]
100%|██████████| 88/88 [00:00<00:00, 94.57it/s, log={'val_loss: 0.89640,val_acc:80.991%,F1: 78.549%,Recall:78.689%,CM:[[28 5]\n [ 8 20]]'}]
[Epoch:0032/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.14it/s, log={'train_loss: 0.03746,train_acc:99.020%,F1: 96.670%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 94.47it/s, log={'val_loss: 0.79363,val_acc:81.737%,F1: 85.278%,Recall:85.246%,CM:[[30 5]\n [ 4 22]]'}]
[Epoch:0033/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.22it/s, log={'train_loss: 0.03398,train_acc:99.170%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.59it/s, log={'val_loss: 0.89114,val_acc:81.560%,F1: 83.607%,Recall:83.607%,CM:[[21 5]\n [ 5 30]]'}]
[Epoch:0034/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.06it/s, log={'train_loss: 0.02994,train_acc:99.255%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.60it/s, log={'val_loss: 0.89480,val_acc:81.578%,F1: 83.669%,Recall:83.607%,CM:[[29 6]\n [ 4 22]]'}]
[Epoch:0035/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.86it/s, log={'train_loss: 0.02764,train_acc:99.300%,F1: 96.648%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 93.67it/s, log={'val_loss: 1.04152,val_acc:81.791%,F1: 76.901%,Recall:77.049%,CM:[[22 11]\n [ 3 25]]'}]
[Epoch:0036/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.83it/s, log={'train_loss: 0.02564,train_acc:99.315%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.91it/s, log={'val_loss: 0.93769,val_acc:81.737%,F1: 80.328%,Recall:80.328%,CM:[[27 6]\n [ 6 22]]'}]
[Epoch:0037/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.96it/s, log={'train_loss: 0.02017,train_acc:99.490%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 96.75it/s, log={'val_loss: 1.04526,val_acc:81.649%,F1: 83.686%,Recall:83.607%,CM:[[27 8]\n [ 2 24]]'}]
[Epoch:0038/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.07it/s, log={'train_loss: 0.01771,train_acc:99.580%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.08it/s, log={'val_loss: 0.98604,val_acc:81.649%,F1: 93.446%,Recall:93.443%,CM:[[29 3]\n [ 1 28]]'}]
[Epoch:0039/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.03it/s, log={'train_loss: 0.01636,train_acc:99.600%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.18it/s, log={'val_loss: 1.12957,val_acc:81.578%,F1: 81.907%,Recall:81.967%,CM:[[19 6]\n [ 5 31]]'}]
[Epoch:0040/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.38it/s, log={'train_loss: 0.01256,train_acc:99.755%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 93.47it/s, log={'val_loss: 1.15598,val_acc:81.435%,F1: 78.758%,Recall:78.689%,CM:[[22 5]\n [ 8 26]]'}]
[Epoch:0041/0050]: 100%|██████████| 313/313 [00:10<00:00, 30.09it/s, log={'train_loss: 0.01042,train_acc:99.785%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.39it/s, log={'val_loss: 1.20844,val_acc:81.471%,F1: 85.278%,Recall:85.246%,CM:[[22 4]\n [ 5 30]]'}]
[Epoch:0042/0050]: 100%|██████████| 313/313 [00:10<00:00, 28.73it/s, log={'train_loss: 0.02031,train_acc:99.435%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:01<00:00, 87.80it/s, log={'val_loss: 1.00797,val_acc:81.364%,F1: 82.055%,Recall:81.967%,CM:[[28 7]\n [ 4 22]]'}]
[Epoch:0043/0050]: 100%|██████████| 313/313 [00:11<00:00, 28.36it/s, log={'train_loss: 0.00935,train_acc:99.800%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:01<00:00, 87.45it/s, log={'val_loss: 1.16982,val_acc:81.116%,F1: 83.598%,Recall:83.607%,CM:[[26 4]\n [ 6 25]]'}]
[Epoch:0044/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.35it/s, log={'train_loss: 0.00776,train_acc:99.855%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:01<00:00, 87.80it/s, log={'val_loss: 1.23823,val_acc:81.311%,F1: 80.208%,Recall:80.328%,CM:[[30 5]\n [ 7 19]]'}]
[Epoch:0045/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.44it/s, log={'train_loss: 0.01205,train_acc:99.640%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 89.85it/s, log={'val_loss: 1.32987,val_acc:81.240%,F1: 83.544%,Recall:83.607%,CM:[[22 6]\n [ 4 29]]'}]
[Epoch:0046/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.29it/s, log={'train_loss: 0.00906,train_acc:99.780%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 89.22it/s, log={'val_loss: 1.29088,val_acc:81.080%,F1: 70.508%,Recall:70.492%,CM:[[21 8]\n [10 22]]'}]
[Epoch:0047/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.22it/s, log={'train_loss: 0.02002,train_acc:99.520%,F1: 96.639%,Recall:96.667%'}]
100%|██████████| 88/88 [00:01<00:00, 86.17it/s, log={'val_loss: 0.71967,val_acc:80.352%,F1: 68.802%,Recall:68.852%,CM:[[18 10]\n [ 9 24]]'}]
[Epoch:0048/0050]: 100%|██████████| 313/313 [00:10<00:00, 28.71it/s, log={'train_loss: 0.01397,train_acc:99.595%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.80it/s, log={'val_loss: 1.25143,val_acc:80.796%,F1: 86.814%,Recall:86.885%,CM:[[29 2]\n [ 6 24]]'}]
[Epoch:0049/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.84it/s, log={'train_loss: 0.00444,train_acc:99.930%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 95.60it/s, log={'val_loss: 1.30688,val_acc:80.814%,F1: 81.907%,Recall:81.967%,CM:[[31 5]\n [ 6 19]]'}]
[Epoch:0050/0050]: 100%|██████████| 313/313 [00:10<00:00, 29.76it/s, log={'train_loss: 0.00473,train_acc:99.880%,F1: 100.000%,Recall:100.000%'}]
100%|██████████| 88/88 [00:00<00:00, 94.37it/s, log={'val_loss: 1.33346,val_acc:81.489%,F1: 78.746%,Recall:78.689%,CM:[[23 4]\n [ 9 25]]'}]
100%|██████████| 6/6 [00:00<00:00, 98.62it/s, log={'test_acc:82.656%,F1: 77.038%,Recall:77.551%,CM:[[22 1]\n [10 16]]'}]
Process finished with exit code 0
可以发现:经过50轮迭代训练,在测试集上的准确率达到了82.656%。
TextCNN 模型运行结果为:
D:\softwares\anaconda3\envs\tfpt368\python.exe D:/PycharmProjects/sxlj/PyTorch_demo/text_classification_base_of_lstm_textcnn/main.py
[Epoch:0001/0050]: 100%|██████████| 313/313 [00:06<00:00, 47.68it/s, log={'train_loss: 0.69621,train_acc:52.150%,F1: 53.333%,Recall:53.333%'}]
100%|██████████| 88/88 [00:00<00:00, 264.41it/s, log={'val_loss: 0.67885,val_acc:63.137%,F1: 52.280%,Recall:52.459%,CM:[[17 11]\n [18 15]]'}]
[Epoch:0002/0050]: 100%|██████████| 313/313 [00:01<00:00, 175.49it/s, log={'train_loss: 0.67432,train_acc:59.216%,F1: 53.333%,Recall:53.333%'}]
100%|██████████| 88/88 [00:00<00:00, 269.04it/s, log={'val_loss: 0.65875,val_acc:69.302%,F1: 60.656%,Recall:60.656%,CM:[[21 12]\n [12 16]]'}]
[Epoch:0003/0050]: 100%|██████████| 313/313 [00:01<00:00, 173.83it/s, log={'train_loss: 0.65302,train_acc:64.061%,F1: 76.745%,Recall:76.667%'}]
100%|██████████| 88/88 [00:00<00:00, 269.86it/s, log={'val_loss: 0.63558,val_acc:71.594%,F1: 78.617%,Recall:78.689%,CM:[[18 7]\n [ 6 30]]'}]
[Epoch:0004/0050]: 100%|██████████| 313/313 [00:01<00:00, 170.94it/s, log={'train_loss: 0.62896,train_acc:67.792%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 269.83it/s, log={'val_loss: 0.61021,val_acc:73.139%,F1: 71.826%,Recall:72.131%,CM:[[18 11]\n [ 6 26]]'}]
[Epoch:0005/0050]: 100%|██████████| 313/313 [00:01<00:00, 175.40it/s, log={'train_loss: 0.60655,train_acc:69.452%,F1: 70.247%,Recall:70.000%'}]
100%|██████████| 88/88 [00:00<00:00, 272.33it/s, log={'val_loss: 0.58741,val_acc:73.814%,F1: 78.903%,Recall:78.689%,CM:[[18 5]\n [ 8 30]]'}]
[Epoch:0006/0050]: 100%|██████████| 313/313 [00:01<00:00, 171.92it/s, log={'train_loss: 0.58609,train_acc:70.827%,F1: 86.546%,Recall:86.667%'}]
100%|██████████| 88/88 [00:00<00:00, 272.33it/s, log={'val_loss: 0.56825,val_acc:74.489%,F1: 76.987%,Recall:77.049%,CM:[[21 8]\n [ 6 26]]'}]
[Epoch:0007/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.77it/s, log={'train_loss: 0.57175,train_acc:71.492%,F1: 69.829%,Recall:70.000%'}]
100%|██████████| 88/88 [00:00<00:00, 277.48it/s, log={'val_loss: 0.55366,val_acc:74.720%,F1: 83.598%,Recall:83.607%,CM:[[25 6]\n [ 4 26]]'}]
[Epoch:0008/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.42it/s, log={'train_loss: 0.55642,train_acc:72.567%,F1: 79.911%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 274.05it/s, log={'val_loss: 0.54188,val_acc:74.773%,F1: 83.669%,Recall:83.607%,CM:[[29 6]\n [ 4 22]]'}]
[Epoch:0009/0050]: 100%|██████████| 313/313 [00:01<00:00, 171.51it/s, log={'train_loss: 0.54889,train_acc:72.782%,F1: 70.171%,Recall:70.000%'}]
100%|██████████| 88/88 [00:00<00:00, 284.63it/s, log={'val_loss: 0.53333,val_acc:74.987%,F1: 78.700%,Recall:78.689%,CM:[[25 7]\n [ 6 23]]'}]
[Epoch:0010/0050]: 100%|██████████| 313/313 [00:01<00:00, 175.80it/s, log={'train_loss: 0.53995,train_acc:73.382%,F1: 60.000%,Recall:60.000%'}]
100%|██████████| 88/88 [00:00<00:00, 273.17it/s, log={'val_loss: 0.52711,val_acc:75.004%,F1: 80.253%,Recall:80.328%,CM:[[28 5]\n [ 7 21]]'}]
[Epoch:0011/0050]: 100%|██████████| 313/313 [00:01<00:00, 170.06it/s, log={'train_loss: 0.53371,train_acc:73.477%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 273.17it/s, log={'val_loss: 0.52056,val_acc:75.253%,F1: 72.041%,Recall:72.131%,CM:[[24 7]\n [10 20]]'}]
[Epoch:0012/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.01it/s, log={'train_loss: 0.52858,train_acc:73.937%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 269.80it/s, log={'val_loss: 0.51623,val_acc:75.431%,F1: 78.723%,Recall:78.689%,CM:[[27 7]\n [ 6 21]]'}]
[Epoch:0013/0050]: 100%|██████████| 313/313 [00:01<00:00, 173.58it/s, log={'train_loss: 0.52194,train_acc:74.527%,F1: 76.800%,Recall:76.667%'}]
100%|██████████| 88/88 [00:00<00:00, 279.19it/s, log={'val_loss: 0.51252,val_acc:75.520%,F1: 83.695%,Recall:83.607%,CM:[[23 3]\n [ 7 28]]'}]
[Epoch:0014/0050]: 100%|██████████| 313/313 [00:01<00:00, 172.54it/s, log={'train_loss: 0.51817,train_acc:74.587%,F1: 89.943%,Recall:90.000%'}]
100%|██████████| 88/88 [00:00<00:00, 274.03it/s, log={'val_loss: 0.50848,val_acc:75.324%,F1: 77.277%,Recall:77.049%,CM:[[15 6]\n [ 8 32]]'}]
[Epoch:0015/0050]: 100%|██████████| 313/313 [00:01<00:00, 173.94it/s, log={'train_loss: 0.51235,train_acc:75.298%,F1: 73.333%,Recall:73.333%'}]
100%|██████████| 88/88 [00:00<00:00, 270.66it/s, log={'val_loss: 0.50518,val_acc:75.395%,F1: 78.747%,Recall:78.689%,CM:[[19 6]\n [ 7 29]]'}]
[Epoch:0016/0050]: 100%|██████████| 313/313 [00:01<00:00, 172.38it/s, log={'train_loss: 0.50972,train_acc:75.133%,F1: 66.518%,Recall:66.667%'}]
100%|██████████| 88/88 [00:00<00:00, 265.77it/s, log={'val_loss: 0.50244,val_acc:75.786%,F1: 80.328%,Recall:80.328%,CM:[[28 6]\n [ 6 21]]'}]
[Epoch:0017/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.45it/s, log={'train_loss: 0.50889,train_acc:75.338%,F1: 61.667%,Recall:60.000%'}]
100%|██████████| 88/88 [00:00<00:00, 270.66it/s, log={'val_loss: 0.49999,val_acc:75.893%,F1: 80.689%,Recall:80.328%,CM:[[30 9]\n [ 3 19]]'}]
[Epoch:0018/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.64it/s, log={'train_loss: 0.50319,train_acc:75.548%,F1: 76.796%,Recall:76.667%'}]
100%|██████████| 88/88 [00:00<00:00, 273.20it/s, log={'val_loss: 0.49774,val_acc:75.875%,F1: 75.530%,Recall:75.410%,CM:[[26 9]\n [ 6 20]]'}]
[Epoch:0019/0050]: 100%|██████████| 313/313 [00:01<00:00, 169.32it/s, log={'train_loss: 0.50157,train_acc:75.578%,F1: 69.900%,Recall:70.000%'}]
100%|██████████| 88/88 [00:00<00:00, 272.36it/s, log={'val_loss: 0.49553,val_acc:76.088%,F1: 73.770%,Recall:73.770%,CM:[[17 8]\n [ 8 28]]'}]
[Epoch:0020/0050]: 100%|██████████| 313/313 [00:01<00:00, 175.62it/s, log={'train_loss: 0.49466,train_acc:76.033%,F1: 83.389%,Recall:83.333%'}]
100%|██████████| 88/88 [00:00<00:00, 273.17it/s, log={'val_loss: 0.49245,val_acc:76.532%,F1: 85.286%,Recall:85.246%,CM:[[27 7]\n [ 2 25]]'}]
[Epoch:0021/0050]: 100%|██████████| 313/313 [00:01<00:00, 171.93it/s, log={'train_loss: 0.49366,train_acc:76.458%,F1: 63.457%,Recall:63.333%'}]
100%|██████████| 88/88 [00:00<00:00, 267.35it/s, log={'val_loss: 0.49038,val_acc:76.568%,F1: 68.769%,Recall:68.852%,CM:[[20 12]\n [ 7 22]]'}]
[Epoch:0022/0050]: 100%|██████████| 313/313 [00:01<00:00, 175.70it/s, log={'train_loss: 0.49277,train_acc:76.383%,F1: 86.411%,Recall:86.667%'}]
100%|██████████| 88/88 [00:00<00:00, 274.02it/s, log={'val_loss: 0.48877,val_acc:76.763%,F1: 73.770%,Recall:73.770%,CM:[[23 8]\n [ 8 22]]'}]
[Epoch:0023/0050]: 100%|██████████| 313/313 [00:01<00:00, 175.35it/s, log={'train_loss: 0.48951,train_acc:76.488%,F1: 66.815%,Recall:66.667%'}]
100%|██████████| 88/88 [00:00<00:00, 274.87it/s, log={'val_loss: 0.48666,val_acc:76.710%,F1: 77.160%,Recall:77.049%,CM:[[22 4]\n [10 25]]'}]
[Epoch:0024/0050]: 100%|██████████| 313/313 [00:01<00:00, 172.19it/s, log={'train_loss: 0.48742,train_acc:76.603%,F1: 69.967%,Recall:70.000%'}]
100%|██████████| 88/88 [00:00<00:00, 273.20it/s, log={'val_loss: 0.48497,val_acc:76.781%,F1: 75.543%,Recall:75.410%,CM:[[21 5]\n [10 25]]'}]
[Epoch:0025/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.55it/s, log={'train_loss: 0.48509,train_acc:77.008%,F1: 76.589%,Recall:76.667%'}]
100%|██████████| 88/88 [00:00<00:00, 272.33it/s, log={'val_loss: 0.48300,val_acc:76.905%,F1: 83.598%,Recall:83.607%,CM:[[25 6]\n [ 4 26]]'}]
[Epoch:0026/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.00it/s, log={'train_loss: 0.48135,train_acc:77.168%,F1: 53.333%,Recall:53.333%'}]
100%|██████████| 88/88 [00:00<00:00, 264.94it/s, log={'val_loss: 0.48137,val_acc:76.923%,F1: 77.049%,Recall:77.049%,CM:[[25 7]\n [ 7 22]]'}]
[Epoch:0027/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.83it/s, log={'train_loss: 0.47785,train_acc:77.313%,F1: 74.299%,Recall:73.333%'}]
100%|██████████| 88/88 [00:00<00:00, 283.75it/s, log={'val_loss: 0.47971,val_acc:77.101%,F1: 76.987%,Recall:77.049%,CM:[[21 8]\n [ 6 26]]'}]
[Epoch:0028/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.86it/s, log={'train_loss: 0.47841,train_acc:77.198%,F1: 69.967%,Recall:70.000%'}]
100%|██████████| 88/88 [00:00<00:00, 275.72it/s, log={'val_loss: 0.47854,val_acc:77.136%,F1: 65.555%,Recall:65.574%,CM:[[21 10]\n [11 19]]'}]
[Epoch:0029/0050]: 100%|██████████| 313/313 [00:01<00:00, 168.87it/s, log={'train_loss: 0.47548,train_acc:77.233%,F1: 86.922%,Recall:86.667%'}]
100%|██████████| 88/88 [00:00<00:00, 279.22it/s, log={'val_loss: 0.47715,val_acc:77.243%,F1: 73.770%,Recall:73.770%,CM:[[22 8]\n [ 8 23]]'}]
[Epoch:0030/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.38it/s, log={'train_loss: 0.47491,train_acc:77.953%,F1: 79.911%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 274.02it/s, log={'val_loss: 0.47600,val_acc:77.190%,F1: 73.742%,Recall:73.770%,CM:[[23 6]\n [10 22]]'}]
[Epoch:0031/0050]: 100%|██████████| 313/313 [00:01<00:00, 172.56it/s, log={'train_loss: 0.47108,train_acc:77.668%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 250.69it/s, log={'val_loss: 0.47458,val_acc:77.172%,F1: 77.111%,Recall:77.049%,CM:[[21 6]\n [ 8 26]]'}]
[Epoch:0032/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.38it/s, log={'train_loss: 0.46813,train_acc:77.813%,F1: 83.315%,Recall:83.333%'}]
100%|██████████| 88/88 [00:00<00:00, 277.47it/s, log={'val_loss: 0.47341,val_acc:77.651%,F1: 76.962%,Recall:77.049%,CM:[[20 8]\n [ 6 27]]'}]
[Epoch:0033/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.12it/s, log={'train_loss: 0.46566,train_acc:77.968%,F1: 83.429%,Recall:83.333%'}]
100%|██████████| 88/88 [00:00<00:00, 273.99it/s, log={'val_loss: 0.47243,val_acc:77.314%,F1: 80.028%,Recall:80.328%,CM:[[21 10]\n [ 2 28]]'}]
[Epoch:0034/0050]: 100%|██████████| 313/313 [00:01<00:00, 172.30it/s, log={'train_loss: 0.46597,train_acc:77.933%,F1: 76.859%,Recall:76.667%'}]
100%|██████████| 88/88 [00:00<00:00, 272.30it/s, log={'val_loss: 0.47175,val_acc:78.025%,F1: 74.030%,Recall:73.770%,CM:[[27 10]\n [ 6 18]]'}]
[Epoch:0035/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.15it/s, log={'train_loss: 0.46090,train_acc:78.378%,F1: 76.091%,Recall:76.667%'}]
100%|██████████| 88/88 [00:00<00:00, 269.01it/s, log={'val_loss: 0.47008,val_acc:77.616%,F1: 87.074%,Recall:86.885%,CM:[[38 5]\n [ 3 15]]'}]
[Epoch:0036/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.45it/s, log={'train_loss: 0.46238,train_acc:78.243%,F1: 72.823%,Recall:73.333%'}]
100%|██████████| 88/88 [00:00<00:00, 247.87it/s, log={'val_loss: 0.46991,val_acc:77.527%,F1: 77.037%,Recall:77.049%,CM:[[23 8]\n [ 6 24]]'}]
[Epoch:0037/0050]: 100%|██████████| 313/313 [00:01<00:00, 172.85it/s, log={'train_loss: 0.46166,train_acc:78.318%,F1: 89.899%,Recall:90.000%'}]
100%|██████████| 88/88 [00:00<00:00, 274.02it/s, log={'val_loss: 0.46848,val_acc:78.042%,F1: 83.119%,Recall:83.607%,CM:[[17 8]\n [ 2 34]]'}]
[Epoch:0038/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.03it/s, log={'train_loss: 0.46136,train_acc:78.353%,F1: 73.704%,Recall:73.333%'}]
100%|██████████| 88/88 [00:00<00:00, 274.87it/s, log={'val_loss: 0.46792,val_acc:77.882%,F1: 84.836%,Recall:85.246%,CM:[[19 8]\n [ 1 33]]'}]
[Epoch:0039/0050]: 100%|██████████| 313/313 [00:01<00:00, 169.26it/s, log={'train_loss: 0.45702,train_acc:78.488%,F1: 73.092%,Recall:73.333%'}]
100%|██████████| 88/88 [00:00<00:00, 282.81it/s, log={'val_loss: 0.46709,val_acc:78.167%,F1: 88.578%,Recall:88.525%,CM:[[35 4]\n [ 3 19]]'}]
[Epoch:0040/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.28it/s, log={'train_loss: 0.45724,train_acc:78.413%,F1: 79.819%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 274.02it/s, log={'val_loss: 0.46706,val_acc:77.740%,F1: 73.870%,Recall:73.770%,CM:[[26 9]\n [ 7 19]]'}]
[Epoch:0041/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.09it/s, log={'train_loss: 0.45531,train_acc:78.678%,F1: 66.518%,Recall:66.667%'}]
100%|██████████| 88/88 [00:00<00:00, 258.78it/s, log={'val_loss: 0.46644,val_acc:77.740%,F1: 76.002%,Recall:75.410%,CM:[[27 13]\n [ 2 19]]'}]
[Epoch:0042/0050]: 100%|██████████| 313/313 [00:01<00:00, 172.67it/s, log={'train_loss: 0.45388,train_acc:78.968%,F1: 86.787%,Recall:86.667%'}]
100%|██████████| 88/88 [00:00<00:00, 286.48it/s, log={'val_loss: 0.46547,val_acc:78.220%,F1: 77.049%,Recall:77.049%,CM:[[24 7]\n [ 7 23]]'}]
[Epoch:0043/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.42it/s, log={'train_loss: 0.45191,train_acc:78.673%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 274.90it/s, log={'val_loss: 0.46534,val_acc:77.794%,F1: 71.582%,Recall:72.131%,CM:[[26 4]\n [13 18]]'}]
[Epoch:0044/0050]: 100%|██████████| 313/313 [00:01<00:00, 170.63it/s, log={'train_loss: 0.45087,train_acc:78.978%,F1: 76.745%,Recall:76.667%'}]
100%|██████████| 88/88 [00:00<00:00, 280.11it/s, log={'val_loss: 0.46433,val_acc:78.273%,F1: 77.037%,Recall:77.049%,CM:[[23 8]\n [ 6 24]]'}]
[Epoch:0045/0050]: 100%|██████████| 313/313 [00:01<00:00, 176.83it/s, log={'train_loss: 0.45019,train_acc:79.073%,F1: 76.588%,Recall:76.667%'}]
100%|██████████| 88/88 [00:00<00:00, 279.20it/s, log={'val_loss: 0.46350,val_acc:78.042%,F1: 81.967%,Recall:81.967%,CM:[[25 7]\n [ 4 25]]'}]
[Epoch:0046/0050]: 100%|██████████| 313/313 [00:01<00:00, 173.57it/s, log={'train_loss: 0.44801,train_acc:79.008%,F1: 83.389%,Recall:83.333%'}]
100%|██████████| 88/88 [00:00<00:00, 263.39it/s, log={'val_loss: 0.46312,val_acc:78.344%,F1: 77.049%,Recall:77.049%,CM:[[18 7]\n [ 7 29]]'}]
[Epoch:0047/0050]: 100%|██████████| 313/313 [00:01<00:00, 175.49it/s, log={'train_loss: 0.44767,train_acc:79.178%,F1: 96.648%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 270.66it/s, log={'val_loss: 0.46221,val_acc:78.202%,F1: 82.114%,Recall:81.967%,CM:[[28 8]\n [ 3 22]]'}]
[Epoch:0048/0050]: 100%|██████████| 313/313 [00:01<00:00, 173.91it/s, log={'train_loss: 0.44852,train_acc:78.873%,F1: 96.678%,Recall:96.667%'}]
100%|██████████| 88/88 [00:00<00:00, 275.74it/s, log={'val_loss: 0.46215,val_acc:78.149%,F1: 70.428%,Recall:70.492%,CM:[[22 5]\n [13 21]]'}]
[Epoch:0049/0050]: 100%|██████████| 313/313 [00:01<00:00, 172.21it/s, log={'train_loss: 0.44544,train_acc:79.273%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 275.73it/s, log={'val_loss: 0.46195,val_acc:78.362%,F1: 83.607%,Recall:83.607%,CM:[[32 5]\n [ 5 19]]'}]
[Epoch:0050/0050]: 100%|██████████| 313/313 [00:01<00:00, 174.59it/s, log={'train_loss: 0.44742,train_acc:79.058%,F1: 80.000%,Recall:80.000%'}]
100%|██████████| 88/88 [00:00<00:00, 274.03it/s, log={'val_loss: 0.46105,val_acc:78.380%,F1: 81.919%,Recall:81.967%,CM:[[24 8]\n [ 3 26]]'}]
100%|██████████| 6/6 [00:00<00:00, 286.50it/s, log={'test_acc:79.133%,F1: 80.067%,Recall:79.592%,CM:[[15 2]\n [ 8 24]]'}]
Process finished with exit code 0
可以发现:经过50轮迭代训练,在测试集上的准确率达到了79.133%。