首先需要下载中文预训练模型文件,下载链接:https://pan.baidu.com/s/1-c068UOgfhrMyIIhR5fHXg (提取码:2z2r)。解压后会得到五个文件:一个词汇表文件 vocab.txt、三个 BERT 的 TensorFlow 模型文件(这里就不一一列举了),以及一个参数配置文件 bert_config.json。接下来再去 GitHub 上把 BERT 的代码下载下来,就可以开始了!
开搞!
首先在 main() 里的 processors 字典中注册一个自己的任务:键是任务名(例如我起名叫 my_bert),值是稍后要定义的处理类的名字(这里是 my_bertProcessor)。
def main(_):
    # Print INFO-level (and more severe) log messages.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Task registry: maps a task name to the processor class that reads its
    # data. Register your own task here; "my_bert" is the custom entry.
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        "my_bert": my_bertProcessor,
    }
接下来可以仿照 processors 里已有的任务来改。以 MrpcProcessor 为例:把 class MrpcProcessor(DataProcessor) 整个类复制一份,粘贴在它的下面,然后对复制出来的代码稍作修改,就可以用来做分类了。原始的类如下,具体怎么改见后面的说明:
class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets.

        Args:
          lines: Iterable of rows, each an indexable sequence of column
            strings. Column 0 is read as the label and columns 3 and 4 as
            the two sentences.
          set_type: "train", "dev" or "test"; used as the guid prefix. For
            "test" the label column is ignored.

        Returns:
          A list of InputExample objects, one per row after the first.
        """
        examples = []
        for (i, line) in enumerate(lines):
            # The first row is skipped (presumably the column header).
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[3])
            text_b = tokenization.convert_to_unicode(line[4])
            if set_type == "test":
                # Test rows get a fixed placeholder label.
                label = "0"
            else:
                label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b,
                             label=label))
        return examples
我们需要做三处改动:把类名改成你在 processors 里注册的那个名字;把 get_labels(self) 返回的类别列表改成自己任务的类别;再对 _create_examples(self, lines, set_type) 做一点小改动。下面给出我改好的版本,和上面的原始代码对比一下,就能明白为什么要这么改了:
class my_bertProcessor(DataProcessor):
    """Processor for a custom single-sentence, 35-class classification task.

    Every row of the .tsv files is read as ``[label, text]``: column 0 is
    the label and column 1 is the sentence to classify. No row is skipped,
    so the files must not start with a header row.
    """

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """Returns the class labels "1" through "35" as strings."""
        return [str(n) for n in range(1, 36)]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets.

        Args:
          lines: Iterable of rows, each an indexable sequence of column
            strings laid out as ``[label, text]``.
          set_type: "train", "dev" or "test"; used as the guid prefix. For
            "test" the label column is ignored.

        Returns:
          A list of single-sentence InputExample objects (text_b is None),
          one per row.
        """
        # Test rows need a placeholder label. The previous placeholder "0"
        # is not one of get_labels(), so any label -> id lookup built from
        # that list (as BERT's feature conversion appears to do) would fail
        # on it; use a label that is guaranteed to be valid instead.
        placeholder_label = self.get_labels()[0]

        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[1])
            if set_type == "test":
                label = placeholder_label
            else:
                label = tokenization.convert_to_unicode(line[0])
            # Single-sentence classification: there is no second segment.
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None,
                             label=label))
        return examples
这里解释一下为什么这么改。第一,我们做的是多分类,所以 get_labels 必须返回我们自己任务的全部类别标签。第二,读取方式取决于输入数据的格式:这里每一行 line 的格式是 [label, context],所以标签取 line[0]、文本取 line[1];同时去掉了原来跳过第一行(i == 0)的判断,数据文件的第一行也会被当作样本,所以文件里不要写表头。第三,单句文本分类只有一句话,不需要 text_b,直接设置成 None 即可。
下面介绍参数的设置。我没有像 GitHub 上的说明那样在运行时传入参数,而是直接在文件里修改了各个 flag 的默认值,下面是我的设置:
# Folder that holds your own classification data (.tsv files).
flags.DEFINE_string(
    "data_dir", "data",
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

# Path to the model configuration file.
flags.DEFINE_string(
    "bert_config_file", "BERT/bert_config.json",
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")

# Name of your own task.
flags.DEFINE_string(
    "task_name", "my_bert",
    "The name of the task to train.")

# Vocabulary file: one of the five files from the extracted archive.
flags.DEFINE_string(
    "vocab_file", "BERT/vocab.txt",
    "The vocabulary file that the BERT model was trained on.")

# Where the trained model files and the evaluation results are stored.
flags.DEFINE_string(
    "output_dir", "model/",
    "The output directory where the model checkpoints will be written.")

## Other parameters

# The downloaded pre-trained BERT model checkpoint.
flags.DEFINE_string(
    "init_checkpoint", 'BERT/bert_model.ckpt',
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

# Set to True to train.
flags.DEFINE_bool(
    "do_train", True,
    "Whether to run training.")

flags.DEFINE_bool(
    "do_eval", True,
    "Whether to run eval on the dev set.")

# Set to True to run inference on the test set.
flags.DEFINE_bool(
    "do_predict", True,
    "Whether to run the model in inference mode on the test set.")

# Author's note: a larger batch size reportedly gives better results, but
# how large it can be depends on the memory of your machine.
flags.DEFINE_integer(
    "train_batch_size", 32,
    "Total batch size for training.")

flags.DEFINE_integer(
    "eval_batch_size", 8,
    "Total batch size for eval.")

flags.DEFINE_integer(
    "predict_batch_size", 8,
    "Total batch size for predict.")

flags.DEFINE_float(
    "learning_rate", 5e-5,
    "The initial learning rate for Adam.")

flags.DEFINE_float(
    "num_train_epochs", 3.0,
    "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer(
    "save_checkpoints_steps", 100,
    "How often to save the model checkpoint.")

flags.DEFINE_integer(
    "iterations_per_loop", 50,
    "How many steps to make in each estimator call.")

flags.DEFINE_bool(
    "use_tpu", False,
    "Whether to use TPU or GPU/CPU.")

# TPU / cluster settings; all optional and left unset by default.
tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
    "master", None,
    "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")
接下来说一下输入数据的格式。BERT 的代码默认读取 .tsv 文件,所以把 txt 文档的后缀改成 .tsv 即可(按上面的代码,data 目录下需要有 train.tsv、dev.tsv 和 test.tsv 三个文件)。文件每一行的格式是:label + \t + sentence,即标签和句子之间用制表符分隔。准备好之后直接运行即可,运行过程中的信息会自动打印出来。训练数据文件的格式如下图所示:
给出一部分我自己的训练的截图
以上就是全部内容,以后想到需要补充的再更新。