# Code adapted from "Multimodal Fusion with Co-Attention Networks for Fake News Detection"
# (fake news detection via multimodal fusion); this snippet isolates the
# frequency-domain (DCT) feature extraction pipeline.
import numpy as np
from scipy.fftpack import fft, dct
import torch
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
from torchvision import transforms
# Define the preprocessing transform, feature-extraction function, and network modules
transform_dct = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
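# Quick sanity check: transform_dct maps any PIL image to a [1, 224, 224] tensor
# once the image is grayscale. Image.new builds a dummy image purely for
# illustration; it stands in for a real photo.
_dummy = transform_dct(Image.new('L', (500, 300)))
assert _dummy.shape == (1, 224, 224)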
def process_dct_img(img):
    """Split a [1, 224, 224] grayscale tensor into 8x8 = 64 blocks, DCT each block,
    take the real part of its FFT, and resize to 250 coefficients per block."""
    img = img.numpy()  # [1, 224, 224]
    height = img.shape[1]
    width = img.shape[2]
    N = 8
    step = int(height / N)  # 28
    dct_img = np.zeros((1, N * N, step * step, 1), dtype=np.float32)  # [1, 64, 784, 1]
    fft_img = np.zeros((1, N * N, step * step, 1))

    i = 0
    for row in np.arange(0, height, step):
        for col in np.arange(0, width, step):
            block = np.array(img[:, row:(row + step), col:(col + step)], dtype=np.float32)
            block1 = block.reshape(-1, step * step, 1)  # [1, 784, 1]
            # DCT along the 784-coefficient axis; the default axis=-1 has length 1
            # and would leave the block essentially unchanged.
            dct_img[:, i, :, :] = dct(block1, axis=1)
            i += 1

    # FFT along the coefficient axis (axis=2), for the same reason as above
    fft_img[:, :, :, :] = fft(dct_img, axis=2).real  # [1, 64, 784, 1]
    fft_img = torch.from_numpy(fft_img).float()
    new_img = F.interpolate(fft_img, size=[250, 1])  # [1, 64, 250, 1]
    new_img = new_img.squeeze(0).squeeze(-1)         # [64, 250]
    return new_img
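# Sanity check for process_dct_img: a random tensor stands in for a real
# transformed image. 224 / 8 = 28-pixel blocks give 64 blocks of 784
# coefficients each, interpolated down to 250 per block.
_blocks = process_dct_img(torch.rand(1, 224, 224))
assert _blocks.shape == (64, 250)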
class DctStem(nn.Module):
    """Stem: three stacked Conv-BN-ReLU layers over each block's coefficient axis."""

    def __init__(self, kernel_sizes, num_channels):
        super(DctStem, self).__init__()
        self.convs = nn.Sequential(
            ConvBNRelu2d(in_channels=1,
                         out_channels=num_channels[0],
                         kernel_size=kernel_sizes[0]),
            ConvBNRelu2d(in_channels=num_channels[0],
                         out_channels=num_channels[1],
                         kernel_size=kernel_sizes[1]),
            ConvBNRelu2d(in_channels=num_channels[1],
                         out_channels=num_channels[2],
                         kernel_size=kernel_sizes[2]),
            nn.MaxPool2d((1, 2)),
        )

    def forward(self, dct_img):
        x = dct_img.unsqueeze(1).unsqueeze(1)  # [64, 250] -> [64, 1, 1, 250]
        img = self.convs(x)                    # [64, 128, 1, 122]
        img = img.permute(0, 2, 1, 3)          # [64, 1, 128, 122]
        return img
class DctInceptionBlock(nn.Module):
    """Inception-style block: four parallel branches over the coefficient axis."""

    def __init__(self,
                 in_channel=128,
                 branch1_channels=[64],
                 branch2_channels=[48, 64],
                 branch3_channels=[64, 96, 96],
                 branch4_channels=[32]):
        super(DctInceptionBlock, self).__init__()
        self.branch1 = ConvBNRelu2d(in_channels=in_channel,
                                    out_channels=branch1_channels[0],
                                    kernel_size=1)
        self.branch2 = nn.Sequential(
            ConvBNRelu2d(in_channels=in_channel,
                         out_channels=branch2_channels[0],
                         kernel_size=1),
            ConvBNRelu2d(in_channels=branch2_channels[0],
                         out_channels=branch2_channels[1],
                         kernel_size=3,
                         padding=(0, 1)),
        )
        self.branch3 = nn.Sequential(
            ConvBNRelu2d(in_channels=in_channel,
                         out_channels=branch3_channels[0],
                         kernel_size=1),
            ConvBNRelu2d(in_channels=branch3_channels[0],
                         out_channels=branch3_channels[1],
                         kernel_size=3,
                         padding=(0, 1)),
            ConvBNRelu2d(in_channels=branch3_channels[1],
                         out_channels=branch3_channels[2],
                         kernel_size=3,
                         padding=(0, 1)),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(1, 3), stride=1, padding=(0, 1)),
            ConvBNRelu2d(in_channels=in_channel,
                         out_channels=branch4_channels[0],
                         kernel_size=1),
        )

    def forward(self, x):
        x = x.permute(0, 2, 1, 3)  # [64, 1, 128, 122] -> [64, 128, 1, 122]
        out1 = self.branch1(x)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        out4 = self.branch4(x)
        # concatenate branch outputs along the channel axis: 64 + 64 + 96 + 32 = 256
        out = torch.cat([out1, out2, out3, out4], dim=1)
        out = out.permute(0, 2, 1, 3)  # [64, 1, 256, 122]
        return out
def ConvBNRelu2d(in_channels, out_channels, kernel_size, stride=1, padding=0):
    # A 1-D convolution over the coefficient axis, expressed as a (1, k) 2-D conv
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels,
                  out_channels=out_channels,
                  kernel_size=(1, kernel_size),
                  stride=stride,
                  padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
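# Shape sketches for the building blocks, on dummy random inputs: ConvBNRelu2d
# convolves only along the last axis with its (1, k) kernel, and the inception
# block widens 128 channels to 64 + 64 + 96 + 32 = 256 while preserving width.
_conv = ConvBNRelu2d(in_channels=1, out_channels=32, kernel_size=3)
assert _conv(torch.rand(4, 1, 1, 250)).shape == (4, 32, 1, 248)
_block = DctInceptionBlock()
assert _block(torch.rand(2, 1, 128, 122)).shape == (2, 1, 256, 122)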
class DctCNN(nn.Module):
    """Stem + inception block + pooling; flattens 64 blocks into one 4096-d vector."""

    def __init__(self,
                 model_dim,
                 dropout,
                 kernel_sizes,
                 num_channels,
                 in_channel=128,
                 branch1_channels=[64],
                 branch2_channels=[48, 64],
                 branch3_channels=[64, 96, 96],
                 branch4_channels=[32],
                 out_channels=64):
        super(DctCNN, self).__init__()
        self.stem = DctStem(kernel_sizes, num_channels)
        self.InceptionBlock = DctInceptionBlock(
            in_channel,
            branch1_channels,
            branch2_channels,
            branch3_channels,
            branch4_channels,
        )
        # kernel (1, 122) with the default stride collapses the width of 122 to 1
        self.maxPool = nn.MaxPool2d((1, 122))
        self.dropout = nn.Dropout(dropout)
        self.conv = ConvBNRelu2d(branch1_channels[-1] + branch2_channels[-1] +
                                 branch3_channels[-1] + branch4_channels[-1],
                                 out_channels,
                                 kernel_size=1)

    def forward(self, dct_img):
        dct_f = self.stem(dct_img)      # [64, 1, 128, 122]
        x = self.InceptionBlock(dct_f)  # [64, 1, 256, 122]
        x = self.maxPool(x)             # [64, 1, 256, 1]
        x = x.permute(0, 2, 1, 3)       # [64, 256, 1, 1]
        x = self.conv(x)                # [64, 64, 1, 1]
        x = x.permute(0, 2, 1, 3)       # [64, 1, 64, 1]
        x = x.squeeze(-1)               # [64, 1, 64]
        x = x.reshape(-1, 4096)         # [1, 4096]: 64 blocks x 64 channels
        return x
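# End-to-end shape check for DctCNN with the demo hyperparameters: 64 blocks of
# 250 coefficients collapse to a single 4096-d feature vector (dummy input).
_dct_cnn = DctCNN(model_dim=256,
                  dropout=0.5,
                  kernel_sizes=[3, 3, 3],
                  num_channels=[32, 64, 128])
assert _dct_cnn(torch.rand(64, 250)).shape == (1, 4096)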
class NetShareFusion(nn.Module):
    def __init__(self,
                 kernel_sizes,
                 num_channels,
                 model_dim,
                 drop_and_BN,
                 num_labels=2,
                 dropout=0.5):
        super(NetShareFusion, self).__init__()
        self.model_dim = model_dim
        self.drop_and_BN = drop_and_BN
        self.dropout = nn.Dropout(dropout)

        # DCT image branch
        self.dct_img = DctCNN(model_dim,
                              dropout,
                              kernel_sizes,
                              num_channels,
                              in_channel=128,
                              branch1_channels=[64],
                              branch2_channels=[48, 64],
                              branch3_channels=[64, 96, 96],
                              branch4_channels=[32],
                              out_channels=64)
        self.linear_dct = nn.Linear(4096, model_dim)
        self.bn_dct = nn.BatchNorm1d(model_dim)

        # classifier
        self.linear1 = nn.Linear(model_dim, 35)
        self.bn_1 = nn.BatchNorm1d(35)
        self.linear2 = nn.Linear(35, num_labels)
        self.softmax = nn.Softmax(dim=1)

    def drop_BN_layer(self, x, part='dct'):
        if part == 'dct':
            bn = self.bn_dct
        if self.drop_and_BN == 'drop-BN':
            x = self.dropout(x)
            x = bn(x)
        return x

    def forward(self, dct_img):
        # DCT feature
        dct_out = self.dct_img(dct_img)                 # [1, 4096]
        dct_out = F.relu(self.linear_dct(dct_out))      # [1, model_dim]
        dct_out = self.drop_BN_layer(dct_out, part='dct')

        output = F.relu(self.linear1(dct_out))
        output = self.dropout(output)
        output = self.linear2(output)                   # [1, num_labels]
        y_pred_prob = self.softmax(output)
        return output, y_pred_prob
model = NetShareFusion(kernel_sizes=[3, 3, 3],
                       num_channels=[32, 64, 128],
                       model_dim=256,
                       dropout=0.5,
                       drop_and_BN='drop-BN')  # a string, not a list: drop_BN_layer compares with ==
# Inference demo on a single image. Eval mode is required here because
# BatchNorm1d rejects a training-mode batch containing only one sample.
model.eval()

image = Image.open(r"D:\mywork\tupian\people\6LmB22tIBO.jpg")  # replace with a local image path
dct_img = transform_dct(image.convert('L'))  # grayscale tensor, [1, 224, 224]
dct_img = process_dct_img(dct_img)           # [64, 250]
print(dct_img.shape)
output, y_pred_prob = model(dct_img)
print(output)
print(y_pred_prob)
# Example output from one run (exact values vary with random initialization):
# torch.Size([64, 250])
# tensor([[0.3687, 0.1634]], grad_fn=<AddmmBackward0>)
# tensor([[0.5512, 0.4488]], grad_fn=<SoftmaxBackward0>)
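# Hedged sketch: batching B images. process_dct_img handles one image at a time;
# concatenating the per-image [64, 250] outputs gives [B*64, 250], which
# DctCNN.forward regroups into [B, 4096] via its reshape, so the classifier sees
# a proper batch. Random tensors stand in for real images here.
batch = torch.cat([process_dct_img(torch.rand(1, 224, 224)) for _ in range(2)], dim=0)
logits, probs = model(batch)
assert logits.shape == (2, 2) and probs.shape == (2, 2)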