pytorch如何freeze模型参数

pytorch如何freeze模型参数

在做迁移学习或者自监督学习时,一般先预训练一个模型,再将该模型参数作为目标任务模型的初始化参数,或者直接freeze预训练模型,不再更新其参数。

今天记录下如何pytorch freeze模型参数

我是参考知乎一个文章,总结的很完整,我直接拿过来用了,原文出处为

https: // www.zhihu.com / question / 311095447 / answer / 589307812
from collections.abc import Iterable


def set_freeze_by_names(model, layer_names, freeze=True):
    if not isinstance(layer_names, Iterable):
        layer_names = [layer_names]
    for name, child in model.named_children():
        if name not in layer_names:
            continue
        for param in child.parameters():
            #print(param.name)
            param.requires_grad = not freeze


def freeze_by_names(model, layer_names):
    set_freeze_by_names(model, layer_names, True)


def unfreeze_by_names(model, layer_names):
    set_freeze_by_names(model, layer_names, False)


def set_freeze_by_idxs(model, idxs, freeze=True):
    if not isinstance(idxs, Iterable):
        idxs = [idxs]
    num_child = len(list(model.children()))
    idxs = tuple(map(lambda idx: num_child + idx if idx < 0 else idx, idxs))
    for idx, child in enumerate(model.children()):
        if idx not in idxs:
            continue
        for param in child.parameters():
            param.requires_grad = not freeze


def freeze_by_idxs(model, idxs):
    set_freeze_by_idxs(model, idxs, True)


def unfreeze_by_idxs(model, idxs):
    set_freeze_by_idxs(model, idxs, False)

 

你可能感兴趣的:(Torch)