与2022.10.25日更新,关于torch.nn.Parameter看看下面这篇博客,同时你可以关注一下nn.Linear与nn.Parameter的区别,可以用nn.Parameter实现nn.Linear
[Pytorch系列-30]:神经网络基础 - torch.nn库五大基本功能:nn.Parameter、nn.Linear、nn.functioinal、nn.Module、nn.Sequentia_文火冰糖的硅基工坊的博客-CSDN博客
这个是fedformer里面提的频率增强模块,其实就是通过傅立叶变换拿到频率特征然后乘上一个可学习参数,为该操作赋能,把提取到的频域信息整合到模型中。
代码
# ########## fourier layer #############
class FourierBlock(nn.Module):
def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'):
super(FourierBlock, self).__init__()
print('fourier enhanced block used!')
"""
1D Fourier block. It performs representation learning on frequency domain,
it does FFT, linear transform, and Inverse FFT.
"""
# get modes on frequency domain
self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)#得到随机打乱选取的基,后续进行DFT操作
print('modes={}, index={}'.format(modes, self.index))
self.scale = (1 / (in_channels * out_channels))
self.weights1 = nn.Parameter(
self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.cfloat))
# Complex multiplication 复数乘法
def compl_mul1d(self, input, weights):
# (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
return torch.einsum("bhi,hio->bho", input, weights)#高维张量的计算
#搞懂这个torch.einsum操作!!!
def forward(self, q, k, v, mask):
# size = [B, L, H, E]
B, L, H, E = q.shape
x = q.permute(0, 2, 3, 1)
# Compute Fourier coefficients
x_ft = torch.fft.rfft(x, dim=-1)
# Perform Fourier neural operations
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
for wi, i in enumerate(self.index):
out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi])#这里就是点对点运算,每个频率点的幅值分别乘以一个可学习参数
# Return to time domain
x = torch.fft.irfft(out_ft, n=x.size(-1))
return (x, None)
torch.nn.Parameter理解_Stoneplay26的博客-CSDN博客_torch.nn.parameter
PyTorch里面的torch.nn.Parameter()_明泽.的博客-CSDN博客_torch.nn.parameter
参考资料
[Pytorch系列-30]:神经网络基础 - torch.nn库五大基本功能:nn.Parameter、nn.Linear、nn.functioinal、nn.Module、nn.Sequentia_文火冰糖的硅基工坊的博客-CSDN博客
torch.nn.Parameter()_chenzy_hust的博客-CSDN博客_nn.parameter()
PyTorch里面的torch.nn.Parameter() - 简书