迁移模型权重文件（.pt / .pth）中某一层数据的代码：将 module.features.0.weight 的通道按 G、B、R、G 顺序重排

# Rewrite one layer of a saved checkpoint so that its channels (dim 1) are
# reordered to G, B, R, G, then save the modified state dict to 'model.pth'.
#
# NOTE(review): `weights_path` and `device` are not defined in this snippet;
# they are assumed to be set earlier in the file -- confirm.
import torch

# Parameter whose channels are reordered.
TARGET_KEY = 'module.features.0.weight'


def to_gbrg(weight, dim=1):
    """Return *weight* with its channels along *dim* reordered to G, B, R, G.

    Channel index 0 is treated as B, 1 as G and 2 as R (the same labelling the
    original slicing used); G is duplicated, so the result has 4 entries along
    ``dim``. This is the same tensor the original three-slice + ``torch.cat``
    construction produced, but works for any rank.

    Args:
        weight: tensor with at least three entries along ``dim``. For a conv
            weight this is presumably ``(out_channels, in_channels, kH, kW)``
            with ``dim=1`` selecting input channels -- confirm for your model.
        dim: axis holding the colour channels (default 1, as originally).

    Returns:
        A new tensor, identical to *weight* except that its size along ``dim``
        is 4.

    NOTE(review): the G weights now appear twice, so the layer's response to
    G input doubles relative to the original model; if the new input is a
    Bayer-style GBRG layout you may want to halve both G slices -- confirm the
    intended semantics.
    """
    order = torch.tensor([1, 0, 2, 1], device=weight.device)
    return torch.index_select(weight, dim, order)


# torch.load unpickles arbitrary objects: only load trusted checkpoints.
state_dict = torch.load(weights_path, map_location=device)

# List every parameter so the target layer's name and shape can be checked.
for name, layer in state_dict.items():
    print(name)
    print(layer.shape)

# The original converted the *first* tensor in the dict but stored the result
# under TARGET_KEY; look the layer up by key so the two always agree, and fail
# loudly instead of silently saving an unconverted checkpoint.
if TARGET_KEY not in state_dict:
    raise KeyError(f'{TARGET_KEY!r} not found in checkpoint {weights_path!r}')

layer = state_dict[TARGET_KEY]
print(layer.shape)
gbrg = to_gbrg(layer)
print(gbrg.shape)

# Copy every entry unchanged except the converted weight; dict() preserves
# insertion order and overwriting an existing key keeps its position.
new_state = dict(state_dict)
new_state[TARGET_KEY] = gbrg
print(new_state)

# To additionally strip the 'module.' key prefix (presumably added by
# nn.DataParallel), use instead:
#     new_state = {key.replace('module.', '', 1): value
#                  for key, value in new_state.items()}
torch.save(new_state, 'model.pth')

你可能感兴趣的：python、开发语言