class RGBNet(nn.Module):
def __init__(self, num_spectral=31, num_res=6, num_fm=64):
super(RGBNet, self).__init__()
self.CA = nn.Sequential(
nn.Conv2d(num_spectral, 1, kernel_size=1, stride=1),
nn.LeakyReLU(),
nn.Conv2d(1, num_spectral, kernel_size=1, stride=1),
nn.Sigmoid()
)
self.SA = nn.Sequential(
nn.Conv2d(1, 1, kernel_size=6, stride=1, padding='same'), # 输入通道由3改为1
nn.Sigmoid(),
)
self.rgb = nn.Sequential(
nn.Conv2d(3, 3, kernel_size=6, stride=4, padding=2),
nn.Conv2d(3, 3, kernel_size=6, stride=2, padding=2)
)
self.rs1 = nn.Sequential(
nn.Conv2d(num_spectral + 3, num_spectral * 2 * 2 * 2 * 2 * 2 * 2, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU()
)
# self.ps= PS(2)
self.ps = nn.PixelShuffle(8)
self.rs2 = nn.Sequential(
nn.Conv2d(num_spectral + 3, num_fm, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU()
)
self.res_blocks = nn.ModuleList()
for _ in range(num_res):
self.res_blocks.append(nn.Sequential(
nn.Conv2d(num_fm, num_fm, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(num_fm, num_fm, kernel_size=3, stride=1, padding=1),
))
self.final_conv = nn.Conv2d(num_fm, num_spectral, kernel_size=3, stride=1, padding=1)
def forward(self, rgb_hp, ms_hp): # ms_hp, rgb_hp
rgb_hp_up = torch.nn.functional.interpolate(ms_hp, scale_factor=8, mode='bicubic')
gap_ms_c = ms_hp.mean(dim=(2, 3), keepdim=True) # 每个通道平均
CA = self.CA(gap_ms_c)
gap_RGB_s = rgb_hp.mean(dim=1, keepdim=True)
SA = self.SA(gap_RGB_s)
rgb = self.rgb(rgb_hp)
# rs = self.rs(ms_hp)
temp1 = ms_hp[:, :15]
temp2 = ms_hp[:, 15:]
rgb_temp1 = rgb[:, 0].unsqueeze(1)
rgb_temp2 = rgb[:, 1].unsqueeze(1)
rgb_temp3 = rgb[:, 2].unsqueeze(1)
# rs = torch.cat((rgb_temp1, temp1, rgb_temp2, temp2,rgb_temp3,temp3,rgb_temp4), dim=1)
rs = torch.cat((rgb_temp1, temp1, rgb_temp2, temp2, rgb_temp3), dim=1)
rs = self.rs1(rs)
rs = self.ps(rs)
temp1 = rs[:, :15]
temp2 = rs[:, 15:]
rgb_temp1 = rgb_hp[:, 0].unsqueeze(1)
rgb_temp2 = rgb_hp[:, 1].unsqueeze(1)
rgb_temp3 = rgb_hp[:, 2].unsqueeze(1)
rs = torch.cat((rgb_temp1, temp1, rgb_temp2, temp2, rgb_temp3), dim=1)
rs = self.rs2(rs)
for res_block in self.res_blocks:
rs1 = res_block(rs)
rs = rs + rs1
rs = SA * rs
rs = self.final_conv(rs)
rs = CA * rs
out = rs + rgb_hp_up
return out
self.CA = nn.Sequential( nn.Conv2d(num_spectral, 1, kernel_size=1, stride=1), nn.LeakyReLU(), nn.Conv2d(1, num_spectral, kernel_size=1, stride=1), nn.Sigmoid() )这部分代码定义了 RGBNet 模块中的通道注意力(Channel Attention)部分。以下是对其结构和功能的解释:
self.CA
: 通道注意力模块,采用nn.Sequential
封装了一系列操作。
- 第一层 (
nn.Conv2d(num_spectral, 1, kernel_size=1, stride=1)
): 通过 1x1 的卷积操作将输入通道数num_spectral
减少到 1,这一步可能有助于提取通道之间的关系。- 激活函数 (
nn.LeakyReLU()
): 使用 LeakyReLU 激活函数进行非线性变换。- 第二层 (
nn.Conv2d(1, num_spectral, kernel_size=1, stride=1)
): 再通过 1x1 的卷积操作将通道数恢复到num_spectral
,这一步可能有助于强调重要的通道信息。- Sigmoid 激活 (
nn.Sigmoid()
): 最后使用 Sigmoid 激活函数,将输出的数值范围缩放到 (0, 1) 之间,得到通道注意力权重。通道注意力的思想是根据每个通道的重要性动态调整其权重,以增强对重要通道的关注,减弱对不重要通道的关注。这有助于网络更加关注对当前任务有意义的信息。
self.SA = nn.Sequential( nn.Conv2d(1, 1, kernel_size=6, stride=1, padding='same'), # 输入通道由3改为1 nn.Sigmoid(), )这部分代码定义了 RGBNet 模块中的空间注意力(Spatial Attention)部分。以下是对其结构和功能的解释:
self.SA
: 空间注意力模块,同样采用nn.Sequential
封装了一系列操作。
- 卷积层 (
nn.Conv2d(1, 1, kernel_size=6, stride=1, padding='same')
): 通过 kernel_size 为 6 的卷积操作进行空间特征的提取,stride 为 1 表示不跳步,padding 为 'same' 可能表示使用零填充以保持输出大小不变。- Sigmoid 激活 (
nn.Sigmoid()
): 最后使用 Sigmoid 激活函数,将输出的数值范围缩放到 (0, 1) 之间,得到空间注意力权重。空间注意力的目的是根据每个像素点的重要性动态调整其权重,以增强对重要空间位置的关注,减弱对不重要空间位置的关注。这有助于网络更加关注对当前任务有意义的空间信息。
self.rgb = nn.Sequential( nn.Conv2d(3, 3, kernel_size=6, stride=4, padding=2), nn.Conv2d(3, 3, kernel_size=6, stride=2, padding=2) )这部分代码定义了 RGBNet 模块中的 RGB 处理部分。以下是对其结构和功能的解释:
self.rgb
: RGB 处理模块,采用nn.Sequential
封装了两个卷积层。
- 第一层 (
nn.Conv2d(3, 3, kernel_size=6, stride=4, padding=2)
): 通过 kernel_size 为 6,stride 为 4 的卷积操作进行下采样,可能有助于捕捉粗糙的全局信息。- 第二层 (
nn.Conv2d(3, 3, kernel_size=6, stride=2, padding=2)
): 通过 kernel_size 为 6,stride 为 2 的卷积操作进行下采样,可能有助于进一步捕捉细节信息。这个 RGB 处理部分的目的可能是在网络中引入一些全局和局部的 RGB 信息,以辅助后续的特征学习和处理。这可以在某些任务中有助于提高网络性能,例如图像融合或目标检测等。
self.rs1 = nn.Sequential( nn.Conv2d(num_spectral + 3, num_spectral * 2 * 2 * 2 * 2 * 2 * 2, kernel_size=3, stride=1, padding=1), nn.LeakyReLU() )
这部分代码定义了 RGBNet 模块中的第一个残差连接(Residual Connection)部分。以下是对其结构和功能的解释:
self.rs1
: 第一个残差连接模块,采用nn.Sequential
封装了两个操作。
- 卷积层 (
nn.Conv2d(num_spectral + 3, num_spectral * 2 * 2 * 2 * 2 * 2 * 2, kernel_size=3, stride=1, padding=1)
): 通过 kernel_size 为 3 的卷积操作进行特征变换,输入通道数为num_spectral + 3
,输出通道数为num_spectral * 2 * 2 * 2 * 2 * 2 * 2
。- LeakyReLU 激活 (
nn.LeakyReLU()
): 使用 LeakyReLU 激活函数进行非线性变换。这个残差连接的目的可能是引入更多的非线性变换和特征映射,以增强网络的表达能力。这可以帮助网络更好地学习复杂的特征和关系。
python复制代码
self.ps = nn.PixelShuffle(8)
这部分代码定义了 RGBNet 模块中的像素洗牌(Pixel Shuffle)操作。以下是对其功能的解释:
self.ps
: 像素洗牌操作,采用 PyTorch 提供的nn.PixelShuffle
模块,其中参数 8 表示每个像素块的大小。Pixel Shuffle 是一种上采样操作,它可以将通道数降低,同时增加空间分辨率。这个像素洗牌的目的可能是在网络中进行高效的上采样,以适应后续处理或任务的需要。这在图像生成或超分辨率重建等任务中常常用于增加图像的细节和清晰度。
python复制代码
self.rs2 = nn.Sequential( nn.Conv2d(num_spectral + 3, num_fm, kernel_size=3, stride=1, padding=1), nn.LeakyReLU() )
这部分代码定义了 RGBNet 模块中的第二个残差连接(Residual Connection)部分。以下是对其结构和功能的解释:
self.rs2
: 第二个残差连接模块,采用nn.Sequential
封装了两个操作。
- 卷积层 (
nn.Conv2d(num_spectral + 3, num_fm, kernel_size=3, stride=1, padding=1)
): 通过 kernel_size 为 3 的卷积操作进行特征变换,输入通道数为num_spectral + 3
,输出通道数为num_fm
。- LeakyReLU 激活 (
nn.LeakyReLU()
): 使用 LeakyReLU 激活函数进行非线性变换。这个残差连接的目的可能与第一个残差连接相似,引入更多的非线性变换和特征映射,以进一步增强网络的表达能力。
python复制代码
self.res_blocks = nn.ModuleList() for _ in range(num_res): self.res_blocks.append(nn.Sequential( nn.Conv2d(num_fm, num_fm, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(), nn.Conv2d(num_fm, num_fm, kernel_size=3, stride=1, padding=1), ))
这部分代码定义了 RGBNet 模块中的一系列残差块(Residual Blocks)。具体来说,通过
nn.ModuleList
创建了一个包含多个残差块的列表,每个残差块采用nn.Sequential
封装了两个卷积层和一个 LeakyReLU 激活函数。这样的残差块通常用于在网络中引入残差连接,帮助网络更好地捕捉和学习特征。每个残差块都包含两个卷积层,可以通过堆叠这些残差块来增加网络的深度,提高对复杂特征的建模能力。
python复制代码
self.final_conv = nn.Conv2d(num_fm, num_spectral, kernel_size=3, stride=1, padding=1)
这部分代码定义了 RGBNet 模块中的最终卷积层
self.final_conv
: 最终的卷积层,采用nn.Conv2d
进行定义。这一层卷积的输入通道数为num_fm
,输出通道数为num_spectral
,使用了 kernel_size 为 3 的卷积核,stride 为 1,padding 为 1。这一层的作用可能是将网络学到的特征映射转化为最终的输出,输出通道数为
num_spectral
,以匹配任务的要求。这个卷积操作可能有助于整合特征信息,使得最终的输出更好地反映输入的光谱信息。
def forward(self, rgb_hp, ms_hp): # ms_hp, rgb_hp rgb_hp_up = torch.nn.functional.interpolate(ms_hp, scale_factor=8, mode='bicubic') gap_ms_c = ms_hp.mean(dim=(2, 3), keepdim=True) # 每个通道平均 CA = self.CA(gap_ms_c) gap_RGB_s = rgb_hp.mean(dim=1, keepdim=True) SA = self.SA(gap_RGB_s) rgb = self.rgb(rgb_hp) temp1 = ms_hp[:, :15] temp2 = ms_hp[:, 15:] rgb_temp1 = rgb[:, 0].unsqueeze(1) rgb_temp2 = rgb[:, 1].unsqueeze(1) rgb_temp3 = rgb[:, 2].unsqueeze(1) rs = torch.cat((rgb_temp1, temp1, rgb_temp2, temp2, rgb_temp3), dim=1) rs = self.rs1(rs) rs = self.ps(rs) temp1 = rs[:, :15] temp2 = rs[:, 15:] rgb_temp1 = rgb_hp[:, 0].unsqueeze(1) rgb_temp2 = rgb_hp[:, 1].unsqueeze(1) rgb_temp3 = rgb_hp[:, 2].unsqueeze(1) rs = torch.cat((rgb_temp1, temp1, rgb_temp2, temp2, rgb_temp3), dim=1) rs = self.rs2(rs) for res_block in self.res_blocks: rs1 = res_block(rs) rs = rs + rs1 rs = SA * rs rs = self.final_conv(rs) rs = CA * rs out = rs + rgb_hp_up return out
这个
forward
函数定义了 RGBNet 模块的前向传播过程。具体来说:
- 输入参数: 接收两个输入,
rgb_hp
和ms_hp
。- 上采样: 使用
torch.nn.functional.interpolate
对ms_hp
进行上采样,将其尺寸放大 8 倍,采用双三次插值的方式。- 通道注意力(CA): 通过通道注意力模块
self.CA
处理ms_hp
的通道平均值,得到通道注意力权重CA
。- 空间注意力(SA): 通过空间注意力模块
self.SA
处理rgb_hp
的通道平均值,得到空间注意力权重SA
。- RGB处理: 通过 RGB 处理模块
self.rgb
处理输入的rgb_hp
。- 通道拼接: 将处理后的
rgb
和部分ms_hp
进行通道拼接得到rs
。- 第一个残差连接(
self.rs1
): 通过残差连接self.rs1
处理rs
。- 像素洗牌(
self.ps
): 对处理后的rs
进行像素洗牌操作。- 第二个残差连接(
self.rs2
): 通过残差连接self.rs2
处理洗牌后的rs
。- 多个残差块: 通过多个残差块
self.res_blocks
处理rs
。- 空间注意力加权: 使用空间注意力权重
SA
对rs
进行加权。- 最终卷积层: 通过最终的卷积层
self.final_conv
处理加权后的rs
。- 通道注意力加权: 使用通道注意力权重
CA
对卷积结果进行加权。- 最终输出: 将加权后的结果与上采样后的
rgb_hp
相加得到最终输出out
。这个前向传播过程整合了通道注意力、空间注意力、RGB信息、残差连接等多种操作,旨在使网络能够更好地处理输入数据并生成相应的输出。