Pytorch中apply函数作用

pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。经常用于初始化init_weights的操作。如下apply递归调用_init_vit_weights,初始化ViT模型的子模块。

		from torch import nn
        #Weight init,初始化pos_embed
        # trunc_normal_利用正态分布生成一个点,点在[a, b]区间之内
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # Weight init,初始化cls_token
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # 调用vit初始函数
        self.apply(_init_vit_weights)


		def _init_vit_weights(m):
		    """
		    ViT weight initialization
		    :param m: module
		    """
		    if isinstance(m, nn.Linear):
		        nn.init.trunc_normal_(m.weight, std=.01)
		        if m.bias is not None:
		            nn.init.zeros_(m.bias)
		    elif isinstance(m, nn.Conv2d):
		        nn.init.kaiming_normal_(m.weight, mode="fan_out")
		        if m.bias is not None:
		            nn.init.zeros_(m.bias)
		    elif isinstance(m, nn.LayerNorm):
		        nn.init.zeros_(m.bias)
		        nn.init.ones_(m.weight)

你可能感兴趣的:(深度学习,pytorch,python,深度学习)