torch.nn.functional.pad()

torch.nn.functional.pad()

用于对张量进行填充。

输入

  • input(Tensor):需要padding的tensor
  • pad(tuple): m维元组,指定填充的维度和大小,输入的 ⌊ l e n ( p a d ) 2 ⌋ \lfloor \frac{len(pad)}{2} \rfloor 2len(pad) 维会被填充,因此 m 2 \frac{m}{2} 2m 应小于输入维度且m应为偶数

    示例:
    高宽H,W大多在最后两维,因此:

    • 填充最后一维宽W,
      pad为(padding_left, padding_right)

    • 填充最后两维高H,宽W,
      pad为(padding_left, padding_right,padding_top, padding_bottom)

    • 填充最后三维通道C,高H,宽W,
      pad为 (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)

  • mode:填充模式,默认为’constant’
    • ‘constant’:使用给定的值填充边界。
    • ‘reflect’:使用输入张量的镜像填充边界。
    • ‘replicate’:使用输入张量的最后一个元素(边界点的值)填充边界。

    以上模式的输出维度计算与’constant’类似,只是填充值有区别

  • value:'constant’模式下指定的填充值,默认为0

padding_mode

具体的关于填充模式的介绍参见

ConstantPad2d — PyTorch 2.1 documentation

ReflectionPad2d — PyTorch 2.1 documentation

ReplicationPad2d — PyTorch 2.1 documentation

constant 填充可用于任意维

Circular, replicate and reflection填充可用于以下情况:

对于4D或5D张量,填充最后3维

对于3D或4D张量,填充最后2维

对于2D或3D张量,填充最后1维

示例

import torch
import torch.nn.functional as F

test = torch.randn(1, 3, 4, 4) #(B, C, H, W) -> (1, 3, 4, 4)

p1d = (1, 1)                   #(pad_left, pad_right) -> (1, 1)
p2d = (1, 1, 2, 2)             #(pad_left, pad_right, pad_top, pad_bottom) -> (1, 1, 2, 2)  
p3d = (0, 1, 2, 2, 3, 3)       #(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back) -> (0, 1, 2, 2, 3, 3)

out = F.pad(test, p1d, "constant", 0) 
print(out.size())              # torch.Size([1, 3, 4, 6]) W = W + pad_left +pad_right -> 6= 4+ 1+ 1

out = F.pad(test, p2d, "constant", 0)
print(out.size())              # torch.Size([1, 3, 8, 6]) W = W + pad_left +pad_right -> 6= 4+ 1+ 1
                               #                          H = H + pad_top +pad_bottom -> 8= 4+ 2+ 2 

out = F.pad(test, p3d, "constant", 0)
print(out.size())              # torch.Size([1, 9, 8, 5]) W = W + pad_left +pad_right -> 5= 4+ 1+ 0
                               #                          H = H + pad_top +pad_bottom -> 8= 4+ 2+ 2 
                               #                          C = C + pad_front +pad_back -> 9= 3+ 3+ 3

Related Links

torch.nn.functional.pad — PyTorch 2.1 documentation

你可能感兴趣的:(pytorch)