LMTC-emnlp论文+代码剖析(BERT-LWAN)

LMTC-emnlp论文

来源

论文Meta-LMTC:Meta-Learning for Large-Scale Multi-Label Text Classification(2021emnlp)提到的使用meta-lmtc方法能增强BERTlike模型即 BERT-LWAN(Ilias Chalkidis…)。原文实验用的是蒸馏的Bert即DistillBert-LWAN。

评估标准

P r e c i s i o n @ K : P @ K = T P @ k T P @ k + F P @ k R e c a l l @ K : R @ K = T P @ k T P @ k + F N @ k n D C G @ K = D C G @ k I D C G @ k   D C G @ k = ∑ i = 1 k r e l i l o g 2 ( i + 1 )   I D C G @ k = ∑ i = 1 ∣ R E L ∣ r e l i l o g 2 ( i + 1 ) Precision@K:P@K=\frac{TP@k}{TP@k+FP@k}\\ Recall@K:R@K=\frac{TP@k}{TP@k+FN@k}\\ nDCG@K = \frac{DCG@k}{IDCG@k}\space DCG@k=\sum_{i=1}^k \frac{rel_i}{log_2(i+1)}\space IDCG@k=\sum_{i=1}^{|REL|} \frac{rel_i}{log_2(i+1)} Precision@K:P@K=TP@k+FP@kTP@kRecall@K:R@K=TP@k+FN@kTP@knDCG@K=IDCG@kDCG@k DCG@k=i=1klog2(i+1)reli IDCG@k=i=1RELlog2(i+1)reli

数据集

EURLEX57K:有关欧盟法律的数据集(发表于2019ACL:Large-Scale Multi-Label Text Classification on EU Legislation)

标签数量:总共4654个。Frequent(出现频次>50):739个 Few(出现频次<50):3369 个 zero:163个

文本数量:

datasets结构:分验证集、测试集、训练集以及一个标签解释器文件。每一个集合里都是json文件

LMTC-emnlp论文+代码剖析(BERT-LWAN)_第1张图片

内容展示:训练时的文本只截取了header、recitals、main body、attachments

{
"celex_id": "32015R0597", 
"uri": "http://publications.europa.eu/resource/cellar/e96dd688-e400-11e4-b1d3-01aa75ed71a1", 
"type": "Regulation", 
"concepts": ["1118", "1605", "2173", "2635", "3191", "693"], 
"title": "Commission Implementing Regulation (EU) 2015/597 of 15 April 2015 establishing the standard import values for determining the entry price of certain fruit and vegetables\n", 
"header": "16.4.2015 EN Official Journal of the European Union L 99/23\nCOMMISSION IMPLEMENTING REGULATION (EU) 2015/597\nof 15 April 2015\nestablishing the standard import values for determining the entry price of certain fruit and vegetables\nTHE EUROPEAN COMMISSION", 
"recitals": ",\nHaving regard to the Treaty on the Functioning of the European Union,\nHaving regard to Regulation (EU) No\u00a01308/2013 of the European Parliament and of the Council of 17\u00a0December 2013 establishing a common organisation of the markets in agricultural products and repealing Council Regulations (EEC) No\u00a0922/72, (EEC) No\u00a0234/79, (EC) No\u00a01037/2001 and (EC) No\u00a01234/2007\u00a0(1),\nHaving regard to Commission Implementing Regulation (EU) No 543/2011 of 7 June 2011 laying down detailed rules for the application of Council Regulation (EC) No 1234/2007 in respect of the fruit and vegetables and processed fruit and vegetables sectors\u00a0(2), and in particular Article 136(1) thereof,\nWhereas:\n(1) Implementing Regulation (EU) No 543/2011 lays down, pursuant to the outcome of the Uruguay Round multilateral trade negotiations, the criteria whereby the Commission fixes the standard values for imports from third countries, in respect of the products and periods stipulated in Annex XVI, Part A thereto.\n(2) The standard import value is calculated each working day, in accordance with Article 136(1) of Implementing Regulation (EU) No 543/2011, taking into account variable daily data. Therefore this Regulation should enter into force on the day of its publication in the Official Journal of the European Union,", 
"main_body": ["The standard import values referred to in Article 136 of Implementing Regulation (EU) No 543/2011 are fixed in the Annex to this Regulation.", "This Regulation shall enter into force on the day of its publication in the Official Journal of the European Union.\nThis Regulation shall be binding in its entirety and directly applicable in all Member States."], 
"attachments": "Done at Brussels, 15 April 2015.\nFor the Commission,\nOn behalf of the President,\nJerzy PLEWA\nDirector-General for Agriculture and Rural Development\n(1)\u00a0\u00a0OJ L\u00a0347, 20.12.2013, p.\u00a0671.\n(2)\u00a0\u00a0OJ L\u00a0157, 15.6.2011, p.\u00a01.\nANNEX\nStandard import values for determining the entry price of certain fruit and vegetables\n(EUR/100 kg)\nCN code Third country code\u00a0(1) Standard import value\n0702\u00a000\u00a000 MA 103,8\nSN 185,4\nTR 120,5\nZZ 136,6\n0707\u00a000\u00a005 MA 176,1\nTR 139,5\nZZ 157,8\n0709\u00a093\u00a010 MA 92,0\nTR 164,4\nZZ 128,2\n0805\u00a010\u00a020 EG 48,6\nIL 72,1\nMA 52,4\nTN 55,3\nTR 67,4\nZZ 59,2\n0805\u00a050\u00a010 MA 57,3\nTR 45,7\nZZ 51,5\n0808\u00a010\u00a080 BR 97,3\nCL 113,9\nCN 100,9\nMK 29,8\nNZ 121,0\nUS 209,2\nZA 122,2\nZZ 113,5\n0808\u00a030\u00a090 AR 107,9\nCL 151,3\nZA 132,7\nZZ 130,6\n(1)\u00a0\u00a0Nomenclature of countries laid down by Commission Regulation (EU) No\u00a01106/2012 of 27\u00a0November 2012 implementing Regulation (EC) No\u00a0471/2009 of the European Parliament and of the Council on Community statistics relating to external trade with non-member countries, as regards the update of the nomenclature of countries and territories (OJ L 328, 28.11.2012, p. 7). Code \u2018ZZ\u2019 stands for \u2018of other origin\u2019."
}

EURLEX57K.json解释器展示

{
"3474": {"concept_id": "3474", "label": "international affairs", "alt_labels": ["international politics"], "parents": []}, 
"1597": {"concept_id": "1597", "label": "school legislation", "alt_labels": [], "parents": ["2467"]}, "3363": {"concept_id": "3363", "label": "union representative", "alt_labels": ["trade union representative"], "parents": ["3374"]}, 
"4488": {"concept_id": "4488", "label": "data processing", "alt_labels": ["automatic data processing", "electronic data processing"], "parents": []}, 
"2316": {"concept_id": "2316", "label": "barge", "alt_labels": ["canal boat"], "parents": ["1036"]}, 
"5709": {"concept_id": "5709", "label": "Lithuania", "alt_labels": ["Republic of Lithuania"], "parents": ["122", "2200", "5283", "5774", "5781"]},.........

论文

2020 emnlp论文:An Empirical Study on Large-Scale Multi-Label Text Classification Including Few and Zero-Shot Labels
LMTC-emnlp论文+代码剖析(BERT-LWAN)_第2张图片

我们实证评估了一系列的LMTC方法,从普通的LWAN到层次分类方法和迁移学习,在来自不同领域的三个数据集()上进行Frequent、Few、Zero-shot学习 。

工作内容(贡献)

  • 基于概率标签树(Probabilistic Label Trees)的层级方法比LWAN(CNN-LWAN)好。
  • 提出了一个新的基于迁移学习的sota模型BERT-LWAN,在总的效果上最好。
  • 通过利用标签的层级关系来增强few and zero-shot learning,提出了新的模型:。

Methods

LWAN:label-wise attention network。(Mullenbachetal.,2018ACL:CNN-LWAN):给予每个标签一个不同的注意力分数。(1)还没从标签的层级关系中利用到结构化信息(2)可能利用到了层级关系只是还在研究中(3)没有结合预训练模型。

CNN-LWAN:

LMTC-emnlp论文+代码剖析(BERT-LWAN)_第3张图片

N为文本长度,x为词向量,de为词向量维度
X = m a t r i x [ x 1 , x 2 , . . . x N ]   X ∈ R d e × N X = matrix[x_1,x_2,...x_N]\space X\in R^{de×N} X=matrix[x1,x2,...xN] XRde×N
嵌入层使用convolutional filter Wc ,dc为输出维度,k为宽度(通道数)
W c ∈ R k × d e × d c W_c \in R^{k×de×dc} WcRk×de×dc
对于每一步n 计算hn(上下文) 最后构成矩阵H
h n = g ( W c ∗ x n : n + k − 1 + b )   b ∈ R d c   H ∈ R d c × N h_n=g(W_c*x_{n:n+k-1} + b)\space b \in R^{dc} \space H \in R^{dc×N} hn=g(Wcxn:n+k1+b) bRdc HRdc×N
通常来说,卷积过后会通过一个池化层减少成为一个向量。但考虑到一个文本中不同部分可能相关。因此,我们给每一个标签运用注意力机制。这样还能给对应的标签从文本中挑出最相关的k-gram。
α l = S o f t M a x ( H T u l ) \alpha_l = SoftMax(H^Tu_l) αl=SoftMax(HTul)
注意力分数与H相乘求和
v l = ∑ n = 1 N α l . . n h n v_l=\sum_{n=1}^N \alpha_{l..n}h_n vl=n=1Nαl..nhn
PLT:Probabilistic Label Trees。考虑到LWAN计算的复杂性,LWAN不能运用在更极端规模的数据集(millions labels)上。Jasinskaetal.,2016;Prabhuetal.,2018;Khandagaleetal.,2019等人把PLT用在了Extreme Multi-label Text Classification(XMTC)。

Flat neural methods

BIGRU-LWAN
a l t = e x p ( h t T u l ) ∑ e x p ( h t T u l ) d l = 1 T ∑ t = 1 T a l t h t a_{lt}=\frac{exp(h_t^T u_l)}{\sum exp(h_t^T u_l)}\\ d_l = \frac{1}{T}\sum_{t=1}^T a_{lt}h_t alt=exp(htTul)exp(htTul)dl=T1t=1Taltht

Transfer learning based LMTC

BERT-LWAN

BERT,ROBERTA

BIGRU-LWAN-ELMO

Hierarchical PLT-based methods

PARABEL,BONSAI

ATTENTION-XML

Zero-shot LMTC

C-BIGRU-LWAN

GC-BIGRU-LWAN

DC-BIGRU-LWAN

DN-BIGRU-LWAN

DNC-BIGRU-LWAN

GNC-BIGRU-LWAN

结果

LMTC-emnlp论文+代码剖析(BERT-LWAN)_第4张图片
LMTC-emnlp论文+代码剖析(BERT-LWAN)_第5张图片
LMTC-emnlp论文+代码剖析(BERT-LWAN)_第6张图片
LMTC-emnlp论文+代码剖析(BERT-LWAN)_第7张图片
LMTC-emnlp论文+代码剖析(BERT-LWAN)_第8张图片

代码

支持LWAN-BIGRU, ZERO-LWAN-BIGRU, GRAPH-ZERO-LWAN-BIGRU, BERT-BASE, ROBERTA-BASE, BERT-LWAN

展示部分代码

标签数据的处理:

train_counts#Counter   for concept in data['concepts']:train_counts[concept] += 1

train_concepts = set(list(train_counts))#存训练集中所有标签id

frequent, few = [], []#分别存训练集中frequent,few标签id

rest_cepts = set()#存dev test中所有的标签id

with open(os.path.join(DATA_SET_DIR, Configuration['task']['dataset'],
                   '{}.json'.format(Configuration['task']['dataset']))) as file:
		data = json.load(file)#导入解释器文件EURLEX57K.json
        none = set(data.keys())#存解释器里所有的标签id 包含了dev test train里所有的标签 以及还有500多个未出现的标签

#存储没有出现在dev test train集合标签
none = none.difference(train_concepts.union((rest_concepts)))
    
parents = []#获得父母标签
#再与所有的父母标签作交集 得到在父母标签中但是没有出现在dev train test里的
none = none.intersection(set(parents))

zero = list(rest_concepts.difference(train_concepts))#出现在test和dev里但是没出现在训练集中的 163个
true_zero = deepcopy(zero)#浅拷贝一份
zero = zero + list(none)

#存储标签列表 [['international', 'affairs'],['school','legislation'],....] 按顺序frequent few zero
#label_terms.append([token for token in word_tokenize(data[label]['label']) if re.search('[A-Za-z]', token)]) 存储的是解释器里所有的标签
label_terms = []

label_terms_ids = vectorizer.tokenize(label_terms)#转词向量

LOGGER.info('#Labels:         {}'.format(len(label_terms)))
LOGGER.info('Frequent labels: {}'.format(len(frequent)))
LOGGER.info('Few labels:      {}'.format(len(few)))
LOGGER.info('Zero labels:     {}'.format(len(true_zero)))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ai3Gt4Hp-1668334791991)(C:\Users\wcx\AppData\Roaming\Typora\typora-user-images\image-20221113150927364.png)]

构建模型

model = LWAN(self.label_terms_id, self.true_labels_cutoff)
model = model.build_compile(n_hidden_layers=Configuration['model']['n_hidden_layers'],
                                        hidden_units_size=Configuration['model']['hidden_units_size'],
                                        dropout_rate=Configuration['model']['dropout_rate'],
                                        word_dropout_rate=Configuration['model']['word_dropout_rate'])

LWAN

class LWAN:
    def __init__(self, label_terms_ids, true_labels_cutoff):
        super().__init__()
        self.label_encoder = Configuration['model']['label_encoder']
        self.token_encoder = Configuration['model']['token_encoder']
        self.word_embedding_path = Configuration['model']['embeddings']
        self.label_terms_ids = label_terms_ids
        self.true_labels_cutoff = true_labels_cutoff
        self.bert_version = Configuration['model']['bert']
    def build_compile(self, n_hidden_layers, hidden_units_size, dropout_rate, word_dropout_rate):
        return self._compile_label_wise_attention(n_hidden_layers=n_hidden_layers,
                                                      hidden_units_size=hidden_units_size,
                                                      dropout_rate=dropout_rate,
                                                      word_dropout_rate=word_dropout_rate)
    def _compile_label_wise_attention(self, n_hidden_layers, hidden_units_size, dropout_rate):
        # Document Encoding 
        inputs = Input(shape=(None,), name='inputs')
        self.pretrained_embeddings = self.PretrainedEmbedding()
        embeddings = self.pretrained_embeddings(inputs)
        token_encodings = self.TokenEncoder(inputs=embeddings, encoder=self.token_encoder,
                                            dropout_rate=dropout_rate, word_dropout_rate=word_dropout_rate,
                                            hidden_layers=n_hidden_layers)

        # Label-wise Attention Mechanism matching documents with labels
        document_label_encodings = LabelWiseAttention(n_classes=len(self.label_terms_ids))(token_encodings)

        model = Model(inputs=[inputs] if not self.elmo else [inputs, inputs_2],
                      outputs=[document_label_encodings])

        return model

LabelwiseAttention

class LabelWiseAttention(Layer):

    def __init__(self, n_classes=4271):
        self.supports_masking = True
        self.n_classes = n_classes
        super(LabelWiseAttention, self).__init__()
    #可训练的参数矩阵
    def build(self, input_shape):
        assert len(input_shape) == 3

        self.Wa = self.add_weight(shape=(self.n_classes, input_shape[-1]),
                                  trainable=True, name='Wa')

        self.Wo = self.add_weight(shape=(self.n_classes, input_shape[-1]),
                                  trainable=True, name='Wo')

        self.bo = self.add_weight(shape=(self.n_classes,),
                                  initializer='zeros',
                                  trainable=True, name='bo')
    #注意力机制
    def call(self, x, mask=None):
    	a = dot_product(x, self.Wa)

    	def label_wise_attention(values):
        	doc_repi, ai = values
        	ai = tf.nn.softmax(tf.transpose(ai))#得注意力分数
        	label_aware_doc_rep = dot_product(ai, tf.transpose(doc_repi))
        	return [label_aware_doc_rep, label_aware_doc_rep]

    	label_aware_doc_reprs, attention_scores = K.map_fn(label_wise_attention, [x, a])

    	# Compute label-scores
    	label_aware_doc_reprs = tf.reduce_sum(label_aware_doc_reprs * self.Wo, axis=-1) + self.bo
    	label_aware_doc_reprs = tf.sigmoid(label_aware_doc_reprs)

    	return label_aware_doc_reprs

训练

fit_history = model.fit_generator(train_generator,
                                  validation_data=val_generator,
                                  epochs=Configuration['model']['epochs'],
                                  callbacks=[early_stopping, model_checkpoint])

评估

在验证集和测试集上分别评估

评估的时候frequenct、few、zero分开评估。由于前面为frequent、few、zero做了一个字典排序。

具体做法:

targets = np.zeros((len(sequences), len(self.label_ids))

label_id = dict()

label_id = [(‘frequent labels’,0),…(‘frequent labels’,738),(‘few labels’,739),…(‘few labels’,4107),(’zero labels’,4108),…]

后面再通过标签号找到对应的区间即可。

总之最后比较的是:预测出来的标签中是frequent的标签和真实的标签中是frequent的标签,预测出来的标签中是few的标签和真实的标签中是few的标签,预测出来的标签中是zero的标签和真实的标签中是zero的标签。

start, end = labels_range
p = precision_score(true_targets[:, start:end], pred_targets[:, start:end], average=average_type)
r = recall_score(true_targets[:, start:end], pred_targets[:, start:end], average=average_type)
f1 = f1_score(true_targets[:, start:end], pred_targets[:, start:end], average=average_type)

``python
start, end = labels_range
p = precision_score(true_targets[:, start:end], pred_targets[:, start:end], average=average_type)
r = recall_score(true_targets[:, start:end], pred_targets[:, start:end], average=average_type)
f1 = f1_score(true_targets[:, start:end], pred_targets[:, start:end], average=average_type)


你可能感兴趣的:(bert,算法,nlp,深度学习)