CW2 Algorithm Source Code Walkthrough

Paper: https://arxiv.org/abs/1608.04644
Source code: https://github.com/Harry24k/adversarial-attacks-pytorch/tree/master


Analysis

The name CW comes from the initials of the paper's two authors, Carlini and Wagner. CW2 is the L2-norm attack from the paper, an optimization-based attack whose objective is:

$$\text{minimize}\ \left\Vert \tfrac{1}{2}(\tanh(w)+1) - x \right\Vert_2^2 + c \cdot f\!\left(\tfrac{1}{2}(\tanh(w)+1)\right)$$

Here $\tanh$ is the hyperbolic tangent, and the perturbed image is represented as $\frac{1}{2}(\tanh(w)+1)$. This guarantees the perturbed image stays in $[0,1]$ and is differentiable everywhere, which benefits the optimization. In the code, the image is first passed through $\operatorname{atanh}$, the inverse of $\tanh$, to obtain the initial $w$, i.e. the starting point of the optimization. The paper also mentions that several random starting points close to the initial one can be tried to avoid getting stuck in a local minimum. The constant $c$ weights $f$ in the objective; the paper uses binary search to find the smallest $c$ that still produces a successful adversarial example. The first term of the objective keeps the L2 distance between the adversarial example and the original image small; the second term drives the misclassification.
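As a quick sanity check of this change of variables, here is a minimal PyTorch sketch (my own illustration, not part of the repository; the small epsilon in the clamp is a guard I added so atanh stays finite at the boundaries):

import torch

x = torch.rand(4)  # pixel values in [0, 1]
# clamp slightly inside (-1, 1) so atanh stays finite at the boundaries
w = torch.atanh(torch.clamp(2*x - 1, min=-1 + 1e-6, max=1 - 1e-6))
x_rec = 0.5*(torch.tanh(w) + 1)  # always lies in (0, 1) and is differentiable everywhere
print(torch.allclose(x, x_rec, atol=1e-4))  # True, up to the boundary clamp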
$f$ is defined as:

$$f(x') = \max\left(\max\{Z(x')_i : i \neq t\} - Z(x')_t,\ -\kappa\right)$$

Only when $f(x') \le 0$ is $x'$ misclassified as the target label $t$; $Z(x')$ denotes the logits the model outputs for $x'$. The parameter $\kappa$ controls the confidence with which the adversarial example $x'$ is assigned the target label: the larger $\kappa$, the higher the final confidence.
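A tiny numerical example (toy logits I made up, three classes, target t = 1, kappa = 0.5) shows how f behaves:

import torch

Z = torch.tensor([1.0, 3.5, 2.0])  # toy logits, target class t = 1
t, kappa = 1, 0.5
other = Z[torch.arange(3) != t].max()      # highest non-target logit: 2.0
f = torch.clamp(other - Z[t], min=-kappa)  # max(2.0 - 3.5, -0.5) = -0.5
print(f)  # tensor(-0.5000): f <= 0, the target class wins; the clamp stops rewarding margins beyond kappa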

Source Code

My analysis of the code is written in the comments.


import torch
import torch.nn as nn
import torch.optim as optim

from ..attack import Attack


class CW(Attack):
    r"""
    CW in the paper 'Towards Evaluating the Robustness of Neural Networks'
    [https://arxiv.org/abs/1608.04644]

    Distance Measure : L2

    Arguments:
        model (nn.Module): model to attack.
        c (float): c in the paper. Trade-off constant weighting f against the L2 distance. (Default: 1)
            :math:`minimize \Vert\frac{1}{2}(tanh(w)+1)-x\Vert^2_2+c\cdot f(\frac{1}{2}(tanh(w)+1))`
        kappa (float): kappa (also written as 'confidence') in the paper. (Default: 0)
            :math:`f(x')=max(max\{Z(x')_i:i\neq t\} -Z(x')_t, - \kappa)`
        steps (int): number of steps. (Default: 50)
        lr (float): learning rate of the Adam optimizer. (Default: 0.01)

    .. warning:: If adversarial images are hard to obtain with the current c, set a higher c, e.g. 10.

    Shape:
        - images: :math:`(N, C, H, W)` where `N = batch size`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` satisfies :math:`0 \leq y_i <` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.CW(model, c=1, kappa=0, steps=50, lr=0.01)
        >>> adv_images = attack(images, labels)

    .. note:: The binary search for c described in the paper is NOT IMPLEMENTED here because it is too time-consuming.

    """
    def __init__(self, model, c=1, kappa=0, steps=50, lr=0.01):
        super().__init__("CW", model)
        self.c = c  # c is fixed here because binary-searching it is too expensive
        self.kappa = kappa  # the kappa from the formula above
        self.steps = steps  # number of optimization iterations
        self.lr = lr  # learning rate of the optimizer that updates w
        # 'default' is an untargeted attack, 'targeted' a targeted one;
        # the paper's L2 attack is targeted, but this code also supports the untargeted variant
        self.supported_mode = ['default', 'targeted']

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        self._check_inputs(images)

        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            # obtain the target labels
            target_labels = self.get_target_label(images, labels)

        # w = torch.zeros_like(images).detach()  # alternative zero init; needs roughly 2x the steps
        w = self.inverse_tanh_space(images).detach()  # initial w obtained via atanh
        w.requires_grad = True
		
        # initialize the best adversarial images and the best L2 distances
        best_adv_images = images.clone().detach()
        best_L2 = 1e10*torch.ones((len(images))).to(self.device)
        prev_cost = 1e10
        dim = len(images.shape)

        # this loss computes the squared L2 distance term of the objective
        MSELoss = nn.MSELoss(reduction='none')
        Flatten = nn.Flatten()

        # use the Adam optimizer to update w
        optimizer = optim.Adam([w], lr=self.lr)

        for step in range(self.steps):
            # Get adversarial images
            adv_images = self.tanh_space(w)

            # compute the squared L2 distance, i.e. the first term of the objective
            current_L2 = MSELoss(Flatten(adv_images),
                                 Flatten(images)).sum(dim=1)
            L2_loss = current_L2.sum()

            outputs = self.get_logits(adv_images)
            # compute f, i.e. the second term of the objective
            if self.targeted:
                f_loss = self.f(outputs, target_labels).sum()
            else:
                f_loss = self.f(outputs, labels).sum()
			
            # cost is the value of the full objective
            cost = L2_loss + self.c*f_loss

            # one optimization step with Adam
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            # Update adversarial images
            pre = torch.argmax(outputs.detach(), 1)
            if self.targeted:
                # find the samples successfully classified as the target: pre == target_labels
                condition = (pre == target_labels).float()
            else:
                # for an untargeted attack, find the successfully misclassified samples
                condition = (pre != labels).float()

            # keep samples whose condition is 1 and whose L2 is lower than the best so far,
            # i.e. an image is kept only if it is (mis)classified as desired AND beats the previous best L2
            mask = condition*(best_L2 > current_L2.detach())
            best_L2 = mask*current_L2.detach() + (1-mask)*best_L2

            mask = mask.view([-1]+[1]*(dim-1))
            best_adv_images = mask*adv_images.detach() + (1-mask)*best_adv_images

            # stop early if the cost no longer decreases
            # max(., 1) keeps the check interval from being zero when steps < 10
            if step % max(self.steps//10,1) == 0:
                if cost.item() > prev_cost:
                    return best_adv_images
                prev_cost = cost.item()

        return best_adv_images
	
    # map w into image space: 0.5*(tanh(w)+1)
    def tanh_space(self, x):
        return 1/2*(torch.tanh(x) + 1)
	
    # map an image back into w space via atanh
    def inverse_tanh_space(self, x):
        # torch.atanh is only for torch >= 1.7.0
        # atanh is defined in the range -1 to 1
        return self.atanh(torch.clamp(x*2-1, min=-1, max=1))

    def atanh(self, x):
        return 0.5*torch.log((1+x)/(1-x))

    # the function f from the objective
    def f(self, outputs, labels):
        one_hot_labels = torch.eye(outputs.shape[1]).to(self.device)[labels]

        other = torch.max((1-one_hot_labels)*outputs, dim=1)[0]  # highest logit among the non-target classes
        real = torch.max(one_hot_labels*outputs, dim=1)[0]       # logit of the target (or true) label

        if self.targeted:
            return torch.clamp((other-real), min=-self.kappa)
        else:
            # for an untargeted attack, push away from the true label, so the sign is flipped
            return torch.clamp((real-other), min=-self.kappa)
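
For reference, the paper's outer binary search over c (which the note in the docstring says is not implemented) could look roughly like the sketch below. This is my own simplification: it assumes the torchattacks package and searches a single shared c for the whole batch, whereas the paper searches a separate c per image.

import torch
import torchattacks

def search_c(model, images, labels, c_lo=1e-3, c_hi=1e2, n_search=9):
    # geometric binary search for (approximately) the smallest c that still fools the model
    best_adv = images.clone()
    for _ in range(n_search):
        c = (c_lo * c_hi) ** 0.5  # midpoint in log space, since c spans orders of magnitude
        adv = torchattacks.CW(model, c=c, steps=50, lr=0.01)(images, labels)
        if (model(adv).argmax(dim=1) != labels).all():
            best_adv, c_hi = adv, c  # success: try a smaller c for less distortion
        else:
            c_lo = c                 # failure: f needs more weight in the objective
    return best_adv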
