PyTorch 0.4 加载 0.4.1、1.0 及更高版本保存的模型

def load_network(self, load_path, network, strict=True, map_location=None):
    """Load a saved state dict into ``network``, skipping ``num_batches_tracked``.

    Checkpoints written by PyTorch 0.4.1 and later store a
    ``<module>.num_batches_tracked`` buffer for BatchNorm layers. PyTorch
    0.4.0 models have no such buffer, so ``load_state_dict(strict=True)``
    would reject those keys as unexpected. Dropping them lets the older
    version load the newer checkpoint.

    Args:
        load_path: Path (or file-like object) handed to ``torch.load``.
            Presumably it holds a plain state dict (a mapping of parameter
            name to tensor); a wrapped checkpoint dict is not unpacked here.
        network: Module to load into. An ``nn.DataParallel`` wrapper is
            unwrapped first, so keys are matched against the inner module.
        strict: Forwarded to ``network.load_state_dict``.
        map_location: Forwarded to ``torch.load``, e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine. ``None`` (the
            default) matches ``torch.load``'s own default.

    Raises:
        RuntimeError: From ``load_state_dict`` when ``strict`` is true and
            the remaining keys do not match ``network``.
    """
    if isinstance(network, nn.DataParallel):
        network = network.module

    state_dict = torch.load(load_path, map_location=map_location)
    # Compare the last dotted component exactly instead of testing for a
    # substring. Otherwise an unrelated entry whose name merely contains
    # "num_batches_tracked" would also be dropped, and a strict load would
    # then fail with a missing key.
    filtered = {
        key: value
        for key, value in state_dict.items()
        if key.rsplit('.', 1)[-1] != 'num_batches_tracked'
    }
    network.load_state_dict(filtered, strict=strict)

你可能感兴趣的：CNN 学习笔记