AidLux之AI对抗防御算法

目前主流对抗防御的总体分支与逻辑:

AidLux之AI对抗防御算法_第1张图片

        其中对抗训练是指在训练过程中加入对抗样本,通过不断的学习对抗样本的特征,从而提升模型的鲁棒性。监测识别对抗样本顾名思义,在项目关键节点处,设置一些能够识别对抗样本的特种模型,从而提前预警对抗攻击风险。 模型鲁棒结构设计是指在模型中设计特定的滤波结构能够一定程度上增强模型鲁棒性,抵御对抗噪 声。

        对抗扰动结构破坏主要在数据流处理的时候使用,通过一些滤波算法,噪声结构破坏算法,噪声覆盖 算法等策略,减弱对抗噪声的影响,使其不能对模型造成攻击。 梯度掩膜则是在白盒对抗防御中非常高效的一种算法,因为其能掩盖真实梯度,从而能够使得白盒攻击算法失效。本章中在后面小节使用到的对抗防御方法主要是基于梯度掩膜的GCM模块。
AidLux之AI对抗防御算法_第2张图片

 

GCM模块通过改造输入,使模型在前向测试时并不影响原来的结果,但是在进行梯度反向传播时,能隐藏原来的梯度,使得白盒攻击无所适从。下图是加入GCM模块后的模型梯度反向传播图示:
AidLux之AI对抗防御算法_第3张图片

 

同时我们设置很大的w参数,设置很小的epsilon参数,使得攻击算法在获取梯度时会得到一个很大的数,其与算法模型的有效梯度流无关。
源码链接: https://pan.baidu.com/s/1RIduv6ngpCM3jx63D3PQGA  提取码:aid6
import os
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
from advertorch.utils import predict_from_logits
from advertorch.utils import NormalizeByChannelMeanStd
from robust_layer import GradientConcealment, ResizedPaddingLayer

from advertorch.attacks import LinfPGDAttack
from advertorch_examples.utils import ImageNetClassNameLookup
from advertorch_examples.utils import bhwc2bchw
from advertorch_examples.utils import bchw2bhwc

device = "cuda" if torch.cuda.is_available() else "cpu"


### 读取图片
def get_image():
    img_path = os.path.join("./images", "school_bus.png")

    def _load_image():
        from skimage.io import imread
        return imread(img_path) / 255.

    if os.path.exists(img_path):
        return _load_image()



def tensor2npimg(tensor):
    return bchw2bhwc(tensor[0].cpu().numpy())


### 展示攻击结果
def show_images(model, img, advimg, enhance=127):
    np_advimg = tensor2npimg(advimg)
    np_perturb = tensor2npimg(advimg - img)

    pred = imagenet_label2classname(predict_from_logits(model(img)))
    advpred = imagenet_label2classname(predict_from_logits(model(advimg)))

    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(np_img)

    plt.axis("off")
    plt.title("original image\n prediction: {}".format(pred))
    plt.subplot(1, 3, 2)
    plt.imshow(np_perturb * enhance + 0.5)

    plt.axis("off")
    plt.title("the perturbation,\n enhanced {} times".format(enhance))
    plt.subplot(1, 3, 3)
    plt.imshow(np_advimg)
    plt.axis("off")
    plt.title("perturbed image\n prediction: {}".format(advpred))
    plt.show()


normalize = NormalizeByChannelMeanStd(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

### GCM模块
robust_mode = GradientConcealment()

### 常规模型+GCM模块
class Model(nn.Module):
    def __init__(self, l=290):
        super(Model, self).__init__()

        self.l = l
        self.gcm = GradientConcealment()
        # model = resnet18(pretrained=True)
        model = mobilenet_v2(pretrained=False)

        pth_path = "../model/mobilenet_v2-b0353104.pth"
        # print(f'Loading pth from {pth_path}')
        state_dict = torch.load(pth_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        # is_strict = False
        # if 'model' in state_dict.keys():
        #    model.load_state_dict(state_dict['model'], strict=is_strict)
        # else:
        #    model.load_state_dict(state_dict, strict=is_strict)

        normalize = NormalizeByChannelMeanStd(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.model = nn.Sequential(normalize, model)

    def load_params(self):
        pass

    def forward(self, x):
        x = self.gcm(x)
        # x = ResizedPaddingLayer(self.l)(x)
        out = self.model(x)
        return out

### 常规模型+GCM模块 加载
model_defense = Model().eval().to(device)


### 数据预处理
np_img = get_image()
img = torch.tensor(bhwc2bchw(np_img))[None, :, :, :].float().to(device)
imagenet_label2classname = ImageNetClassNameLookup()


### 测试模型输出结果
pred_defense = imagenet_label2classname(predict_from_logits(model_defense(img)))
print("test output:", pred_defense)

pre_label = predict_from_logits(model_defense(img))


### 对抗攻击:PGD攻击算法
adversary = LinfPGDAttack(
    model_defense, eps=8 / 255, eps_iter=2 / 255, nb_iter=80,
    rand_init=True, targeted=False)


### 完成攻击,输出对抗样本
advimg = adversary.perturb(img, pre_label)


### 展示源图片,对抗扰动,对抗样本以及模型的输出结果
show_images(model_defense, img, advimg)

输出结果:被攻击的原图,仍能正确识别为校车

AidLux之AI对抗防御算法_第4张图片

 

你可能感兴趣的:(算法,人工智能)