Pytorch 中的 torch.nn.Identity( ) 的作用

1 torch.nn.Identity( ) 的作用

torch.nn.Identity( ) 相当于一个恒等函数

f(x) = x

  • 这个函数相当于输入什么就输出什么, 可以用在对已经设计好模型结构的修改, 比如模型的最后一层是 1000 分类, 我们可以将最后一层用 nn.Identity( ) 替换掉, 得到它之前学习的特征, 然后再自己设计最后一层的结构

  • 在迁移学习中经常使用

2 示例

import torch
from torch import nn
from torch.nn import NLLLoss
import timm


class MiniModel(nn.Module):
    
    def __init__(self, backbone, num_class, pretrained=False, backbone_ckpt=None):
        
        super().__init__()
        self.backbone = timm.crear_model(backbone, pretrained=pretrained, checkpoint_path=backbone_ckpt)
        self.head = nn.Linear(self.backbone.get_classifier().in_features, num_class)
        # 替代最后一层的全连接网络
        self.backbone.head.fc = nn.Identity()
        self.loss_fn = NLLLoss()
        
    def forward(self, image, label):
        embed = self.backbone(image)
        logit = self.head(embed)
        
        if label is not None:
            logit_logsoftmax = torch.log_softmax(logit, 1)
            loss = self.loss_fn(logit_logsoftmax, label)
            return {"prediction": logit, "loss": loss}
        return {"prediction": logit}

你可能感兴趣的:(Pytorch,中的各种函数,Pytorch)