目录
torch.nn子模块Vision Layers详解
nn.PixelShuffle
用法与用途
使用技巧
注意事项
参数
示例代码
nn.PixelUnshuffle
用法与用途
使用技巧
注意事项
参数
示例代码
nn.Upsample
用法与用途
使用技巧
注意事项
参数
示例代码
nn.UpsamplingNearest2d
用法与用途
使用技巧
注意事项
参数
形状(同上)
示例代码
nn.UpsamplingBilinear2d
用法与用途
使用技巧
注意事项
参数
示例代码
总结
torch.nn.PixelShuffle
是 PyTorch 深度学习框架中的一个子模块,主要用于图像超分辨率(Super Resolution)任务。这个模块通过重新排列输入张量(Tensor)的元素,从而将图像的分辨率提高。
PixelShuffle
接收一个输入张量,并按照指定的上采样因子(upscale factor)重新排列张量中的元素,以提高图像的分辨率。upscale_factor
(int): 用于提高空间分辨率的因子。import torch
import torch.nn as nn
# 初始化 PixelShuffle 模块
pixel_shuffle = nn.PixelShuffle(3)
# 创建一个随机张量作为输入
# 输入张量的形状为 (批大小, 通道数, 高, 宽)
# 通道数必须是上采样因子的平方倍数,这里为 3^2 = 9
input = torch.randn(1, 9, 4, 4)
# 应用 PixelShuffle
output = pixel_shuffle(input)
# 输出张量的形状
print(output.size())
这段代码首先创建了一个 PixelShuffle
模块,上采样因子设置为 3。然后,创建一个形状为 (1, 9, 4, 4) 的输入张量,并将其传递给 PixelShuffle
模块。输出的张量形状会变为 (1, 1, 12, 12),即分辨率提高了。
torch.nn.PixelUnshuffle
是 PyTorch 深度学习框架中的一个子模块,它执行 PixelShuffle
的逆操作。PixelUnshuffle
通过重新排列输入张量的元素,从而降低图像的分辨率。这个模块在一些特定的图像处理任务中非常有用,特别是当需要降采样图像时。
PixelUnshuffle
接收一个输入张量,并按照指定的下采样因子(downscale factor)重新排列张量中的元素,以降低图像的分辨率。downscale_factor
(int): 用于降低空间分辨率的因子。import torch
import torch.nn as nn
# 初始化 PixelUnshuffle 模块
pixel_unshuffle = nn.PixelUnshuffle(3)
# 创建一个随机张量作为输入
# 输入张量的形状为 (批大小, 通道数, 高, 宽)
input = torch.randn(1, 1, 12, 12)
# 应用 PixelUnshuffle
output = pixel_unshuffle(input)
# 输出张量的形状
print(output.size())
这段代码首先创建了一个 PixelUnshuffle
模块,下采样因子设置为 3。然后,创建一个形状为 (1, 1, 12, 12) 的输入张量,并将其传递给 PixelUnshuffle
模块。输出的张量形状会变为 (1, 9, 4, 4),即通道数增加,而空间分辨率降低了。
torch.nn.Upsample
是 PyTorch 中的一个子模块,用于对多通道的 1D(时间序列)、2D(空间)或 3D(体积)数据进行上采样(增加分辨率)。
Upsample
可以增加数据的尺寸,例如将一个低分辨率的图像转换成高分辨率的图像。它可以处理 3D、4D 或 5D 的张量,分别对应于 1D、2D 和 3D 数据。Upsample
常用于图像超分辨率、放大图像或视频帧等任务。nearest
, linear
, bilinear
, bicubic
或 trilinear
。align_corners
参数控制角点像素的对齐方式。在使用 linear
, bilinear
, bicubic
和 trilinear
模式时,它会影响插值的结果。nearest
通常用于类别标签,而 bilinear
更适用于图像。size
或 scale_factor
指定输出的尺寸,但不能同时指定两者,因为这会引起歧义。size
(int or Tuple[int]): 输出的空间尺寸。scale_factor
(float or Tuple[float]): 空间尺寸的乘数。mode
(str): 上采样算法,包括 'nearest', 'linear', 'bilinear', 'bicubic', 'trilinear'。align_corners
(bool): 控制角点像素的对齐方式。recompute_scale_factor
(bool): 重新计算用于插值计算的比例因子。import torch
import torch.nn as nn
# 创建一个 2x2 的输入张量
input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
# 初始化 Upsample 模块,上采样因子为 2,使用最近邻插值
m = nn.Upsample(scale_factor=2, mode='nearest')
output_nearest = m(input)
# 初始化 Upsample 模块,上采样因子为 2,使用双线性插值
m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
output_bilinear = m(input)
# 输出结果
print("Nearest neighbor upsampling:\n", output_nearest)
print("\nBilinear upsampling:\n", output_bilinear)
这段代码展示了如何使用 Upsample
来对一个小张量进行上采样,分别使用最近邻和双线性插值。这可以在图像放大等场景中被应用。
torch.nn.UpsamplingNearest2d
是 PyTorch 中的一个子模块,专门用于对 2D 数据(如图像)应用最近邻上采样。这种类型的上采样通过复制邻近的像素值来增加图像的尺寸,从而提高图像的分辨率。
size
)或上采样因子(scale_factor
)来使用此模块。UpsamplingNearest2d
已在较新版本的 PyTorch 中弃用,建议改用 torch.nn.functional.interpolate()
方法。size
(int or Tuple[int, int], optional): 输出的空间尺寸。scale_factor
(float or Tuple[float, float], optional): 空间尺寸的乘数。import torch
import torch.nn as nn
# 创建一个 2x2 的输入张量
input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
# 初始化 UpsamplingNearest2d 模块,上采样因子为 2
m = nn.UpsamplingNearest2d(scale_factor=2)
output = m(input)
# 输出结果
print("Nearest neighbor upsampling:\n", output)
这段代码展示了如何使用 UpsamplingNearest2d
对一个小张量进行最近邻上采样。这种上采样方法简单但可能导致像素化的视觉效果。
torch.nn.UpsamplingBilinear2d
是 PyTorch 深度学习框架中的一个子模块,用于将输入信号(由多个输入通道组成)应用 2D 双线性上采样。这个模块在图像处理中非常有用,特别是在需要放大图像并保持图像内容平滑时。
size
(输出图像的尺寸)或 scale_factor
(空间尺寸的乘数)来使用 UpsamplingBilinear2d
。size
或 scale_factor
。size
直接指定输出图像的高度和宽度,而 scale_factor
指定相对于原始尺寸的放大比例。UpsamplingBilinear2d
类在最新版本的 PyTorch 中已被废弃,推荐使用 torch.nn.functional.interpolate(..., mode='bilinear', align_corners=True)
方法进行上采样。size
(int or Tuple[int, int], optional): 输出空间尺寸。scale_factor
(float or Tuple[float, float], optional): 空间尺寸的乘数。import torch
import torch.nn as nn
# 创建一个 2x2 的输入张量
input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
# 初始化 UpsamplingBilinear2d 模块,上采样因子为 2
m = nn.UpsamplingBilinear2d(scale_factor=2)
output = m(input)
# 输出结果
print("Bilinear upsampling:\n", output)
这段代码展示了如何使用 UpsamplingBilinear2d
对一个小张量进行双线性上采样。这种上采样方法能够在放大图像时保持更好的图像质量,避免像素化的视觉效果。
这篇博客深入探讨了 PyTorch 深度学习框架中的几个关键的图像上采样和下采样子模块,包括 nn.PixelShuffle
, nn.PixelUnshuffle
, nn.Upsample
, nn.UpsamplingNearest2d
, 和 nn.UpsamplingBilinear2d
。每个模块的用法、用途、关键技巧和注意事项都进行了详细的说明。PixelShuffle
和 PixelUnshuffle
分别用于图像的超分辨率提升和降采样处理,而 Upsample
提供了多种上采样方法,包括最近邻和双线性插值等。UpsamplingNearest2d
和 UpsamplingBilinear2d
则专注于 2D 图像的最近邻和双线性上采样。