tensorflow与pytorch的移植转换函数对比表

     相信有了这份表格对比,tensorflow与pytorch的基本移植转换,应该是手到擒来。

名称 tensorflow pytorch
二维卷积 tf.nn.conv2d(input_x, w, strides=[1, 1, 1, 1], padding='SAME') torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
relu激活函数 tf.nn.relu(input_x) torch.nn.ReLU()
填充函数 tf.pad(input_x, [(a0,b0), (a1,b1), (a2,b2), (a3,b3)])

pad_num = (a3,b3, a2, b2, a1, b1, a0, b0)

torch.nn.functional.pad(input_x, pad_num, mode='constant')

元素个数 tf.size(input_x) torch.numel(input_x)
展平 tf.reshape(input_x, (tf.size(input_x), -1)) input_x.view(torch.numel(input_x), -1)
softmax tf.nn.softmax(input_x, axis=1) torch.nn.functional.softmax(input_x, dim=1)
调整类型 tf.cast(input_x, tf.int32) input_x.type(torch.LongTensor)
除去维度为1 tf.squeeze(input_x, squeeze_dims=1) torch.squeeze(input_x)
合并 tf.concat((input_x1, input_x2), axis=3) torch.cat((input_x1, input_x2), dim=3)
划分成相同维度的块 tf.split(input_x, axis=3, num_or_size_splits=2) torch.chunk(input_x, dim=3, chunks=2)
产生1的矩阵 tf.ones((a,b)) torch.ones(a,b)
重复 tf.tile() input_x.repeat()
交换维度 tf.transpose(input_x, [0,1,2,3]) input_x.permute((0,1,2,3))  注意这种用法:p01.permute((0, 2, 3, 1)).contiguous().view(int(np.prod(shape01)), -1)

你可能感兴趣的:(深度学习基本组件,深度学习基本组件)