利用Pytorch搭建自己的数据预测模型

前言

在Pytorch环境下搭建多层神经感知机,实现对数据的预测。本文提供的数据为两组RGB值,一组是纯色图像的RGB。另一组是在特定场景下拍摄的纯色图像的RGB数值。因为在特定的场景下,所以RGB值会被改变,现在要做的是如何利用网络,模拟“特定场景”。输入一组RGB值,让网络能够准确的预测同样场景下RGB值的改变。


一、多层神经感知机是什么?

多层感知机(MLP,Multilayer Perceptron)也叫人工神经网络(ANN,Artificial Neural Network),是由输入、输出和一个或多个隐藏层组成的简单神经网络。

二、网络搭建

1.数据准备

本文所用数据如下所示:

25 119 89 53 61 30
103 89 138 62 51 54
89 190 177 71 74 52
142 134 89 52 64 64
91 163 102 57 70 54
112 89 109 58 51 58
221 201 89 52 77 79
89 117 136 61 58 50
89 89 89 54 54 54
78 124 144 64 60 49
89 151 93 54 67 53
108 205 0 1 75 57
237 89 225 83 55 85
95 139 89 52 63 53
89 186 167 71 73 53
169 199 89 52 75 70
212 190 177 73 76 80
132 154 141 65 69 64
55 99 126 61 55 42
123 165 89 54 71 62
186 178 89 52 71 74
116 185 89 53 73 60
0 172 178 72 73 1
112 102 89 51 55 56
49 84 111 55 51 40

以下五组用于预测。

160 161 89 55 71 70
31 131 22 28 62 32
89 91 126 60 52 52
89 108 90 51 57 52
122 89 171 71 52 61

前三列为纯色图像的RGB值,后三列就是与之对应在特定场景下的RGB数值。直接将其复制到txt文件中即可。

2.数据处理

数据处理代码如下:

import torch
import numpy as np
class MyDataset(torch.utils.data.Dataset):
    def __init__(self,data_path):
        self.data_path = data_path
        self.load_data(data_path)
    def load_data(self, data_path):
        # load data
        data = np.loadtxt(data_path)
        self.size = data.shape[0]
        self.X = data[:, :3]#获取前三列数据
        self.Y = data[:, 3:]#获取后三列数据
    def __len__(self):
        return self.size
    def __iter__(self):
        return self
    def __getitem__(self, index):
        # 获得原始输入
        return self.X[index, :], self.Y[index]

以上为处理原始数据的代码。

3.搭建MLP模型

MLP代码如下:

import torch.nn as nn
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(3, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
            nn.Linear(16, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(32, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
            nn.Linear(16, 3),
        )
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

nn.Linear()是用于设置网络中的全连接层。其表达式为Y=AX+b。输入的数据为[25,3],经过(3,16)的线性变换之后输出为:[25,16]。 BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。Relu会使一部分神经元的输出为0,这样就造成了网络的稀疏性,并且减少了参数的相互依存关系,缓解了过拟合问题的发生。

3.搭建训练网络

代码如下:

from dataset import MyDataset
import torch
import torch.nn.functional as F
import torch.optim as optim
from main import MLP
N_EPOCHS = 1000  # 训练的 epoch 次数
BATCH_SIZE = 20  # 训练的 batch size
train_set = MyDataset('./RGB.txt')#将RGB数据加载到数据处理的包中
train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=BATCH_SIZE,
        shuffle=True)
model = MLP()
optimizer = optim.Adam(model.parameters(), lr=0.01)#Adam自适应学习率来加快收敛速度。
for epoch in range(N_EPOCHS):
    for b_index, (x, y) in enumerate(train_loader):
        x = x.view(x.size()[0], -1)
        decoded = model(x.float())#decoded为经过MLP模型之后的输出结果。
        mse_loss = F.smooth_l1_loss(decoded, y.float()) #定义损失函数
        optimizer.zero_grad()
        mse_loss.backward()
        optimizer.step()
    print("Epoch: [%3d], Loss: %.4f" %(epoch + 1, mse_loss.data))
    print('Saving state, iter:', str(epoch + 1)), torch.save(model.state_dict(), 'logs/Epoch%d.pth' % ((epoch + 1)))#将权重保存到logs文件中

3.数据预测

预测代码如下:

import numpy as np
import torch
import torch.nn as nn
from main import MLP
import cv2
from PIL import Image
import numpy as np

model_path = "logs/Epoch1000.pth"
model = MLP()
model.eval()
state_dict = torch.load(model_path)
model.load_state_dict(state_dict,strict=False)
model = nn.DataParallel(model)
rgb = input('输入rgb:')
rgb = np.fromstring(rgb, dtype=int, sep=' ')
rgb_tensor = torch.from_numpy(rgb).float()
out = model(rgb_tensor.unsqueeze(0))
print(out)

训练完成后,选取权重文件,输入数据进行预测。

输入rgb:221 201 89
tensor([[53.4057, 75.9493, 78.0163]]

RGB为:221,201,89在特定场景下的真实值52,77,79,网络预测值为53.4057, 75.9493, 78.0163。


总结

代码、数据、权重文件已经放在:https://github.com/0neDawn/RGB-prediction-based-on-MLP-,需要的可以自取。有什么不明白的地方,可以留言,私信。有什么说的不对或者有误的地方希望您指出,我也会及时改正!

你可能感兴趣的:(pytorch,深度学习,python)