PyTorch笔记:修改模型中的某些权重参数

要解决的问题

之前尝试复现MCNN,但由于这是16年的工作,现有的许多代码由于版本等各种各样的问题,所以我都跑不起来。在那些能跑起来的代码里又没有给权重,后来发现一个给了权重,但是确实.h5格式的,不能直接由Pytorch加载。而且里面参数名都有一个前缀DEM.,需要匹配前缀并且加载到模型中。

方法

其实很多模型权重文件就是一个字典,有key和value。所以要读取出key和value,然后找到现有模型中的对应的键。

def load_h5(fname,model):
    import h5py
    h5f = h5py.File(fname, mode='r')
    for k, v in model.state_dict().items():        
        param = torch.from_numpy(np.asarray(h5f['DME.'+k]))
        v.copy_(param)   

你可能感兴趣的:(图像处理与计算机视觉,pytorch,深度学习,人工智能)