register_forward_hook多输入问题

class UpBlock_attention(nn.Module):


	def forward(self, fuse, out, fuse_a, out_a): 

假如模块有四个输出,我需要勾取UpBlock_attention模块的输入,其中在模型中定义了self.trans3=UpBlock_attention,如果直接注册:

input_list = []
output_list = []

def forward_hook(model, input_data, output_data):
    input_list.append(input_data)
    output_list.append(output_data)
model.trans3.register_forward_hook(forward_hook)

在取其中的张量时候会报错:ValueError: only one element tensors can be converted to Python scalars。

for i in range(len(input_list)):
        input_list_tensor = torch.tensor(input_list[i])
        tensor_threewei = input_list_tensor.squeeze(0)

正确做法:在注册钩子函数时就直接定义只有一个输入:

input_list.append(input_data[1])

这样在遍历时不会报错。

你可能感兴趣的:(pytorch踩坑,pytorch函数,pytorch)