Invoice Key Information Extraction with SER

SER (Semantic Entity Recognition) takes a line of text and assigns it a category, such as name or address. This article uses a multimodal SER method based on VI-LayoutXLM.

1. VAT invoice dataset

    https://download.csdn.net/download/ronshi/88467149
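
The label files referenced later in the config (train.json, val.json) are not shown in this article. As a rough sketch, assuming they follow the XFUND-style layout that PaddleOCR's KIE datasets use (one line per image: the image name, a tab, then a JSON list of text-line annotations with "transcription", "label" and "points"), a line could be parsed as below. The sample file name, texts, labels and coordinates are made up for illustration.

```python
import json


def parse_label_line(line):
    """Split one label-file line into (image name, list of annotation dicts)."""
    image_name, annotations_json = line.rstrip("\n").split("\t", 1)
    return image_name, json.loads(annotations_json)


# Hypothetical sample line; every value here is invented for illustration.
sample_line = "b001.jpg\t" + json.dumps([
    {"transcription": "Invoice No.", "label": "other",
     "points": [[10, 10], [90, 10], [90, 30], [10, 30]]},
    {"transcription": "12345678", "label": "number",
     "points": [[100, 10], [180, 10], [180, 30], [100, 30]]},
])

image_name, annotations = parse_label_line(sample_line)
labels = [a["label"] for a in annotations]
print(image_name, labels)
```

Running a check like this over every line of a label file before training is a cheap way to catch malformed JSON or missing keys.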

2. Configuration file ser_vi_layoutxlm_xfund_zh_udml.yml

Global:
  use_gpu: true
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh_udml
  save_epoch_step: 2000
  # evaluation is run every 19 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: my/b201.jpg
  save_res_path: ./output/ser_layoutxlm_xfund_zh/res


Architecture:
  model_type: &model_type "kie"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: &algorithm "LayoutXLM"
      Transform:
      Backbone:
        name: LayoutXLMForSer
        pretrained: True
        # one of base or vi
        mode: vi
        checkpoints:
        num_classes: &num_classes 5
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: *algorithm
      Transform:
      Backbone:
        name: LayoutXLMForSer
        pretrained: True
        # one of base or vi
        mode: vi
        checkpoints:
        num_classes: *num_classes


Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationVQASerTokenLayoutLMLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: backbone_out
      num_classes: *num_classes
  - DistillationSERDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationVQADistanceLoss:
      weight: 0.5
      mode: "l2"
      model_name_pairs:
        - ["Student", "Teacher"]
      key: hidden_states_5
      name: "loss_5"
  - DistillationVQADistanceLoss:
      weight: 0.5
      mode: "l2"
      model_name_pairs:
        - ["Student", "Teacher"]
      key: hidden_states_8
      name: "loss_8"
  
  

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000
    
PostProcess:
  name: DistillationSerPostProcess
  model_name: ["Student", "Teacher"]
  key: backbone_out
  class_path: &class_path my/zzsfp/class_list.txt

Metric:
  name: DistillationMetric
  base_metric_name: VQASerTokenMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: my/zzsfp/imgs
    label_file_list: 
      - my/zzsfp/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
          # one of [None, "tb-yx"]
          order_method: &order_method "tb-yx"
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 4
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: my/zzsfp/imgs
    label_file_list:
      - my/zzsfp/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
          order_method: *order_method
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 4
    num_workers: 4
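
The value num_classes: 5 in the config has to match the contents of class_list.txt, which this article does not show. A minimal sketch of the relationship, assuming the usual BIO tagging scheme in which every entity class contributes a B- and an I- tag and the "other" class collapses into a single O tag, is given below. The function name build_bio_labels and the example class names are hypothetical.

```python
def build_bio_labels(class_names):
    """Expand entity class names into BIO tags: O, then B-X and I-X per class X."""
    ignored = {"OTHER", "OTHERS", "IGNORE"}  # assumed to fold into the single "O" tag
    labels = ["O"]
    for name in class_names:
        name = name.strip().upper()
        if not name or name in ignored:
            continue
        labels.extend(["B-" + name, "I-" + name])
    return labels


# Hypothetical class list; the real class_list.txt may use different names.
class_names = ["OTHER", "NUMBER", "AMOUNT"]
labels = build_bio_labels(class_names)
print(labels, len(labels))
```

Under this assumption, two entity classes plus "other" give 2 * 2 + 1 = 5 tags, which agrees with num_classes: 5. If entity classes are added or removed, num_classes would need to change accordingly.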

3. Train the model

python tools/train.py -c my/ser_vi_layoutxlm_xfund_zh_udml.yml 
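
During training, the Optimizer section of the config asks for a Linear learning-rate schedule with learning_rate 0.00005, 200 epochs and 10 warmup epochs. The sketch below shows one common way to combine linear warmup with linear decay to zero. It is an approximation for building intuition, not the exact scheduler PaddleOCR constructs, and steps_per_epoch is a hypothetical value that depends on the dataset size and batch size.

```python
def linear_warmup_decay_lr(step, steps_per_epoch, base_lr=0.00005,
                           epochs=200, warmup_epoch=10):
    """Learning rate at a given step: linear warmup, then linear decay to 0."""
    warmup_steps = warmup_epoch * steps_per_epoch
    total_steps = epochs * steps_per_epoch
    if step < warmup_steps:
        # Ramp up from 0 to base_lr over the warmup steps.
        return base_lr * step / warmup_steps
    # Decay from base_lr to 0 over the remaining steps (assumed behaviour).
    progress = min(step - warmup_steps, total_steps - warmup_steps)
    return base_lr * (1.0 - progress / (total_steps - warmup_steps))


# Hypothetical: 10 iterations per epoch.
for step in (0, 50, 100, 1050, 2000):
    print(step, linear_warmup_decay_lr(step, steps_per_epoch=10))
```

With a small dataset the warmup covers only a few hundred iterations, so the learning rate reaches its peak early and spends most of the training in the decay phase.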

4. Evaluate the model

python tools/eval.py -c my/ser_vi_layoutxlm_xfund_zh_udml.yml -o Global.pretrained_model="./output/ser_vi_layoutxlm_xfund_zh_udml/latest/model_state.pdparams"

5. Run prediction

python tools/infer_kie_token_ser.py -c my/ser_vi_layoutxlm_xfund_zh_udml.yml -o Global.pretrained_model="./output/ser_vi_layoutxlm_xfund_zh_udml/latest/model_state.pdparams"
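
The prediction script writes its output under save_res_path. As a hedged sketch, assuming each line of the result text file holds the image path, a tab, and a JSON object whose "ocr_info" list carries a "transcription" and a predicted label "pred" per text line, the recognized fields could be grouped by label as follows. The key names and the sample line are assumptions and should be checked against the actual output.

```python
import json
from collections import defaultdict


def group_by_predicted_label(result_line):
    """Return (image path, {predicted label: [transcriptions]}) for one result line."""
    image_path, payload = result_line.rstrip("\n").split("\t", 1)
    fields = defaultdict(list)
    for item in json.loads(payload)["ocr_info"]:
        fields[item["pred"]].append(item["transcription"])
    return image_path, dict(fields)


# Hypothetical result line; labels and texts are invented for illustration.
sample_result = "my/b201.jpg\t" + json.dumps({"ocr_info": [
    {"transcription": "Invoice No.", "pred": "O"},
    {"transcription": "12345678", "pred": "NUMBER"},
    {"transcription": "100.00", "pred": "AMOUNT"},
]})

image_path, fields = group_by_predicted_label(sample_result)
print(image_path, fields)
```

Grouping by label turns the per-line predictions into a small dictionary of invoice fields, which is usually the form downstream code wants.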
