pytorch tensorflow 卷积权重互导

做个记录

测试版本:
pytorch 1.15.1
tensorflow 1.15.3

通道格式
tensorflow NHWC
权重格式为
[filter_height, filter_width, in_channels, out_channels]

pytorch NCHW
权重格式为
[out_channels, in_channels, filter_height, filter_width]

tensorflow 权重转 pytorch

pt_weight = np.transpose(tf_weight, [3, 2, 0, 1])

pytorch 权重转 tensorflow

tf_weight = np.transpose(pt_weight, [2, 3, 1, 0])

额外注意,在偶数大小的卷积核上,tensorflow 和 pytorch 填充方式也要特殊处理
https://blog.csdn.net/weixin_44554475/article/details/106239967

你可能感兴趣的:(神经网络,深度学习的经验,tensorflow,深度学习,pytorch,卷积权重)