使用xlnet实现中文文本分类 超详细(附代码)

**

使用xlnet实现中文文本分类

**

1、下载xlnet代码
https://github.com/zihangdai/xlnet
2、下载xlnet中文预训练模型
https://github.com/ymcui/Chinese-PreTrained-XLNet
3、训练数据的处理
新建一个文件夹,文件夹名字随意。在此处放置三个单独的文件:train.tsv dev.tsv和test.tsv。在train.tsv,dev.tsv没有标题,如下所示:第1列:行的ID(可以是计数,或者如果你不希望跟踪每个人,则每行甚至可以是相同的数字或字母),第2列:该行的标签为int。这些是分类器旨在预测的分类标签。第3列:所有字母均相同的列,因此您需要包括一个一次性的列。第4栏:您要分类的文本示例。
train.tsv和dev.tsv的示例:
使用xlnet实现中文文本分类 超详细(附代码)_第1张图片
test.tsv格式略有不同。它具有第1列:每个示例的ID,类似于train和dev文件中的第1列,以及第2列:要分类的文本。另外,test.tsv应该有一个标题行(而train和dev没有)。这是test.tsv的示例:
使用xlnet实现中文文本分类 超详细(附代码)_第2张图片
4、打开xlnet-master中的run_classifier.py文件,进行两处修改
1)新建一个mytask类,此类与已有的类并列即可;标签可以根据需要自定义。

class MyTaskProcessor(DataProcessor):
    def __init__(self):
        self.train_file = "train.tsv"
        self.dev_file = "dev.tsv"
        self.test_file = "test.tsv"
        self.label_column = 1
        self.text_a_column = 3
        self.text_b_column = None
        self.contains_header = True
        self.test_text_a_column = None
        self.test_text_b_column = None
        self.test_contains_header = True

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, self.train_file)), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        if self.test_text_a_column is None:
            self.test_text_a_column = self.text_a_column
        if self.test_text_b_column is None:
            self.test_text_b_column = self.text_b_column

        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, self.test_file)), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1", "2"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0 and self.contains_header and set_type != "test":
                continue
            if i == 0 and self.test_contains_header and set_type == "test":
                continue
            guid = "%s-%s" % (set_type, i)

            a_column = (self.text_a_column if set_type != "test" else
                        self.test_text_a_column)
            b_column = (self.text_b_column if set_type != "test" else
                        self.test_text_b_column)

            # there are some incomplete lines in QNLI
            if len(line) <= a_column:
                tf.logging.warning('Incomplete line, ignored.')
                continue
            text_a = line[a_column]

            if b_column is not None:
                if len(line) <= b_column:
                    tf.logging.warning('Incomplete line, ignored.')
                    continue
                text_b = line[b_column]
            else:
                text_b = None

            if set_type == "test":
                label = self.get_labels()[0]
            else:
                if len(line) <= self.label_column:
                    tf.logging.warning('Incomplete line, ignored.')
                    continue
                label = line[self.label_column]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

2)继续在该文件中进行修改,找到main(_)函数,添加我们刚刚创建的类

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  #### Validate flags
  if FLAGS.save_steps is not None:
    FLAGS.iterations = min(FLAGS.iterations, FLAGS.save_steps)

  if FLAGS.do_predict:
    predict_dir = FLAGS.predict_dir
    if not tf.gfile.Exists(predict_dir):
      tf.gfile.MakeDirs(predict_dir)

  processors = {
      "mnli_matched": MnliMatchedProcessor,
      "mnli_mismatched": MnliMismatchedProcessor,
      'sts-b': StsbProcessor,
      'imdb': ImdbProcessor,
      "yelp5": Yelp5Processor,
      "mytask": MyTaskProcessor,
  }

5、新建一个out文件和一个modelout文件,这在后面运行脚本的时候需要用到
然后建一个run.sh文件,用来存储和运行脚本
运行脚本如下:
如果没有使用tpu可以修改脚本的相关参数,根据电脑性能更改最后几个参数,地址可以使用绝对地址(例如model_dir的地址可以是刚才新建的modelout的文件夹)

python run_classifier.py \
--use_tpu=True \
--tpu=${TPU_NAME} \
--do_train=True \
--do_eval=True \
--eval_all_ckpt=True \
--task_name=imdb \
--data_dir=${IMDB_DIR} \
--output_dir=${GS_ROOT}/proc_data/imdb \
--model_dir=${GS_ROOT}/exp/imdb \
--uncased=False \
--spiece_model_file=${LARGE_DIR}/spiece.model \
--model_config_path=${GS_ROOT}/${LARGE_DIR}/model_config.json \
--init_checkpoint=${GS_ROOT}/${LARGE_DIR}/xlnet_model.ckpt \
--max_seq_length=512 \
--train_batch_size=32 \
--eval_batch_size=8 \
--num_hosts=1 \
--num_core_per_host=8 \
--learning_rate=2e-5 \
--train_steps=4000 \
--warmup_steps=500 \
--save_steps=500 \
--iterations=500

6、最后就是把脚本的每一行变成总体的一行,中间用空格隔开。我是放到pycharm的terminal中运行的。
结果如下:忽略准确率,为了加快运行
在这里插入图片描述
感谢支持!
版权声明:本文为博主原创文章,转载请附上博文链接!

你可能感兴趣的:(使用xlnet实现中文文本分类 超详细(附代码))