Paper link: Hierarchy-Aware Global Model for Hierarchical Text Classification
GitHub code link: HiAGM
See the reading notes for details. Repository structure:
-HiAGM
--config # JSON config files (the files below are examples)
--data # sample data and preprocessing files
--data_modules # data-processing scripts
--helper # training helper scripts
--models # model architecture
--train_modules # training-loop scripts
--LICENSE
--README.md
--evaluate.py # evaluation script used during training (metric computation)
--train.py # model-training entry point
Example JSON config files:
-HiAGM
--config
# GCN structure encoder, HiAGM-TP, rcv1 dataset, CPU
---gcn-rcv1-v2-cpu.json
# GCN structure encoder, HiAGM-TP, rcv1 dataset, GPU
---gcn-rcv1-v2.json
# GCN structure encoder, HiAGM-LA, rcv1 dataset, GPU
---gcnla-rcv1-v2.json
# GCN structure encoder, Origin model, rcv1 dataset, GPU
---rcv1-v2.json
# TreeLSTM structure encoder, HiAGM-TP, rcv1 dataset, GPU
---tree-rcv1-v2.json
# TreeLSTM structure encoder, HiAGM-LA, rcv1 dataset, GPU
---treela-rcv1-v2.json
Sample rcv1 data, plus preprocessing files for the rcv1, nyt, and wos datasets.
-HiAGM
--data
# nyt preprocessing script
---preprocess_nyt.py
# wos preprocessing script
---preprocess_wos.py
# New York Times data files
---idnewnyt_test.json
---idnewnyt_train.json
---idnewnyt_val.json
---nyt.taxonomy
---nyt_label.vocab
# rcv1 data files
---rcv1.taxonomy
---rcv1_overall_corpus_train_prob.json
---rcv1_prob.json
---rcv1_test.json
---rcv1_train.json
---rcv1_val.json
---sample_rcv1.taxonomy
Each dataset consists of the following five files:
# training set
dataset_train.json
# test set
dataset_test.json
# validation set
dataset_val.json
# label hierarchy
dataset.taxonomy
# label prior probabilities
dataset_prob.json
Scripts that preprocess the prepared data, build the vocabulary, and provide the model's data-loading modules.
-HiAGM
--data_modules
# collates batches, used by data_loader.py
---collator.py
# data loading: feeds the datasets to the model for training
---data_loader.py
# reads the data, used by data_loader.py
---dataset.py
# data preprocessing: stop-word removal, punctuation cleaning, etc.
---preprocess.py
# builds the label.dict and word.dict files in the vocab folder (see the sketch after this list)
---vocab.py
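As an illustration of what the word.dict side of vocab.py does, here is a minimal sketch assuming it simply counts token frequencies over the training corpus; the file paths and the 60000 cap are assumptions (the cap mirrors max_token_vocab in the configs), and the output format matches the word.dict example shown later.

```python
import json
from collections import Counter

# count token frequencies over the training corpus (one JSON object per line)
counter = Counter()
with open("data/dataset_train.json", encoding="utf-8") as f:
    for line in f:
        counter.update(json.loads(line)["token"])

# word.dict format: one "word\tfrequency" pair per line, most frequent first
with open("vocab/word.dict", "w", encoding="utf-8") as f:
    for word, freq in counter.most_common(60000):  # cap at max_token_vocab
        f.write(f"{word}\t{freq}\n")
```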
Helper scripts, e.g. logging and checkpoint loading/saving.
-HiAGM
--helper
# loads the contents of the JSON config file
---configure.py
# computes the dataset.taxonomy and dataset_prob.json files in the data folder
---hierarchy_tree_statistic.py
# logging
---logger.py
# checkpoint loading/saving and other utility functions (see the sketch after this list)
---utils.py
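The actual signatures live in utils.py; purely as a generic illustration of PyTorch checkpoint save/load, with hypothetical function names and state-dict keys:

```python
import torch

def save_checkpoint(state, path):
    # "state" bundles everything needed to resume: model and optimizer
    # state dicts, the current epoch, and the best metric so far
    torch.save(state, path)

def load_checkpoint(path, model, optimizer=None):
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint.get("epoch", 0), checkpoint.get("best_performance")
```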
Model architecture folder: structure encoders, the embedding layer, multi-label attention, and other model components.
-HiAGM
--models
# structure-encoder scripts (GCN, structure encoder, tree construction, weighted TreeLSTM)
---structure_model
# embedding layer
---embedding_layer.py
# main model file
---model.py
# multi-label attention (used by HiAGM-LA)
---multi_label_attention.py
# the Origin baseline model
---origin.py
# text encoder
---text_encoder.py
# text feature propagation (used by HiAGM-TP)
---text_feature_propagation.py
Training-loop scripts: loss functions, evaluation metrics, and the trainer.
-HiAGM
--train_modules
# loss functions
---criterions.py
# evaluation-metric computation
---evaluation_metrics.py
# trainer
---trainer.py
|  | Predicted = 1 | Predicted = 0 |
| --- | --- | --- |
| Actual = 1 | TP | FN |
| Actual = 0 | FP | TN |
precision (PPV, positive predictive value) = TP / (TP + FP)
recall (sensitivity, TPR, true positive rate) = TP / (TP + FN)
specificity (TNR, true negative rate) = TN / (TN + FP)
F1_score = 2 * precision * recall / (precision + recall) = 2 * TP / (2 * TP + FP + FN)
Micro-F1 does not distinguish between classes: precision and recall are computed over the pooled counts of all samples, and then micro_f1 = 2 * precision * recall / (precision + recall).
# micro metrics pool the counts over all classes:
#   right_total   = correctly predicted labels (true positives)
#   predict_total = total predicted labels
#   gold_total    = total gold labels
precision_micro = float(right_total) / predict_total if predict_total > 0 else 0.0
recall_micro = float(right_total) / gold_total
micro_f1 = 2 * precision_micro * recall_micro / (precision_micro + recall_micro) if (precision_micro + recall_micro) > 0 else 0.0
Macro-F1 first computes each class's precision, recall, and F1-score, then averages them to get the F1-score over the whole sample.
# macro metrics average the per-class values; each dict maps class name -> value.
# Note that macro_f1 is the mean of the per-class F1 scores, not the F1 of the
# macro-averaged precision and recall.
precision_macro = sum([v for _, v in precision_dict.items()]) / len(list(precision_dict.keys()))
recall_macro = sum([v for _, v in recall_dict.items()]) / len(list(precision_dict.keys()))
macro_f1 = sum([v for _, v in fscore_dict.items()]) / len(list(fscore_dict.keys()))
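To make the difference concrete, here is a small self-contained toy example (the counts are made up; variable names mirror the snippets above):

```python
# toy counts for two classes: class -> (true positives, predicted, gold)
counts = {"CCAT": (8, 10, 12), "MCAT": (1, 2, 4)}

# micro: pool the counts over classes first
right_total = sum(tp for tp, _, _ in counts.values())        # 9
predict_total = sum(pred for _, pred, _ in counts.values())  # 12
gold_total = sum(gold for _, _, gold in counts.values())     # 16
precision_micro = right_total / predict_total                # 0.75
recall_micro = right_total / gold_total                      # 0.5625
micro_f1 = 2 * precision_micro * recall_micro / (precision_micro + recall_micro)

# macro: per-class F1 first, then average over classes
fscore_dict = {}
for label, (tp, pred, gold) in counts.items():
    p = tp / pred if pred > 0 else 0.0
    r = tp / gold if gold > 0 else 0.0
    fscore_dict[label] = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
macro_f1 = sum(fscore_dict.values()) / len(fscore_dict)

print(round(micro_f1, 4), round(macro_f1, 4))  # 0.6429 0.5303
```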
Runs successfully on the sample dataset.
Config file: gcn-rcv1-v2.json
Training info:
2021/08/26 09:58:42 - INFO : Building Vocabulary....
2021/08/26 09:58:42 - INFO : Loading Vocabulary from Cached Dictionary...
2021/08/26 09:58:42 - INFO : Vocabulary of token 50002
2021/08/26 09:58:42 - INFO : Vocabulary of label 40
2021/08/26 09:58:43 - INFO : Loading 300-dimension token embedding from pretrained file: USERPATH/HiAGM/glove.6B/glove.6B.300d.txt
2021/08/26 09:59:00 - INFO : Total vocab size of token is 50002.
2021/08/26 09:59:00 - INFO : Pretrained vocab embedding has 49965 / 50002
Results:
2021/08/26 10:02:43 - INFO : Epoch 76 Time Cost 2.3963842391967773 secs.
2021/08/26 10:02:45 - INFO : TRAIN performance at epoch 77 --- Precision: 1.000000, Recall: 0.128205, Micro-F1: 0.227273, Macro-F1: 0.034722, Loss: 0.345776.
2021/08/26 10:02:46 - INFO : DEV performance at epoch 77 --- Precision: 1.000000, Recall: 0.037037, Micro-F1: 0.071429, Macro-F1: 0.016667, Loss: 0.399174.
2021/08/26 10:02:50 - INFO : TEST performance at epoch 27 --- Precision: 1.000000, Recall: 0.178571, Micro-F1: 0.303030, Macro-F1: 0.045000, Loss: 0.425028.
Runs successfully.
Config file: gcn-wos-v2-cpu.json
Training info:
2021/08/23 13:26:43 - INFO : Building Vocabulary....
2021/08/23 13:26:43 - INFO : Loading Vocabulary from Cached Dictionary...
2021/08/23 13:26:43 - INFO : Vocabulary of token 50002
2021/08/23 13:26:43 - INFO : Vocabulary of label 141
2021/08/23 13:26:45 - INFO : Loading 300-dimension token embedding from pretrained file: USERPATH/HiAGM/glove.6B/glove.6B.300d.txt
2021/08/23 13:26:56 - INFO : Total vocab size of token is 50002.
2021/08/23 13:26:56 - INFO : Pretrained vocab embedding has 49965 / 50002
Results:
2021/08/28 07:31:06 - INFO : Epoch 198 Time Cost 2063.6894171237946 secs.
2021/08/28 08:05:01 - INFO : TRAIN performance at epoch 199 --- Precision: 0.997131, Recall: 0.999917, Micro-F1: 0.998522, Macro-F1: 0.989585, Loss: 0.001030.
2021/08/28 08:05:36 - INFO : DEV performance at epoch 199 --- Precision: 0.865440, Recall: 0.822559, Micro-F1: 0.843455, Macro-F1: 0.764416, Loss: 0.018950.
2021/08/28 08:06:24 - INFO : TEST performance at epoch 149 --- Precision: 0.856174, Recall: 0.819091, Micro-F1: 0.837222, Macro-F1: 0.760100, Loss: 0.018396.
No ready-made 300-dimensional Chinese embedding was found, so the sgns.merge.word file is used for now.
Data preprocessing script: preprocess_ty.py.
File list:
data/ty_total.json
data/ty_train.json
data/ty_test.json
data/ty_val.json
data/ty.taxonomy
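The ty_train/ty_val/ty_test files are splits of ty_total.json; a minimal sketch of producing such a split (the 8:1:1 ratio and the existence of a dedicated split step are assumptions):

```python
import random

random.seed(42)
with open("data/ty_total.json", encoding="utf-8") as f:
    samples = [line for line in f if line.strip()]  # one JSON object per line

random.shuffle(samples)
n_train = int(0.8 * len(samples))
n_val = int(0.1 * len(samples))

splits = {
    "data/ty_train.json": samples[:n_train],
    "data/ty_val.json": samples[n_train:n_train + n_val],
    "data/ty_test.json": samples[n_train + n_val:],
}
for path, lines in splits.items():
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)
```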
.json file content example:
{"label": ["M131", "M13", "MCAT"], "token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate", "Thursday", "boosting", "US", "dollar", "triggering", "interest", "rate", "cuts", "European"]}
.taxonomy file content example:
Root CCAT ECAT GCAT MCAT
CCAT C12 C13 C15 C18 C22 C24 C31 C33 C41
C15 C151 C152
...
data/ty_prob.json
Data processing script: get_prior_prob.py.
It computes the prior probability of each child label given its parent.
File content example:
{"Root": {"CCAT": 0.37500000000000006, "ECAT": 0.16666666666666669, "GCAT": 0.29166666666666674, "MCAT": 0.16666666666666669},
"CCAT": {"C12": 0.08333333333333333, "C13": 0.16666666666666666, "C15": 0.3333333333333333, "C18": 0.08333333333333333, "C22": 0.08333333333333333, "C24": 0.16666666666666666, "C31": 0.0, "C33": 0.08333333333333333, "C41": 0.0},
...
}
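get_prior_prob.py itself is not shown; a minimal sketch of how such priors can be derived, assuming they are the per-parent normalized counts of child labels in the training set:

```python
import json
from collections import defaultdict

# parent -> children, read from the taxonomy file (one "parent child ..." line each)
hierarchy = {}
with open("data/ty.taxonomy", encoding="utf-8") as f:
    for line in f:
        parent, *children = line.split()
        hierarchy[parent] = children

# count label occurrences in the training set (one JSON object per line)
label_count = defaultdict(int)
with open("data/ty_train.json", encoding="utf-8") as f:
    for line in f:
        for label in json.loads(line)["label"]:
            label_count[label] += 1

# normalize the child counts under each parent
prob = {}
for parent, children in hierarchy.items():
    total = sum(label_count[c] for c in children)
    prob[parent] = {c: label_count[c] / total if total > 0 else 0.0 for c in children}

with open("data/ty_prob.json", "w", encoding="utf-8") as f:
    json.dump(prob, f)
```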
ty_vocab/word.dict file content example:
# word\tfrequency
said 104
The 65
percent 50
company 30
pct 30
year 29
government 27
ty_vocab/label.dict file content example:
# label\tnumber of child nodes
E12 1
ECAT 5
G15 2
G154 1
GCAT 8
GPOL 3
M13 2
config/gcn-ty-v2.json file content:
{
"data": {
"dataset": "ty",
"data_dir": "data",
"train_file": "ty_train.json",
"val_file": "ty_val.json",
"test_file": "ty_test.json",
"prob_json": "ty_prob.json",
"hierarchy": "ty.taxonomy"
},
"vocabulary": {
"dir": "ty_vocab",
"vocab_dict": "word.dict",
"max_token_vocab": 60000,
"label_dict": "label.dict"
},
"embedding": {
"token": {
"dimension": 300,
"type": "pretrain",
"pretrained_file": "USERPATH/HiAGM/sgns/sgns.merge.word",
"dropout": 0.5,
"init_type": "uniform"
},
"label": {
"dimension": 300,
"type": "random",
"dropout": 0.5,
"init_type": "kaiming_uniform"
}
},
"text_encoder": {
"max_length": 256,
"RNN": {
"bidirectional": true,
"num_layers": 1,
"type": "GRU",
"hidden_dimension": 64,
"dropout": 0.1
},
"CNN": {
"kernel_size": [2, 3, 4],
"num_kernel": 100
},
"topK_max_pooling": 1
},
"structure_encoder": {
"type": "GCN",
"node": {
"type": "text",
"dimension": 300,
"dropout": 0.05
}
},
"model": {
"type": "HiAGM-TP",
"linear_transformation": {
"text_dimension": 300,
"node_dimension": 300,
"dropout": 0.5
},
"classifier": {
"num_layer": 1,
"dropout": 0.5
}
},
"train": {
"optimizer": {
"type": "Adam",
"learning_rate": 0.0001,
"lr_decay": 1.0,
"lr_patience": 5,
"early_stopping": 50
},
"batch_size": 64,
"start_epoch": 0,
"end_epoch": 250,
"loss": {
"classification": "BCEWithLogitsLoss",
"recursive_regularization": {
"flag": true,
"penalty": 0.000001
}
},
"device_setting": {
"device": "cuda",
"visible_device_list": "0",
"num_workers": 10
},
"checkpoint": {
"dir": "ty_hiagm_tp_checkpoint",
"max_number": 10,
"save_best": ["Macro_F1", "Micro_F1"]
}
},
"eval": {
"batch_size": 512,
"threshold": 0.5
},
"test": {
"best_checkpoint": "best_micro_HiAGM-TP",
"batch_size": 512
},
"log": {
"level": "info",
"filename": "gcn-ty-v2.log"
}
}
Since the project does not ship with a predict script, a predict module was added for project purposes.
The focus was on getting the functionality working quickly, so the script/framework is rough and will be revised later.
Use data/xlsx2json.py to generate the ty_predict.json file; the label field can simply be an empty list.
File content example:
{"token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate", "Thursday", "boosting", "US", "dollar", "triggering", "interest", "rate", "cuts", "European"], "label": []}
{"token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate"], "label": []}
...
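xlsx2json.py is likewise not included; a minimal sketch, assuming a spreadsheet with a single text column named "text" (both pandas and the column name are assumptions):

```python
import json
import pandas as pd

# read the spreadsheet; the column name "text" is an assumption
df = pd.read_excel("data/ty_predict.xlsx")

with open("data/ty_predict.json", "w", encoding="utf-8") as f:
    for text in df["text"]:
        # naive whitespace tokenization; swap in the project's tokenizer if needed
        sample = {"token": str(text).split(), "label": []}
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```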
For project-confidentiality reasons, the code for this subsection is not provided.
In data_modules/dataset.py, add a "PRED" entry to self.corpus_files in the __init__ function of the ClassificationDataset class.
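The actual change is withheld; purely as an illustration based on the public HiAGM repository, the addition might look roughly like this (the pred_file key comes from the config change described below):

```python
# in ClassificationDataset.__init__ (data_modules/dataset.py) -- illustrative sketch only
self.corpus_files = {
    "TRAIN": os.path.join(config.data.data_dir, config.data.train_file),
    "VAL": os.path.join(config.data.data_dir, config.data.val_file),
    "TEST": os.path.join(config.data.data_dir, config.data.test_file),
    "PRED": os.path.join(config.data.data_dir, config.data.pred_file),  # added
}
```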
In data_modules/data_loader.py, add a pred_loader section to the data_loaders function.
In config/gcn-ty-v2.json, add pred_file to the data section.
File content:
{
"data": {
"dataset": "ty",
"data_dir": "data",
"train_file": "ty_train.json",
"val_file": "ty_val.json",
"test_file": "ty_test.json",
"pred_file": "ty_pred.json",
"prob_json": "ty_prob.json",
"hierarchy": "ty.taxonomy"
}
}
For project-confidentiality reasons, the code for this subsection is not provided.
predict.py was adapted from train.py; it only implements prediction in a quick-and-dirty way and still needs cleanup and optimization.
In train_modules/trainer.py, add a "PRED" branch to the run function of the Trainer class and add a new pred function.
In train_modules/evaluation_metrics.py, add pred_label to the evaluate function and return it.
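This code is also withheld; as an illustration only, recovering predicted label names inside evaluate might look roughly like this (the helper name and arguments are hypothetical; the 0.5 threshold matches the eval section of the config):

```python
import torch

def predicted_labels(logits, id2label, threshold=0.5):
    """Hypothetical helper: map sigmoid scores above the threshold to label names."""
    scores = torch.sigmoid(logits)  # (batch_size, num_labels)
    pred_label = [
        [id2label[i] for i, s in enumerate(row) if s >= threshold]
        for row in scores.tolist()
    ]
    return pred_label
```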
Since no optimization or data cleaning has been done (only the pipeline itself was run) and the label hierarchy has four levels, the current results are rather poor.