Bert-BiLSTM-CRF pytorch 代码解析-1:def _forward_alg(self, feats, mask=None)

理解 github上代码:Bert-BiLSTM-CRF-pytorch
Github 相关链接: link.

neg_log_likelihood_loss = forward_score - gold_score
_forward_alg 用于计算所有可能标签路径分数的 log-sum-exp(即 forward_score,也就是配分函数的对数 log Z);它与真实路径分数 gold_score 之差就是负对数似然损失。

    def _forward_alg(self, feats, mask=None):
        """Run the CRF forward algorithm to compute the partition function (batched).

        Args:
            feats: emission scores, size=(batch_size, seq_len, self.target_size+2).
            mask: size=(batch_size, seq_len); positions where it is non-zero
                update the forward variable, other positions keep the value
                from the previous step. If None, every position is treated
                as valid.

        Returns:
            A tuple ``(total, scores)``. ``total`` is the sum over the batch
            of the log-sum-exp of all tag-path scores (log Z). ``scores`` is
            the (seq_len, batch_size, tag_size, tag_size) tensor with
            ``scores[t, b, i, j] = feats[b, t, j] + transitions[i, j]``.
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(-1)

        if mask is None:
            # The signature allows mask=None, but mask.transpose below would
            # fail on None; treat every position as valid instead.
            mask = feats.new_ones(batch_size, seq_len).bool()

        # 1. Transpose mask to (seq_len, batch_size).
        #    feats: (batch_size, seq_len, tag_size)
        #        -> transpose: (seq_len, batch_size, tag_size)
        #        -> view:      (seq_len * batch_size, 1, tag_size)
        #        -> expand along dim -2: (seq_len * batch_size, tag_size, tag_size)
        #    so every row i of a slice holds the same emission vector.
        mask = mask.transpose(1, 0).contiguous()
        ins_num = batch_size * seq_len
        feats = feats.transpose(1, 0).contiguous().view(
            ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)

        # 2. Add the transition scores to the emissions:
        #    scores[t, b, i, j] = feats[b, t, j] + transitions[i, j].
        scores = feats + self.transitions.view(
            1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)

        # Iterate over time steps starting at t=0. The first step (t=0)
        # initialises the recursion; the loop below consumes t=1, 2, ...
        seq_iter = enumerate(scores)
        _, inivalues = next(seq_iter)  # inivalues = scores[0]

        # 3. Initial forward variable at t=0: the START_TAG_IDX row, i.e. the
        #    score of moving from START into each tag plus its t=0 emission.
        partition = inivalues[:, self.START_TAG_IDX, :].clone().view(
            batch_size, tag_size, 1)

        # 4. Update the forward variable (partition) for t = 1, 2, ...
        for idx, cur_values in seq_iter:
            # cur_values[b, i, j] = partition[b, i] + transitions[i, j] + feats[b, t, j]
            cur_values = cur_values + partition.contiguous().view(
                batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            # Reduce over the previous tag i -> (batch_size, tag_size).
            cur_partition = log_sum_exp(cur_values, tag_size)
            # Boolean masks are the supported dtype for masked_select /
            # masked_scatter_ (uint8 masks are deprecated).
            mask_idx = mask[idx, :].view(batch_size, 1).expand(
                batch_size, tag_size).bool()
            masked_cur_partition = cur_partition.masked_select(mask_idx)
            # masked_select always returns a 1-D tensor, so a ``dim() != 0``
            # test is always true; check whether anything was selected.
            if masked_cur_partition.numel() > 0:
                # Copy the updated values into partition only where mask is 1
                # (mask_idx has as many elements as partition), so masked
                # positions keep their previous partition value.
                mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
                partition.masked_scatter_(mask_idx, masked_cur_partition)

        # 5. Add the transition from every tag into END and reduce over the
        #    last real tag; the END_TAG_IDX column is log Z per sequence.
        cur_values = self.transitions.view(1, tag_size, tag_size).expand(
            batch_size, tag_size, tag_size) + partition.contiguous().view(
            batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        cur_partition = log_sum_exp(cur_values, tag_size)
        final_partition = cur_partition[:, self.END_TAG_IDX]
        return final_partition.sum(), scores

def log_sum_exp(vec, m_size):
    """Compute ``log(sum(exp(vec), dim=1))`` without overflow.

    Evaluating it directly can overflow (e.g. ``exp(999) == inf``). The
    maximum over dim 1 is therefore subtracted before exponentiating and
    added back afterwards: ``log sum exp(x) = m + log sum exp(x - m)``.

    Args:
        vec: size=(batch_size, vanishing_dim, hidden_dim); reduced over dim 1.
        m_size: hidden_dim.

    Returns:
        Tensor of size=(batch_size, hidden_dim).
    """
    # Per-column maximum over dim 1, kept as (B, 1, M) so it broadcasts
    # against vec (equivalent to gathering at the argmax indices).
    max_score = torch.max(vec, 1, keepdim=True)[0]
    # If a column's maximum is +/-inf, ``vec - max_score`` would contain
    # inf - inf = nan; shift such columns by 0 so the result is the correct
    # +/-inf instead.
    max_score = max_score.masked_fill(torch.isinf(max_score), 0)
    summed = torch.sum(torch.exp(vec - max_score), 1)
    return max_score.view(-1, m_size) + torch.log(summed).view(-1, m_size)


class CRF(nn.Module) 中 _forward_alg 过程

0. 输入的设置


import torch

# Toy emission scores (e.g. LSTM outputs): batch_size=2, seq_len=3, tag_size=4.
feats = torch.FloatTensor([[[ 0.1938, -0.0033, -0.0786,  0.1115],
                            [-0.0450, -0.1575,  0.0550, -0.1546],
                            [-0.0271, -0.0669, -0.0533, -0.1674]],

                            [[-0.0269, -0.1714, -0.0775, -0.0791],
                            [-0.0745, -0.2008, -0.1868,  0.2168],
                            [ 0.0703,  0.0196,  0.0457,  0.0400]]])
# 1 = real token, 0 = padding. uint8 matches the dtype shown in the printed trace.
mask = torch.tensor([[1, 1, 0],
                     [1, 1, 1]], dtype=torch.uint8)
# Gold tag ids (integer indices; not used by _forward_alg).
# NOTE(review): 5, 4 and 6 are out of range for tag_size=4 — confirm these values.
tags = torch.LongTensor([[5, 4, 0],
                         [5, 1, 6]])

# transitions[i, j] = score of moving from tag i to tag j. Column 2 (START)
# and row 3 (END) are -1000, so no path enters START or leaves END.
transitions = torch.Tensor([[    7,     3,  -1000,     2],
                            [    2,     1,  -1000,     5],
                            [    1,     3,  -1000,     2],
                            [-1000, -1000, -1000, -1000]])
# transitions.shape: torch.Size([4, 4]); the 4 tags include the START and END tags.
# The trace below indexes START_TAG_IDX / END_TAG_IDX; the printed values
# correspond to index 2 (START) and index 3 (END).
START_TAG_IDX = 2
END_TAG_IDX = 3

1. mask 和 feats 的转置及扩展维度等处理

batch_size = feats.size(0)
seq_len = feats.size(1)
tag_size = feats.size(-1)
# 1. After transposing, mask has shape (seq_len, batch).
#    feats starts as shape=(batch_size, seq_len, tag_size)
#          transpose: (seq_len, batch_size, tag_size)
#          view:      (seq_len*batch_size, 1, tag_size)
#          then expand along dim -2: (seq_len*batch_size, [tag_size], tag_size),
#          so every row of a slice repeats the same emission vector.
mask = mask.transpose(1, 0).contiguous()
ins_num = batch_size * seq_len
feats = feats.transpose(1, 0).contiguous().view(
    ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
print ('batch_size,seq_len,tag_size:',batch_size,seq_len,tag_size)
print ('\nmask:',mask.shape,'\n',mask)
print ('\nins_num=',ins_num)
print ('\nfeats:',feats.shape,'\n',feats)
batch_size,seq_len,tag_size: 2 3 4

mask: torch.Size([3, 2]) 
      tensor([[1, 1],
              [1, 1],
              [0, 1]], dtype=torch.uint8)

ins_num= 6

feats: torch.Size([6, 4, 4]) 
 tensor([[[ 0.1938, -0.0033, -0.0786,  0.1115],
         [ 0.1938, -0.0033, -0.0786,  0.1115],
         [ 0.1938, -0.0033, -0.0786,  0.1115],
         [ 0.1938, -0.0033, -0.0786,  0.1115]],

        [[-0.0269, -0.1714, -0.0775, -0.0791],
         [-0.0269, -0.1714, -0.0775, -0.0791],
         [-0.0269, -0.1714, -0.0775, -0.0791],
         [-0.0269, -0.1714, -0.0775, -0.0791]],

        [[-0.0450, -0.1575,  0.0550, -0.1546],
         [-0.0450, -0.1575,  0.0550, -0.1546],
         [-0.0450, -0.1575,  0.0550, -0.1546],
         [-0.0450, -0.1575,  0.0550, -0.1546]],

        [[-0.0745, -0.2008, -0.1868,  0.2168],
         [-0.0745, -0.2008, -0.1868,  0.2168],
         [-0.0745, -0.2008, -0.1868,  0.2168],
         [-0.0745, -0.2008, -0.1868,  0.2168]],

        [[-0.0271, -0.0669, -0.0533, -0.1674],
         [-0.0271, -0.0669, -0.0533, -0.1674],
         [-0.0271, -0.0669, -0.0533, -0.1674],
         [-0.0271, -0.0669, -0.0533, -0.1674]],

        [[ 0.0703,  0.0196,  0.0457,  0.0400],
         [ 0.0703,  0.0196,  0.0457,  0.0400],
         [ 0.0703,  0.0196,  0.0457,  0.0400],
         [ 0.0703,  0.0196,  0.0457,  0.0400]]])

2. 计算 scores: LSTM所有时间步的输出 feats 先加上 转移分数

# Replicate transitions seq_len * batch_size times to shape [6, 4, 4], then add
# element-wise to feats: scores[k, i, j] = feats_row[j] + transitions[i, j].
tr_ = transitions.view(
    1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
print ('tr_',tr_.shape,'\n',tr_)

scores = feats + tr_
print ('scores',scores.shape,'\n',scores)
tr_:  torch.Size([6, 4, 4]) 
 tensor([[[    7.,     3., -1000.,     2.],
         [    2.,     1., -1000.,     5.],
         [    1.,     3., -1000.,     2.],
         [-1000., -1000., -1000., -1000.]],
        [[    7.,     3., -1000.,     2.],
         [    2.,     1., -1000.,     5.],
         [    1.,     3., -1000.,     2.],
         [-1000., -1000., -1000., -1000.]]])

scores: torch.Size([6, 4, 4]) 
 tensor([[[ 7.1938,  2.9967, -1000.1,  2.1115],
         [ 2.1938,   0.9967, -1000.1,  5.1115],
         [ 1.1938,   2.9967, -1000.1,  2.1115],
         [-999.81,    -1000, -1000.1, -999.89]],

        [[ 6.9731e+00,  2.8286e+00, -1.0001e+03,  1.9209e+00],
         [ 1.9731e+00,  8.2860e-01, -1.0001e+03,  4.9209e+00],
         [ 9.7310e-01,  2.8286e+00, -1.0001e+03,  1.9209e+00],
         [-1.0000e+03, -1.0002e+03, -1.0001e+03, -1.0001e+03]],

        [[ 6.9550e+00,  2.8425e+00, -9.9995e+02,  1.8454e+00],
         [ 1.9550e+00,  8.4250e-01, -9.9995e+02,  4.8454e+00],
         [ 9.5500e-01,  2.8425e+00, -9.9995e+02,  1.8454e+00],
         [-1.0000e+03, -1.0002e+03, -9.9995e+02, -1.0002e+03]],

        [[ 6.9255e+00,  2.7992e+00, -1.0002e+03,  2.2168e+00],
         [ 1.9255e+00,  7.9920e-01, -1.0002e+03,  5.2168e+00],
         [ 9.2550e-01,  2.7992e+00, -1.0002e+03,  2.2168e+00],
         [-1.0001e+03, -1.0002e+03, -1.0002e+03, -9.9978e+02]],

        [[ 6.9729e+00,  2.9331e+00, -1.0001e+03,  1.8326e+00],
         [ 1.9729e+00,  9.3310e-01, -1.0001e+03,  4.8326e+00],
         [ 9.7290e-01,  2.9331e+00, -1.0001e+03,  1.8326e+00],
         [-1.0000e+03, -1.0001e+03, -1.0001e+03, -1.0002e+03]],

        [[ 7.0703e+00,  3.0196e+00, -9.9995e+02,  2.0400e+00],
         [ 2.0703e+00,  1.0196e+00, -9.9995e+02,  5.0400e+00],
         [ 1.0703e+00,  3.0196e+00, -9.9995e+02,  2.0400e+00],
         [-9.9993e+02, -9.9998e+02, -9.9995e+02, -9.9996e+02]]])

# Split the flattened first dim back into (seq_len, batch_size) so that
# scores[t] holds the (batch_size, tag_size, tag_size) scores of time step t.
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
print ('scores reshape:',scores.shape,'\n',scores)

scores reshape: torch.Size([3, 2, 4, 4]) 
 tensor([[[[ 7.1938e+00,  2.9967e+00, -1.0001e+03,  2.1115e+00],
          [ 2.1938e+00,  9.9670e-01, -1.0001e+03,  5.1115e+00],
          [ 1.1938e+00,  2.9967e+00, -1.0001e+03,  2.1115e+00],
          [-9.9981e+02, -1.0000e+03, -1.0001e+03, -9.9989e+02]],

         [[ 6.9731e+00,  2.8286e+00, -1.0001e+03,  1.9209e+00],
          [ 1.9731e+00,  8.2860e-01, -1.0001e+03,  4.9209e+00],
          [ 9.7310e-01,  2.8286e+00, -1.0001e+03,  1.9209e+00],
          [-1.0000e+03, -1.0002e+03, -1.0001e+03, -1.0001e+03]]],

        [[[ 6.9550e+00,  2.8425e+00, -9.9995e+02,  1.8454e+00],
          [ 1.9550e+00,  8.4250e-01, -9.9995e+02,  4.8454e+00],
          [ 9.5500e-01,  2.8425e+00, -9.9995e+02,  1.8454e+00],
          [-1.0000e+03, -1.0002e+03, -9.9995e+02, -1.0002e+03]],

         [[ 6.9255e+00,  2.7992e+00, -1.0002e+03,  2.2168e+00],
          [ 1.9255e+00,  7.9920e-01, -1.0002e+03,  5.2168e+00],
          [ 9.2550e-01,  2.7992e+00, -1.0002e+03,  2.2168e+00],
          [-1.0001e+03, -1.0002e+03, -1.0002e+03, -9.9978e+02]]],

        [[[ 6.9729e+00,  2.9331e+00, -1.0001e+03,  1.8326e+00],
          [ 1.9729e+00,  9.3310e-01, -1.0001e+03,  4.8326e+00],
          [ 9.7290e-01,  2.9331e+00, -1.0001e+03,  1.8326e+00],
          [-1.0000e+03, -1.0001e+03, -1.0001e+03, -1.0002e+03]],

         [[ 7.0703e+00,  3.0196e+00, -9.9995e+02,  2.0400e+00],
          [ 2.0703e+00,  1.0196e+00, -9.9995e+02,  5.0400e+00],
          [ 1.0703e+00,  3.0196e+00, -9.9995e+02,  2.0400e+00],
          [-9.9993e+02, -9.9998e+02, -9.9995e+02, -9.9996e+02]]]])

3. 设置时间步迭代器 seq_iter

4. inivalues:

inivalues 是 batch=2 个样本在第一个时间步(t=0)的 LSTM 输出 (2, 4) 复制扩展为 (2, 4, 4) 之后与转移矩阵相加的结果,即 scores[0]。

# Iterator over time steps starting at t=0; each step yields (t, scores[t]),
# where scores[t] has shape (batch_size, tag_size, tag_size) = (2, 4, 4).
seq_iter = enumerate(scores)

# Consume the first step: inivalues = scores[0]. The recursion loop further
# below then continues from t=1.
_, inivalues = next(seq_iter)
print ('inivalues',inivalues.shape,'\n',inivalues)
inivalues torch.Size([2, 4, 4]) 
 tensor([[[ 7.1938e+00,  2.9967e+00, -1.0001e+03,  2.1115e+00],
         [ 2.1938e+00,  9.9670e-01, -1.0001e+03,  5.1115e+00],
         [ 1.1938e+00,  2.9967e+00, -1.0001e+03,  2.1115e+00],
         [-9.9981e+02, -1.0000e+03, -1.0001e+03, -9.9989e+02]],

        [[ 6.9731e+00,  2.8286e+00, -1.0001e+03,  1.9209e+00],
         [ 1.9731e+00,  8.2860e-01, -1.0001e+03,  4.9209e+00],
         [ 9.7310e-01,  2.8286e+00, -1.0001e+03,  1.9209e+00],
         [-1.0000e+03, -1.0002e+03, -1.0001e+03, -1.0001e+03]]])

5. 计算 a 在 t=0 时刻的初始值

# At t=0 the partition (forward variable) is the START_TAG_IDX row of
# inivalues: inivalues[b, START, j] = transitions[START, j] + feats[b, 0, j],
# i.e. the score of moving from START into tag j plus its t=0 emission.
# Reshaped to (batch_size, tag_size, 1) so it can later be expanded along the last dim.
partition = inivalues[:, START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
print ('partition',partition.shape,'\n',partition)
partition torch.Size([2, 4, 1]) 
 tensor([[[ 1.1938e+00],
         [ 2.9967e+00],
         [-1.0001e+03],
         [ 2.1115e+00]],

        [[ 9.7310e-01],
         [ 2.8286e+00],
         [-1.0001e+03],
         [ 1.9209e+00]]])

6. 迭代计算 a(即 partition)在 t=1, 2, … 时刻更新后的值

# Recursion for t = 1, 2, ...: seq_iter already yielded t=0 above, so idx
# starts at 1. cur_values = scores[idx] (LSTM emission + transition scores).
for idx, cur_values in seq_iter:
    print('\n\n',idx,cur_values.shape,'\n', cur_values)
    # Broadcast partition[b, i] across every target tag j.
    pa_ = partition.contiguous().view(
        batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
    print('\npa_', pa_.shape,'\n',pa_)
    cur_values = cur_values + pa_
    print('相加后的 cur_values:',cur_values.shape,'\n', cur_values)

    # Reduce over the previous tag i (dim 1) -> (batch_size, tag_size).
    cur_partition = log_sum_exp(cur_values, tag_size)
    print ('cur_partition 是 log sum exp 之后的结果:\n',cur_partition)

    mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)

    # Boolean masks are the supported dtype for masked_select / masked_scatter_
    # (uint8 masks are deprecated).
    masked_cur_partition = cur_partition.masked_select(mask_idx.bool())
    print ('masked_cur_partition:',masked_cur_partition)

    # masked_select always returns a 1-D tensor, so a ``dim() != 0`` test could
    # never be False; check whether any element was selected instead.
    if masked_cur_partition.numel() > 0:
        # Copy masked_cur_partition into partition at the positions where
        # mask_idx is 1 (mask_idx has as many elements as partition), so
        # masked (padding) positions keep their previous partition value.
        mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
        print ('mask_idx:',mask_idx)
        partition.masked_scatter_(mask_idx.bool(), masked_cur_partition)
        print ('partition',partition)
idx = 1 torch.Size([2, 4, 4]) 
 tensor([[[ 6.9550e+00,  2.8425e+00, -9.9995e+02,  1.8454e+00],
         [ 1.9550e+00,  8.4250e-01, -9.9995e+02,  4.8454e+00],
         [ 9.5500e-01,  2.8425e+00, -9.9995e+02,  1.8454e+00],
         [-1.0000e+03, -1.0002e+03, -9.9995e+02, -1.0002e+03]],

        [[ 6.9255e+00,  2.7992e+00, -1.0002e+03,  2.2168e+00],
         [ 1.9255e+00,  7.9920e-01, -1.0002e+03,  5.2168e+00],
         [ 9.2550e-01,  2.7992e+00, -1.0002e+03,  2.2168e+00],
         [-1.0001e+03, -1.0002e+03, -1.0002e+03, -9.9978e+02]]])

pa_ torch.Size([2, 4, 4]) 
 tensor([[[ 1.1938e+00,  1.1938e+00,  1.1938e+00,  1.1938e+00],
         [ 2.9967e+00,  2.9967e+00,  2.9967e+00,  2.9967e+00],
         [-1.0001e+03, -1.0001e+03, -1.0001e+03, -1.0001e+03],
         [ 2.1115e+00,  2.1115e+00,  2.1115e+00,  2.1115e+00]],

        [[ 9.7310e-01,  9.7310e-01,  9.7310e-01,  9.7310e-01],
         [ 2.8286e+00,  2.8286e+00,  2.8286e+00,  2.8286e+00],
         [-1.0001e+03, -1.0001e+03, -1.0001e+03, -1.0001e+03],
         [ 1.9209e+00,  1.9209e+00,  1.9209e+00,  1.9209e+00]]])
相加后的 cur_values: torch.Size([2, 4, 4]) 
 tensor([[[    8.1488,     4.0363,  -998.7512,     3.0392],
         [    4.9517,     3.8392,  -996.9483,     7.8421],
         [ -999.1236,  -997.2361, -2000.0237,  -998.2332],
         [ -997.9335,  -998.0460,  -997.8335,  -998.0431]],

        [[    7.8986,     3.7723,  -999.2137,     3.1899],
         [    4.7541,     3.6278,  -997.3582,     8.0454],
         [ -999.1520,  -997.2783, -2000.2644,  -997.8607],
         [ -998.1536,  -998.2799,  -998.2659,  -997.8623]]])

cur_partition 是 log sum exp 之后的结果:
 tensor([[   8.1889,    4.6357, -996.4925,    7.8503],
        [   7.9408,    4.3958, -996.9136,    8.0532]])

mask_idx tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]], dtype=torch.uint8)

masked_cur_partition: tensor([   8.1889,    4.6357, -996.4925,    7.8503,    7.9408,    4.3958,
        -996.9136,    8.0532])

mask_idx: tensor([[[1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1]]], dtype=torch.uint8)

partition tensor([[[   8.1889],
         [   4.6357],
         [-996.4925],
         [   7.8503]],

        [[   7.9408],
         [   4.3958],
         [-996.9136],
         [   8.0532]]])


 idx = 2, torch.Size([2, 4, 4]) 
 tensor([[[ 6.9729e+00,  2.9331e+00, -1.0001e+03,  1.8326e+00],
         [ 1.9729e+00,  9.3310e-01, -1.0001e+03,  4.8326e+00],
         [ 9.7290e-01,  2.9331e+00, -1.0001e+03,  1.8326e+00],
         [-1.0000e+03, -1.0001e+03, -1.0001e+03, -1.0002e+03]],

        [[ 7.0703e+00,  3.0196e+00, -9.9995e+02,  2.0400e+00],
         [ 2.0703e+00,  1.0196e+00, -9.9995e+02,  5.0400e+00],
         [ 1.0703e+00,  3.0196e+00, -9.9995e+02,  2.0400e+00],
         [-9.9993e+02, -9.9998e+02, -9.9995e+02, -9.9996e+02]]])

pa_ torch.Size([2, 4, 4]) 
 tensor([[[   8.1889,    8.1889,    8.1889,    8.1889],
         [   4.6357,    4.6357,    4.6357,    4.6357],
         [-996.4925, -996.4925, -996.4925, -996.4925],
         [   7.8503,    7.8503,    7.8503,    7.8503]],

        [[   7.9408,    7.9408,    7.9408,    7.9408],
         [   4.3958,    4.3958,    4.3958,    4.3958],
         [-996.9136, -996.9136, -996.9136, -996.9136],
         [   8.0532,    8.0532,    8.0532,    8.0532]]])
相加后的 cur_values: torch.Size([2, 4, 4]) 
 tensor([[[   15.1618,    11.1220,  -991.8644,    10.0215],
         [    6.6086,     5.5688,  -995.4175,     9.4683],
         [ -995.5196,  -993.5594, -1996.5458,  -994.6599],
         [ -992.1768,  -992.2166,  -992.2030,  -992.3171]],

        [[   15.0111,    10.9604,  -992.0135,     9.9808],
         [    6.4661,     5.4154,  -995.5585,     9.4358],
         [ -995.8433,  -993.8940, -1996.8679,  -994.8737],
         [ -991.8765,  -991.9272,  -991.9011,  -991.9069]]])

cur_partition 是 log sum exp 之后的结果:
 tensor([[  15.1620,   11.1258, -991.3098,   10.4758],
        [  15.0113,   10.9643, -991.2490,   10.4381]])

mask_idx tensor([[0, 0, 0, 0],
        [1, 1, 1, 1]], dtype=torch.uint8)
masked_cur_partition: tensor([  15.0113,   10.9643, -991.2490,   10.4381])
mask_idx: tensor([[[0],
         [0],
         [0],
         [0]],

        [[1],
         [1],
         [1],
         [1]]], dtype=torch.uint8)
partition tensor([[[   8.1889],
         [   4.6357],
         [-996.4925],
         [   7.8503]],

        [[  15.0113],
         [  10.9643],
         [-991.2490],
         [  10.4381]]])

7. 计算 final_partition

# v1[b, i, j] = transitions[i, j], replicated for every sequence in the batch.
v1 = transitions.view(1, tag_size, tag_size).expand(
        batch_size, tag_size, tag_size)
print ('\nv1:',v1)
# v2[b, i, j] = partition[b, i]: the forward score of ending the real tokens in tag i.
v2 = partition.contiguous().view(
        batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
print ('\nv2:',v2)
cur_values = v1 + v2

# Reduce over the last real tag i (dim 1) -> (batch_size, tag_size).
cur_partition = log_sum_exp(cur_values, tag_size)
print ('\ncur_partition:',cur_partition)

# Column END_TAG_IDX = score of all paths that finish by moving into END,
# i.e. log Z for each sequence.
final_partition = cur_partition[:, END_TAG_IDX]
print ('\nfinal_partition:',final_partition)

# Mirrors the method's return value: (sum of log Z over the batch, scores).
return1 = final_partition.sum()
return2 = scores
print (return1,'\n',return2)

v1: tensor([[[    7.,     3., -1000.,     2.],
         [    2.,     1., -1000.,     5.],
         [    1.,     3., -1000.,     2.],
         [-1000., -1000., -1000., -1000.]],

        [[    7.,     3., -1000.,     2.],
         [    2.,     1., -1000.,     5.],
         [    1.,     3., -1000.,     2.],
         [-1000., -1000., -1000., -1000.]]])

v2: tensor([[[   8.1889,    8.1889,    8.1889,    8.1889],
         [   4.6357,    4.6357,    4.6357,    4.6357],
         [-996.4925, -996.4925, -996.4925, -996.4925],
         [   7.8503,    7.8503,    7.8503,    7.8503]],

        [[  15.0113,   15.0113,   15.0113,   15.0113],
         [  10.9643,   10.9643,   10.9643,   10.9643],
         [-991.2490, -991.2490, -991.2490, -991.2490],
         [  10.4381,   10.4381,   10.4381,   10.4381]]])

cur_values: tensor([[[   15.1889,    11.1889,  -991.8112,    10.1889],
         [    6.6357,     5.6357,  -995.3643,     9.6357],
         [ -995.4925,  -993.4925, -1996.4924,  -994.4925],
         [ -992.1497,  -992.1497,  -992.1497,  -992.1497]],

        [[   22.0113,    18.0113,  -984.9887,    17.0113],
         [   12.9643,    11.9643,  -989.0357,    15.9643],
         [ -990.2490,  -988.2490, -1991.2490,  -989.2490],
         [ -989.5619,  -989.5619,  -989.5619,  -989.5619]]])

cur_partition: tensor([[  15.1891,   11.1927, -991.2565,   10.6432],
        [  22.0114,   18.0136, -984.9613,   17.3121]])

final_partition: tensor([10.6432, 17.3121])

return1: tensor(27.9553) 

return2: torch.Size([3, 2, 4, 4])
 tensor([[[[ 7.1938e+00,  2.9967e+00, -1.0001e+03,  2.1115e+00],
          [ 2.1938e+00,  9.9670e-01, -1.0001e+03,  5.1115e+00],
          [ 1.1938e+00,  2.9967e+00, -1.0001e+03,  2.1115e+00],
          [-9.9981e+02, -1.0000e+03, -1.0001e+03, -9.9989e+02]],

         [[ 6.9731e+00,  2.8286e+00, -1.0001e+03,  1.9209e+00],
          [ 1.9731e+00,  8.2860e-01, -1.0001e+03,  4.9209e+00],
          [ 9.7310e-01,  2.8286e+00, -1.0001e+03,  1.9209e+00],
          [-1.0000e+03, -1.0002e+03, -1.0001e+03, -1.0001e+03]]],

        [[[ 6.9550e+00,  2.8425e+00, -9.9995e+02,  1.8454e+00],
          [ 1.9550e+00,  8.4250e-01, -9.9995e+02,  4.8454e+00],
          [ 9.5500e-01,  2.8425e+00, -9.9995e+02,  1.8454e+00],
          [-1.0000e+03, -1.0002e+03, -9.9995e+02, -1.0002e+03]],

         [[ 6.9255e+00,  2.7992e+00, -1.0002e+03,  2.2168e+00],
          [ 1.9255e+00,  7.9920e-01, -1.0002e+03,  5.2168e+00],
          [ 9.2550e-01,  2.7992e+00, -1.0002e+03,  2.2168e+00],
          [-1.0001e+03, -1.0002e+03, -1.0002e+03, -9.9978e+02]]],

        [[[ 6.9729e+00,  2.9331e+00, -1.0001e+03,  1.8326e+00],
          [ 1.9729e+00,  9.3310e-01, -1.0001e+03,  4.8326e+00],
          [ 9.7290e-01,  2.9331e+00, -1.0001e+03,  1.8326e+00],
          [-1.0000e+03, -1.0001e+03, -1.0001e+03, -1.0002e+03]],

         [[ 7.0703e+00,  3.0196e+00, -9.9995e+02,  2.0400e+00],
          [ 2.0703e+00,  1.0196e+00, -9.9995e+02,  5.0400e+00],
          [ 1.0703e+00,  3.0196e+00, -9.9995e+02,  2.0400e+00],
          [-9.9993e+02, -9.9998e+02, -9.9995e+02, -9.9996e+02]]]])
