Image Frequency-Domain Feature Extraction

   The code comes from Multimodal Fusion with Co-Attention Networks for Fake News Detection (multimodal fusion for fake news detection); the frequency-domain feature-extraction component has been pulled out on its own here.

Defining the functions and the network structure

import numpy as np
from scipy.fftpack import fft, dct
import torch
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# Define the preprocessing transform and the network
transform_dct = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
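
As a quick check, the transform maps any PIL image to a 1x224x224 tensor once the image is converted to grayscale. A sketch with a synthetic image (illustrative only):

# Illustrative: transform_dct expects a PIL image; convert('L') makes it single-channel
img = Image.fromarray(np.uint8(np.random.rand(300, 400) * 255)).convert('L')
t = transform_dct(img)
print(t.shape)  # torch.Size([1, 224, 224])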

def process_dct_img(img):
    """Split a [1, 224, 224] grayscale tensor into 8 x 8 = 64 blocks of 28 x 28
    pixels, transform each flattened block (DCT, then the real part of an FFT),
    and resample every block to 250 coefficients."""
    img = img.numpy()  # [1, 224, 224]
    height = img.shape[1]
    width = img.shape[2]
    N = 8
    step = int(height / N)  # 28

    dct_img = np.zeros((1, N * N, step * step, 1), dtype=np.float32)  # [1, 64, 784, 1]
    fft_img = np.zeros((1, N * N, step * step, 1))

    i = 0
    for row in np.arange(0, height, step):
        for col in np.arange(0, width, step):
            block = np.array(img[:, row:(row + step), col:(col + step)], dtype=np.float32)
            block1 = block.reshape(-1, step * step, 1)  # [1, 784, 1]
            # DCT along the 784-coefficient axis (scipy defaults to the last
            # axis, which here has size 1, so the axis must be given explicitly)
            dct_img[:, i, :, :] = dct(block1, axis=1)
            i += 1

    # FFT over the coefficient axis; keep only the real part
    fft_img[:, :, :, :] = fft(dct_img, axis=2).real  # [1, 64, 784, 1]

    fft_img = torch.from_numpy(fft_img).float()      # [1, 64, 784, 1]
    new_img = F.interpolate(fft_img, size=[250, 1])  # resample 784 -> 250: [1, 64, 250, 1]
    new_img = new_img.squeeze(0).squeeze(-1)         # [64, 250]

    return new_img
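
A quick sanity check of the block-wise transform (the random tensor stands in for a transformed grayscale image):

# Illustrative shape check for process_dct_img
dummy = torch.rand(1, 224, 224)
feat = process_dct_img(dummy)
print(feat.shape)  # torch.Size([64, 250]): 64 blocks, each resampled to 250 coefficients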

class DctStem(nn.Module):
    def __init__(self, kernel_sizes, num_channels):
        super(DctStem, self).__init__()
        self.convs = nn.Sequential(
            ConvBNRelu2d(in_channels=1,
                         out_channels=num_channels[0],
                         kernel_size=kernel_sizes[0]),
            ConvBNRelu2d(
                in_channels=num_channels[0],
                out_channels=num_channels[1],
                kernel_size=kernel_sizes[1],
            ),
            ConvBNRelu2d(
                in_channels=num_channels[1],
                out_channels=num_channels[2],
                kernel_size=kernel_sizes[2],
            ),
            nn.MaxPool2d((1, 2)),
        )

    def forward(self, dct_img):
        x = dct_img.unsqueeze(1)  # [64, 250] -> [64, 1, 250]
        x = x.unsqueeze(1)        # -> [64, 1, 1, 250]
        print(x.shape)            # debug: torch.Size([64, 1, 1, 250])
        img = self.convs(x)       # three (1, 3) convs + (1, 2) max-pool -> [64, 128, 1, 122]
        img = img.permute(0, 2, 1, 3)  # -> [64, 1, 128, 122]

        return img

class DctInceptionBlock(nn.Module):
    def __init__(
        self,
        in_channel=128,
        branch1_channels=[64],
        branch2_channels=[48, 64],
        branch3_channels=[64, 96, 96],
        branch4_channels=[32],
    ):
        super(DctInceptionBlock, self).__init__()

        self.branch1 = ConvBNRelu2d(in_channels=in_channel,
                                    out_channels=branch1_channels[0],
                                    kernel_size=1)

        self.branch2 = nn.Sequential(
            ConvBNRelu2d(in_channels=in_channel,
                         out_channels=branch2_channels[0],
                         kernel_size=1),
            ConvBNRelu2d(
                in_channels=branch2_channels[0],
                out_channels=branch2_channels[1],
                kernel_size=3,
                padding=(0, 1),
            ),
        )

        self.branch3 = nn.Sequential(
            ConvBNRelu2d(in_channels=in_channel,
                         out_channels=branch3_channels[0],
                         kernel_size=1),
            ConvBNRelu2d(
                in_channels=branch3_channels[0],
                out_channels=branch3_channels[1],
                kernel_size=3,
                padding=(0, 1),
            ),
            ConvBNRelu2d(
                in_channels=branch3_channels[1],
                out_channels=branch3_channels[2],
                kernel_size=3,
                padding=(0, 1),
            ),
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(1, 3), stride=1, padding=(0, 1)),
            ConvBNRelu2d(in_channels=in_channel,
                         out_channels=branch4_channels[0],
                         kernel_size=1),
        )

    def forward(self, x):
        x = x.permute(0, 2, 1, 3)  # [64, 1, 128, 122] -> [64, 128, 1, 122]
        out1 = self.branch1(x)     # [64, 64, 1, 122]
        out2 = self.branch2(x)     # [64, 64, 1, 122]
        out3 = self.branch3(x)     # [64, 96, 1, 122]
        out4 = self.branch4(x)     # [64, 32, 1, 122]
        out = torch.cat([out1, out2, out3, out4], dim=1)  # [64, 256, 1, 122]
        out = out.permute(0, 2, 1, 3)  # -> [64, 1, 256, 122]

        return out

def ConvBNRelu2d(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv2d with a (1, kernel_size) kernel (effectively a 1-D convolution
    along the coefficient axis), followed by BatchNorm and ReLU."""
    return nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(1, kernel_size),
            stride=stride,
            padding=padding,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
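
With ConvBNRelu2d defined, the stem and the inception block can be sanity-checked on their own. A minimal shape trace (illustrative; the hyperparameters match the demo at the end of this post):

# Illustrative shape trace through the stem and the inception block
stem = DctStem(kernel_sizes=[3, 3, 3], num_channels=[32, 64, 128])
block = DctInceptionBlock(in_channel=128)

x = torch.rand(64, 250)  # stands in for the output of process_dct_img
h = stem(x)              # [64, 1, 128, 122]: three unpadded (1, 3) convs shrink 250 -> 244,
                         # then MaxPool2d((1, 2)) halves the width to 122
y = block(h)             # [64, 1, 256, 122]: 64 + 64 + 96 + 32 branch channels concatenated
print(h.shape, y.shape)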



class DctCNN(nn.Module):
    def __init__(self,
                 model_dim,
                 dropout,
                 kernel_sizes,
                 num_channels,
                 in_channel=128,
                 branch1_channels=[64],
                 branch2_channels=[48, 64],
                 branch3_channels=[64, 96, 96],
                 branch4_channels=[32],
                 out_channels=64):

        super(DctCNN, self).__init__()

        self.stem = DctStem(kernel_sizes, num_channels)

        self.InceptionBlock = DctInceptionBlock(
            in_channel,
            branch1_channels,
            branch2_channels,
            branch3_channels,
            branch4_channels,
        )

        self.maxPool = nn.MaxPool2d((1, 122))

        self.dropout = nn.Dropout(dropout)

        self.conv = ConvBNRelu2d(branch1_channels[-1] + branch2_channels[-1] +
                               branch3_channels[-1] + branch4_channels[-1],
                               out_channels,
                               kernel_size=1)

    def forward(self, dct_img):
        dct_f = self.stem(dct_img)      # [64, 1, 128, 122]
        x = self.InceptionBlock(dct_f)  # [64, 1, 256, 122]
        x = self.maxPool(x)             # pool over the full width -> [64, 1, 256, 1]
        x = x.permute(0, 2, 1, 3)       # -> [64, 256, 1, 1]
        x = self.conv(x)                # 1x1 conv, 256 -> 64 channels
        x = x.permute(0, 2, 1, 3)       # -> [64, 1, 64, 1]
        x = x.squeeze(-1)               # -> [64, 1, 64]
        x = x.reshape(-1, 4096)         # [1, 4096]: 64 blocks x 64 channels per image

        return x
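
The whole frequency branch can be exercised end to end. A quick check (illustrative; hyperparameters as in the demo below):

# End-to-end check of the frequency feature extractor
dct_cnn = DctCNN(model_dim=256, dropout=0.5,
                 kernel_sizes=[3, 3, 3], num_channels=[32, 64, 128])
feat = dct_cnn(torch.rand(64, 250))
print(feat.shape)  # torch.Size([1, 4096])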
    
class NetShareFusion(nn.Module):
    def __init__(self,
                 kernel_sizes,
                 num_channels,
                 model_dim,
                 drop_and_BN,
                 num_labels=2,
                 dropout=0.5):

        super(NetShareFusion, self).__init__()

        self.model_dim = model_dim
        self.drop_and_BN = drop_and_BN
        self.dropout = nn.Dropout(dropout)

        # DCT image branch
        self.dct_img = DctCNN(model_dim,
                              dropout,
                              kernel_sizes,
                              num_channels,
                              in_channel=128,
                              branch1_channels=[64],
                              branch2_channels=[48, 64],
                              branch3_channels=[64, 96, 96],
                              branch4_channels=[32],
                              out_channels=64)
        self.linear_dct = nn.Linear(4096, model_dim)
        self.bn_dct = nn.BatchNorm1d(model_dim)

       
        # classifier head
        self.linear1 = nn.Linear(model_dim, 35)
        self.bn_1 = nn.BatchNorm1d(35)
        self.linear2 = nn.Linear(35, num_labels)
        self.softmax = nn.Softmax(dim=1)
    
    def drop_BN_layer(self, x, part='dct'):
        # pick the BatchNorm layer for the given branch (only 'dct' exists here)
        if part == 'dct':
            bn = self.bn_dct

        # apply dropout followed by BatchNorm when configured with 'drop-BN'
        if self.drop_and_BN == 'drop-BN':
            x = self.dropout(x)
            x = bn(x)

        return x

    def forward(self, dct_img):
        # frequency-domain feature
        dct_out = self.dct_img(dct_img)             # [1, 4096]
        dct_out = F.relu(self.linear_dct(dct_out))  # [1, model_dim]
        dct_out = self.drop_BN_layer(dct_out, part='dct')
        print(dct_out.shape)                        # debug: torch.Size([1, 256])

        output = F.relu(self.linear1(dct_out))
        output = self.dropout(output)
        output = self.linear2(output)               # logits, [1, num_labels]
        y_pred_prob = self.softmax(output)          # class probabilities

        return output, y_pred_prob

Calling the model

model = NetShareFusion(kernel_sizes=[3, 3, 3],
                       num_channels=[32, 64, 128],
                       model_dim=256,
                       dropout=0.5,
                       drop_and_BN='drop-BN')  # a string, not a list, so the check
                                               # in drop_BN_layer actually matches

model.eval()  # BatchNorm1d fails on a batch of 1 in training mode, so run the demo in eval mode
image = Image.open(r"D:\mywork\tupian\people\6LmB22tIBO.jpg")
dct_img = transform_dct(image.convert('L'))
dct_img = process_dct_img(dct_img)
print(dct_img.shape)
output, y_pred_prob = model(dct_img)
print(output)
print(y_pred_prob)

#torch.Size([64, 250])
#torch.Size([64, 1, 1, 250])
#torch.Size([1, 256])
#tensor([[0.3687, 0.1634]], grad_fn=<AddmmBackward0>)
#tensor([[0.5512, 0.4488]], grad_fn=<SoftmaxBackward0>)

Note: the frequency-domain feature extraction itself performs no classification; the classifier head above is untrained and only demonstrates how the extracted features can be consumed.
