基于GPT2实现考公申论文章生成

近几年来,考公的人数越来越多,而申论作为考公非常重要的一部分,也是另很多人头痛的一部分。很多人在考试之前都会背一些优秀范文或句段,以便在考试时派上用场。这里我用GPT2预训练很多篇申论范文,使之能在某个话题的提示下自动申成一片范文或句段。话不多说,直接上代码。

数据预处理

这里我找了500篇申论范文,不是很多,当然也可以多找点,最好是各类话题都有,越多越好。

造字典

将所有文章中的字,符号提取出来,去重后存入一个txt文档中
代码实现

import os
DIR_PATH = r"novels"
VOCAB_FILE = r"Vocab.txt"
words = set()
x=0
for i, filename in enumerate(os.listdir(DIR_PATH)):
    x=x+1
    f_path = os.path.join(DIR_PATH, filename)
    print(f_path)
    with open(f_path, "r+", encoding="utf-8") as f:
        w = f.read(1)
        while w:

            if w == '\n' or w == '\r' or w == ' ':
                # words.add('[SEP]')
                pass
            else:
                words.add(w)
            w = f.read(1)
print(x)
with open(VOCAB_FILE, "w+", encoding="utf-8") as f:
    f.write("[START] [SEQ] [UNK] [PAD] [END] ")
    f.write(" ".join(words))
    f.flush()

对文章进行编码

利用字典对文章进行编码,如字典中第12个字是“我”,则在原文中的“我”就用数字11代替,然后保存每篇文章的编码。
代码实现:

import os
SRC_DIR = r"novels"
DST_DIR = r"encoded_novels"
VOCAB_FILE = "Vocab.txt"
if not os.path.exists(DST_DIR):
    os.makedirs(DST_DIR)
with open(VOCAB_FILE, "r+", encoding="utf-8") as f:
    tokens = f.read().split()
count = 0
for i, filename in enumerate(os.listdir(SRC_DIR)):
    f_path = os.path.join(SRC_DIR, filename)
    print(f_path)
    with open(f_path, "r+", encoding="utf-8") as f:
        dst = ["0"]
        w = f.read(1)
        while w:
            if w == '\n' or w == '\r' or w == '\t' or ord(w) == 12288:
                dst.append("1")
            elif w == ' ':
                dst.append("3")
            else:
                try:
                    dst.append(str(tokens.index(w)))
                except:
                    dst.append("2")
            w = f.read(1)
    count+=1
    with open(os.path.join(DST_DIR, "{}.txt".format(count)), "w+", encoding="utf-8") as df:
        df.write(" ".join(dst))
print(count)

网络模型

我搭建的是带多头注意力的GPT模型,由于电脑GPU显存不大,所以头数设的12,模块数设的6,字的维数为768,最多可生成500字

# config文件
block_num = 6
head_num = 12
embed_dim = 768
vocab_num = 3012
pos_num =500
multi=4
stride=1
device = "cuda:0"
import torch
from torch import nn
import config as cfg
class Attention(nn.Module):
    def __init__(self, isMask=True):
        super().__init__()
        self.dk = (cfg.embed_dim // cfg.head_num) ** 0.5
        self.isMask = isMask
        self.c_attn = nn.Linear(cfg.embed_dim, cfg.embed_dim * 3)
        self.attn_drop = nn.Dropout(0.1)
        self.resi_drop = nn.Dropout(0.1)
        self.c_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
        if self.isMask:
            # self.register_buffer("mask", torch.tril(torch.ones(cfg.pos_num, cfg.pos_num)))
                self.mask = torch.tril(torch.ones(cfg.pos_num, cfg.pos_num)).cuda()
    def forward(self, x):
        x = self.c_attn(x) # x形状(N,S,V),N代表多少个句子,S代表多少个词,V代表每个词的维度
        x = x.reshape(*x.shape[:-1], cfg.head_num, -1)  # (N,S,V)——>(N,S,12,768/12*3)
        x = x.transpose(-2, -3)  # (N,S,12,768/12*3)——>(N,12,,S,768/12*3)
        q, k, v = x.chunk(3, dim=-1)
        w = (q @ k.transpose(-1, -2)) / self.dk  # (N,12,S,64)@(N,12,64,S)=(N,12,S,S)
        # if self.isMask:
        # mask=(self.mask if self.isMask else 1)
        mask=torch.tril(torch.ones(w.size(-2), w.size(-1))).cuda()
        w = w * mask - (1 - mask) * 1e5
        w = torch.softmax(w, dim=-1)
        w = self.attn_drop(w)
        a = w @ v  # (N,12,S,S)@(N,12,S,64)-->(N,12,S,64)
        a = a.transpose(-2, -3)  # (N,12,S,64)-->(N,S,12,64)
        a = a.reshape(*a.shape[:-2], cfg.embed_dim)  # (N,S,12,64)-->(N,S,768)
        h = self.c_proj(a)
        h = self.resi_drop(h)
        return h
class Block(nn.Module):
    def __init__(self, isMask=True):
        super().__init__()
        self.layer_normal_1 = nn.LayerNorm(cfg.embed_dim)
        self.attention = Attention(isMask)
        self.layer_normal_2 = nn.LayerNorm(cfg.embed_dim)
        self.proj = nn.Sequential(
            nn.Linear(cfg.embed_dim, cfg.multi * cfg.embed_dim),
            nn.LeakyReLU(),
            nn.Linear(cfg.multi * cfg.embed_dim, cfg.embed_dim),
        )
        self.dropout = nn.Dropout(0.1)
    def forward(self, x):
        h = self.layer_normal_1(x)
        a = self.attention(h)
        a = a + x  # 加一个残差
        a = self.layer_normal_2(a)
        h = self.proj(a)
        h = self.dropout(h)
        y = h + a  # 加一个残差
        return y
class GPT2(nn.Module):
    def __init__(self):
        super().__init__()
        self.vocab_embed = nn.Embedding(cfg.vocab_num, cfg.embed_dim) # 定义一个字典
        self.pos_embed = nn.Embedding(cfg.pos_num, cfg.embed_dim)   # 定义一个位置编码
        # self.type_embed = nn.Embedding(cfg.type_num, cfg.embed_dim)   # 定义一个类型编码
        self.blocks = []
        for _ in range(cfg.block_num):
            self.blocks.append(Block())
        self.drop = nn.Dropout(0.1)
        self.sequential = nn.Sequential(*self.blocks)
        self.output_layer = nn.Linear(cfg.embed_dim, cfg.vocab_num, bias=False)
    def forward(self, x, p):
        e = self.vocab_embed(x)  # 对输入进行词向量编码
        p = self.pos_embed(p)    # 对输入进行位置编码
        # t = self.type_embed(t)   # 对输入进行类型编码
        h = self.drop(e + p)
        h = self.sequential(h)
        return self.output_layer(h)

网络训练

生成训练数据

import torch, os
from torch.utils.data import Dataset
import config as cfg
class MyDataset(Dataset):
    def __init__(self, dir):
        self.dataset = []
        for filename in os.listdir(dir):
            with open(os.path.join(dir, filename), "r+") as f:
                ws = [int(x) for x in f.readline().split()]
                ws_len = len(ws)
                start = 0
                while ws_len - start > cfg.pos_num + 1:
                    self.dataset.append(ws[start:start + cfg.pos_num + 1])
                    start += cfg.stride
                else:
                    if ws_len > cfg.pos_num + 1:
                        self.dataset.append(ws[ws_len - cfg.pos_num - 1:])
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index):
        data = torch.tensor(self.dataset[index])
        return data[0:-1], data[1:]

训练


from module import *
from dataset import *
import torch, os
from torch import  optim
from torch.utils.data import DataLoader
from torch.nn import  functional as F
# def weight_init(m):
#     if isinstance(m, nn.Linear):
#         nn.init.xavier_normal_(m.weight)
#         if m.bias is not None:
#             nn.init.constant_(m.bias, 0)
save_path=r"网络参数"
class Trainer:
    def __init__(self):
        self.net = GPT2()
        self.weight_file = os.path.join(save_path, "gpt2_k.pt")
        if os.path.exists(self.weight_file):
            self.net.load_state_dict(torch.load(self.weight_file))
        # else:
        #     self.net.apply(weight_init)

        self.net.to(torch.device(cfg.device))

        self.opt = optim.Adam(self.net.parameters(), lr=0.0001)
    def train(self):
        myDataset = MyDataset(r"encoded_novels")
        print(len(myDataset))
        dataloader = DataLoader(myDataset, batch_size=4, shuffle=True)
        epoch=0
        while True:
            epoch=epoch+1
            sum_loss = 0
            for i, (x, y) in enumerate(dataloader):
                x, y = x.to(torch.device(cfg.device)), y.to(torch.device(cfg.device))
                p = torch.arange(0, x.shape[1])[None, :].repeat(x.shape[0], 1).to(torch.device(cfg.device))
                # print(p)
                _y = self.net(x, p).reshape(-1, cfg.vocab_num)
                y = y.reshape(-1)
                loss = F.cross_entropy(_y, y)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                print(loss.cpu().detach().item())
                sum_loss += loss.cpu().detach().item()
                if i % 1000 == 0 and i > 0:
                    torch.save(self.net.state_dict(), self.weight_file)
            print("第{0}轮训练完毕".format(epoch))
            print("轮的平均损失为{0}".format(sum_loss / len(dataloader)))
            torch.save(self.net.state_dict(), self.weight_file)
            print("参数保存成功")

测试

from module import *
def transer(x):              # 索引到字的换算
    VOCAB_FILE = "Vocab.txt"
    with open(VOCAB_FILE, "r+", encoding="utf-8") as f:
        tokens = f.read().split()
    y=x[0]
    for i in y:
        print(tokens[i], end=" ")
def Transfer(str):          # 字到索引的换算
    VOCAB_FILE = "Vocab.txt"
    with open(VOCAB_FILE, "r+", encoding="utf-8") as f:
        tokens = f.read().split()
    idx=tokens.index(str)
    return idx
if __name__ == '__main__':
    gpt = GPT2()
    gpt.to(torch.device(cfg.device))
    gpt.eval()
    gpt.load_state_dict(torch.load(r"网络参数\gpt2_k.pt"))

    os = []
    x = torch.tensor([[Transfer("依"),Transfer("法"),Transfer("治"),Transfer("国")]]).cuda()  # 给定一个开始词
    p = torch.tensor([[0,1,2,3]]).cuda()  # 给定一个起始位置
    l=x.size()[1]
    for i in range(400):
        y = gpt(x, p)
        y = y[:, -1:]
        v, y = torch.topk(y, 8, dim=-1)

        v, y = v.reshape(-1, 8), y.reshape(-1, 8)
        v = torch.multinomial(torch.softmax(v, dim=-1), 1)
        y = torch.gather(y, -1, v)

        x = torch.cat([x, y], dim=1)
        p = torch.tensor([range(i + l + 1)]).cuda()
    print(transer(x))

比如,输入“人工智能”,则会生成如下片段:

人 工 智 能 , 网 上 购 物 , 物 联 网 , 各 种 新 兴 技 术 层 出 不 穷 , 各 种 创 新 思 想 不 断 迸 发 , 国 家 政 策 环 境 需 求 都 为 创 新 提 供 了 丰 富 的 土 壤 , 这 也 是 最 坏 的 时 代 , 自 主 品 牌 创 新 能 力 薄 弱 , 山 寨 产 品 盛 行 , 核 心 技 术 被 外 方 意 志 很 大 程 度 上 削 减 了 我 国 的 竞 争 力 , 究 其 原 因 , 一 方 面 是 企 业 缺 乏 竞 争 意 识 , 创 新 意 识 目 光 短 浅 所 致 , 而 另 一 方 面 在 于 人 才 的 流 失 , 由 于 学 术 界 浮 躁 的 氛 围 , 以 及 体 制 的 不 完 善 等 , 许 多 科 研 人 员 面 临 工 资 低 , 没 有 项 目 的 窘 境 , 为 了 改 善 环 境 , 降 低 生 存 压 力 , 转 而 流 向 其 他 的 领 域 , 因 此 想 要 中 国 品 牌 走 出 国 门 , 提 升 竞 争 力 , 创 新 是 关 键 。 打 造 中 国 品 牌 提 升 国 家 竞 争 力 , 融 入 民 族 精 神 是 重 点 。 中 国 品 牌 之 所 以 被 称 为 中 国 品 牌 , 关 键 在 于 其 拥 有 独 特 的 魅 力 , 不 同 于 其 他 国 家 , 必 须 有 中 国 的 特 色 , 必 须 有 中 国 的 文 化 , 与 文 化 紧 密 结 合 , 故 宫 博 物 院 的 文 创 产 品 , 就 是 将 这 一 融 合 发 挥 到 极 致 的 典 范 , 将 文 物 蕴 含 的 文 化 内 容 融 入 到 产 品 设 计 当 中 , 设 计 出 具 有 中 国 特 色 的 独 一 无 二 的 文 创 产 品 , 不 仅 能 够 吸 引 大 量 的 游 客 , 更 传 承 了 中 国 文 化 之 道 , 不 仅 打 造 了 品 牌 , 更 将 这 一 品 牌 销 往 国 外 , 可 见 , 打 造 中 国 品 牌 , 还 必 须 要 将 中 国 文 化 结 合 其 中 , 方 能 够 让 中 国 品 牌 脱 颖 而 出 , 与 众 不 同 , 方 能 体 现 中 国 竞 争 力 。

你可能感兴趣的:(基于GPT2实现考公申论文章生成)