Technical Deep Dive | Migrating BERT to MindSpore for NLP: Text Matching Task (Part 2): Training and Evaluation

Preface:

This post walks through how to use a MindSpore BERT model for a downstream task: text matching on the LCQMC dataset.

Host environment:

OS: Ubuntu 18

GPU: RTX 3090

MindSpore version: 1.3

Dataset: LCQMC

Definition of the LCQMC text matching task:

LCQMC is a question semantic matching dataset built by Harbin Institute of Technology and published at COLING 2018, a top international NLP conference. The task is to decide whether two questions have the same meaning.

Each record in the dataset contains the following fields:

text_a, text_b, label.

text_a and text_b are the texts of the two questions; label is 1 if the two questions have the same meaning and 0 otherwise.
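For intuition, records are tab-separated lines like the ones below (the two example rows are made up for illustration, not taken from the dataset):

text_a                  text_b                      label
这个手机多少钱            这部手机的价格是多少           1
怎么学好英语              英语四级怎么报名              0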

Weight migration: PyTorch -> MindSpore

Since fine-tuned weights are already provided officially, we can try converting them directly and using them for prediction.

We first need to know the parameter names and shapes of both models, so that the PyTorch and MindSpore parameters can be matched one to one.

To start, download the PyTorch bin file of bert-base-chinese from HuggingFace.
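A minimal sketch of this download step, assuming the transformers package is installed; the local directory ./model/bert_cn is just the path reused later in this post:

from transformers import BertModel

# Download bert-base-chinese from the HuggingFace hub and save it locally;
# this writes the pytorch_model.bin file that will be converted below.
torch_model = BertModel.from_pretrained("bert-base-chinese")
torch_model.save_pretrained("./model/bert_cn")

# Print the PyTorch parameter names and shapes so they can be matched
# one to one against the MindSpore model's parameters_dict() keys.
for name, tensor in torch_model.state_dict().items():
    print(name, tuple(tensor.shape))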

Next, use the function below to convert the PyTorch weight file into a MindSpore checkpoint file:

from mindspore import Tensor, Parameter, save_checkpoint
from mindspore.common.initializer import initializer

def torch_to_ms(model, torch_model, save_path):
    """
    Updates the MindSpore BERT model's parameters from the torch parameters.
    Args:
        model: mindspore model
        torch_model: torch state_dict
        save_path: path of the converted MindSpore checkpoint
    """
    print("start load")
    # load torch parameter and mindspore parameter
    torch_param_dict = torch_model
    ms_param_dict = model.parameters_dict()
    count = 0
    for ms_key in ms_param_dict.keys():
        ms_key_tmp = ms_key.split('.')
        if ms_key_tmp[0] == 'bert_embedding_lookup':
            count+=1
            update_torch_to_ms(torch_param_dict, ms_param_dict, 'embeddings.word_embeddings.weight', ms_key)
        elif ms_key_tmp[0] == 'bert_embedding_postprocessor':
            if ms_key_tmp[1] == "token_type_embedding":
                count+=1
                update_torch_to_ms(torch_param_dict, ms_param_dict, 'embeddings.token_type_embeddings.weight', ms_key)
            elif ms_key_tmp[1] == "full_position_embedding":
                count+=1
                update_torch_to_ms(torch_param_dict, ms_param_dict, 'embeddings.position_embeddings.weight',
                                   ms_key)
            elif ms_key_tmp[1] =="layernorm":
                if ms_key_tmp[2]=="gamma":
                    count+=1
                    update_torch_to_ms(torch_param_dict, ms_param_dict, 'embeddings.LayerNorm.weight',
                                       ms_key)
                else:
                    count+=1
                    update_torch_to_ms(torch_param_dict, ms_param_dict, 'embeddings.LayerNorm.bias',
                                       ms_key)
        elif ms_key_tmp[0] == "bert_encoder":
            if ms_key_tmp[3] == 'attention':
                par = ms_key_tmp[4].split('_')[0]
                count+=1
                update_torch_to_ms(torch_param_dict, ms_param_dict, 'encoder.layer.'+ms_key_tmp[2]+'.'+ms_key_tmp[3]+'.'
                                   +'self.'+par+'.'+ms_key_tmp[5],
                                   ms_key)
            elif ms_key_tmp[3] == 'attention_output':
                if ms_key_tmp[4] == 'dense':
                    count+=1
                    update_torch_to_ms(torch_param_dict, ms_param_dict,
                                   'encoder.layer.' + ms_key_tmp[2] + '.attention.output.'+ms_key_tmp[4]+'.'+ms_key_tmp[5],
                                   ms_key)

                elif ms_key_tmp[4]=='layernorm':
                    if ms_key_tmp[5]=='gamma':
                        count+=1
                        update_torch_to_ms(torch_param_dict, ms_param_dict,
                                           'encoder.layer.' + ms_key_tmp[2] + '.attention.output.LayerNorm.weight',
                                           ms_key)
                    else:
                        count+=1
                        update_torch_to_ms(torch_param_dict, ms_param_dict,
                                           'encoder.layer.' + ms_key_tmp[2] + '.attention.output.LayerNorm.bias',
                                           ms_key)
            elif ms_key_tmp[3] == 'intermediate':
                count+=1
                update_torch_to_ms(torch_param_dict, ms_param_dict,
                                   'encoder.layer.' + ms_key_tmp[2] + '.intermediate.dense.'+ms_key_tmp[4],
                                   ms_key)
            elif ms_key_tmp[3] == 'output':
                if ms_key_tmp[4] == 'dense':
                    count+=1
                    update_torch_to_ms(torch_param_dict, ms_param_dict,
                                   'encoder.layer.' + ms_key_tmp[2] + '.output.dense.'+ms_key_tmp[5],
                                   ms_key)
                else:
                    if ms_key_tmp[5] == 'gamma':
                        count+=1
                        update_torch_to_ms(torch_param_dict, ms_param_dict,
                                       'encoder.layer.' + ms_key_tmp[2] + '.output.LayerNorm.weight',
                                       ms_key)

                    else:
                        count+=1
                        update_torch_to_ms(torch_param_dict, ms_param_dict,
                                       'encoder.layer.' + ms_key_tmp[2] + '.output.LayerNorm.bias',
                                       ms_key)

        if ms_key_tmp[0] == 'dense':
            if ms_key_tmp[1] == 'weight':
                count+=1
                update_torch_to_ms(torch_param_dict, ms_param_dict,
                                   'pooler.dense.weight',
                                   ms_key)
            else:
                count+=1
                update_torch_to_ms(torch_param_dict, ms_param_dict,
                                   'pooler.dense.bias',
                                   ms_key)

    save_checkpoint(model, save_path)
    print("finish load")

def update_bn(torch_param_dict, ms_param_dict, ms_key, ms_key_tmp):
    """Updates mindspore batchnorm param's data from torch batchnorm param's data."""

    str_join = '.'
    if ms_key_tmp[-1] == "moving_mean":
        ms_key_tmp[-1] = "running_mean"
        torch_key = str_join.join(ms_key_tmp)
        update_torch_to_ms(torch_param_dict, ms_param_dict, torch_key, ms_key)
    elif ms_key_tmp[-1] == "moving_variance":
        ms_key_tmp[-1] = "running_var"
        torch_key = str_join.join(ms_key_tmp)
        update_torch_to_ms(torch_param_dict, ms_param_dict, torch_key, ms_key)
    elif ms_key_tmp[-1] == "gamma":
        ms_key_tmp[-1] = "weight"
        torch_key = str_join.join(ms_key_tmp)
        update_torch_to_ms(torch_param_dict, ms_param_dict, 'transformer.' + torch_key, ms_key)
    elif ms_key_tmp[-1] == "beta":
        ms_key_tmp[-1] = "bias"
        torch_key = str_join.join(ms_key_tmp)
        update_torch_to_ms(torch_param_dict, ms_param_dict, 'transformer.' + torch_key, ms_key)

def update_torch_to_ms(torch_param_dict, ms_param_dict, torch_key, ms_key):
    """Updates mindspore param's data from torch param's data."""

    value = torch_param_dict[torch_key].cpu().numpy()
    value = Parameter(Tensor(value), name=ms_key)
    _update_param(ms_param_dict[ms_key], value)

def _update_param(param, new_param):
    """Updates param's data from new_param's data."""

    if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
        if param.data.dtype != new_param.data.dtype:
            print("Failed to combine the net and the parameters for param %s.", param.name)
            msg = ("Net parameters {} type({}) different from parameter_dict's({})"
                   .format(param.name, param.data.dtype, new_param.data.dtype))
            raise RuntimeError(msg)

        if param.data.shape != new_param.data.shape:
            if not _special_process_par(param, new_param):
                print("Failed to combine the net and the parameters for param %s.", param.name)
                msg = ("Net parameters {} shape({}) different from parameter_dict's({})"
                       .format(param.name, param.data.shape, new_param.data.shape))
                raise RuntimeError(msg)
            return

        param.set_data(new_param.data)
        return

    if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
        if param.data.shape != (1,) and param.data.shape != ():
            print("Failed to combine the net and the parameters for param %s.", param.name)
            msg = ("Net parameters {} shape({}) is not (1,), inconsistent with parameter_dict's(scalar)."
                   .format(param.name, param.data.shape))
            raise RuntimeError(msg)
        param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))

    elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor):
        print("Failed to combine the net and the parameters for param %s.", param.name)
        msg = ("Net parameters {} type({}) different from parameter_dict's({})"
               .format(param.name, type(param.data), type(new_param.data)))
        raise RuntimeError(msg)

    else:
        param.set_data(type(param.data)(new_param.data))

def _special_process_par(par, new_par):
    """
    Processes the special condition.

    Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor.
    """
    par_shape_len = len(par.data.shape)
    new_par_shape_len = len(new_par.data.shape)
    delta_len = new_par_shape_len - par_shape_len
    delta_i = 0
    for delta_i in range(delta_len):
        if new_par.data.shape[par_shape_len + delta_i] != 1:
            break
    if delta_i == delta_len - 1:
        new_val = new_par.data.asnumpy()
        new_val = new_val.reshape(par.data.shape)
        par.set_data(Tensor(new_val, par.data.dtype))
        return True
    return False

A concrete usage example:

# The MindSpore BertModel/BertConfig come from this project's bert module
# (adjust the import path to your layout); the PyTorch BertModel comes from transformers.
from mindtext.modules.encoder.bert import BertConfig, BertModel as ms_bm
from transformers import BertModel as tc_bm

bert_config_file = "./model/test.yaml"
bert_config = BertConfig.from_yaml_file(bert_config_file)
model = ms_bm(bert_config, False)

torch_model = tc_bm.from_pretrained("/content/model/bert_cn")
torch_to_ms(model, torch_model.state_dict(), "./model/bert2.ckpt")

The parameter names here must match one to one. If the model is modified later, re-check whether this conversion function still maps every parameter correctly.
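As a quick sanity check after conversion (a sketch reusing the model object and checkpoint path from the example above), load the checkpoint back and confirm that every network parameter was found:

from mindspore import load_checkpoint, load_param_into_net

# Load the converted checkpoint into the MindSpore model; load_param_into_net
# returns the list of network parameters that were NOT found in the file.
param_dict = load_checkpoint("./model/bert2.ckpt")
not_loaded = load_param_into_net(model, param_dict)
print("parameters missing from the checkpoint:", not_loaded)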

Downstream task: training on the LCQMC text matching task

1. Wrapping BERT as BertEmbedding

First, wrap the BERT model built earlier into a BertEmbedding class:

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert Embedding."""
import logging
from typing import Tuple
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindtext.modules.encoder.bert import BertModel, BertConfig  # adjust to where BertModel/BertConfig live in your project
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class BertEmbedding(nn.Cell):
    """
    This is a class that loads pre-trained weight files into the model.
    """
    def __init__(self, bert_config: BertConfig, is_training: bool = False):
        super(BertEmbedding, self).__init__()
        self.bert = BertModel(bert_config, is_training)

    def init_bertmodel(self, bert):
        """
        Manually initialize the BertModel.
        """
        self.bert = bert

    def from_pretrain(self, ckpt_file):
        """
        Load the model parameters from checkpoint
        """
        param_dict = load_checkpoint(ckpt_file)
        load_param_into_net(self.bert, param_dict)

    def construct(self, input_ids: Tensor, token_type_ids: Tensor, input_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Returns the result of the model after loading the pre-training weights

        Args:
            input_ids (:class:`mindspore.tensor`): A vector containing the ids of the input characters.
            token_type_ids (:class:`mindspore.tensor`): A vector containing the segment ids.
            input_mask (:class:`mindspore.tensor`): The mask for input_ids.
        Returns:
            sequence_output: the sequence output.
            pooled_output: the pooled output of the first token ([CLS]).
        """
        sequence_output, pooled_output, _ = self.bert(input_ids, token_type_ids, input_mask)
        return sequence_output, pooled_output
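A minimal usage sketch of the wrapper (it assumes the converted checkpoint from the previous section and the same yaml config file used throughout this post):

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

# Build the wrapper, load the converted MindSpore checkpoint, and run one forward pass.
bert_config = BertConfig.from_yaml_file("./model/test.yaml")
embedding = BertEmbedding(bert_config, is_training=False)
embedding.from_pretrain("./model/bert_cn.ckpt")

# Dummy batch of shape (batch_size, seq_length); a real batch would come from the tokenizer.
input_ids = Tensor(np.ones((1, 128)), mstype.int32)
token_type_ids = Tensor(np.zeros((1, 128)), mstype.int32)
input_mask = Tensor(np.ones((1, 128)), mstype.int32)

sequence_output, pooled_output = embedding(input_ids, token_type_ids, input_mask)
print(sequence_output.shape, pooled_output.shape)  # (1, 128, hidden_size), (1, hidden_size)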

2. Downstream task: BertforSequenceClassification

BertforSequenceClassification uses BERT as the pre-trained backbone: the embedding of BERT's [CLS] token (the pooled output) is fed into a fully connected layer to produce the classification logits. In training mode the code below also smooths the labels and averages the loss over several dropout samples (multi-sample dropout); a small numeric sketch of both follows the code listing.

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert for Sequence Classification script."""

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from mindspore.context import ParallelMode
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import Squeeze
from mindspore.communication.management import get_group_size
from mindspore import context, load_checkpoint, load_param_into_net
from mindspore.common.seed import _get_graph_seed
from bert_embedding import BertEmbedding

class BertforSequenceClassification(nn.Cell):
    """
    Train interface for classification finetuning task.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        num_labels (int): Number of label types.
        dropout_prob (float): The dropout probability for BertforSequenceClassification.
        multi_sample_dropout (int): Dropout times per step
        label_smooth (float): Label Smoothing Regularization
    """
    def __init__(self, config, is_training, num_labels, dropout_prob=0.0, multi_sample_dropout=1, label_smooth=1):
        super(BertforSequenceClassification, self).__init__()
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.hidden_probs_dropout_prob = 0.0
        self.bert = BertEmbedding(config, is_training)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.softmax = nn.Softmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(mstype.float32)
        self.dropout_list=[]
        for count in range(0, multi_sample_dropout):
            seed0, seed1 = _get_graph_seed(1, "dropout")
            self.dropout_list.append(ops.Dropout(1-dropout_prob, seed0, seed1))
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction="mean")
        self.squeeze = Squeeze(1)
        self.num_labels = num_labels
        self.is_training = is_training
        self.one_hot = nn.OneHot(depth=num_labels, axis=-1)
        self.label_smooth = label_smooth

    def from_pretrain(self, ckpt_file):
        """
        Load the model parameters from checkpoint
        """
        param_dict = load_checkpoint(ckpt_file)
        load_param_into_net(self, param_dict)

    def init_embedding(self, embedding):
        """
        Manually initialize the embedding.
        """
        self.bert = embedding

    def construct(self, input_ids, input_mask, token_type_id, label_ids=0):
        """
        Classification task
        """
        _, pooled_output = self.bert(input_ids, token_type_id, input_mask)
        loss = None
        if self.is_training:
            onehot_label = self.one_hot(self.squeeze(label_ids))
            smooth_label = self.label_smooth * onehot_label + (1-self.label_smooth)/(self.num_labels-1) * (1-onehot_label)
            for dropout in self.dropout_list:
                cls, _ = dropout(pooled_output)
                logits = self.dense_1(cls)
                temp_loss = self.loss(logits, smooth_label)
                if loss is None:
                    loss = temp_loss
                else:
                    loss += temp_loss
            loss = loss/len(self.dropout_list)
        else:
            loss = self.dense_1(pooled_output)
        return loss

class BertLearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for Bert network.
    """

    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr

class BertFinetuneCell(nn.Cell):
    """
    Especially defined for finetuning where only four input tensors are needed.

    Appends an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    Different from the builtin loss_scale wrapper cell, we apply grad_clip before the optimization.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """

    def __init__(self, network, optimizer, scale_update_cell=None):

        super(BertFinetuneCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.gpu_target = False
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.alloc_status = P.NPUAllocFloatStatus()
            self.get_status = P.NPUGetFloatStatus()
            self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        """Bert Finetune"""
        
        weights = self.weights
        init = False
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        if not self.gpu_target:
            init = self.alloc_status()
            init = F.depend(init, loss)
            clear_status = self.clear_status(init)
            scaling_sens = F.depend(scaling_sens, clear_status)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        self.optimizer(grads)
        return loss
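Two details of the training branch in BertforSequenceClassification above are easy to miss: the one-hot label is smoothed before the loss is computed, and the loss is averaged over several independent dropout masks (multi-sample dropout), which acts as a cheap ensemble during finetuning. A small numeric sketch of the smoothing, using the values passed in the training script below (label_smooth=0.9, num_labels=2):

import numpy as np

# Smoothing as implemented in construct(): the true class gets probability
# label_smooth and the remaining mass is spread evenly over the other classes.
label_smooth, num_labels = 0.9, 2
onehot = np.array([0.0, 1.0])  # true class is 1
smooth = label_smooth * onehot + (1 - label_smooth) / (num_labels - 1) * (1 - onehot)
print(smooth)  # [0.1 0.9]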

3. Task training

from mindspore.train.callback import Callback
from mindspore.train.callback import TimeMonitor
from mindspore.train import Model
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore import save_checkpoint, context, load_checkpoint, load_param_into_net
from mindtext.modules.encoder.bert import BertConfig
from bert import BertforSequenceClassification, BertLearningRate, BertFinetuneCell
from bert_embedding import BertEmbedding
from mindtext.dataset import LCQMCDataset  # adjust to where LCQMCDataset is defined in your project
from mindspore.common.tensor import Tensor
import time

def get_ms_timestamp():
    t = time.time()
    return int(round(t * 1000))

class LossCallBack(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, training is terminated.

    Note:
        If per_print_times is 0 do not print loss.

    Args:
        per_print_times (int): Print the loss every this many steps. Default: 1.
    """

    def __init__(self, per_print_times=1, rank_ids=0):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.rank_id = rank_ids
        self.time_stamp_first = get_ms_timestamp()

    def step_end(self, run_context):
        """Monitor the loss in training."""
        time_stamp_current = get_ms_timestamp()
        cb_params = run_context.original_args()
        print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - self.time_stamp_first,
                                                                     cb_params.cur_epoch_num,
                                                                     cb_params.cur_step_num,
                                                                     str(cb_params.net_outputs)))
        with open("./loss_{}.log".format(self.rank_id), "a+") as f:
            f.write("time: {}, epoch: {}, step: {}, loss: {}".format(
                time_stamp_current - self.time_stamp_first,
                cb_params.cur_epoch_num,
                cb_params.cur_step_num,
                str(cb_params.net_outputs.asnumpy())))
            f.write('\n')

def train(train_data, bert, optimizer, save_path, epoch_num):
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(bert, optimizer=optimizer, scale_update_cell=update_cell)
    callbacks = [TimeMonitor(train_data.get_dataset_size()), LossCallBack(train_data.get_dataset_size())]
    model = Model(netwithgrads)
    model.train(epoch_num, train_data, callbacks=callbacks, dataset_sink_mode=False)
    save_checkpoint(model.train_network.network, save_path)

def main():
    #context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    #context.set_context(enable_graph_kernel=True)

    epoch_num = 6
    save_path = "./model/output/train_lcqmc2.ckpt"
   
    dataset = LCQMCDataset(paths='./dataset/lcqmc',
                      tokenizer="./model",
                      max_length=128,
                      truncation_strategy=True,
                      batch_size=32, columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'],
                      test_columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'])

    ds = dataset.from_cache(batch_size=128,
                      columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'],
                      test_columns_list=['input_ids', 'attention_mask', 'token_type_ids'])
    train_data = ds['train']
    bert_config_file = "./model/test.yaml"
    bert_config = BertConfig.from_yaml_file(bert_config_file)
    model_path = "./model/bert_cn.ckpt"
    bert = BertforSequenceClassification(bert_config, True, num_labels=2, dropout_prob=0.1, multi_sample_dropout=5, label_smooth=0.9)
    eb = BertEmbedding(bert_config, True)
    eb.from_pretrain(model_path)
    bert.init_embedding(eb)

    lr_schedule = BertLearningRate(learning_rate=2e-5,
                                   end_learning_rate=0.0,
                                   warmup_steps=int(train_data.get_dataset_size() * epoch_num * 0.1),
                                   decay_steps=train_data.get_dataset_size() * epoch_num,
                                   power=1.0)
    params = bert.trainable_params()
    optimizer = AdamWeightDecay(params, lr_schedule, eps=1e-8)

    train(train_data, bert, optimizer, save_path, epoch_num)

if __name__ == "__main__":
    main()

Key parameters:

bert_config = BertConfig.from_yaml_file(bert_config_file): reads BERT's configuration parameters.

eb.from_pretrain(model_path): loads BERT's MindSpore checkpoint; bert.init_embedding(eb): plugs the loaded embedding into the classification model.

lr_schedule: the learning-rate scheduler (warmup followed by polynomial decay).

optimizer: the gradient optimizer.
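To make the schedule concrete: LCQMC's training split has roughly 239k sentence pairs, so with batch_size=128 and 6 epochs the step counts passed to BertLearningRate work out roughly as below (a rough estimate; the real values come from train_data.get_dataset_size()):

# Back-of-the-envelope numbers for the warmup/decay schedule in main().
steps_per_epoch = 238766 // 128            # ~1865 steps per epoch, assuming ~239k training pairs
epoch_num = 6
decay_steps = steps_per_epoch * epoch_num  # ~11190 steps: lr decays linearly (power=1.0) to 0
warmup_steps = int(decay_steps * 0.1)      # ~1119 steps: lr first warms up from 0 to 2e-5
print(steps_per_epoch, decay_steps, warmup_steps)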

Evaluation

Use the LCQMC test set to evaluate the trained model and print its accuracy on the test set.

from mindspore.nn import Accuracy
from tqdm import tqdm
from mindspore import context
from bert import BertforSequenceClassification
from mindtext.modules.encoder.bert import BertConfig
import mindspore
from mindtext.dataset import LCQMCDataset  # adjust to where LCQMCDataset is defined in your project

def evaluate(eval_data, model):
    metric = Accuracy('classification')
    metric.clear()
    squeeze = mindspore.ops.Squeeze(1)
    for batch in tqdm(eval_data.create_dict_iterator(num_epochs=1), total=eval_data.get_dataset_size()):
        input_ids = batch['input_ids']
        token_type_id = batch['token_type_ids']
        input_mask = batch['attention_mask']
        label_ids = batch['label']
        inputs = {"input_ids": input_ids,
                  "input_mask": input_mask,
                  "token_type_id": token_type_id
                  }
        output = model(**inputs)
        sm = mindspore.nn.Softmax(axis=-1)
        output = sm(output)
        #print(output)
        metric.update(output, squeeze(label_ids))
    print(metric.eval())

def main():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    dataset = LCQMCDataset(paths='./dataset/lcqmc',
                      tokenizer="./model",
                      max_length=128,
                      truncation_strategy=True,
                      batch_size=128, columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'],
                      test_columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'])

    #ds = dataset()
    ds = dataset.from_cache(batch_size=128,
                      columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'],
                      test_columns_list=['input_ids', 'attention_mask', 'token_type_ids','label'])

    eval_data = ds['test']

    bert_config_file = "./model/test.yaml"
    bert_config = BertConfig.from_yaml_file(bert_config_file)

    bert = BertforSequenceClassification(bert_config, is_training=False, num_labels=2,  dropout_prob=0.0)
    model_path = "./model/output/train_lcqmc2.ckpt"
    bert.from_pretrain(model_path)

    evaluate(eval_data, bert)

if __name__ == "__main__":
    main()

Results

[Figure 1: evaluation output showing the model's accuracy]

Accuracy of the model on the validation and test sets of the corresponding dataset.
