中文NLP笔记:11. 基于 LSTM 生成古诗


基于 LSTM 生成古诗

1. 语料准备

  一共四万多首古诗,一行一首诗

2. 预处理

  将汉字表示为 One-Hot 的形式

  在每行末尾加上 ] 符号是为了标识这首诗已经结束,说明 ] 符号之前的语句和之后的语句是没有关联关系的,后面会舍弃掉包含 ] 符号的训练数据。

      puncs = [']', '[', '(', ')', '{', '}', ':', '《', '》']

    def preprocess_file(Config):

        # 语料文本内容

        files_content = ''

        with open(Config.poetry_file, 'r', encoding='utf-8') as f:

            for line in f:

                # 每行的末尾加上"]"符号代表一首诗结束

                for char in puncs:

                    line = line.replace(char, "")

                files_content += line.strip() + "]"


        words = sorted(list(files_content))

        words.remove(']')

        counted_words = {}

        for word in words:

            if word in counted_words:

                counted_words[word] += 1

            else:

                counted_words[word] = 1


        # 去掉低频的字

        erase = []

        for key in counted_words:

            if counted_words[key] <= 2:

                erase.append(key)

        for key in erase:

            del counted_words[key]

        del counted_words[']']

        wordPairs = sorted(counted_words.items(), key=lambda x: -x[1])


        words, _ = zip(*wordPairs)

        # word到id的映射

        word2num = dict((c, i + 1) for i, c in enumerate(words))

        num2word = dict((i, c) for i, c in enumerate(words))

        word2numF = lambda x: word2num.get(x, 0)

        return word2numF, num2word, words, files_content

3. 模型参数配置

  class Config(object):

    poetry_file = 'poetry.txt'

    weight_file = 'poetry_model.h5'

    # 根据前六个字预测第七个字

    max_len = 6

    batch_size = 512

    learning_rate = 0.001

4. 构建模型

  通过 PoetryModel 类实现

      class PoetryModel(object):

        def __init__(self, config):

            pass


        def build_model(self):

            pass


        def sample(self, preds, temperature=1.0):

            pass


        def generate_sample_result(self, epoch, logs):

            pass


        def predict(self, text):

            pass


        def data_generator(self):

            pass

        def train(self):

            pass

  (1)init 函数

  加载 Config 配置信息,进行语料预处理和模型加载

      def __init__(self, config):

            self.model = None

            self.do_train = True

            self.loaded_model = False

            self.config = config


            # 文件预处理

            self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)

            if os.path.exists(self.config.weight_file):

                self.model = load_model(self.config.weight_file)

                self.model.summary()

            else:

                self.train()

            self.do_train = False

            self.loaded_model = True

  (2)build_model 函数

  GRU 模型建立

      def build_model(self):

            '''建立模型'''

            input_tensor = Input(shape=(self.config.max_len,))

            embedd = Embedding(len(self.num2word)+1, 300, input_length=self.config.max_len)(input_tensor)

            lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)

            dropout = Dropout(0.6)(lstm)

            lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)

            dropout = Dropout(0.6)(lstm)

            flatten = Flatten()(lstm)

            dense = Dense(len(self.words), activation='softmax')(flatten)

            self.model = Model(inputs=input_tensor, outputs=dense)

            optimizer = Adam(lr=self.config.learning_rate)

            self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

  (3)sample 函数

      def sample(self, preds, temperature=1.0):


        preds = np.asarray(preds).astype('float64')

        preds = np.log(preds) / temperature

        exp_preds = np.exp(preds)

        preds = exp_preds / np.sum(exp_preds)

        probas = np.random.multinomial(1, preds, 1)

        return np.argmax(probas)

  (4)训练模型

      def generate_sample_result(self, epoch, logs): 

            print("\n==================Epoch {}=====================".format(epoch))

            for diversity in [0.5, 1.0, 1.5]:

                print("------------Diversity {}--------------".format(diversity))

                start_index = random.randint(0, len(self.files_content) - self.config.max_len - 1)

                generated = ''

                sentence = self.files_content[start_index: start_index + self.config.max_len]

                generated += sentence

                for i in range(20):

                    x_pred = np.zeros((1, self.config.max_len))

                    for t, char in enumerate(sentence[-6:]):

                        x_pred[0, t] = self.word2numF(char)


                    preds = self.model.predict(x_pred, verbose=0)[0]

                    next_index = self.sample(preds, diversity)

                    next_char = self.num2word[next_index]

                    generated += next_char

                    sentence = sentence + next_char

                print(sentence)

  (5)predict 函数

  根据给出的文字,生成诗句

      def predict(self, text):

            if not self.loaded_model:

                return

            with open(self.config.poetry_file, 'r', encoding='utf-8') as f:

                file_list = f.readlines()

            random_line = random.choice(file_list)

            # 如果给的text不到四个字,则随机补全

            if not text or len(text) != 4:

                for _ in range(4 - len(text)):

                    random_str_index = random.randrange(0, len(self.words))

                    text += self.num2word.get(random_str_index) if self.num2word.get(random_str_index) not in [',', '。',

                                                                                                              ','] else self.num2word.get(

                        random_str_index + 1)

            seed = random_line[-(self.config.max_len):-1]

            res = ''

            seed = 'c' + seed

            for c in text:

                seed = seed[1:] + c

                for j in range(5):

                    x_pred = np.zeros((1, self.config.max_len))

                    for t, char in enumerate(seed):

                        x_pred[0, t] = self.word2numF(char)

                    preds = self.model.predict(x_pred, verbose=0)[0]

                    next_index = self.sample(preds, 1.0)

                    next_char = self.num2word[next_index]

                    seed = seed[1:] + next_char

                res += seed

            return res

  (6) data_generator 函数

  生成数据,提供给模型训练时使用

        def data_generator(self):

            i = 0

            while 1:

                x = self.files_content[i: i + self.config.max_len]

                y = self.files_content[i + self.config.max_len]

                puncs = [']', '[', '(', ')', '{', '}', ':', '《', '》', ':']

                if len([i for i in puncs if i in x]) != 0:

                    i += 1

                    continue

                if len([i for i in puncs if i in y]) != 0:

                    i += 1

                    continue

                y_vec = np.zeros(

                    shape=(1, len(self.words)),

                    dtype=np.bool

                )

                y_vec[0, self.word2numF(y)] = 1.0

                x_vec = np.zeros(

                    shape=(1, self.config.max_len),

                    dtype=np.int32

                )

                for t, char in enumerate(x):

                    x_vec[0, t] = self.word2numF(char)

                yield x_vec, y_vec

                i += 1

  (7)train 函数

      def train(self):

            #number_of_epoch = len(self.files_content) // self.config.batch_size

            number_of_epoch = 10

            if not self.model:

                self.build_model()

            self.model.summary()

            self.model.fit_generator(

                generator=self.data_generator(),

                verbose=True,

                steps_per_epoch=self.config.batch_size,

                epochs=number_of_epoch,

                callbacks=[

                    keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),

                    LambdaCallback(on_epoch_end=self.generate_sample_result)

                ]

            )

5. 进行模型训练

  model = PoetryModel(Config)

6. 作诗

      text = input("text:")

    sentence = model.predict(text)

    print(sentence)



学习资料:

《中文自然语言处理入门实战》

你可能感兴趣的:(中文NLP笔记:11. 基于 LSTM 生成古诗)