Freeze BatchNorm (BN) layers in PyTorch

def set_bn_eval(m):
    """Put ``m`` into evaluation mode when it looks like a BatchNorm layer.

    Designed to be handed to ``model.apply(...)``, which invokes it once for
    every submodule. A module counts as BatchNorm when the name of its class
    contains the case-sensitive substring ``'BatchNorm'`` (for example
    ``BatchNorm1d``, ``BatchNorm2d`` or ``SyncBatchNorm``). Every other
    module is left untouched.

    Args:
        m: the module currently being visited by ``model.apply``.

    Returns:
        None.

    NOTE: matching is purely by class name. Any module whose class name
    contains 'BatchNorm', including a custom container, gets ``eval()``.
    NOTE: ``eval()`` changes only the module's mode. It does not set
    ``requires_grad=False`` on the layer's weight/bias, so the optimizer
    may still update those parameters.
    """
    if 'BatchNorm' in m.__class__.__name__:
        m.eval()

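Use `model.apply()` to freeze the BN layers. `apply` calls `set_bn_eval` on every submodule, so it must run after `model.train()`; calling `model.train()` afterwards would switch the BN layers back to training mode.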

def train(model, data_loader, criterion, epoch):
    """Run one training epoch with the BatchNorm layers kept in eval mode.

    Args:
        model: the network being trained; must provide ``train()`` and
            ``apply()``.
        data_loader: source of training batches (unused in this stub).
        criterion: loss function (unused in this stub).
        epoch: index of the current epoch (unused in this stub).
    """
    # Step 1: switch every submodule into training mode.
    model.train()
    # Step 2: flip the BatchNorm layers back to eval mode. Order matters:
    # calling model.train() after this line would undo the freeze.
    model.apply(set_bn_eval)
    # --- training code goes here ---

Putting it together (a common training-loop layout):

def main():
    """Alternate one training pass and one evaluation pass per epoch.

    NOTE(review): ``model``, ``train_loader``, ``eval_loader``, ``criterion``
    and ``epochs`` are not defined in this snippet. Presumably they come from
    the setup code elided at ``# ...`` or from module scope; confirm.
    """
    # ... (setup elided in the original snippet)
    for epoch in epochs:
        # train() re-runs model.train() followed by model.apply(set_bn_eval)
        # on every call. The BN freeze is therefore restored each epoch, even
        # if test() leaves the model in a different mode.
        train(model, train_loader, criterion, epoch)
        test(model, eval_loader, epoch)
    # ... (teardown elided in the original snippet)

You may also be interested in: (Freeze BN in PyTorch)