# 运行代码 (Run the code)
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
# Perturbation magnitudes to sweep; 0 is the clean (unattacked) baseline.
epsilons = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]

# File holding the pretrained LeNet weights (loaded later with torch.load
# and passed to load_state_dict).
pretrained_model = "lenet_mnist_model.pth"

# Whether to run on the GPU when CUDA is available.
use_cuda = True
# LeNet Model definition
class Net(nn.Module):
    """Small LeNet-style CNN: 1-channel 28x28 input, 10 output classes.

    The attribute names (conv1, conv2, conv2_drop, fc1, fc2) are the keys the
    pretrained state_dict is loaded into, so they must not be renamed.
    """

    def __init__(self):
        super(Net, self).__init__()
        # 1 input channel (grayscale images) -> 10 feature maps.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 28x28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4,
        # so 20 channels * 4 * 4 = 320 features per sample.
        self.fc1 = nn.Linear(320, 50)
        # One output per class (10 classes).
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten each sample separately. Unlike view(-1, 320), this can never
        # silently regroup elements across samples when the input has an
        # unexpected spatial size; the mismatch surfaces at fc1 instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        # Dropout is active only in training mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# MNIST test split, downloaded into ../data if missing. Images are converted
# to tensors by ToTensor only (no normalization) and served one per batch in
# shuffled order.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data',
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=1,
    shuffle=True,
)

# Pick the GPU only when it is both requested and actually present.
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network on the chosen device, then copy the pretrained weights
# (read onto the CPU first) into it.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = Net().to(device)
model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))

# Evaluation mode: disables the dropout layers.
model.eval()
def fgsm_attack(image, epsilon, data_grad, *, clip_min=0.0, clip_max=1.0):
    """Build an FGSM adversarial example from a clean input.

    Moves every element of ``image`` by ``epsilon`` in the direction of the
    sign of the loss gradient, then clamps to the valid input range::

        perturbed = clamp(image + epsilon * sign(data_grad), clip_min, clip_max)

    Args:
        image: the clean input tensor.
        epsilon: perturbation magnitude; 0 leaves ``image`` unchanged apart
            from clamping.
        data_grad: gradient of the loss with respect to ``image``.
        clip_min, clip_max: bounds of the valid input range (keyword-only).
            The defults, [0, 1], match un-normalized ``ToTensor()`` images.

    Returns:
        The perturbed tensor, clamped to ``[clip_min, clip_max]``.
    """
    sign_data_grad = data_grad.sign()
    perturbed_image = image + epsilon * sign_data_grad
    # Keep the adversarial image inside the range the model was trained on.
    return torch.clamp(perturbed_image, clip_min, clip_max)
def test(model, device, test_loader, epsilon):
    """Evaluate ``model`` under an FGSM attack of strength ``epsilon``.

    For every sample the model initially classifies correctly, the loss
    gradient with respect to the input is computed, an FGSM-perturbed input
    is built with ``fgsm_attack``, and that input is re-classified. Samples
    already misclassified are skipped, so they count as errors.

    The loop calls ``.item()`` on predictions and targets, so it assumes the
    loader yields one sample per batch (``batch_size=1``).

    Args:
        model: classifier whose output is fed to ``F.nll_loss`` (i.e.
            log-probabilities).
        device: device each batch is moved to.
        test_loader: iterable of ``(data, target)`` batches supporting ``len()``.
        epsilon: perturbation magnitude passed to ``fgsm_attack``.

    Returns:
        ``(final_acc, adv_examples)``:
        - ``final_acc`` is ``correct / len(test_loader)``.
        - ``adv_examples`` holds up to 6 tuples
          ``(initial_pred, final_pred, perturbed_image_as_numpy)`` for plotting.
    """
    max_examples = 6
    correct = 0  # samples still classified correctly after the attack
    adv_examples = []  # adversarial examples kept for visualization

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # The attack needs d(loss)/d(input), so the input must track gradients.
        data.requires_grad = True

        output = model(data)
        init_pred = output.max(1, keepdim=True)[1]  # index of the highest score
        if init_pred.item() != target.item():
            # Already misclassified before the attack: nothing to attack.
            continue

        loss = F.nll_loss(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.detach()

        perturbed_data = fgsm_attack(data, epsilon, data_grad)
        # Re-classification needs no gradients; don't build an autograd graph.
        with torch.no_grad():
            output = model(perturbed_data)
        final_pred = output.max(1, keepdim=True)[1]

        fooled = final_pred.item() != target.item()
        if not fooled:
            correct += 1
        # Keep successful adversarial examples. At epsilon == 0 the input is
        # not perturbed and the model is typically not fooled, so keep those
        # clean samples instead; that way there is still something to plot.
        if (fooled or epsilon == 0) and len(adv_examples) < max_examples:
            adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
            adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))

    # Calculate final accuracy for this epsilon
    final_acc = correct / float(len(test_loader))
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))

    return final_acc, adv_examples
# Sweep every epsilon, collecting the accuracy and the saved adversarial
# examples for each one. Both lists are index-aligned with ``epsilons``.
accuracies = []
examples = []
for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)
# Accuracy as a function of attack strength.
plt.plot(epsilons, accuracies)
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.title("Accuracy vs Epsilon")
plt.show()

# Grid of saved examples: one row per epsilon, one column per example.
# The grid is sized from the longest row and each image is placed at
# (row i, column j) explicitly. A row with fewer or more examples than the
# first row therefore neither shifts later rows nor overflows the grid.
n_rows = len(epsilons)
n_cols = max((len(row) for row in examples), default=0)
plt.figure(figsize=(8, 10))
for i, (eps, row) in enumerate(zip(epsilons, examples)):
    for j, (orig, adv, ex) in enumerate(row):
        plt.subplot(n_rows, n_cols, i * n_cols + j + 1)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel("Eps: {}".format(eps), fontsize=14)
        # Title shows "initial prediction -> prediction after the attack".
        plt.title("{} -> {}".format(orig, adv))
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()
# 运行结果 (Output):