主机环境:
系统:ubuntu18
GPU:3090
MindSpore版本:1.3
数据集:lcqmc
lcqmc文本匹配任务的定义:
哈工大文本匹配数据集:LCQMC 是哈尔滨工业大学在自然语言处理国际顶会 COLING 2018 上发布的问题语义匹配数据集,其目标是判断两个问题的语义是否相同。
数据集中的字段分别如下:
text_a, text_b, label。
其中text_a和text_b为两个问题的文本。若两个问题的语义相同则label为1,否则为0。
权重迁移PyTorch->MindSpore
由于 Hugging Face 官方已经提供了预训练好的中文 BERT 权重,我们直接将其转换为 MindSpore 权重,再在此基础上进行微调和预测。
我们先要知道模型权重名称以及形状等,需要PyTorch与MindSpore模型一一对应。
首先,从 Hugging Face 下载 bert-base-chinese 的 PyTorch 权重文件(.bin)。
接下来使用下面的函数将Torch权重参数文件转化为MindSpore权重参数文件
def _ln_param_name(ms_suffix):
    """Map a MindSpore LayerNorm suffix ("gamma"/other) to the PyTorch one ("weight"/"bias")."""
    return 'weight' if ms_suffix == 'gamma' else 'bias'


def _ms_key_to_torch_key(ms_key):
    """
    Return the HuggingFace PyTorch BERT parameter name for a MindSpore BERT parameter name.

    Returns None when the MindSpore name has no known PyTorch counterpart.
    """
    parts = ms_key.split('.')
    head = parts[0]
    if head == 'bert_embedding_lookup':
        return 'embeddings.word_embeddings.weight'
    if head == 'bert_embedding_postprocessor':
        sub = parts[1]
        if sub == 'token_type_embedding':
            return 'embeddings.token_type_embeddings.weight'
        if sub == 'full_position_embedding':
            return 'embeddings.position_embeddings.weight'
        if sub == 'layernorm':
            return 'embeddings.LayerNorm.' + _ln_param_name(parts[2])
        return None
    if head == 'bert_encoder':
        layer_prefix = 'encoder.layer.' + parts[2]
        sub = parts[3]
        if sub == 'attention':
            # e.g. "query_layer" -> "query" (PyTorch: attention.self.query.*)
            projection = parts[4].split('_')[0]
            return layer_prefix + '.attention.self.' + projection + '.' + parts[5]
        if sub == 'attention_output':
            if parts[4] == 'dense':
                return layer_prefix + '.attention.output.dense.' + parts[5]
            if parts[4] == 'layernorm':
                return layer_prefix + '.attention.output.LayerNorm.' + _ln_param_name(parts[5])
            return None
        if sub == 'intermediate':
            return layer_prefix + '.intermediate.dense.' + parts[4]
        if sub == 'output':
            if parts[4] == 'dense':
                return layer_prefix + '.output.dense.' + parts[5]
            return layer_prefix + '.output.LayerNorm.' + _ln_param_name(parts[5])
        return None
    if head == 'dense':
        # The top-level dense layer corresponds to the PyTorch pooler.
        return 'pooler.dense.' + ('weight' if parts[1] == 'weight' else 'bias')
    return None


def torch_to_ms(model, torch_model, save_path):
    """
    Copy HuggingFace PyTorch BERT weights into a MindSpore BERT model and save a checkpoint.

    Args:
        model: MindSpore BERT model; its parameters are overwritten in place.
        torch_model: PyTorch parameter mapping (name -> tensor), e.g. ``torch_model.state_dict()``.
        save_path: path of the MindSpore checkpoint file to write.
    """
    print("start load")
    torch_param_dict = torch_model
    ms_param_dict = model.parameters_dict()
    count = 0
    skipped = []
    for ms_key in ms_param_dict:
        torch_key = _ms_key_to_torch_key(ms_key)
        if torch_key is None:
            skipped.append(ms_key)
            continue
        update_torch_to_ms(torch_param_dict, ms_param_dict, torch_key, ms_key)
        count += 1
    # Surface parameters that keep their initial values so name mismatches are noticed.
    if skipped:
        print("Unmapped MindSpore parameters (left unchanged): {}".format(skipped))
    print("converted {} parameters".format(count))
    save_checkpoint(model, save_path)
    print("finish load")
def update_bn(torch_param_dict, ms_param_dict, ms_key, ms_key_tmp):
    """
    Update a MindSpore batch-norm parameter from the matching PyTorch batch-norm parameter.

    Args:
        torch_param_dict: PyTorch parameter mapping (name -> tensor).
        ms_param_dict: MindSpore parameter mapping (name -> Parameter).
        ms_key: full MindSpore parameter name to update.
        ms_key_tmp: ``ms_key`` split on '.'; it is NOT modified by this function.

    Parameters whose last name component is not a known batch-norm suffix are ignored.
    """
    # MindSpore suffix -> (PyTorch suffix, prefix prepended to the PyTorch name).
    # NOTE(review): only gamma/beta get the 'transformer.' prefix; this looks inherited
    # from another model's conversion script -- confirm before reusing.
    renames = {
        "moving_mean": ("running_mean", ""),
        "moving_variance": ("running_var", ""),
        "gamma": ("weight", "transformer."),
        "beta": ("bias", "transformer."),
    }
    rename = renames.get(ms_key_tmp[-1])
    if rename is None:
        return
    torch_suffix, prefix = rename
    # Build the PyTorch name from a copy instead of mutating the caller's list.
    torch_key = '.'.join(list(ms_key_tmp[:-1]) + [torch_suffix])
    update_torch_to_ms(torch_param_dict, ms_param_dict, prefix + torch_key, ms_key)
def update_torch_to_ms(torch_param_dict, ms_param_dict, torch_key, ms_key):
    """
    Overwrite one MindSpore parameter with the value of one PyTorch tensor.

    Args:
        torch_param_dict: PyTorch parameter mapping (name -> tensor).
        ms_param_dict: MindSpore parameter mapping (name -> Parameter).
        torch_key: name of the source tensor in ``torch_param_dict``.
        ms_key: name of the destination parameter in ``ms_param_dict``.
    """
    # Move to host memory and convert to numpy before wrapping as a MindSpore Tensor.
    source_array = torch_param_dict[torch_key].cpu().numpy()
    incoming = Parameter(Tensor(source_array), name=ms_key)
    destination = ms_param_dict[ms_key]
    _update_param(destination, incoming)
def _update_param(param, new_param):
    """
    Copy ``new_param``'s data into ``param``, checking dtype and shape compatibility.

    Raises:
        RuntimeError: on dtype mismatch, on a shape mismatch that
            ``_special_process_par`` cannot resolve, when a non-scalar tensor
            parameter receives a non-tensor value, or when a non-tensor
            parameter receives a tensor.
    """
    param_is_tensor = isinstance(param.data, Tensor)
    new_is_tensor = isinstance(new_param.data, Tensor)
    failure_note = "Failed to combine the net and the parameters for param %s." % param.name

    if param_is_tensor and new_is_tensor:
        if param.data.dtype != new_param.data.dtype:
            print(failure_note)
            msg = ("Net parameters {} type({}) different from parameter_dict's({})"
                   .format(param.name, param.data.dtype, new_param.data.dtype))
            raise RuntimeError(msg)
        if param.data.shape != new_param.data.shape:
            # Allow e.g. (12, 2048, 1, 1) -> (12, 2048); otherwise it is an error.
            if not _special_process_par(param, new_param):
                print(failure_note)
                msg = ("Net parameters {} shape({}) different from parameter_dict's({})"
                       .format(param.name, param.data.shape, new_param.data.shape))
                raise RuntimeError(msg)
            return
        param.set_data(new_param.data)
        return

    if param_is_tensor:
        # Scalar value into a tensor parameter: only valid for scalar-shaped parameters.
        if param.data.shape != (1,) and param.data.shape != ():
            print(failure_note)
            msg = ("Net parameters {} shape({}) is not (1,), inconsistent with parameter_dict's(scalar)."
                   .format(param.name, param.data.shape))
            raise RuntimeError(msg)
        param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))
    elif new_is_tensor:
        print(failure_note)
        msg = ("Net parameters {} type({}) different from parameter_dict's({})"
               .format(param.name, type(param.data), type(new_param.data)))
        raise RuntimeError(msg)
    else:
        param.set_data(type(param.data)(new_param.data))
def _special_process_par(par, new_par):
"""
Processes the special condition.
Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor.
"""
par_shape_len = len(par.data.shape)
new_par_shape_len = len(new_par.data.shape)
delta_len = new_par_shape_len - par_shape_len
delta_i = 0
for delta_i in range(delta_len):
if new_par.data.shape[par_shape_len + delta_i] != 1:
break
if delta_i == delta_len - 1:
new_val = new_par.data.asnumpy()
new_val = new_val.reshape(par.data.shape)
par.set_data(Tensor(new_val, par.data.dtype))
return True
return False
实际应用案例如下:
# Example: convert HuggingFace PyTorch BERT weights into a MindSpore checkpoint.
# NOTE(review): as written, ms_bm and tc_bm both alias the same module ("BertModel");
# presumably these are placeholders for the MindSpore BertModel and the HuggingFace
# PyTorch BertModel respectively -- TODO confirm the real import paths.
import BertConfig
import BertModel as ms_bm
import BertModel as tc_bm
bert_config_file = "./model/test.yaml"
bert_config = BertConfig.from_yaml_file(bert_config_file)
# MindSpore model; second argument is presumably is_training (it matches how
# BertEmbedding calls BertModel(bert_config, is_training)) -- confirm.
model = ms_bm(bert_config, False)
torch_model = tc_bm.from_pretrained("/content/model/bert_cn")
# torch_to_ms indexes its second argument by PyTorch parameter name, hence .state_dict().
# NOTE(review): the training script loads "./model/bert_cn.ckpt", not this output path --
# make sure the converted checkpoint ends up where training expects it.
torch_to_ms(model, torch_model.state_dict(),"./model/bert2.ckpt")
这里参数名称必须一一对应。如果后期改动了模型结构,也需要再检查一下这个转换函数是否仍能正确对应。
下游任务:lcqmc文本匹配任务训练
1、首先,将之前构建好的 Bert 再封装一层,得到 BertEmbedding:
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert Embedding."""
import logging
from typing import Tuple
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import BertModel, BertConfig
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class BertEmbedding(nn.Cell):
    """
    Wrap a BertModel so pre-trained weights can be loaded into it and it can be
    used as an encoder returning the sequence and pooled outputs.

    Args:
        bert_config (BertConfig): configuration forwarded to BertModel.
        is_training (bool): training flag forwarded to BertModel. Default: False.
    """

    def __init__(self, bert_config: BertConfig, is_training: bool = False):
        super(BertEmbedding, self).__init__()
        self.bert = BertModel(bert_config, is_training)

    def init_bertmodel(self, bert):
        """
        Replace the wrapped BertModel with an already-built instance.
        """
        self.bert = bert

    def from_pretrain(self, ckpt_file):
        """
        Load the wrapped BertModel's parameters from a MindSpore checkpoint.

        Args:
            ckpt_file (str): path of the checkpoint file.

        Returns:
            Whatever ``load_param_into_net`` returns (the names it reports as not
            loaded); any such names are also logged as a warning.
        """
        param_dict = load_checkpoint(ckpt_file)
        not_loaded = load_param_into_net(self.bert, param_dict)
        # A checkpoint/network name mismatch would otherwise go unnoticed here.
        if not_loaded:
            logger.warning("%d parameters were not loaded from %s: %s",
                           len(not_loaded), ckpt_file, not_loaded)
        return not_loaded

    def construct(self, input_ids: Tensor, token_type_ids: Tensor, input_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Run the wrapped BertModel and return its first two outputs.

        Args:
            input_ids (:class:`mindspore.tensor`): token ids of the input text.
            token_type_ids (:class:`mindspore.tensor`): segment ids.
            input_mask (:class:`mindspore.tensor`): the mask for input_ids.

        Returns:
            sequence_output: the sequence output.
            pooled_output: the pooled output of the first token ([CLS]).
        """
        sequence_output, pooled_output, _ = self.bert(input_ids, token_type_ids, input_mask)
        return sequence_output, pooled_output
2、下游任务:BertforSequenceClassification
将 Bert 作为预训练模型,在 Bert 的基础上取 [CLS] token 的 embedding(即 pooled output)作为输入,送入全连接层进行分类,这就是 BertforSequenceClassification。
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert for Sequence Classification script."""
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from mindspore.context import ParallelMode
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import Squeeze
from mindspore.communication.management import get_group_size
from mindspore import context, load_checkpoint, load_param_into_net
from mindspore.common.seed import _get_graph_seed
from bert_embedding import BertEmbedding
class BertforSequenceClassification(nn.Cell):
    """
    Train interface for classification finetuning task.

    BERT's pooled output is passed (through one or more dropout samples) to a
    single dense layer. In training mode ``construct`` returns the averaged
    label-smoothed cross-entropy loss; in eval mode it returns the logits.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        num_labels (int): Number of label types.
        dropout_prob (float): The dropout probability for BertforSequenceClassification.
        multi_sample_dropout (int): Number of dropout samples per step; their losses are averaged.
        label_smooth (float): Target value of the true label (1 disables smoothing); the
            remaining mass is spread evenly over the other labels.
    """

    def __init__(self, config, is_training, num_labels, dropout_prob=0.0, multi_sample_dropout=1, label_smooth=1):
        super(BertforSequenceClassification, self).__init__()
        if not is_training:
            config.hidden_dropout_prob = 0.0
            # The attention-dropout field of BERT configs is attention_probs_dropout_prob;
            # the previously-set "hidden_probs_dropout_prob" is not a BERT config field.
            config.attention_probs_dropout_prob = 0.0
        self.bert = BertEmbedding(config, is_training)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.softmax = nn.Softmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(mstype.float32)
        self.dropout_list = []
        for _ in range(multi_sample_dropout):
            seed0, seed1 = _get_graph_seed(1, "dropout")
            # ops.Dropout takes the keep probability.
            self.dropout_list.append(ops.Dropout(1 - dropout_prob, seed0, seed1))
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction="mean")
        self.squeeze = Squeeze(1)
        self.is_training = is_training
        self.one_hot = nn.OneHot(depth=num_labels, axis=-1)
        self.label_smooth = label_smooth
        # Target value of each non-true label; guarded so num_labels == 1 cannot divide by zero.
        self.smooth_off_value = (1 - label_smooth) / (num_labels - 1) if num_labels > 1 else 0.0

    def from_pretrain(self, ckpt_file):
        """
        Load the model parameters from checkpoint
        """
        param_dict = load_checkpoint(ckpt_file)
        load_param_into_net(self, param_dict)

    def init_embedding(self, embedding):
        """
        Replace the BERT encoder with an already-built (e.g. pre-loaded) BertEmbedding.
        """
        self.bert = embedding

    def construct(self, input_ids, input_mask, token_type_id, label_ids=0):
        """
        Classification task: return the training loss, or the logits in eval mode.

        ``label_ids`` is squeezed on axis 1 before one-hot encoding.
        """
        _, pooled_output = self.bert(input_ids, token_type_id, input_mask)
        if not self.is_training:
            return self.dense_1(pooled_output)
        onehot_label = self.one_hot(self.squeeze(label_ids))
        smooth_label = self.label_smooth * onehot_label + self.smooth_off_value * (1 - onehot_label)
        loss = None
        # Multi-sample dropout: average the loss over several dropout masks.
        for dropout in self.dropout_list:
            cls, _ = dropout(pooled_output)
            logits = self.dense_1(cls)
            temp_loss = self.loss(logits, smooth_label)
            if loss is None:
                loss = temp_loss
            else:
                loss += temp_loss
        loss = loss / len(self.dropout_list)
        return loss
class BertLearningRate(LearningRateSchedule):
    """
    BERT learning-rate schedule: linear warm-up, then polynomial decay.

    Args:
        learning_rate: peak learning rate.
        end_learning_rate: learning rate at the end of the decay.
        warmup_steps: number of warm-up steps; 0 or less disables warm-up.
        decay_steps: number of steps of the polynomial decay.
        power: power of the polynomial decay.
    """

    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = bool(warmup_steps > 0)
        if self.warmup_flag:
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        lr = self.decay_lr(global_step)
        if self.warmup_flag:
            # 1.0 while global_step < warmup_steps, else 0.0: selects warm-up vs decay.
            in_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            lr = (self.one - in_warmup) * lr + in_warmup * self.warmup_lr(global_step)
        return lr
_grad_unscale = C.MultitypeFuncGraph("_grad_unscale")
_reciprocal = P.Reciprocal()


@_grad_unscale.register("Tensor", "Tensor")
def _tensor_grad_unscale(scale, grad):
    """Divide one gradient by the loss scale its backward pass was seeded with."""
    return grad * _reciprocal(scale)


class BertFinetuneCell(nn.Cell):
    """
    Especially defined for finetuning where only four inputs tensor are needed.

    Each step runs the forward pass and computes gradients with the loss scale as
    the backward sensitivity. It then divides the gradients by that scale,
    all-reduces them in data/hybrid-parallel mode, and applies the optimizer.

    Note:
        No gradient clipping or overflow check is performed, and the loss scale is
        never updated: ``scale_update_cell`` only supplies its initial value.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell providing the loss scale. Default: None.
    """

    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertFinetuneCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.gpu_target = False
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.alloc_status = P.NPUAllocFloatStatus()
            self.get_status = P.NPUGetFloatStatus()
            self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        """Run one finetuning step and return the (unscaled) loss."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        if not self.gpu_target:
            init = self.alloc_status()
            init = F.depend(init, loss)
            clear_status = self.clear_status(init)
            scaling_sens = F.depend(scaling_sens, clear_status)
        scaling_sens = self.cast(scaling_sens, mstype.float32)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 scaling_sens)
        # The backward pass was seeded with scaling_sens (e.g. 2**32), so every gradient
        # carries that factor; remove it before the optimizer sees the gradients.
        grads = self.hyper_map(F.partial(_grad_unscale, scaling_sens), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        self.optimizer(grads)
        return loss
3、任务训练
from mindspore.train.callback import Callback
from mindspore.train.callback import TimeMonitor
from mindspore.train import Model
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore import save_checkpoint, context, load_checkpoint, load_param_into_net
from mindtext.modules.encoder.bert import BertConfig
from bert import BertforSequenceClassification, BertLearningRate, BertFinetuneCell
from bert_embedding import BertEmbedding
import LCQMCDataset
from mindspore.common.tensor import Tensor
import time
def get_ms_timestamp():
    """Return the current wall-clock time as whole milliseconds since the Unix epoch."""
    return int(round(time.time() * 1000))
class LossCallBack(Callback):
    """
    Monitor the loss in training.

    Every step's loss is appended to ``./loss_{rank_ids}.log``. It is also printed
    to stdout every ``per_print_times`` steps, counted by the global step number.

    Note:
        If per_print_times is 0, nothing is printed to stdout; the log file is still written.

    Args:
        per_print_times (int): Print loss every this many steps. Default: 1.
        rank_ids (int): Rank id used in the log file name. Default: 0.
    """

    def __init__(self, per_print_times=1, rank_ids=0):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.rank_id = rank_ids
        self.time_stamp_first = get_ms_timestamp()

    def step_end(self, run_context):
        """Monitor the loss in training."""
        time_stamp_current = get_ms_timestamp()
        elapsed = time_stamp_current - self.time_stamp_first
        cb_params = run_context.original_args()
        # Honour per_print_times (previously validated but ignored: printed every step).
        if self._per_print_times and cb_params.cur_step_num % self._per_print_times == 0:
            print("time: {}, epoch: {}, step: {}, outputs are {}".format(elapsed,
                                                                        cb_params.cur_epoch_num,
                                                                        cb_params.cur_step_num,
                                                                        str(cb_params.net_outputs)))
        with open("./loss_{}.log".format(self.rank_id), "a+") as f:
            f.write("time: {}, epoch: {}, step: {}, loss: {}".format(
                elapsed,
                cb_params.cur_epoch_num,
                cb_params.cur_step_num,
                str(cb_params.net_outputs.asnumpy())))
            f.write('\n')
def train(train_data, bert, optimizer, save_path, epoch_num):
    """
    Finetune ``bert`` on ``train_data`` for ``epoch_num`` epochs and save the weights.

    Args:
        train_data: MindSpore dataset yielding (input_ids, input_mask, token_type_id, label_ids).
        bert: network whose forward pass returns the training loss.
        optimizer: optimizer over ``bert``'s trainable parameters.
        save_path: checkpoint path written after training.
        epoch_num: number of epochs.
    """
    steps_per_epoch = train_data.get_dataset_size()
    scale_manager = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    train_cell = BertFinetuneCell(bert, optimizer=optimizer, scale_update_cell=scale_manager)
    monitors = [TimeMonitor(steps_per_epoch), LossCallBack(steps_per_epoch)]
    trainer = Model(train_cell)
    trainer.train(epoch_num, train_data, callbacks=monitors, dataset_sink_mode=False)
    save_checkpoint(trainer.train_network.network, save_path)
def main():
    """Finetune Chinese BERT on LCQMC and save the finetuned checkpoint."""
    # GRAPH_MODE (== 0) on GPU; switch to context.PYNATIVE_MODE for debugging.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    epoch_num = 6
    save_path = "./model/output/train_lcqmc2.ckpt"
    feature_columns = ['input_ids', 'attention_mask', 'token_type_ids']
    # NOTE(review): batch_size=32 here but from_cache below uses 128 -- confirm which one
    # actually determines the training batch size.
    dataset = LCQMCDataset(paths='./dataset/lcqmc',
                           tokenizer="./model",
                           max_length=128,
                           truncation_strategy=True,
                           batch_size=32,
                           columns_list=feature_columns + ['label'],
                           test_columns_list=feature_columns + ['label'])
    splits = dataset.from_cache(batch_size=128,
                                columns_list=feature_columns + ['label'],
                                test_columns_list=list(feature_columns))
    train_data = splits['train']

    bert_config = BertConfig.from_yaml_file("./model/test.yaml")
    model_path = "./model/bert_cn.ckpt"
    bert = BertforSequenceClassification(bert_config, True, num_labels=2, dropout_prob=0.1,
                                         multi_sample_dropout=5, label_smooth=0.9)
    # Load the converted pre-trained weights into a fresh encoder and plug it in.
    encoder = BertEmbedding(bert_config, True)
    encoder.from_pretrain(model_path)
    bert.init_embedding(encoder)

    total_steps = train_data.get_dataset_size() * epoch_num
    lr_schedule = BertLearningRate(learning_rate=2e-5,
                                   end_learning_rate=0.0,
                                   warmup_steps=int(total_steps * 0.1),
                                   decay_steps=total_steps,
                                   power=1.0)
    optimizer = AdamWeightDecay(bert.trainable_params(), lr_schedule, eps=1e-8)
    train(train_data, bert, optimizer, save_path, epoch_num)


if __name__ == "__main__":
    main()
关键参数:
bert_config = BertConfig.from_yaml_file(bert_config_file):读取Bert的配置参数
eb.from_pretrain(model_path):加载 Bert 的 MindSpore 权重文件
bert.init_embedding(eb):用已加载权重的 BertEmbedding 替换分类模型中的 Bert 编码器
lr_schedule :学习率控制器
optimizer:梯度优化器
评估
from mindspore.nn import Accuracy
from tqdm import tqdm
from mindspore import context
import BertforSequenceClassification
import BertConfig
import mindspore
import LCQMCDataset
def eval(eval_data, model):
    """
    Compute, print and return the classification accuracy of ``model`` on ``eval_data``.

    Args:
        eval_data: MindSpore dataset whose dict iterator yields 'input_ids',
            'token_type_ids', 'attention_mask' and 'label' columns.
        model: classifier called as model(input_ids=..., input_mask=..., token_type_id=...)
            that returns logits.

    Returns:
        The value of ``Accuracy.eval()`` (also printed).
    """
    # NOTE: this function name shadows the builtin eval(); kept for existing callers.
    metric = Accuracy('classification')
    metric.clear()
    squeeze = mindspore.ops.Squeeze(1)
    # Build the Softmax cell once: a new cell per batch is re-compiled in graph mode.
    softmax = mindspore.nn.Softmax(axis=-1)
    for batch in tqdm(eval_data.create_dict_iterator(num_epochs=1), total=eval_data.get_dataset_size()):
        logits = model(input_ids=batch['input_ids'],
                       input_mask=batch['attention_mask'],
                       token_type_id=batch['token_type_ids'])
        metric.update(softmax(logits), squeeze(batch['label']))
    accuracy = metric.eval()
    print(accuracy)
    return accuracy
def main():
    """Evaluate the finetuned LCQMC checkpoint on the test split and print the accuracy."""
    # GRAPH_MODE (== 0) on GPU.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    feature_columns = ['input_ids', 'attention_mask', 'token_type_ids']
    dataset = LCQMCDataset(paths='./dataset/lcqmc',
                           tokenizer="./model",
                           max_length=128,
                           truncation_strategy=True,
                           batch_size=128,
                           columns_list=feature_columns + ['label'],
                           test_columns_list=feature_columns + ['label'])
    splits = dataset.from_cache(batch_size=128,
                                columns_list=feature_columns + ['label'],
                                test_columns_list=feature_columns + ['label'])
    eval_data = splits['test']

    bert_config = BertConfig.from_yaml_file("./model/test.yaml")
    bert = BertforSequenceClassification(bert_config, is_training=False, num_labels=2, dropout_prob=0.0)
    bert.from_pretrain("./model/output/train_lcqmc2.ckpt")
    eval(eval_data, bert)


if __name__ == "__main__":
    main()
结果
模型在 LCQMC 数据集验证集和测试集上的准确率: