# Load a checkpoint, rebuild one weight tensor so that dim 1 holds the
# channels (g, b, r, g) instead of (b, g, r), and save the result as model.pth.
#
# NOTE(review): `weights_path` and `device` are defined outside this chunk.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# trusted checkpoints (newer torch versions accept `weights_only=True`).

# Entry whose dim 1 is expanded from 3 to 4 channels.
EXPANDED_KEY = 'module.features.0.weight'
# Source indices along dim 1, in output order g, b, r, g.  The checkpoint is
# read with index 0 as b, 1 as g and 2 as r.
GBRG_ORDER = [1, 0, 2, 1]

state_dict = torch.load(weights_path, map_location=device)

# List every entry with its shape as a quick sanity check.
for name, layer in state_dict.items():
    print(name)
    print(layer.shape)

# Look the tensor up by key so the data we transform is the same entry we
# overwrite below (previously it was simply "whatever came first").
try:
    weight = state_dict[EXPANDED_KEY]
except KeyError:
    raise KeyError(f'checkpoint has no entry {EXPANDED_KEY!r}') from None
print(weight.shape)

# The indexing below needs a rank-4 tensor with at least 3 entries on dim 1
# (presumably a conv weight laid out as (out, in, kh, kw) -- confirm).
if weight.dim() != 4 or weight.shape[1] < 3:
    raise ValueError(
        f'{EXPANDED_KEY!r} has shape {tuple(weight.shape)}; '
        'expected 4 dims with at least 3 entries on dim 1'
    )

# weight[:, c:c + 1] keeps dim 1, equivalent to unsqueeze(weight[:, c], 1).
gbrg = torch.cat([weight[:, c:c + 1] for c in GBRG_ORDER], dim=1)
print(gbrg.shape)

# Copy every entry unchanged (insertion order is preserved) and replace only
# the expanded weight.
new_state = dict(state_dict)
new_state[EXPANDED_KEY] = gbrg
print(new_state)

# NOTE: the keys carry a "module." prefix (presumably from nn.DataParallel).
# To load into an unwrapped model, strip it before saving, e.g.:
#   new_state = {k.replace('module.', '', 1): v for k, v in new_state.items()}
torch.save(new_state, 'model.pth')