为了节省显存(内存),PyTorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数在使用后应及时删除(remove),以避免每次都运行钩子增加运行负载。
这里总结一下并给出实际用法和注意点。hook方法有4种:
1、Tensor.register_hook()
2、torch.nn.Module.register_forward_hook()
3、torch.nn.Module.register_backward_hook()
4、torch.nn.Module.register_forward_pre_hook()
1.Tensor.register_hook(hook):对于单个张量,可以使用register_hook()方法注册一个hook。该方法将一个函数(即hook)注册到张量上,在张量被计算时调用该函数。这个函数可以用来获取张量的梯度或值,或者对张量进行其他操作。例如,以下代码演示了如何使用register_hook()方法获取张量的梯度:
import torch
x = torch.randn(2, 2, requires_grad=True)
def print_grad(grad):
print(grad)
hook_handle = x.register_hook(print_grad)
y = x.sum()
y.backward()
hook_handle.remove()
在这个例子中,我们创建了一个包含梯度的张量x,并使用register_hook()方法注册了一个打印梯度的hook函数print_grad()。在计算y的梯度时,hook函数被调用并打印梯度。最后,我们使用hook_handle.remove()方法从张量中删除hook函数。
2.torch.nn.Module.register_forward_hook(hook):对于模型中的每个层,可以使用register_forward_hook()方法注册一个hook函数。这个函数将在模型的前向传递中被调用,并可以用来获取层的输出。例如,以下代码演示了如何使用register_forward_hook()方法获取模型的某一层的输出:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
x = self.avgpool(x)
return x
def hook(module, input, output):
print(output.shape)
model = MyModel()
handle = model.conv2.register_forward_hook(hook)
x = torch.randn(1, 3, 224, 224)
output = model(x)
handle.remove()
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
pass
在这个例子中,我们定义了一个包含两个卷积层和一个全局平均池化层的模型,并使用register_forward_hook()方法在第二个卷积层上注册了一个hook函数。当模型进行前向传递时,hook函数将被调用并打印输出张量的形状。最后,我们使用handle.remove()方法从模型中删除hook函数。
torch.nn.Module.register_backward_hook(hook):对于模型中的每个层,可以使用register_backward_hook()方法注册一个hook函数。这个函数将在模型的反向传递中被调用,并可以用来获取梯度或其他信息。例如,以下代码演示了如何使用register_backward_hook()方法获取某一层的梯度
:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
x = self.avgpool(x)
return x
def hook(module, grad_input, grad_output):
print(grad_input[0].shape, grad_output[0].shape)
model = MyModel()
handle = model.conv2.register_backward_hook(hook)
x = torch.randn(1, 3, 224,
output = model(x)
output.sum().backward()
handle.remove()
在这个例子中,我们定义了一个包含两个卷积层和一个全局平均池化层的模型,并使用register_backward_hook()方法在第二个卷积层上注册了一个hook函数。当模型进行反向传递时,hook函数将被调用并打印输入梯度和输出梯度张量的形状。最后,我们使用handle.remove()方法从模型中删除hook函数。
- torch.nn.Module.register_forward_pre_hook(hook):对于模型中的每个层,可以使用register_forward_pre_hook()方法注册一个hook函数。这个函数将在模型的前向传递之前被调用,并可以用来获取输入张量或其他信息。例如,以下代码演示了如何使用register_forward_pre_hook()方法获取模型的输入张量:
import torch.nn as nn
class MyModel(nn.Module):
def init(self):
super(MyModel, self).init()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
x = self.avgpool(x)
[return x](poe://www.poe.com/_api/key_phrase?phrase=return%20x&prompt=Tell%20me%20more%20about%20return%20x.)
def hook(module, input):
print(input[0].shape)
model = MyModel()
handle = model.conv1.register_forward_pre_hook(hook)
x = torch.randn(1, 3, 224, 224)
output = model(x)
handle.remove()
在这个例子中,我们定义了一个包含两个卷积层和一个全局平均池化层的模型,并使用register_forward_pre_hook()方法在第一个卷积层上注册了一个hook函数。当模型进行前向传递时,hook函数将被调用并打印输入张量的形状。最后,我们使用handle.remove()方法从模型中删除hook函数。
需要注意的是,hook函数应该尽可能快地执行,以避免对模型的计算时间造成过多的影响。此外,如果注册了太多的hook函数,会导致额外的内存占用和计算负担。因此,应该仔细考虑何时需要注册hook函数,并在使用后及时删除它们。