在文本分类任务中,有时会碰到多标签分类问题,即一篇文章可能同时属于多个标签,如下表所示:
新闻标题 | 标签 |
---|---|
湖人VS凯尔特人比赛中,最后时刻塔图姆打手詹姆斯裁判未响哨,引得外界不满。 | 体育、竞技 |
为什么说狂飙是开年最强爆款 | 娱乐资讯 |
2022年汽车出口超300万辆 | 汽车、财经 |
# Load the pre-trained BERT backbone. with_pool=True requests the pooled
# output, and return_keras_model=False returns the bert4keras wrapper, whose
# Keras model (``bert.model``) and initializer (``bert.initializer``) are used
# below.
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    with_pool=True,
    return_keras_model=False,
)

# Multi-label head: dropout followed by one independent sigmoid unit per class
# (len(class_id) outputs), so several labels can be active at once.
features = Dropout(rate=0.1)(bert.model.output)
output = Dense(
    units=len(class_id),
    activation='sigmoid',
    kernel_initializer=bert.initializer,
    name="zhx",
)(features)
model = keras.models.Model(bert.model.input, output)

# Every label is an independent yes/no decision, hence binary cross-entropy.
# NOTE: with this loss Keras resolves 'accuracy' to element-wise binary
# accuracy, i.e. it scores individual label slots rather than whole samples.
optimizer = Adam(2e-5)  # keep the learning rate small for fine-tuning
model.compile(
    loss='binary_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)
最后一层使用 sigmoid 激活,对每个标签独立输出一个概率;相应地,训练标签采用 multi-hot 编码(文章所属标签对应的位置为 1,其余为 0),而不是 one-hot(one-hot 只允许一个位置为 1)。
def load_data(data, train_doc_names=(), label_map=None):
    """Convert raw records into ``(text, multi_hot_label)`` pairs and split them.

    Each record is a dict read with the keys ``'label'`` (an iterable of label
    names), ``'docu_name'`` and ``'text'``. Every label name is mapped to a
    column index through ``label_map``; the sample's target is a list of
    length ``len(label_map)`` holding 1 at each of its labels and 0 elsewhere.

    Args:
        data: iterable of record dicts as described above.
        train_doc_names: ``'docu_name'`` values whose samples go to the first
            returned list. Defaults to empty, so every sample goes to the
            second list.
        label_map: mapping from label name to column index (values are passed
            through ``int``, so int-like strings are accepted). Defaults to the
            module-level ``class_id``.

    Returns:
        ``(D1, D2)``: the samples whose ``docu_name`` is in
        ``train_doc_names``, and all remaining samples. Each is a list of
        ``(text, label_list)`` tuples.

    Raises:
        KeyError: if a record lacks one of the keys or uses an unknown label.
    """
    if label_map is None:
        label_map = class_id
    # A set gives O(1) membership tests per record.
    selected = set(train_doc_names or ())
    num_classes = len(label_map)
    D1, D2 = [], []
    for item in data:
        y = [0] * num_classes
        for label in item['label']:
            y[int(label_map[label])] = 1
        target = D1 if item['docu_name'] in selected else D2
        target.append((item['text'], y))
    return D1, D2
# Read the raw records and split them into two subsets.
# NOTE(review): the first subset only receives documents picked by
# load_data's docu_name filter; confirm it is not empty before training.
raw_records = get_data_from_dir(r'E:\Dataset_self')
train_self, valid_self = load_data(raw_records)

# Build the tokenizer from the BERT vocabulary file (lower-casing input).
tokenizer = Tokenizer(dict_path, do_lower_case=True)
class data_generator(DataGenerator):
    """Batch generator for the multi-label classifier.

    Walks the ``(text, label)`` samples produced by ``self.sample`` and encodes
    each text with the module-level ``tokenizer``, passing the module-level
    ``maxlen``. Once ``self.batch_size`` samples have been collected, or the
    final sample has been reached, it yields
    ``([token_ids, segment_ids], labels)``. All three parts go through
    ``sequence_padding`` first.
    """

    def __iter__(self, random=False):
        tokens, segments, labels = [], [], []
        for is_last, (text, label) in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
            tokens.append(token_ids)
            segments.append(segment_ids)
            labels.append(label)
            # Keep accumulating until the batch is full or the data runs out.
            if len(tokens) != self.batch_size and not is_last:
                continue
            inputs = [sequence_padding(tokens), sequence_padding(segments)]
            yield inputs, sequence_padding(labels)
            tokens, segments, labels = [], [], []
def evaluate(data, threshold=0.2, predict_fn=None):
    """Return the exact-match (subset) accuracy of multi-label predictions.

    Each predicted score vector is binarised by marking every score strictly
    greater than ``threshold`` as 1 and the rest as 0. A sample counts as
    correct only if the binarised vector equals its label vector in every
    position.

    Args:
        data: iterable of ``(x, y_true)`` batches, e.g. a ``data_generator``.
            ``y_true`` is a 2-D array of 0/1 labels with one row per sample.
        threshold: decision threshold applied to each sigmoid output.
        predict_fn: callable mapping a batch input ``x`` to a 2-D score array.
            Defaults to the module-level ``model.predict``.

    Returns:
        Fraction of samples predicted exactly right, or ``0.`` when ``data``
        contains no samples.
    """
    predict = model.predict if predict_fn is None else predict_fn
    # Float accumulators keep the final division true division on Python 2 too.
    total, right = 0., 0.
    for x_true, y_true in data:
        y_true = np.asarray(y_true)
        # Threshold all scores at once. Unlike an element-wise loop, this
        # leaves the array returned by the predictor untouched.
        y_pred = (np.asarray(predict(x_true)) > threshold).astype(int)
        total += len(y_true)
        right += float(np.all(y_true == y_pred, axis=-1).sum())
    return right / total if total else 0.
class Evaluator(keras.callbacks.Callback):
    """Keras callback that evaluates, checkpoints and reports after every epoch.

    At the end of each epoch it scores the module-level ``valid_generator``
    with ``evaluate``. Whenever that score beats the best seen so far, it saves
    the current weights of the module-level ``model`` to
    ``data/best_model.weights``. It then scores ``test_generator`` and prints
    all three numbers.
    """

    def __init__(self):
        # Let keras.callbacks.Callback initialise its own attributes before
        # ours; the original skipped the base-class constructor entirely.
        super(Evaluator, self).__init__()
        self.best_val_acc = 0.

    def on_epoch_end(self, epoch, logs=None):
        val_acc = evaluate(valid_generator)
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            # NOTE(review): presumably the 'data/' directory must already exist.
            model.save_weights('data/best_model.weights')
        test_acc = evaluate(test_generator)
        print(
            u'val_acc: %.5f, best_val_acc: %.5f, test_acc: %.5f\n' %
            (val_acc, self.best_val_acc, test_acc)
        )