For the training procedure, refer to my earlier post on the four-layer CNN for the MNIST dataset in PyTorch (99.77% test accuracy), where every step is described in detail; simply replace MNIST with the garbage-classification dataset and adjust the parameters.
Weights and biases
Every layer with learnable parameters needs its weights and biases extracted. Since the network uses two convolutional layers and two fully-connected layers, that means exporting the weights and biases of both conv layers and both FC layers.
# Extraction_Parameter.py
# Import the required libraries
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Model definition (must match the trained network exactly)
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # Convolution layer 1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=0)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Convolution layer 2
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=0)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Fully-connected layer 1
        self.fc1 = nn.Linear(400, 40)
        # Fully-connected layer 2
        self.fc2 = nn.Linear(40, 4)

    def forward(self, x):
        # Forward pass of conv layer 1
        out = self.conv1(x)
        out = self.relu1(out)
        out = self.maxpool1(out)
        # Forward pass of conv layer 2
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.maxpool2(out)
        # Flatten
        out = out.view(out.size(0), -1)
        # Forward pass of the FC layers
        out = self.fc1(out)
        out = self.fc2(out)
        return F.log_softmax(out, dim=1)

# Instantiate the model
network = CNNModel()
# Load the trained weights
model_path = "model1.pth"
network.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
#network.eval()

parm = {}
for name, parameters in network.state_dict().items():
    parm[name] = parameters.detach().numpy()
    print(name, parameters)

w1 = parm['conv1.weight']
b1 = parm['conv1.bias']
w2 = parm['conv2.weight']
b2 = parm['conv2.bias']
fc1_w = parm['fc1.weight']
fc1_b = parm['fc1.bias']
fc2_w = parm['fc2.weight']
fc2_b = parm['fc2.bias']

# conv1 weight and bias
with open("parameters1_wb.h", "a") as f:
    new_str1 = str(w1.tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float conv1_weight[8][3][9] = {" + new_str3 + "};\n\n")
    print("conv1 weights saved")
with open("parameters1_wb.h", "a") as f:
    new_str1 = str(b1.tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float conv1_bias[8] = {" + new_str3 + "};\n\n")
    print("conv1 bias saved")
# conv2 weight and bias
with open("parameters1_wb.h", "a") as f:
    new_str1 = str(w2.tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float conv2_weight[16][8][9] = {" + new_str3 + "};\n\n")
    print("conv2 weights saved")
with open("parameters1_wb.h", "a") as f:
    new_str1 = str(b2.tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float conv2_bias[16] = {" + new_str3 + "};\n\n")
    print("conv2 bias saved")
# fc1 weight and bias
# Note: the FC weights are transposed before export, so the C code can index them as w[in * OUT + out]
with open("parameters1_wb.h", "a") as f:
    new_str1 = str(np.transpose(fc1_w).tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float fc1_weight[" + str(400 * 40) + "] = {" + new_str3 + "};\n\n")
    print("fc1 weights saved")
with open("parameters1_wb.h", "a") as f:
    new_str1 = str(fc1_b.tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float fc1_bias[40] = {" + new_str3 + "};\n\n")
    print("fc1 bias saved")
# fc2 weight and bias
with open("parameters1_wb.h", "a") as f:
    new_str1 = str(np.transpose(fc2_w).tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float fc2_weight[" + str(40 * 4) + "] = {" + new_str3 + "};\n\n")
    print("fc2 weights saved")
with open("parameters1_wb.h", "a") as f:
    new_str1 = str(fc2_b.tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float fc2_bias[4] = {" + new_str3 + "};\n\n")
    print("fc2 bias saved")
Once the extraction finishes you get a parameters1_wb.h file containing all of the weight and bias arrays.
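A quick way to confirm from the C side that the export is usable is a minimal check program like the sketch below; it is only an illustration (the file name check_params.c is made up here), but the array names are exactly the ones written by the script above.
// check_params.c -- sanity check of the exported parameters (illustrative sketch)
#include <stdio.h>
#include "parameters1_wb.h"
int main(void)
{
    // Print one element of each exported array; the values should match
    // what PyTorch printed during extraction.
    printf("conv1_weight[0][0][0] = %f\n", conv1_weight[0][0][0]);
    printf("conv1_bias[0]         = %f\n", conv1_bias[0]);
    printf("conv2_weight[0][0][0] = %f\n", conv2_weight[0][0][0]);
    printf("fc1_weight[0]         = %f\n", fc1_weight[0]);
    printf("fc2_bias[3]           = %f\n", fc2_bias[3]);
    return 0;
}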
Next, export the test image to a .h file in the same way.
# Extract_Image.py
from torchvision import transforms
import torch
import numpy as np
from PIL import Image

# The image must be preprocessed exactly the same way as during training;
# otherwise the extracted data will not match the expected dimensions.
data_transform = transforms.Compose(
    [transforms.ToTensor()
     #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
img = Image.open("./test/Others1.jpg")  # image to predict
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)
img = img.numpy()
print(img)
# Export the pixel data
with open("pic1.h", "a") as f:
    new_str1 = str(img.tolist())
    new_str2 = new_str1.replace('[', '')
    new_str3 = new_str2.replace(']', '')
    f.write("float Others1" + "[3][28][28] = {" + new_str3 + "};\n\n")
    print("image Others1 exported")
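The C model that follows consumes the image exactly as exported above. If the input array is ever filled from raw 8-bit RGB data instead of the Python script, keep in mind that transforms.ToTensor() scales pixels to the range [0, 1], so the same scaling has to happen on the C side. A minimal sketch, assuming a 28x28 HWC byte buffer (the rgb argument is hypothetical):
// Illustrative only: fill the [3][28][28] float input from an 8-bit HWC RGB buffer,
// reproducing the ToTensor() scaling to [0, 1].
void rgb_to_input(const unsigned char rgb[28][28][3], float img[3][28][28])
{
    for (int c = 0; c < 3; c++)
        for (int i = 0; i < 28; i++)
            for (int j = 0; j < 28; j++)
                img[c][i][j] = (float)rgb[i][j][c] / 255.0f;
}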
#include <stdio.h>
#include "parameters1_wb.h"
#include "pic1.h"
#define CONV_KERNEL_SIZE 3
#define POLL_KERNEL_SIZE 2
#define POLL_STRIDE 2
#define IMG_SIZE 28
#define CONV1_IN_KERNEL 3
#define CONV1_OUT_SIZE 26
#define CONV1_OUT_KERNEL 8
#define POLL1_OUT_SIZE 13
#define CONV2_OUT_KERNEL 16
#define CONV2_OUT_SIZE 11
#define POLL2_OUT_SIZE 5
#define FC_X 400 //16*5*5
#define FC1_OUT 40
#define FC1_B 40
#define FC2_OUT 4
#define FC2_B 4
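//How the sizes above follow from the network (no padding, stride-1 conv, 2x2 stride-2 max pool):
//  conv1: 28 - 3 + 1 = 26             -> CONV1_OUT_SIZE
//  pool1: (26 - 2)/2 + 1 = 13         -> POLL1_OUT_SIZE
//  conv2: 13 - 3 + 1 = 11             -> CONV2_OUT_SIZE
//  pool2: (11 - 2)/2 + 1 = 5 (floor)  -> POLL2_OUT_SIZE
//  flatten: 16 * 5 * 5 = 400          -> FC_X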
int cnn_predict(float img[CONV1_IN_KERNEL][IMG_SIZE][IMG_SIZE],
float conv1_w[CONV1_OUT_KERNEL][CONV1_IN_KERNEL][CONV_KERNEL_SIZE * CONV_KERNEL_SIZE],
float conv1_b[CONV1_OUT_KERNEL],
float conv2_w[CONV2_OUT_KERNEL][CONV1_OUT_KERNEL][CONV_KERNEL_SIZE * CONV_KERNEL_SIZE],
float conv2_b[CONV2_OUT_KERNEL],
float fc1_w[FC_X * FC1_OUT],
float fc1_b[FC1_B],
float fc2_w[FC1_OUT * FC2_OUT],
float fc2_b[FC2_B])
{
//--------------------------- Conv layer 1 ---------------------------//
//in img size : 3*28*28
//out img size : 8*26*26
printf("\n------------------------------------Conv1_out------------------------------------\n");
int conv1_row, conv1_col, conv1_out_kernel, conv1_in_kernel, conv1_i, conv1_j;
float temp;
float conv1_out[CONV1_OUT_KERNEL][CONV1_OUT_SIZE][CONV1_OUT_SIZE] = {0.0};
for(conv1_out_kernel = 0; conv1_out_kernel < CONV1_OUT_KERNEL; conv1_out_kernel++)
{
//loop over output rows
for(conv1_row = 0; conv1_row < IMG_SIZE - CONV_KERNEL_SIZE + 1; conv1_row++)
{
//loop over output columns
for(conv1_col = 0; conv1_col < IMG_SIZE - CONV_KERNEL_SIZE + 1; conv1_col++)
{
temp = 0.0;
//accumulate over the input channels
for(conv1_in_kernel = 0; conv1_in_kernel < CONV1_IN_KERNEL; conv1_in_kernel++)
{
//single 3x3 convolution at this position
for(conv1_i = 0; conv1_i < CONV_KERNEL_SIZE; conv1_i++)
{
for(conv1_j = 0; conv1_j < CONV_KERNEL_SIZE; conv1_j++)
{
float a = img[conv1_in_kernel][conv1_i + conv1_row][conv1_j + conv1_col];
float b = conv1_w[conv1_out_kernel][conv1_in_kernel][conv1_i * CONV_KERNEL_SIZE + conv1_j];
temp += a * b;
}
}
}
temp += conv1_b[conv1_out_kernel];//add the bias
conv1_out[conv1_out_kernel][conv1_row][conv1_col] = temp > 0 ? temp : 0;//ReLU activation
printf("%f ",conv1_out[conv1_out_kernel][conv1_row][conv1_col]);
if(conv1_col % 6 == 0)
{
printf("\n");
}
}
}
}
//--------------------------- Pooling layer 1 ---------------------------//
//in img size : 8*26*26
//out img size : 8*13*13
printf("\n------------------------------------Poll1_out------------------------------------\n");
int poll1_kernel, poll1_row, poll1_col, poll1_i, poll1_j;
float poll1_out[CONV1_OUT_KERNEL][POLL1_OUT_SIZE][POLL1_OUT_SIZE] = {0};
for(poll1_kernel = 0; poll1_kernel < CONV1_OUT_KERNEL; poll1_kernel++)
{
//loop over output rows
for(poll1_row = 0; poll1_row < (CONV1_OUT_SIZE - POLL_KERNEL_SIZE)/POLL_STRIDE + 1; poll1_row++)
{
//loop over output columns
for(poll1_col = 0; poll1_col < (CONV1_OUT_SIZE - POLL_KERNEL_SIZE)/POLL_STRIDE + 1; poll1_col++)
{
temp = 0.0;
//2x2 max over the pooling window
for(poll1_i = 0; poll1_i < POLL_KERNEL_SIZE; poll1_i++)
{
for(poll1_j = 0; poll1_j < POLL_KERNEL_SIZE; poll1_j++)
{
temp = (conv1_out[poll1_kernel][poll1_i + poll1_row * POLL_STRIDE][poll1_j + poll1_col * POLL_STRIDE] > temp) ?
conv1_out[poll1_kernel][poll1_i + poll1_row * POLL_STRIDE][poll1_j + poll1_col * POLL_STRIDE] : temp;
}
}
poll1_out[poll1_kernel][poll1_row][poll1_col] = temp;
printf("%f ",poll1_out[poll1_kernel][poll1_row][poll1_col]);
if(poll1_col % 6 == 0)
{
printf("\n");
}
}
}
}
//--------------------------- Conv layer 2 ---------------------------//
//in img size : 8*13*13
//out img size : 16*11*11
printf("\n------------------------------------Conv2_out------------------------------------\n");
int conv2_row, conv2_col, conv2_out_kernel, conv2_in_kernel, conv2_i, conv2_j;
float conv2_out[CONV2_OUT_KERNEL][CONV2_OUT_SIZE][CONV2_OUT_SIZE] = {0.0};
for(conv2_out_kernel = 0; conv2_out_kernel < CONV2_OUT_KERNEL; conv2_out_kernel++)
{
//loop over output rows
for(conv2_row = 0; conv2_row < POLL1_OUT_SIZE - CONV_KERNEL_SIZE + 1; conv2_row++)
{
//loop over output columns
for(conv2_col = 0; conv2_col < POLL1_OUT_SIZE - CONV_KERNEL_SIZE + 1; conv2_col++)
{
temp = 0.0;
//accumulate over the input channels
for(conv2_in_kernel = 0; conv2_in_kernel < CONV1_OUT_KERNEL; conv2_in_kernel++)
{
//single 3x3 convolution at this position
for(conv2_i = 0; conv2_i < CONV_KERNEL_SIZE; conv2_i++)
{
for(conv2_j = 0; conv2_j < CONV_KERNEL_SIZE; conv2_j++)
{
float a = poll1_out[conv2_in_kernel][conv2_i + conv2_row][conv2_j + conv2_col];
float b = conv2_w[conv2_out_kernel][conv2_in_kernel][conv2_i * CONV_KERNEL_SIZE + conv2_j];
temp += a * b;
}
}
}
temp += conv2_b[conv2_out_kernel];//add the bias
conv2_out[conv2_out_kernel][conv2_row][conv2_col] = temp > 0 ? temp : 0;//ReLU activation
printf("%f ",conv2_out[conv2_out_kernel][conv2_row][conv2_col]);
if(conv2_col % 6 == 0)
{
printf("\n");
}
}
}
}
//--------------------------- Pooling layer 2 ---------------------------//
//in img size : 16*11*11
//out img size : 16*5*5
printf("\n------------------------------------Poll2_out------------------------------------\n");
int poll2_kernel, poll2_row, poll2_col, poll2_i, poll2_j;
float poll2_out[CONV2_OUT_KERNEL][POLL2_OUT_SIZE][POLL2_OUT_SIZE] = {0};
for(poll2_kernel = 0; poll2_kernel < CONV2_OUT_KERNEL; poll2_kernel++)
{
//loop over output rows
for(poll2_row = 0; poll2_row < (CONV2_OUT_SIZE - POLL_KERNEL_SIZE)/POLL_STRIDE + 1; poll2_row++)
{
//loop over output columns
for(poll2_col = 0; poll2_col < (CONV2_OUT_SIZE - POLL_KERNEL_SIZE)/POLL_STRIDE + 1; poll2_col++)
{
temp = 0.0;
//2x2 max over the pooling window
for(poll2_i = 0; poll2_i < POLL_KERNEL_SIZE; poll2_i++)
{
for(poll2_j = 0; poll2_j < POLL_KERNEL_SIZE; poll2_j++)
{
temp = (conv2_out[poll2_kernel][poll2_i + poll2_row * POLL_STRIDE][poll2_j + poll2_col * POLL_STRIDE] > temp) ?
conv2_out[poll2_kernel][poll2_i + poll2_row * POLL_STRIDE][poll2_j + poll2_col * POLL_STRIDE] : temp;
}
}
poll2_out[poll2_kernel][poll2_row][poll2_col] = temp;
printf("%f ",poll2_out[poll2_kernel][poll2_row][poll2_col]);
if(poll2_col % 6 == 0)
{
printf("\n");
}
}
}
}
//--------------------------- Flatten (3-D to 1-D) ---------------------------//
//in size : 16*5*5 (3-D)
//out size : 400 (1-D)
printf("\n------------------------------------N to one------------------------------------\n");
float out[FC_X] = {0.0};
int i, j, k;
for(k = 0; k < CONV2_OUT_KERNEL; k++)
{
for(i = 0; i < POLL2_OUT_SIZE; i++)
{
for(j = 0; j < POLL2_OUT_SIZE; j++)
{
//This index formula matters: on some hardware, multi-dimensional arrays
//cause memory problems and the program cannot run, so all data has to be
//flattened into 1-D arrays, indexed as
//index = channel * H * W + row * W + col
//e.g. the last element (channel 15, row 4, col 4) maps to 15*5*5 + 4*5 + 4 = 399
out[k * POLL2_OUT_SIZE * POLL2_OUT_SIZE + i * POLL2_OUT_SIZE + j] = poll2_out[k][i][j];
printf("%f ",out[k * POLL2_OUT_SIZE * POLL2_OUT_SIZE + i * POLL2_OUT_SIZE + j]);
}
}
}
//--------------------------- FC layer 1 ---------------------------//
//in img size : 400
//out img size : 40
printf("\n------------------------------------ FC1_OUT ------------------------------------\n");
int fc1_i, fc1_j;
float fc1_out[FC1_OUT] = {0.0};
for(fc1_i = 0; fc1_i < FC1_OUT; fc1_i++)
{
temp = 0.0;
for(fc1_j = 0; fc1_j < FC_X; fc1_j++)
{
temp += fc1_w[fc1_j * FC1_OUT + fc1_i] * out[fc1_j];
}
//add the bias
temp += fc1_b[fc1_i];
fc1_out[fc1_i] = temp;
printf(" %f ",fc1_out[fc1_i]);
if(fc1_i % 8 == 0)
{
printf("\n");
}
}
//--------------------------- FC layer 2 ---------------------------//
//in img size : 40
//out img size : 4
printf("\n------------------------------------ FC2_OUT ------------------------------------\n");
int fc2_i, fc2_j;
float fc2_out[FC2_OUT] = {0.0};
for(fc2_i = 0; fc2_i < FC2_OUT; fc2_i++)
{
temp = 0.0;
for(fc2_j = 0; fc2_j < FC1_OUT; fc2_j++)
{
temp += fc2_w[fc2_j * FC2_OUT + fc2_i] * fc1_out[fc2_j];
}
//add the bias
temp += fc2_b[fc2_i];
fc2_out[fc2_i] = temp;
printf(" %f ",fc2_out[fc2_i]);
if(fc2_i % 8 == 0)
{
printf("\n");
}
}
//--------------------------- Argmax: index of the largest output ---------------------------//
int ret = 0;
temp = fc2_out[0];
for(i = 1; i < FC2_OUT; i++ )
{
if(fc2_out[i] > temp)
{
temp = fc2_out[i];
ret = i;
}
}
//0: Hazardous;
//1: Kitchen;
//2: Others;
//3: Recycled;
return ret;
}
int cnn_test()
{
int ret= 0;
ret = cnn_predict(Others1,conv1_weight,conv1_bias,conv2_weight,conv2_bias,
fc1_weight,fc1_bias,fc2_weight,fc2_bias);
char class[][10] = {"Hazardous","Kitchen","Others","Recycled"};
printf("\n input Others predict is: %s\n",class[ret]);
return 0;
}
int main()
{
cnn_test();
return 0;
}
Because of the FPGA's limited on-chip memory, the HLS version performs all computation on one-dimensional arrays.
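The flat indexing used throughout the HLS code is the same channel-major mapping as in the C model's flatten step. A small helper makes it explicit (the helper itself is only an illustration; the HLS code below writes the expression out inline):
// Illustrative helper: flat offset of element (c, i, j) in a C-order c x h x w volume.
static inline int flat_index(int c, int i, int j, int h, int w)
{
    return c * h * w + i * w + j;
}
// Example: the last element of the 16x5x5 pooling output is
// flat_index(15, 4, 4, 5, 5) = 15*25 + 4*5 + 4 = 399, the last slot of the 400-entry FC input.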
#include "HLS/hls.h"
#include "HLS/stdio.h"
#include "parameters_wb.h"
#include "pic.h"
#define CONV_KERNEL_SIZE 3
#define POLL_KERNEL_SIZE 2
#define POLL_STRIDE 2
#define IMG_SIZE 28
#define CONV1_IN_KERNEL 3
#define CONV1_OUT_SIZE 26
#define CONV1_OUT_KERNEL 8
#define POLL1_OUT_SIZE 13
#define CONV2_OUT_KERNEL 16
#define CONV2_OUT_SIZE 11
#define POLL2_OUT_SIZE 5
#define FC_X 400 //16*5*5
#define FC1_OUT 40
#define FC1_B 40
#define FC2_OUT 4
#define FC2_B 4
hls_avalon_slave_component
component int one_dim_rubbish(
hls_avalon_slave_memory_argument(3*28*28*sizeof(float)) float *in_img,
hls_avalon_slave_memory_argument(8*3*3*3*sizeof(float)) float *conv1_w,
hls_avalon_slave_memory_argument(8*sizeof(float)) float *conv1_b,
hls_avalon_slave_memory_argument(16*8*3*3*sizeof(float)) float *conv2_w,
hls_avalon_slave_memory_argument(16*sizeof(float)) float *conv2_b,
hls_avalon_slave_memory_argument(16*5*5*40*sizeof(float)) float *fc1_w,
hls_avalon_slave_memory_argument(40*sizeof(float)) float *fc1_b,
hls_avalon_slave_memory_argument(40*4*sizeof(float)) float *fc2_w,
hls_avalon_slave_memory_argument(4*sizeof(float)) float *fc2_b
)
{
float out1[CONV1_IN_KERNEL * IMG_SIZE * IMG_SIZE];
float out2[CONV1_OUT_KERNEL * CONV1_OUT_SIZE * CONV1_OUT_SIZE];
//--------------------------- Conv layer 1 ---------------------------//
//in img size : 3*28*28
//out img size : 8*26*26
//printf("\n------------------------------------Conv1_out------------------------------------\n");
int conv1_row, conv1_col, conv1_out_kernel, conv1_in_kernel, conv1_i, conv1_j;
float temp;
for(conv1_out_kernel = 0; conv1_out_kernel < CONV1_OUT_KERNEL; conv1_out_kernel++)
{
//loop over output rows
for(conv1_row = 0; conv1_row < IMG_SIZE - CONV_KERNEL_SIZE + 1; conv1_row++)
{
//loop over output columns
for(conv1_col = 0; conv1_col < IMG_SIZE - CONV_KERNEL_SIZE + 1; conv1_col++)
{
temp = 0.0;
//accumulate over the input channels
for(conv1_in_kernel = 0; conv1_in_kernel < CONV1_IN_KERNEL; conv1_in_kernel++)
{
//single 3x3 convolution at this position
for(conv1_i = 0; conv1_i < CONV_KERNEL_SIZE; conv1_i++)
{
for(conv1_j = 0; conv1_j < CONV_KERNEL_SIZE; conv1_j++)
{
//flat index: channel * H * W + row * W + col
float a = in_img[conv1_in_kernel * IMG_SIZE * IMG_SIZE +
(conv1_i + conv1_row) * IMG_SIZE +
conv1_j + conv1_col];
float b = conv1_w[conv1_out_kernel * CONV1_IN_KERNEL * CONV_KERNEL_SIZE * CONV_KERNEL_SIZE +
conv1_in_kernel * CONV_KERNEL_SIZE * CONV_KERNEL_SIZE +
conv1_i * CONV_KERNEL_SIZE +
conv1_j];
temp += a * b;
}
}
}
temp += conv1_b[conv1_out_kernel];//add the bias
out2[conv1_out_kernel * CONV1_OUT_SIZE * CONV1_OUT_SIZE +
conv1_row * CONV1_OUT_SIZE +
conv1_col] = temp > 0 ? temp : 0;
}
}
}
//--------------------------- Pooling layer 1 ---------------------------//
//in img size : 8*26*26
//out img size : 8*13*13
//printf("\n------------------------------------Poll1_out------------------------------------\n");
int poll1_kernel, poll1_row, poll1_col, poll1_i, poll1_j;
for(poll1_kernel = 0; poll1_kernel < CONV1_OUT_KERNEL; poll1_kernel++)
{
//loop over output rows
for(poll1_row = 0; poll1_row < (CONV1_OUT_SIZE - POLL_KERNEL_SIZE)/POLL_STRIDE + 1; poll1_row++)
{
//loop over output columns
for(poll1_col = 0; poll1_col < (CONV1_OUT_SIZE - POLL_KERNEL_SIZE)/POLL_STRIDE + 1; poll1_col++)
{
temp = 0.0;
//2x2 max over the pooling window
for(poll1_i = 0; poll1_i < POLL_KERNEL_SIZE; poll1_i++)
{
for(poll1_j = 0; poll1_j < POLL_KERNEL_SIZE; poll1_j++)
{
temp = (out2[poll1_kernel * CONV1_OUT_SIZE * CONV1_OUT_SIZE +
(poll1_i + poll1_row * POLL_STRIDE) * CONV1_OUT_SIZE +
poll1_j + poll1_col * POLL_STRIDE] > temp) ?
out2[poll1_kernel * CONV1_OUT_SIZE * CONV1_OUT_SIZE +
(poll1_i + poll1_row * POLL_STRIDE) * CONV1_OUT_SIZE +
poll1_j + poll1_col * POLL_STRIDE] : temp;
}
}
out1[poll1_kernel * POLL1_OUT_SIZE * POLL1_OUT_SIZE +
poll1_row * POLL1_OUT_SIZE +
poll1_col] = temp;
}
}
}
//--------------------------- Conv layer 2 ---------------------------//
//in img size : 8*13*13
//out img size : 16*11*11
int i;
for(i = 0; i < CONV1_OUT_KERNEL * CONV1_OUT_SIZE * CONV1_OUT_SIZE; i++)
{
out2[i] = 0;
}
//printf("\n------------------------------------Conv2_out------------------------------------\n");
int conv2_row, conv2_col, conv2_out_kernel, conv2_in_kernel, conv2_i, conv2_j;
for(conv2_out_kernel = 0; conv2_out_kernel < CONV2_OUT_KERNEL; conv2_out_kernel++)
{
//loop over output rows
for(conv2_row = 0; conv2_row < POLL1_OUT_SIZE - CONV_KERNEL_SIZE + 1; conv2_row++)
{
//loop over output columns
for(conv2_col = 0; conv2_col < POLL1_OUT_SIZE - CONV_KERNEL_SIZE + 1; conv2_col++)
{
temp = 0.0;
//accumulate over the input channels
for(conv2_in_kernel = 0; conv2_in_kernel < CONV1_OUT_KERNEL; conv2_in_kernel++)
{
//single 3x3 convolution at this position
for(conv2_i = 0; conv2_i < CONV_KERNEL_SIZE; conv2_i++)
{
for(conv2_j = 0; conv2_j < CONV_KERNEL_SIZE; conv2_j++)
{
float a = out1[conv2_in_kernel * POLL1_OUT_SIZE * POLL1_OUT_SIZE +
(conv2_i + conv2_row) * POLL1_OUT_SIZE +
conv2_j + conv2_col];
float b = conv2_w[conv2_out_kernel * CONV1_OUT_KERNEL * CONV_KERNEL_SIZE * CONV_KERNEL_SIZE +
conv2_in_kernel * CONV_KERNEL_SIZE * CONV_KERNEL_SIZE +
conv2_i * CONV_KERNEL_SIZE +
conv2_j];
temp += a * b;
}
}
}
temp += conv2_b[conv2_out_kernel];//add the bias
out2[conv2_out_kernel * CONV2_OUT_SIZE * CONV2_OUT_SIZE +
conv2_row * CONV2_OUT_SIZE +
conv2_col] = temp > 0 ? temp : 0;
}
}
}
//--------------------------- Pooling layer 2 ---------------------------//
//in img size : 16*11*11
//out img size : 16*5*5
for(i = 0; i < CONV1_IN_KERNEL * IMG_SIZE * IMG_SIZE; i++)
{
out1[i] = 0;
}
//printf("\n------------------------------------Poll2_out------------------------------------\n");
int poll2_kernel, poll2_row, poll2_col, poll2_i, poll2_j;
for(poll2_kernel = 0; poll2_kernel < CONV2_OUT_KERNEL; poll2_kernel++)
{
//loop over output rows
for(poll2_row = 0; poll2_row < (CONV2_OUT_SIZE - POLL_KERNEL_SIZE)/POLL_STRIDE + 1; poll2_row++)
{
//loop over output columns
for(poll2_col = 0; poll2_col < (CONV2_OUT_SIZE - POLL_KERNEL_SIZE)/POLL_STRIDE + 1; poll2_col++)
{
temp = 0.0;
//2x2 max over the pooling window
for(poll2_i = 0; poll2_i < POLL_KERNEL_SIZE; poll2_i++)
{
for(poll2_j = 0; poll2_j < POLL_KERNEL_SIZE; poll2_j++)
{
temp = (out2[poll2_kernel * CONV2_OUT_SIZE * CONV2_OUT_SIZE +
(poll2_i + poll2_row * POLL_STRIDE) * CONV2_OUT_SIZE +
poll2_j + poll2_col * POLL_STRIDE] > temp) ?
out2[poll2_kernel * CONV2_OUT_SIZE * CONV2_OUT_SIZE +
(poll2_i + poll2_row * POLL_STRIDE) * CONV2_OUT_SIZE +
poll2_j + poll2_col * POLL_STRIDE] : temp;
}
}
out1[poll2_kernel * POLL2_OUT_SIZE * POLL2_OUT_SIZE +
poll2_row * POLL2_OUT_SIZE +
poll2_col] = temp;
}
}
}
//--------------------------- FC layer 1 ---------------------------//
//in img size : 400
//out img size : 40
for(i = 0; i < CONV1_OUT_KERNEL * CONV1_OUT_SIZE * CONV1_OUT_SIZE; i++)
{
out2[i] = 0;
}
//printf("\n------------------------------------ FC1_OUT ------------------------------------\n");
int fc1_i, fc1_j;
for(fc1_i = 0; fc1_i < FC1_OUT; fc1_i++)
{
temp = 0.0;
for(fc1_j = 0; fc1_j < FC_X; fc1_j++)
{
temp += fc1_w[fc1_j * FC1_OUT + fc1_i] * out1[fc1_j];
}
//add the bias
temp += fc1_b[fc1_i];
out2[fc1_i] = temp;
}
//--------------------------- FC layer 2 ---------------------------//
//in img size : 40
//out img size : 4
for(i = 0; i < CONV1_IN_KERNEL * IMG_SIZE * IMG_SIZE; i++)
{
out1[i] = 0;
}
//printf("\n------------------------------------ FC2_OUT ------------------------------------\n");
int fc2_i, fc2_j;
for(fc2_i = 0; fc2_i < FC2_OUT; fc2_i++)
{
temp = 0.0;
for(fc2_j = 0; fc2_j < FC1_OUT; fc2_j++)
{
temp += fc2_w[fc2_j * FC2_OUT + fc2_i] * out2[fc2_j];
}
//add the bias
temp += fc2_b[fc2_i];
out1[fc2_i] = temp;
}
//--------------------------- Argmax: index of the largest output ---------------------------//
int ret = 0;
temp = out1[0];
for(i = 1; i < FC2_OUT; i++ )
{
if(out1[i] > temp)
{
temp = out1[i];
ret = i;
}
}
//0: Hazardous;
//1: Kitchen;
//2: Others;
//3: Recycled;
return ret;
}
int main()
{
int ret;
#if 1
ret = one_dim_rubbish(Others1,conv1_weight,conv1_bias,conv2_weight,conv2_bias,fc1_weight,fc1_bias,fc2_weight,fc2_bias);
char const *input_img[] = {"Hazardous1","Hazardous2","Kitchen1","Kitchen2","Others1","Others2","Recycled1","Recycled2"};
char const *classes[] = {"Hazardous","Kitchen","Others","Recycled"};
printf("\n input %s \t predict is: %s\n",input_img[4],classes[ret]);
#else
float *imgx[] = {Hazardous1, Hazardous2, Kitchen1, Kitchen2,
Others1, Others2, Recycled1, Recycled2};
char const *input_img[] = {"Hazardous1","Hazardous2","Kitchen1","Kitchen2","Others1","Others2","Recycled1","Recycled2"};
for(int i = 0; i < 8; i++)
{
ret = one_dim_rubbish(imgx[i],conv1_weight,conv1_bias,conv2_weight,conv2_bias,fc1_weight,fc1_bias,fc2_weight,fc2_bias);
char const *classes[] = {"Hazardous","Kitchen","Others","Recycled"};
printf("\n input %s \t predict is: %s\n",input_img[i],classes[ret]);
}
#endif
return 0;
}
main.c, cross-compiled for and run on the HPS (ARM) side of the FPGA
/*
* main.c
*
* Created on: Oct. 21, 2021
* Author: eye
*/
//standard C headers
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <time.h>
//low-level definition headers provided by the HPS vendor
#define soc_cv_av //targeting the Cyclone V SoC series
#include "hwlib.h"
#include "socal/socal.h"
#include "socal/hps.h"
//hardware description header generated for this specific HPS system
#include "hps_0.h"
#include "conv.h"
//interface definitions
#define HW_REGS_BASE (ALT_STM_OFST) //base address of the HPS peripheral span
#define HW_REGS_SPAN (0x04000000) //64 MB HPS peripheral address span
#define HW_REGS_MASK (HW_REGS_SPAN - 1) //address mask for the peripheral span
//component data ports (grouped in a struct)
typedef struct{
volatile float *img;
volatile float *c1_w;
volatile float *c1_b;
volatile float *c2_w;
volatile float *c2_b;
volatile float *f1_w;
volatile float *f1_b;
volatile float *f2_w;
volatile float *f2_b;
}fc_port_def;
fc_port_def my_fc_port;
typedef struct{
volatile long long busy;
volatile long long start;
volatile long long irq_en;
volatile long long done;
volatile long long result;
}fc_ctrl_def;
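//Note: this mirrors the control/status register block that the HLS compiler generates for an
//hls_avalon_slave_component: 64-bit slots for busy, start, interrupt enable, done/interrupt
//status and the return value, in that order. Confirm the offsets against the register-map
//header generated for the component before relying on this layout.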
fc_ctrl_def *my_fc_ctrl;
const float *imgx[8] = {Hazardous1, Hazardous2, Kitchen1, Kitchen2,
Others1, Others2, Recycled1, Recycled2};
const char *input_img[] = {"Hazardous1","Hazardous2","Kitchen1","Kitchen2","Others1","Others2","Recycled1","Recycled2"};
const char *classes[] = {"Hazardous","Kitchen","Others","Recycled"};
int fc_init(void *virtual_base)
{
void *fc_ctrl_addr;
fc_ctrl_addr = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_CRA_BASE) & (unsigned long)(HW_REGS_MASK));
//map the component's control/status registers
my_fc_ctrl = (fc_ctrl_def *)fc_ctrl_addr;
my_fc_ctrl->start = 0x0;
my_fc_port.img = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_IN_IMG_BASE) & (unsigned long)(HW_REGS_MASK));
my_fc_port.c1_w = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_CONV1_W_BASE) & (unsigned long)(HW_REGS_MASK));
my_fc_port.c1_b = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_CONV1_B_BASE) & (unsigned long)(HW_REGS_MASK));
my_fc_port.c2_w = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_CONV2_W_BASE) & (unsigned long)(HW_REGS_MASK));
my_fc_port.c2_b = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_CONV2_B_BASE) & (unsigned long)(HW_REGS_MASK));
my_fc_port.f1_w = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_FC1_W_BASE) & (unsigned long)(HW_REGS_MASK));
my_fc_port.f1_b = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_FC1_B_BASE) & (unsigned long)(HW_REGS_MASK));
my_fc_port.f2_w = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_FC2_W_BASE) & (unsigned long)(HW_REGS_MASK));
my_fc_port.f2_b = virtual_base + ((unsigned long)(ALT_LWFPGASLVS_OFST + RUBBISH_0_ONE_DIM_RUBBISH_INTERNAL_INST_AVS_FC2_B_BASE) & (unsigned long)(HW_REGS_MASK));
//load the weight and bias parameters into the component's on-chip memories
memcpy(my_fc_port.c1_w,conv1_weight,8*3*3*3*sizeof(float));
memcpy(my_fc_port.c1_b,conv1_bias,8*sizeof(float));
memcpy(my_fc_port.c2_w,conv2_weight,16*8*3*3*sizeof(float));
memcpy(my_fc_port.c2_b,conv2_bias,16*sizeof(float));
memcpy(my_fc_port.f1_w,fc1_weight,400*40*sizeof(float));
memcpy(my_fc_port.f1_b,fc1_bias,40*sizeof(float));
memcpy(my_fc_port.f2_w,fc2_weight,40*4*sizeof(float));
memcpy(my_fc_port.f2_b,fc2_bias,4*sizeof(float));
return 0;
}
//main function
int main()
{
int fd,ret;
int i;
void *virtual_base;
float time_s,time_ns,time_ms;
struct timespec ts1,ts2;
clock_t start,finish;
float win_runtime;
//1. open the memory device with open()
fd = open("/dev/mem",(O_RDWR | O_SYNC));
if(fd == (-1))
{
printf("Error:could not open\"/dev/mem\"...\n");
return 1;
}
//2. map the physical registers into virtual memory with mmap()
virtual_base = mmap(NULL,HW_REGS_SPAN,(PROT_READ | PROT_WRITE),MAP_SHARED,fd,HW_REGS_BASE);
//3. initialize the component interface
fc_init(virtual_base);
//4. run inference
while(1)
{
for(i = 0; i < 8; i++)
{
start = clock();//time the pure-software C model
ret = conv(imgx[i],conv1_weight,conv1_bias,conv2_weight,conv2_bias,fc1_weight,fc1_bias,fc2_weight,fc2_bias);
finish = clock();
win_runtime = (float)(finish - start)*1000/CLOCKS_PER_SEC;
memcpy(my_fc_port.img,imgx[i],3*28*28*sizeof(float));
clock_gettime(CLOCK_MONOTONIC,&ts1); //timestamp before starting the FPGA inference
my_fc_ctrl->start = 0x01;
while((my_fc_ctrl->done & 0x02) == 0 );
my_fc_ctrl->start = 0x0;
clock_gettime(CLOCK_MONOTONIC,&ts2); //timestamp after the FPGA inference completes
time_ns = ts2.tv_nsec - ts1.tv_nsec;
time_s = ts2.tv_sec - ts1.tv_sec;
time_ms = time_ns / 1000000 + time_s*1000;
//0: Hazardous
//1: Kitchen
//2: Others
//3: Recycled
printf("\n windows: running time:%.6f \t FPGA: running time:%.6f \n ",win_runtime, time_ms);
printf("\n input is: %s \n",input_img[i]);
printf("\n Windows predict is: %s \t FPGA predict is: %s \n\n",classes[ret], classes[my_fc_ctrl->result]);
}
break;
}
//5. unmap the virtual address range with munmap()
if(munmap(virtual_base, HW_REGS_SPAN) != 0)
{
printf("Error:munmap is failed...\n");
close(fd);
return 1;
}
//6. close the file descriptor
close(fd);
return 0;
}