stn在mnist上的实现

stn在mnist上的实现

个人博客 - https://cxy-sky.github.io/

代码参考来源:PyTorch框架实战系列(3)——空间变换器网络STN_Daniel Yuz的博客-CSDN博客

理论:Pytorch中的仿射变换(affine_grid)_liangbaqiang的博客-CSDN博客

详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_黄小猿的博客-CSDN博客_stn

图片显示使用的是 matplotlib(本机没有安装 OpenCV)。

CNN

import torch
from torch import nn, optim


class CNN(nn.Module):
    """Baseline convolutional classifier without a spatial transformer.

    Two conv -> ReLU -> max-pool stages extract features, which are flattened
    and passed through dropout and a linear layer producing 10 scores per
    sample. The linear layer expects 512 input features, which is what the
    conv stack yields for 1x28x28 inputs (128 channels * 2 * 2).
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        self.linear = nn.Sequential(
            # The head sees flattened (N, 512) features, so plain element-wise
            # Dropout is the correct layer. Dropout2d is meant for
            # channel-shaped feature maps and is deprecated for 2-D input.
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        """Return the (N, 10) class scores for a batch of images ``x``."""
        x = self.cnn(x)
        # Collapse every dimension except the batch dimension.
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    # Smoke test: print the architecture, then the scores for one random
    # 1x28x28 input.
    net = CNN()
    sample = torch.rand(1, 1, 28, 28)
    print(net)
    print(net(sample))

STN

import torch
from torch import nn


class STN(nn.Module):
    """CNN classifier preceded by a spatial transformer module.

    A small localization network regresses a 2x3 affine matrix per sample;
    the input is resampled with that transform and then classified by the
    same conv stack and linear head as the plain CNN. The layer sizes assume
    1x28x28 inputs (localization features: 10 * 3 * 3, classifier
    features: 512).
    """

    def __init__(self):
        super(STN, self).__init__()
        # Convolutional part of the localization network.
        self.location_cov = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        )

        # Regressor that maps localization features to the 6 affine parameters.
        self.localization_linear = nn.Sequential(
            nn.Linear(in_features=10 * 3 * 3, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=2 * 3)
        )

        # Initialise the regressor so it outputs the identity transform:
        # zero weights plus a bias equal to the flattened [[1, 0, 0], [0, 1, 0]].
        # no_grad() keeps the in-place initialisation out of autograd without
        # going through the legacy ``.data`` attribute.
        with torch.no_grad():
            theta_layer = self.localization_linear[2]
            theta_layer.weight.zero_()
            theta_layer.bias.copy_(torch.tensor([1, 0, 0,
                                                 0, 1, 0], dtype=torch.float))

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        self.linear = nn.Sequential(
            # The head sees flattened (N, 512) features, so plain element-wise
            # Dropout is the correct layer. Dropout2d is meant for
            # channel-shaped feature maps and is deprecated for 2-D input.
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def stn(self, x):
        """Warp ``x`` with the affine transform predicted from ``x`` itself.

        Returns a tensor with the same size as ``x``.
        """
        features = self.location_cov(x)
        features = torch.flatten(features, 1)
        theta = self.localization_linear(features).view(-1, 2, 3)
        # Build a sampling grid of the input's size and resample the input.
        grid = nn.functional.affine_grid(theta, x.size(), align_corners=True)
        return nn.functional.grid_sample(x, grid, align_corners=True)

    def forward(self, x):
        """Return the (N, 10) class scores for a batch of images ``x``."""
        x = self.stn(x)
        x = self.cnn(x)
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    # Smoke test: print the architecture, then the scores for one random
    # 1x28x28 input.
    dummy_batch = torch.rand(1, 1, 28, 28)
    net = STN()
    print(net)
    scores = net(dummy_batch)
    print(scores)

train

import numpy as np
import torch
from torchvision import transforms
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from PIL import Image
from torch import nn, optim

from stn.CNN import CNN
from stn.STN import STN

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 数据处理
# Data pipeline: randomly rotate each digit (RandomRotation(45)), convert it
# to a tensor and normalise with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    # MNIST is single-channel, so mean/std are one-element tuples.
    # Note that `(0.5)` without the trailing comma is just the float 0.5.
    transforms.Normalize((0.5,), (0.5,)),
])

train_data = torchvision.datasets.MNIST('../data/mnist',
                                        download=True,
                                        train=True,
                                        transform=transform)

# The test split goes through the same random rotation as the training split.
test_data = torchvision.datasets.MNIST('../data/mnist',
                                       download=True,
                                       train=False,
                                       transform=transform)

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=64,
                                           shuffle=True)
# Evaluation does not depend on sample order, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=64,
                                          shuffle=False)

# Preview one training batch as an 8-column image grid.
data_iter = iter(train_loader)
imgs = torchvision.utils.make_grid(next(data_iter)[0], 8)
imgs = imgs.numpy().transpose(1, 2, 0)  # CHW -> HWC for matplotlib
imgs = imgs * 0.5 + 0.5  # undo Normalize(0.5, 0.5) before display
plt.imshow(imgs)
plt.show()


# Network under training, placed on the selected device. Replace STN() with
# CNN() to train the baseline without the spatial transformer.
model = STN().to(device)
loss_fun = nn.CrossEntropyLoss().to(device)
opt_fun = optim.Adam(model.parameters(), lr=0.001)

# Module-level bookkeeping shared with the functions below.
loss = 0
train_acc_count, test_acc_count = [], []
train_loss, test_loss = [], []


def train(epoch):
    """Train the module-level ``model`` on ``train_loader``.

    Every 100th batch the current batch loss is printed and appended to the
    module-level ``train_loss`` list.

    Args:
        epoch: Number of full passes over ``train_loader``.
    """
    # Make sure dropout is active even if the model was put into eval mode
    # earlier in the process.
    model.train()
    for i in range(epoch):
        for index, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device)
            labels = labels.to(device)
            # Clear gradients left over from the previous step.
            opt_fun.zero_grad()
            # The model already lives on `device`, so its output does too.
            outputs = model(imgs)
            loss = loss_fun(outputs, labels)
            loss.backward()
            opt_fun.step()
            if index % 100 == 0:
                batch_loss = loss.item()
                print("第{}轮,第{}次,loss为:{}".format(i + 1, index, batch_loss))
                train_loss.append(batch_loss)


def test():
    """Evaluate the module-level ``model`` on ``test_loader`` and print accuracy.

    Correct predictions and sample counts are accumulated over every batch,
    so the printed value is the accuracy over the whole loader rather than
    over the last batch only. The model is switched to eval mode for the
    duration of the evaluation (disabling dropout) and its previous mode is
    restored afterwards.
    """
    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, labels in test_loader:
                labels = labels.to(device)
                outputs = model(imgs.to(device))
                predictions = outputs.argmax(dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)
    # Guard against an empty loader instead of dividing by zero.
    accuracy = correct / total if total else 0.0
    print("测试集准确率{}".format(accuracy))


if __name__ == '__main__':
    import os

    # Seed the NumPy / PyTorch / CUDA random number generators.
    # NOTE(review): `model` is constructed at module level before this point,
    # so its initial weights are not covered by these seeds.
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    # Ask cuDNN for deterministic algorithms so runs are repeatable.
    torch.backends.cudnn.deterministic = True

    train(10)
    test()

    save_path = '../model/mnistStn.pth'
    # Create the target directory first; otherwise torch.save fails after
    # the whole training run has already been paid for.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)

    # Plot the losses recorded during training.
    plt.plot(train_loss)
    plt.show()


showImage

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import torch
import matplotlib.pyplot as plt

from stn.STN import STN

# Same preprocessing as the training script: random rotation, tensor
# conversion and normalisation with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    # MNIST is single-channel, so mean/std are one-element tuples.
    # Note that `(0.5)` without the trailing comma is just the float 0.5.
    transforms.Normalize((0.5,), (0.5,)),
])

train_data = torchvision.datasets.MNIST('../data/mnist',
                                        download=True,
                                        train=True,
                                        transform=transform)

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=64,
                                           shuffle=True)

# Take one batch and draw it, before any spatial transform, in the top subplot.
data_iter = iter(train_loader)
imgs, labels = next(data_iter)
pre = torchvision.utils.make_grid(imgs, 8)
pre = pre.numpy().transpose(1, 2, 0)  # CHW -> HWC for matplotlib
pre = pre * 0.5 + 0.5  # undo Normalize(0.5, 0.5) before display
plt.subplot(2, 1, 1)
plt.imshow(pre)
plt.title('pre')

model = STN()
# The training script moves the model to CUDA when available before saving,
# so map the checkpoint to CPU to make it loadable on machines without a GPU.
model.load_state_dict(torch.load('../model/mnistStn.pth', map_location='cpu'))
model.eval()
# Only the spatial transformer is applied here; no gradients are needed.
with torch.no_grad():
    now = model.stn(imgs)
now = torchvision.utils.make_grid(now, 8)
now = now.numpy().transpose(1, 2, 0)  # CHW -> HWC for matplotlib
now = now * 0.5 + 0.5  # undo Normalize(0.5, 0.5) before display
plt.subplot(2, 1, 2)
plt.imshow(now)
plt.title('now')

plt.show()

训练结果(epoch=10):

stn在mnist上的实现_第1张图片

展示经过 STN 变换(transform)后的图片,还是感觉很神奇。

stn在mnist上的实现_第2张图片

你可能感兴趣的:(python,深度学习,计算机视觉,pytorch,STN,CNN)