RuntimeError: Error(s) in loading state_dict for BertClassifier (model mismatch)

RuntimeError: Error(s) in loading state_dict for BertClassifier:
	size mismatch for lstm.weight_ih_l0: copying a param with shape torch.Size([400, 768]) from checkpoint, the shape in current model is torch.Size([360, 768]).
	size mismatch for lstm.weight_hh_l0: copying a param with shape torch.Size([400, 100]) from checkpoint, the shape in current model is torch.Size([360, 90]).
	size mismatch for lstm.weight_ih_l0_reverse: copying a param with shape torch.Size([400, 768]) from checkpoint, the shape in current model is torch.Size([360, 768]).
	size mismatch for lstm.weight_hh_l0_reverse: copying a param with shape torch.Size([400, 100]) from checkpoint, the shape in current model is torch.Size([360, 90]).
	size mismatch for linear1.weight: copying a param with shape torch.Size([100, 401]) from checkpoint, the shape in current model is torch.Size([90, 361]).
	size mismatch for linear1.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([90]).
	size mismatch for linear2.weight: copying a param with shape torch.Size([7, 100]) from checkpoint, the shape in current model is torch.Size([7, 90]).
	size mismatch for linear1_ent.weight: copying a param with shape torch.Size([50, 200]) from checkpoint, the shape in current model is torch.Size([45, 180]).
	size mismatch for linear1_ent.bias: copying a param with shape torch.Size([50]) from checkpoint, the shape in current model is torch.Size([45]).
	size mismatch for linear2_ent.weight: copying a param with shape torch.Size([2, 50]) from checkpoint, the shape in current model is torch.Size([2, 45]).

Cause

The parameter shapes in the loaded checkpoint do not match those of the current model.

First, locate the current model definition:

```python
class BertClassifier(nn.Module):
    'Neural Network Architecture'
    def __init__(self, args):
        super(BertClassifier, self).__init__()

        self.hid_size = args.hid
        self.batch_size = args.batch
        self.num_layers = args.num_layers
        self.num_classes = len(args.label_to_id)
        self.num_ent_classes = 2

        self.dropout = nn.Dropout(p=args.dropout)
        # lstm is shared for both relation and entity
        self.lstm = nn.LSTM(768, self.hid_size, self.num_layers, bias=False, bidirectional=True)

        # MLP classifier for relation
        self.linear1 = nn.Linear(self.hid_size*4 + args.n_fts, self.hid_size)
        self.linear2 = nn.Linear(self.hid_size, self.num_classes)

        # MLP classifier for entity
        self.linear1_ent = nn.Linear(self.hid_size*2, int(self.hid_size / 2))
        self.linear2_ent = nn.Linear(int(self.hid_size / 2), self.num_ent_classes)

        self.act = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        self.softmax_ent = nn.Softmax(dim=2)

model = BertClassifier(args)
```
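To pinpoint every mismatch at once, you can diff the checkpoint's parameter shapes against the current model's. The sketch below uses plain dictionaries filled with a few shapes taken from the traceback above; with real models you would build the same dictionaries from `torch.load(path)` and `model.state_dict()` (the dictionary contents here are illustrative, not a full listing).

```python
# Shapes copied from the error message above (a subset, for illustration).
# In practice: {k: tuple(v.shape) for k, v in torch.load(path).items()}
checkpoint_shapes = {
    "lstm.weight_ih_l0": (400, 768),
    "linear1.weight": (100, 401),
    "linear2_ent.weight": (2, 50),
}
# In practice: {k: tuple(v.shape) for k, v in model.state_dict().items()}
model_shapes = {
    "lstm.weight_ih_l0": (360, 768),
    "linear1.weight": (90, 361),
    "linear2_ent.weight": (2, 45),
}

# Collect every parameter whose shape differs between the two.
mismatches = {
    name: (ckpt_shape, model_shapes[name])
    for name, ckpt_shape in checkpoint_shapes.items()
    if model_shapes.get(name) != ckpt_shape
}
for name, (ckpt, cur) in mismatches.items():
    print(f"{name}: checkpoint {ckpt} vs current {cur}")
```

This reproduces the same list the RuntimeError prints, but in a form you can inspect programmatically.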

Read the error message and find the mismatched parameters.

For example:

size mismatch for linear2_ent.weight: copying a param with shape torch.Size([2, 50]) from checkpoint, the shape in current model is torch.Size([2, 45]).

For `linear2_ent.weight`, find `linear2_ent` in `BertClassifier(nn.Module)` and check how it is constructed:

```python
self.linear2_ent = nn.Linear(int(self.hid_size / 2), self.num_ent_classes)
```

Compare the hyperparameters of the two models and change them so the shapes match. Here the checkpoint's `linear2_ent.weight` is [2, 50] while the current model's is [2, 45]; since this layer's input dimension is `hid_size / 2`, the checkpoint was trained with `args.hid = 100` but the current model uses 90. Setting `args.hid = 100` to match the checkpoint resolves the error.
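The hidden size can be recovered directly from the reported shapes, as a sanity check before editing the arguments. `nn.LSTM` stores `weight_ih_l0` as `(4 * hidden_size, input_size)` (one block per gate), and in this model `linear2_ent` takes `hid_size / 2` input features, so both numbers should point to the same `args.hid`:

```python
# Numbers taken from the checkpoint side of the error message above.
ckpt_lstm_ih_rows = 400       # lstm.weight_ih_l0: (400, 768)
ckpt_linear2_ent_cols = 50    # linear2_ent.weight: (2, 50)

hid_from_lstm = ckpt_lstm_ih_rows // 4     # LSTM packs 4 gates per cell
hid_from_ent = ckpt_linear2_ent_cols * 2   # linear2_ent input = hid_size / 2

# Both derivations agree: the checkpoint was trained with args.hid = 100,
# so set args.hid = 100 before calling model.load_state_dict(...).
print(hid_from_lstm, hid_from_ent)
```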
