【ChatBot Development Notes】Importing, Using, and Analyzing the GPT2 Model; Model Training

Transformers is a state-of-the-art NLP framework for PyTorch and TensorFlow 2.0. The GPT2 model used here is an advanced architecture from OpenAI that performs remarkably well in contextual coherence and emotional expression. In practice it can be imported directly from the transformers library:

from transformers.models.gpt2.modeling_gpt2 import GPT2Config, GPT2LMHeadModel

OpenAI has also released GPT3, but with 175 billion parameters and a compute bill in the tens of millions, it is not something a laptop GPU can drive. GPT2 has up to 1.5 billion parameters and needs only about 4 hours on an 8 MB corpus, which makes it a much more practical option.
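As a minimal sketch of how these two classes are typically used together (the config path and checkpoint directory below are placeholders, not files from this note):

from transformers.models.gpt2.modeling_gpt2 import GPT2Config, GPT2LMHeadModel

# Build a fresh model from a JSON config file (hypothetical path)
config = GPT2Config.from_json_file('config/model_config_dialogue_small.json')
model = GPT2LMHeadModel(config=config)

# Or continue from an existing checkpoint directory (hypothetical path)
# model = GPT2LMHeadModel.from_pretrained('dialogue_model/model_epoch10')

print('number of parameters: {}'.format(sum(p.numel() for p in model.parameters())))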


Define the arguments, count the training steps, use gradient accumulation and a warmup strategy, and import tensorboardX:

def train(model, device, train_list, multi_gpu, args):
    # model: the model to be trained
    # device: the GPU device to use
    # train_list: the training set after splitting
    # multi_gpu: flag for multi-GPU training
    # args: the script's argument container

    train_dataset = MyDataset(train_list)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                  collate_fn=collate_fn)
    model.train()
    # Compute total_steps, the total number of optimizer steps across all epochs
    # Within limits, a larger batch_size tends to work better; gradient accumulation sums the gradients of
    # several batches, achieving the effect of a large batch under a limited memory budget
    total_steps = int(train_dataset.__len__() * args.epochs / args.batch_size / args.gradient_accumulation)
    logger.info('total training steps = {}'.format(total_steps))

    # Set up the optimizer, and use a warmup strategy at the start of training
    # The common view is that warming up with a learning rate below the base rate helps avoid overfitting
    # to the first batches, lets the model get familiar with the data, and saves time overall
    optimizer = transformers.AdamW(model.parameters(), lr=args.lr, correct_bias=True)
    scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)


    logger.info('starting training')
    # Accumulates the loss over each gradient-accumulation cycle
    running_loss = 0
    # Counts the total number of optimizer steps taken
    overall_step = 0
    # tensorboardX brings TensorFlow-style visualization (TensorBoard) to PyTorch
    tb_writer = SummaryWriter(log_dir=args.writer_dir)
    # Counts the number of out-of-memory events
    oom_time = 0
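
The warmup comment above can be made concrete with a tiny standalone sketch, separate from the train() code (toy numbers, unrelated to this project's defaults): the learning rate ramps up linearly from 0 over num_warmup_steps optimizer steps, then decays linearly towards 0 at num_training_steps.

import torch
import transformers

toy_layer = torch.nn.Linear(2, 2)
optimizer = transformers.AdamW(toy_layer.parameters(), lr=1.5e-4, correct_bias=True)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=4, num_training_steps=20)
for step in range(8):
    optimizer.step()
    scheduler.step()
    # The learning rate climbs towards 1.5e-4 during the first 4 steps, then starts to decay
    print(step, scheduler.get_last_lr())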
    

default:

  • epoch = 10
  • batch_count = 4000
  • batch_size = 2
  • gradient_accumulation = 1
  • steps = 40000
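
A quick sanity check on these defaults (assuming batch_count means 4000 batches of batch_size 2 per epoch, i.e. 8000 training samples):

# Reproduces the total_steps formula from train() with the default values above
dataset_len = 4000 * 2                         # batch_count * batch_size
total_steps = int(dataset_len * 10 / 2 / 1)    # epochs=10, batch_size=2, gradient_accumulation=1
print(total_steps)                             # 40000, matching the steps value above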
    # Start training
    for epoch in range(args.epochs):
        epoch_start_time = datetime.now()
        for batch_idx, input_ids in enumerate(train_dataloader):
            # Note: the GPT2 model's forward() generates one token for a given context, not a whole sequence of tokens
            # When GPT2Model is fed n token_ids, it outputs n hidden_states; the n-th hidden_state is used to predict the (n+1)-th token
            input_ids = input_ids.to(device)
            # Handle "CUDA out of memory" errors caused by running out of GPU memory during training
            try:
                outputs = model.forward(input_ids=input_ids)  # https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel
                loss, accuracy = calculate_loss_and_accuracy(outputs, labels=input_ids, device=device)

                if multi_gpu:
                    loss = loss.mean()
                    accuracy = accuracy.mean()
                if args.gradient_accumulation > 1:
                    loss = loss / args.gradient_accumulation
                    accuracy = accuracy / args.gradient_accumulation
                loss.backward()
                # Gradient clipping guards against exploding gradients by capping the gradient norm at a threshold
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                # After accumulating gradients for the configured number of steps, update the parameters
                if (batch_idx + 1) % args.gradient_accumulation == 0:
                    running_loss += loss.item()
                    # Update the parameters
                    optimizer.step()
                    # Clear the accumulated gradients
                    optimizer.zero_grad()
                    # Advance the warmup/decay learning-rate schedule
                    scheduler.step()
                    overall_step += 1
                    # Update the log and the tensorboardX record
                    if (overall_step + 1) % args.log_step == 0:
                        logger.info(
                            "batch {} of epoch {}, loss {}, accuracy {}".format(batch_idx + 1, epoch + 1, loss,
                                                                                accuracy))
                        tb_writer.add_scalar('loss', loss.item(), overall_step)
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    oom_time += 1
                    logger.info("WARNING: ran out of memory,times: {}".format(oom_time))
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    logger.info(str(exception))
                    raise exception
        logger.info('saving model for epoch {}'.format(epoch + 1))
        if args.train_mmi:  # currently training the MMI model
            model_path = join(args.mmi_model_output_path, 'model_epoch{}'.format(epoch + 1))
        else:  # currently training the dialogue model
            model_path = join(args.dialogue_model_output_path, 'model_epoch{}'.format(epoch + 1))
        if not os.path.exists(model_path):
            os.mkdir(model_path)
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(model_path)
        logger.info('epoch {} finished'.format(epoch + 1))
        epoch_finish_time = datetime.now()
        logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))
    logger.info('training finished')
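
calculate_loss_and_accuracy is called in the loop above but not shown in this note. A minimal sketch of what such a function could look like for this setup, assuming torch is imported and pad_id is the padding token id defined elsewhere in the script: the logits and labels are shifted by one position (position n predicts token n+1), and padding positions are excluded from both the loss and the accuracy.

def calculate_loss_and_accuracy(outputs, labels, device):
    # The first element of outputs holds the LM logits: (batch_size, seq_len, vocab_size)
    logits = outputs[0]
    # Use position n to predict token n+1: drop the last logit and the first label
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous().to(device)

    # Cross entropy summed over all non-padding positions (pad_id assumed to be defined at module level)
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    # Accuracy: fraction of non-padding tokens whose argmax prediction equals the label
    preds = shift_logits.argmax(dim=-1)
    not_ignore = shift_labels.ne(pad_id)
    num_targets = not_ignore.long().sum().item()
    correct = (preds == shift_labels) & not_ignore
    accuracy = correct.float().sum() / num_targets

    loss = loss / num_targets
    return loss, accuracy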

Algorithm analysis

Forward pass

outputs = model.forward(input_ids=input_ids)

The input_ids argument is a torch LongTensor.

The returned outputs consist of two parts:

  • lm_logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)) – prediction scores of the language modeling head over the vocabulary; in this project [2, 135, 13317] (the last dimension is the vocabulary size, since GPT2LMHeadModel places an LM head on top of GPT2Model's hidden states)
  • past_key_values (tuple(tuple(torch.FloatTensor))) – the cached keys and values of each attention layer; in this project, 10 layers of [2, 2, 12, 135, 64]
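
These shapes can be verified with a dummy batch, reusing the model and device from the training code above; a small inspection sketch, where the random input_ids are placeholders and the shapes in the comments are the ones reported above for this project's 10-layer, 12-head configuration:

import torch

# Dummy batch of token ids standing in for real dialogue data: (batch_size=2, sequence_length=135)
input_ids = torch.randint(0, 13317, (2, 135), dtype=torch.long).to(device)
outputs = model(input_ids=input_ids)
logits, past_key_values = outputs[0], outputs[1]
print(logits.shape)          # torch.Size([2, 135, 13317]) -> (batch_size, sequence_length, vocab_size)
print(len(past_key_values))  # 10 -> one entry per transformer layer
# Depending on the transformers version, each entry is either a single stacked tensor of shape
# [2, 2, 12, 135, 64] (key and value stacked together) or a (key, value) pair of [2, 12, 135, 64] tensors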

Backward update

loss.backward()
