1. The _apply method of the class BaseModel(nn.Module):
def _apply(self, fn):
    # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
    self = super()._apply(fn)
    m = self.model[-1]  # Detect()
    if isinstance(m, (Detect, Segment)):
        m.stride = fn(m.stride)
        m.grid = list(map(fn, m.grid))
        if isinstance(m.anchor_grid, list):
            m.anchor_grid = list(map(fn, m.anchor_grid))
    return self
This is code from YOLOv8. My understanding is that it applies to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers.
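To see why those attributes need the manual fn call, here is a minimal sketch (the TinyModel class is hypothetical, not from YOLO): super()._apply(fn) only converts parameters and registered buffers, so a plain tensor attribute such as stride would otherwise be left behind at the old dtype/device.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)                  # parameter: handled by _apply
        self.register_buffer("scale", torch.ones(2))   # buffer: handled by _apply
        self.stride = torch.tensor([8., 16.])          # plain attribute: NOT handled

    def _apply(self, fn):
        self = super()._apply(fn)
        self.stride = fn(self.stride)  # convert it ourselves, as the YOLO code does
        return self

m = TinyModel().half()
print(m.linear.weight.dtype)  # torch.float16
print(m.scale.dtype)          # torch.float16
print(m.stride.dtype)         # torch.float16 only because _apply was overridden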
_apply(fn)
The _apply(fn) method is applied recursively to the module itself and to every one of its submodules. _apply() is an "internal-use only" interface implemented specifically for parameters and buffers, while apply() is the "public" interface (Python does not strictly enforce "public" vs. "private" on classes; the convention is to mark internal members with a single leading underscore). In fact, apply() can achieve what _apply() does if fn is written accordingly.
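A short sketch of the difference: apply(fn) hands fn each submodule, while _apply(fn) hands fn each parameter/buffer tensor. A module-level fn (to_double below is an illustrative name) can therefore replicate the tensor conversion through the public apply():

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

# apply(): fn receives each submodule (and finally net itself)
net.apply(lambda mod: print(type(mod).__name__))  # Linear, ReLU, Linear, Sequential

# Replicating an _apply-style dtype conversion via apply():
# fn rewrites the tensors of each module it visits.
def to_double(mod):
    for p in mod.parameters(recurse=False):
        p.data = p.data.double()          # the same in-place `.data =` trick _apply uses
    for name, buf in list(mod.named_buffers(recurse=False)):
        setattr(mod, name, buf.double())  # nn.Module.__setattr__ routes this to _buffers

net.apply(to_double)
print(net[0].weight.dtype)  # torch.float64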
Usage 1: when a module is moved to the CPU/GPU, the _apply() method is called. For example, executing net.cuda() invokes:
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Moves all model parameters and buffers to the GPU.

    This also makes associated parameters and buffers different objects. So
    it should be called before constructing optimizer if the module will
    live on GPU while being optimized.

    .. note::
        This method modifies the module in-place.

    Args:
        device (int, optional): if specified, all parameters will be
            copied to that device

    Returns:
        Module: self
    """
    return self._apply(lambda t: t.cuda(device))
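The other conversion helpers are equally thin wrappers around _apply(). A quick check on the CPU (so it runs anywhere) shows half() funneling through the same mechanism; the lambda in the comment mirrors the one in the PyTorch 1.x source:

import torch
import torch.nn as nn

net = nn.Linear(3, 3)
# half() is implemented as:
#   return self._apply(lambda t: t.half() if t.is_floating_point() else t)
net.half()
print(net.weight.dtype)  # torch.float16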
Inside _apply(fn), three steps are performed: first, recursively call _apply on self.children(); then use fn to process the parameters in self._parameters along with their gradients; finally, use fn to process the buffers in self._buffers.
def _apply(self, fn):
    # Step 1: recurse into every child module
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False

    # Step 2: apply fn to every parameter and its gradient
    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    # Step 3: apply fn to every buffer
    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self
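The practical consequence of compute_should_use_set_data: under the default flag, a same-device dtype conversion swaps tensors in-place via `.data =`, so the Parameter object itself survives. A small sketch (the effect of the future-flag may vary across PyTorch versions, so treat the second half as illustrative):

import torch
import torch.nn as nn

net = nn.Linear(2, 2)
p_before = net.weight            # hold a reference to the Parameter object
net.double()                     # routes through _apply(lambda t: t.double() ...)
print(net.weight is p_before)    # True: same object, only its .data was replaced
print(p_before.dtype)            # torch.float64

# Opting in to the "future" behavior replaces the Parameter object instead:
torch.__future__.set_overwrite_module_params_on_conversion(True)
net2 = nn.Linear(2, 2)
p2 = net2.weight
net2.double()
print(net2.weight is p2)         # False: a new Parameter was created
torch.__future__.set_overwrite_module_params_on_conversion(False)  # restore default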
I looked up a lot of references on this point, but most of them only discuss the apply() method. I still have some doubts myself, so if anyone knows more, explanations are welcome.