Heads up, everyone! The LayoutLMv2 model may well be a boon for the NLP community!
Link to the detailed theory
PyTorch has as many pitfalls as TensorFlow: above all, make sure your CUDA/cuDNN versions match your install.
1. Load the dataset
from datasets import load_dataset

datasets = load_dataset("data/datasets")
# FUNSD-style token-classification tags live in the "ner_tags" feature
labels = datasets['train'].features['ner_tags'].feature.names
print(labels)
2. Wrap the dataset: map labels to ids
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}
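As a minimal sketch of what these two dictionaries look like (using a hypothetical BIO label list standing in for whatever the loading step printed for your data), they come out as exact inverses of each other:

```python
# Hypothetical BIO labels; substitute the list printed by the loading step above.
labels = ["O", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]

id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}

print(id2label[1])           # → B-QUESTION
print(label2id["B-ANSWER"])  # → 3
```

The model config later uses `num_labels=len(labels)`, so the ids must be the contiguous range 0..len(labels)-1, which `enumerate` guarantees.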
The result is a pair of dictionaries mapping each id to its label and each label back to its id.
3. Convert the dataset
# A few readers asked how the preprocessor is loaded — here is the answer
from transformers import LayoutLMv2Processor
from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

# Define the feature types of the converted dataset
features = Features({
    'image': Array3D(dtype="int64", shape=(3, 224, 224)),
    'input_ids': Sequence(feature=Value(dtype='int64')),
    'attention_mask': Sequence(Value(dtype='int64')),
    'token_type_ids': Sequence(Value(dtype='int64')),
    'bbox': Array2D(dtype="int64", shape=(512, 4)),
    'labels': Sequence(ClassLabel(names=labels)),
})
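One detail worth knowing before running the conversion: LayoutLMv2 expects every bounding box in `bbox` to be normalized onto a 0–1000 coordinate grid, regardless of the page's pixel size. A minimal sketch of that normalization (the helper name `normalize_bbox` is my own, not part of the library):

```python
def normalize_bbox(bbox, width, height):
    """Scale an (x0, y0, x1, y1) pixel box onto LayoutLMv2's 0-1000 grid."""
    x0, y0, x1, y1 = bbox
    return [
        int(1000 * x0 / width),
        int(1000 * y0 / height),
        int(1000 * x1 / width),
        int(1000 * y1 / height),
    ]

# A 100x50 pixel box at (200, 100) on a 1000x2000 page:
print(normalize_bbox((200, 100, 300, 150), 1000, 2000))  # → [200, 50, 300, 75]
```

If your OCR boxes are already in this range (as in pre-annotated FUNSD data), you can pass them through unchanged.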
1. Sanity-check by decoding a training example (here train_dataset is the result of mapping the raw dataset through the processor with the features defined above)
processor.tokenizer.decode(train_dataset['input_ids'][0])
2. Build the data iterators that will feed the model
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2)
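Conceptually, all `DataLoader` does here is shuffle indices and yield fixed-size batches (plus collating them into tensors). A dependency-free sketch of that behavior, just to make the batching explicit (`simple_loader` is my own illustrative name):

```python
import random

def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of examples, mimicking DataLoader's batching (minus tensor collation)."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]

data = list(range(10))
batches = list(simple_loader(data, batch_size=4))
print([len(b) for b in batches])  # → [4, 4, 2]
```

Note the final batch is smaller; with `batch_size=4` on the training set, your last step per epoch may also see a short batch unless you drop it.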
3. Train the model in native PyTorch, using AdamW as the optimizer.
import torch
from torch.optim import AdamW
from tqdm import tqdm
from transformers import LayoutLMv2ForTokenClassification

# Load the pretrained model
model = LayoutLMv2ForTokenClassification.from_pretrained('microsoft/layoutlmv2-base-uncased',
                                                         num_labels=len(labels))
# Move the model to GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Define the optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)
# Training settings
global_step = 0
num_train_epochs = 10
# Total number of optimization steps
t_total = len(train_dataloader) * num_train_epochs
# Train
model.train()
for epoch in range(num_train_epochs):
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
        # Move the batch to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}
        # Zero the gradients
        optimizer.zero_grad()
        # Forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        # Print the loss every 100 steps
        if global_step % 100 == 0:
            print(f"Loss after {global_step} steps: {loss.item()}")
        loss.backward()
        optimizer.step()
        global_step += 1
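The loop above computes `t_total` but never uses it; its usual job is to drive a learning-rate schedule. A dependency-free sketch of linear warmup followed by linear decay, the schedule `t_total` most often feeds (the function name `linear_schedule` is my own):

```python
def linear_schedule(step, warmup_steps, total_steps, base_lr=5e-5):
    """Linear warmup to base_lr, then linear decay to 0 over total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

total = 100
print(linear_schedule(0, 10, total))    # → 0.0
print(linear_schedule(10, 10, total))   # → 5e-05
print(linear_schedule(100, 10, total))  # → 0.0
```

With a constant `lr=5e-5` as above the run still works; a schedule like this mainly smooths the first steps and the tail of training.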
You can define whatever evaluation metrics you need; I won't spell mine out here, so as not to cramp your creativity. For reference, my results:
{
  'ANSWER': {
    'precision': 0.7677725118483413,
    'recall': 0.8009888751545118,
    'f1': 0.7840290381125227,
    'number': 809
  },
  'HEADER': {
    'precision': 0.6095238095238096,
    'recall': 0.5378151260504201,
    'f1': 0.5714285714285715,
    'number': 119
  },
  'QUESTION': {
    'precision': 0.8166058394160584,
    'recall': 0.8403755868544601,
    'f1': 0.8283202221193892,
    'number': 1065
  },
  'overall_precision': 0.7858190709046454,
  'overall_recall': 0.8063221274460612,
  'overall_f1': 0.7959385834571568,
  'overall_accuracy': 0.7996731263133318
}
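The f1 values above are simply the harmonic mean of precision and recall, so the table is easy to sanity-check by hand. Verifying the ANSWER row:

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# The ANSWER row from the results above:
f1 = f1_score(0.7677725118483413, 0.8009888751545118)
print(round(f1, 6))  # → 0.784029
```

The same check reproduces `overall_f1` from `overall_precision` and `overall_recall`.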