# pytorch--基于参数权重初始化模型 (PyTorch: initialize a model from existing parameter weights)

#coding=utf-8

"""
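背景: 将CPM-large模型权重映射到transformers的GPT2LMHeadModel预训练模型上。
存在问题:
1. 原始CPM-large部分层的计算方式与transformers中的不一样。例如线性计算有的用Linear, 有的用Conv1D。
   Linear的权重形状为(out_features, in_features), 而Conv1D的权重形状为(in_features, out_features),
   因此映射到Conv1D时需要先将权重转置。
2. GPT2LMHeadModel最后的输出层lm_head是带参数的Linear层, 而原始CPM-large用的是不带单独参数的矩阵计算
   (本脚本用词嵌入矩阵word_embeddings.weight初始化lm_head)。

权重映射方法:
1. 先查看模型1(待更新权重的transformers GPT2LMHeadModel)与模型2(已加载CPM-large权重的原始模型)的模型结构与参数shape, 理解模型结构, 梳理映射关系
2. 基于映射关系更新模型1的权重
3. 保存参数更新后的模型
"""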

import torch
import torch.nn.functional as F
import os
import argparse
from tqdm import trange
from transformers import GPT2LMHeadModel, GPT2Config, CpmTokenizer ,GPT2Model
import numpy as np

# Device for every model in this script: the first CUDA GPU when one is
# available, otherwise the CPU. (The --device command-line option is not
# consulted here.)
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

def set_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace with two string attributes:
          * ``device`` -- which GPUs to use (default ``'0,1'``);
          * ``model_config`` -- path of the GPT2Config JSON file
            (default ``'cpm-large.json'``).
    """
    cli = argparse.ArgumentParser()
    # (flag, default, help text) for every option; all are optional strings.
    option_specs = (
        ('--device', '0,1', '设置使用哪些显卡'),
        ('--model_config', 'cpm-large.json', '需要从头训练一个模型时,模型参数的配置文件'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, type=str, required=False, help=help_text)
    return cli.parse_args()

# Parsed once at import time; pytorch_model() and weight_trans() read
# args.model_config from this module-level object.
args = set_args()

# 权重待更新模型: the transformers model whose weights will be overwritten.
def pytorch_model():
    """Build a freshly initialised transformers ``GPT2LMHeadModel`` (no pretrained weights).

    The architecture is read from the JSON file named by ``args.model_config``.
    The model is moved to the module-level ``device``, its structure and every
    parameter name/shape are printed (for comparison with the CPM model), and it
    is returned.
    """
    config = GPT2Config.from_json_file(args.model_config)
    hf_model = GPT2LMHeadModel(config=config)
    hf_model.to(device)

    print("###################transformer gpt2 model ")
    print(hf_model)
    params = hf_model.state_dict()
    for name, tensor in params.items():
        print('------k', name, tensor.shape)
    print('-----pytorch w', len(params.keys()))
    return hf_model


#torch_model = pytorch_model()

# 已知模型参数: source model that holds the known (original CPM-large) weights.
def paddle_model():
    """Build the original CPM-large model and load its checkpoint weights.

    Despite the function name this is a PyTorch model: ``GPT2Model`` below is
    imported from the project-local ``toch_model`` module (not the transformers
    class of the same name) and is built with hard-coded CPM-large
    hyper-parameters. Its weights are read from the ``'module'`` entry of
    ``./model/mp_rank_00_model_states.pt``.

    The model structure and every parameter (name, shape and values) are
    printed, and the model is returned after being moved to ``device``.
    """
    # Project-local CPM implementation; inside this function it shadows the
    # transformers ``GPT2Model`` imported at module level.
    from toch_model import GPT2Model

    # finetune model file
    pretrained_model = './model/mp_rank_00_model_states.pt'
    # 初始化GPT-2模型: build the CPM-large architecture.
    model = GPT2Model(
        vocab_size=30000,
        layer_size=32,
        block_size=1024,
        embedding_dropout=0.0,
        embedding_size=2560,
        num_attention_heads=32,
        attention_dropout=0.0,
        residual_dropout=0.0)

    print('正在加载模型,耗时需要几分钟,请稍后...')

    # 读取CPM-LM模型参数(FP16). map_location='cpu' lets the checkpoint load on a
    # CPU-only machine even if it was saved from GPU tensors, and avoids keeping
    # an extra copy of the raw weights on the GPU; the model is moved afterwards.
    state_dict = torch.load(pretrained_model, map_location='cpu')
    model.load_state_dict(state_dict['module'])
    # The raw checkpoint is no longer needed once copied into the model.
    del state_dict
    model.to(device)

    print("############自定义model:")
    print(model)
    # 查看模型参数: dump every parameter for manual comparison with the target.
    for k, v in model.state_dict().items():
        print('------k', k, v.shape, v)
    print('-----paddle w', len(model.state_dict().keys()))
    return model

# Runs at import time: builds the CPM-large source model and loads its checkpoint.
# The target model (torch_model = pytorch_model()) and the weight_trans(...) call
# are commented out in this file, so as written the script only loads and
# prints this model.
p_model = paddle_model()

"""对应模型进行映射"""
# Per-layer key suffixes: transformers ``transformer.h.<i>.<suffix>`` is filled
# from CPM ``transformer.layers.<i>.<mapped suffix>``.
_CPM_LAYER_KEY_MAP = {
    'ln_1.weight': 'input_layernorm.weight',
    'ln_1.bias': 'input_layernorm.bias',
    'attn.c_attn.weight': 'attention.query_key_value.weight',
    'attn.c_attn.bias': 'attention.query_key_value.bias',
    'attn.c_proj.weight': 'attention.dense.weight',
    'attn.c_proj.bias': 'attention.dense.bias',
    'ln_2.weight': 'post_attention_layernorm.weight',
    'ln_2.bias': 'post_attention_layernorm.bias',
    'mlp.c_fc.weight': 'mlp.dense_h_to_4h.weight',
    'mlp.c_fc.bias': 'mlp.dense_h_to_4h.bias',
    'mlp.c_proj.weight': 'mlp.dense_4h_to_h.weight',
    'mlp.c_proj.bias': 'mlp.dense_4h_to_h.bias',
}

# transformers' GPT-2 implements these projections with Conv1D, whose weight is
# stored as (in_features, out_features) -- transposed relative to the Linear
# layout of the CPM weights -- so they are transposed when copied.
_CONV1D_WEIGHT_SUFFIXES = frozenset({
    'attn.c_attn.weight',
    'attn.c_proj.weight',
    'mlp.c_fc.weight',
    'mlp.c_proj.weight',
})

# Keys outside the transformer blocks, copied without transposition.
_CPM_GLOBAL_KEY_MAP = {
    'transformer.wte.weight': 'word_embeddings.weight',
    'transformer.wpe.weight': 'position_embeddings.weight',
    'transformer.ln_f.weight': 'transformer.final_layernorm.weight',
    'transformer.ln_f.bias': 'transformer.final_layernorm.bias',
    # The LM head is initialised from the CPM word-embedding matrix.
    'lm_head.weight': 'word_embeddings.weight',
}


def _cpm_source_key(hf_key):
    """Return ``(cpm_key, needs_transpose)`` for a transformers key, or None if unmapped."""
    if hf_key in _CPM_GLOBAL_KEY_MAP:
        return _CPM_GLOBAL_KEY_MAP[hf_key], False
    # Layer keys look like 'transformer.h.<i>.<suffix>'; any layer index is accepted.
    parts = hf_key.split('.', 3)
    if (len(parts) == 4 and parts[0] == 'transformer' and parts[1] == 'h'
            and parts[2].isdigit() and parts[3] in _CPM_LAYER_KEY_MAP):
        suffix = parts[3]
        cpm_key = 'transformer.layers.{}.{}'.format(parts[2], _CPM_LAYER_KEY_MAP[suffix])
        return cpm_key, suffix in _CONV1D_WEIGHT_SUFFIXES
    return None


def _convert_cpm_state_dict(target_state, cpm_state):
    """Return a dict with ``target_state``'s keys, filled from ``cpm_state`` where mapped.

    Keys without a CPM counterpart (e.g. attention buffers) keep the value from
    ``target_state``. Raises KeyError if a mapped CPM key is absent from
    ``cpm_state`` (e.g. the target has more layers than the checkpoint).
    """
    converted = {}
    for key, value in target_state.items():
        source = _cpm_source_key(key)
        if source is None:
            converted[key] = value
            continue
        cpm_key, needs_transpose = source
        tensor = cpm_state[cpm_key].clone()
        converted[key] = tensor.transpose(1, 0) if needs_transpose else tensor
    return converted


def weight_trans(torch_model, paddle_model, save_dir='./model/CPM-large_bin/'):
    """对应模型进行映射: copy CPM-large weights into a transformers GPT2LMHeadModel and save it.

    Args:
        torch_model: target ``GPT2LMHeadModel`` (e.g. from ``pytorch_model()``).
            Its state dict supplies the key set and the fallback value of every
            key with no CPM counterpart. It is not modified.
        paddle_model: source model holding the CPM-large weights (e.g. from
            ``paddle_model()``); despite the name it is a PyTorch module.
        save_dir: directory used for ``save_pretrained`` and the verification
            ``from_pretrained`` reload.

    Returns:
        The converted model re-loaded from ``save_dir`` and moved to ``device``.

    Raises:
        KeyError: if a mapped CPM parameter is missing from ``paddle_model``.
    """
    state_dict = _convert_cpm_state_dict(torch_model.state_dict(), paddle_model.state_dict())

    # Fresh target model built from the same config, so torch_model stays untouched.
    model_config = GPT2Config.from_json_file(args.model_config)
    model = GPT2LMHeadModel(config=model_config)
    model.to(device)

    # 加载我们真正需要的 state_dict
    print('############### after trans par model')
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False would otherwise hide a naming mismatch; surface it explicitly.
    if load_result.missing_keys or load_result.unexpected_keys:
        print('-----missing keys', load_result.missing_keys)
        print('-----unexpected keys', load_result.unexpected_keys)
    for k, v in model.state_dict().items():
        print('------k', k, v.shape, v)

    print('-----last para num', len(state_dict.keys()))
    model.save_pretrained(save_dir)
    # Release the in-memory copy before re-loading to keep peak memory down.
    del model

    # 查看转的权重是否ok: reload from disk to verify the saved weights.
    model = GPT2LMHeadModel.from_pretrained(save_dir)
    model.to(device)

    print('############### after load par model')
    print(model)
    for k, v in model.state_dict().items():
        print('------k', k, v.shape, v)
    print('-----pytorch w', len(model.state_dict().keys()))

    return model

#weight_trans(torch_model, p_model)

# 你可能感兴趣的:(python学习,pytorch,python,深度学习)  (blog tags: "you may also be interested in")