1、先贴上我的还原代码(Jupyter)
本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,
并使用MNIST数据集(28*28手写数字图片集)进行训练和测试。针对过程中的每个步骤都尽可能的给出了详尽的解释。
MNIST 包括6万张28x28的训练样本,1万张测试样本。很多教程都会对它"下手",它几乎成为一个"典范",可以说它就是计算机视觉里面的Hello World。所以我们这里也会使用MNIST来进行实战。
# 1 准备工作
## 导入包
```python
import torch
import torchvision
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np # scientific computing: multi-dimensional arrays and data analysis
import pandas as pd
import matplotlib.pyplot as plt # provides a Matlab-like plotting framework
import scipy.io as sio # data I/O, used to read .mat files; scipy is a high-level scientific library closely tied to NumPy
from scipy.optimize import minimize # optimization routines
# Converts a tensor (C, H, W) back into a PIL image — used below to visualize reconstructions.
To_image = transforms.ToPILImage()
```
## 导入数据集
```python
# Download/load MNIST (train and test splits) as tensors.
# Fix: the original used plain strings 'G:\ProgramData\TestData\MNIST' where
# '\P', '\T' and '\M' are invalid escape sequences — Python only tolerates them
# with a Deprecation/SyntaxWarning. Raw strings yield the identical path safely.
train_dataset = datasets.MNIST(root=r'G:\ProgramData\TestData\MNIST', train=True,
                               transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root=r'G:\ProgramData\TestData\MNIST', train=False,
                              transform=transforms.ToTensor(), download=True)
```
```python
# Wrap both datasets in loaders. shuffle=False keeps the batch order
# deterministic, so the same sample is picked for the attack on every run.
train_loader = DataLoader(dataset=train_dataset, batch_size=200, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=200, shuffle=False)
```
```python
def weights_init(m):
    """Re-initialize a module's weight and bias tensors uniformly in [-0.5, 0.5].

    Modules lacking one of the attributes (e.g. activations) are skipped.
    Intended to be used via ``net.apply(weights_init)``.
    """
    for attr_name in ("weight", "bias"):
        if hasattr(m, attr_name):
            getattr(m, attr_name).data.uniform_(-0.5, 0.5)
```
```python
# Model definition: the small CNN used in the original DLG paper.
class ConvNet(nn.Module):
    """Small sigmoid CNN for 28x28 single-channel (MNIST) images.

    body: three 5x5 convolutions (strides 2, 2, 1, padding 2) with Sigmoid
    activations, mapping (N, 1, 28, 28) -> (N, 12, 7, 7).
    fc: flattens to 12 * 7 * 7 = 588 features and maps to 10 class logits.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        activation = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(1, 12, kernel_size=5, padding=5 // 2, stride=2),
            activation(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            activation(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            activation(),
        )
        self.fc = nn.Sequential(
            nn.Linear(588, 10),  # 12 channels * 7 * 7 spatial = 588 inputs
        )

    def forward(self, x):
        features = self.body(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
```
```python
# Pick the compute device and build the model on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
net = ConvNet().to(device)
# Uniform random re-initialization in [-0.5, 0.5], as in the original DLG code.
net.apply(weights_init)
```
### 损失函数
这里不调用 PyTorch 内置的损失函数,而是采用 DLG 原作者自定义的损失函数
```python
# Convert integer class labels to one-hot encoding.
def label_to_onehot(target, num_classes=10):
    """Return a (batch, num_classes) float tensor with a 1 at each label's index.

    ``target`` is a 1-D tensor of integer class labels; the output is created
    on the same device as ``target``.
    """
    indices = torch.unsqueeze(target, 1)
    encoded = torch.zeros(indices.size(0), num_classes, device=indices.device)
    encoded.scatter_(1, indices, 1)
    return encoded
# Loss function: cross-entropy against a one-hot (or soft) target distribution.
def cross_entropy_for_onehot(pred, target):
    """Mean over the batch of -sum(target * log_softmax(pred)).

    Unlike ``nn.CrossEntropyLoss`` this accepts a dense target distribution,
    which DLG needs because the dummy label is a softmax over free variables.
    """
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample = torch.sum(-target * log_probs, 1)
    return torch.mean(per_sample)

criterion = cross_entropy_for_onehot
```
### 计算初始梯度
```python
# Copy one batch of original data from the test loader.
# Fix: `iter(test_loader).next()` is the Python-2-style call that modern
# PyTorch/Python 3 iterators no longer support — use the builtin next().
# Using a single iterator also avoids loading the same batch twice.
images, labels = next(iter(test_loader))
test_loader_X = images.clone().detach()
test_loader_y = labels.clone().detach()
# Keep only one sample (batch of size 1).
test_loader_X = test_loader_X[0].view(1, 1, 28, 28)
test_loader_y = test_loader_y[0].view(1)
print(test_loader_X.size())
print(test_loader_y)
```
```python
# Compute the original (ground-truth) gradients that the attack will match.
# One-hot encode the label
test_loader_y_onehot=label_to_onehot(test_loader_y)
pred = net(test_loader_X)
loss =criterion(pred, test_loader_y_onehot) # loss value
dy_dx = torch.autograd.grad(loss, net.parameters()) # gradients w.r.t. the model parameters W
original_dy_dx = list((_.detach().clone() for _ in dy_dx)) # detached copy of the original gradients
#print(original_dy_dx)
print(loss)
print(pred)
```
### 构建虚拟数据
```python
# Generate a random dummy image and label with the same shapes as the real
# sample; requires_grad=True because the optimizer updates them directly.
dummy_data = torch.randn(test_loader_X.size()).to(device)
dummy_data.requires_grad_(True)
dummy_label = torch.randn(test_loader_y_onehot.size()).to(device)
dummy_label.requires_grad_(True)
```
### 还原图片
```python
# Create the optimizer over the dummy data/label (the model itself is frozen).
# L-BFGS is used instead of Adam — see the notes at the end of this post.
#optimizer = optim.Adam(model_2.parameters())
optimizer = optim.LBFGS([dummy_data, dummy_label])
#optimizer = optim.Adam([dummy_data, dummy_label])
```
```python
history = []
for iters in range(300):
    # LBFGS requires a closure that re-evaluates the loss and its gradients.
    def closure():
        optimizer.zero_grad()
        dummy_pred = net(dummy_data)
        # softmax turns the free dummy_label variables into a valid distribution
        dummy_onehot_label = F.softmax(dummy_label, dim=-1)
        dummy_loss = criterion(dummy_pred, dummy_onehot_label)
        # create_graph=True so the gradient-matching loss is itself differentiable
        # (requires twice-differentiable activations, e.g. sigmoid)
        dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
        # squared L2 distance between dummy gradients and the captured originals
        grad_diff = 0
        for gx, gy in zip(dummy_dy_dx, original_dy_dx):
            grad_diff += ((gx - gy) ** 2).sum()
        grad_diff.backward()
        return grad_diff
    optimizer.step(closure)
    if iters % 10 == 0:
        current_loss = closure()
        print(iters, "%.4f" % current_loss.item())
        # snapshot the reconstruction every 10 iterations — 30 images total,
        # which is what the 3x10 plotting grid below expects
        history.append(To_image(dummy_data[0].cpu()))
```
```python
# Display the 30 reconstruction snapshots in a 3 x 10 grid,
# one image per 10 optimization iterations.
plt.figure(figsize=(12, 5))
for idx in range(30):
    plt.subplot(3, 10, idx + 1)
    plt.imshow(history[idx])
    plt.title("iter=%d" % (idx * 10))
    plt.axis('off')
plt.show()
```
2、注意点
- 这里选用的网络模型很重要,作者的网络模型能很好的还原图片,我自己的网络模型还原不出来,我还在找其中的问题。其中有个问题是要求能二次求导,激活函数很关键,原文用的是sigmoid,这个是可以二次求导的
- 这边的优化器用LBFGS,用Adam不一定能还原图片
- 数据集我这边用的是一个,我用batchsize=200是还原不出图片的
- 虚拟数据的初始值也是一个因素,它会影响还原效果:如果随机生成的虚拟数据合适,那么很快就能还原出图片;如果不合适,则Loss损失值无法下降,图片也无法还原