version:tensorflow==2.3.0, transformers==4.5.0
1. 继承 TFBertPreTrainedModel
class TFMyBertModelPreTrained(TFBertPreTrainedModel):
    """Binary classifier built by subclassing HuggingFace's TFBertPreTrainedModel.

    The BERT backbone is constructed from a config object, so its weights
    are not loaded here; a 2-unit Dense head classifies the pooled output.
    """

    def __init__(self, config):
        super().__init__(config)
        # Backbone built from config only — weights come from a later load.
        self.bert = TFBertModel(config)
        self.classifier = tf.keras.layers.Dense(2, name="classifier")

    def call(self, inputs):
        # Classify from BERT's pooled [CLS] representation.
        features = self.bert(inputs)
        return self.classifier(features.pooler_output)
2.继承tf.keras.models.Model
class TFMyBertModel(tf.keras.models.Model):
    """Binary classifier built by subclassing tf.keras.models.Model.

    `from_pretrained` locates the config under ``model_path`` and also
    loads the pretrained weights, unlike the config-only variant.
    """

    def __init__(self, model_path):
        super().__init__()
        self.bert = TFBertModel.from_pretrained(model_path)
        # 2-way Dense head on top of BERT.
        # Kept minimal for the tutorial: no dropout is added before the
        # classifier, though real training would want it to curb overfitting.
        self.classifier = tf.keras.layers.Dense(2, name="classifier")

    def call(self, inputs, *args, **kwargs):
        # Forward any extra arguments straight through to BERT.
        features = self.bert(inputs, *args, **kwargs)
        return self.classifier(features.pooler_output)
使用区别主要在于模型初始化和调用方式。
1.需要根据给定的config文件,来构建模型。
2.直接在model_path下查找config文件来构建模型,并且加载参数。
如果使用1的方法构建,调用时,需要构建一个config对象:
bert_config = BertConfig.from_pretrained(model_path)
model = TFMyBertModelPreTrained(bert_config)
如果使用2的方法构建,需要给一个 model_path:
model = TFMyBertModel(model_path=model_path)
另外,继承 TFBertPreTrainedModel,可以继续被 fine-tuning.
整体代码
from transformers import TFBertModel, TFBertPreTrainedModel
import tensorflow as tf
class TFMyBertModel(tf.keras.models.Model):
    """Binary classifier on a pretrained BERT, subclassing tf.keras.models.Model.

    Loads pretrained weights from ``model_path`` and adds a 2-unit Dense
    head over BERT's pooled output.
    """

    def __init__(self, model_path, **kwargs):
        # BUG FIX: the original passed ``model_path`` positionally to
        # tf.keras.Model.__init__, which does not take a path argument;
        # only the keyword arguments should be forwarded to the base class.
        super(TFMyBertModel, self).__init__(**kwargs)
        # Hidden states and attentions are requested so intermediate
        # activations are available for inspection/extension.
        self.bert = TFBertModel.from_pretrained(
            model_path, output_hidden_states=True, output_attentions=True
        )
        # NOTE(review): no dropout before the classifier — presumably
        # deliberate for this tutorial; add one for real training.
        self.classifier = tf.keras.layers.Dense(2, name="classifier")

    @tf.autograph.experimental.do_not_convert
    def call(self, inputs, **kwargs):
        """Run BERT and classify from the pooled [CLS] representation.

        ``**kwargs`` is accepted for Keras compatibility (e.g. ``training``)
        but, as in the original, is not forwarded to the backbone.
        """
        pretrain_features = self.bert(
            inputs, output_hidden_states=True, output_attentions=True
        )
        logits = self.classifier(pretrain_features.pooler_output)
        return logits
# Subclass HuggingFace's TFBertPreTrainedModel so the raw BERT outputs are
# returned and easy to extend with custom heads.
class TFMyBertModelPreTrained(TFBertPreTrainedModel):
    """Binary classifier head on top of a config-built TFBertModel."""

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # Backbone built from the config; extra kwargs are forwarded as-is.
        self.bert = TFBertModel(config, **kwargs)
        self.classifier = tf.keras.layers.Dense(2, name="classifier")

    @tf.autograph.experimental.do_not_convert
    def call(self, inputs, **kwargs):
        # Request hidden states and attentions, then classify from the
        # pooled [CLS] representation.
        features = self.bert(
            inputs, output_hidden_states=True, output_attentions=True
        )
        return self.classifier(features.pooler_output)