卷积神经网络LeNet-5的RTL实现(四)

卷积神经网络LeNet-5的RTL实现(四):池化

前文回顾

上一期文章利用Shift RAM和卷积模块实现了卷积层电路,本期文章将讲解池化(pooling)电路的实现。池化和卷积都是窗口运算,因此均可以利用Shift RAM进行窗口构造,窗口有效性判断也与卷积层类似。

池化简介

池化又称下采样,即对窗口数据取平均或最大值,达到扩大感受野、减少参数等目的。LeNet-5网络结构中的池化层为平均池化,本工程则采用最大池化,取窗口最大值,如下图所示:

卷积神经网络LeNet-5的RTL实现(四)_第1张图片

从图中可以看到池化运算与卷积运算窗口数据构造的相似之处,不同之处在于窗口内数据的具体运算。因此,池化层也使用Shift RAM进行窗口构造。

池化电路

接口

module MAXPOOLING
#(
	parameter DATA_WIDTH = 16,
	parameter INPUT_SIZE = 28,						//输入特征图尺寸
	parameter SCALE = 2								//池化规模
)
(
	input clk, 
	input rst_n,
	input clear,
	input ena,
	input [DATA_WIDTH*2-1:0]tap,					//池化Shift RAM输出
	output [DATA_WIDTH-1:0]pooling_out,				//池化结果输出
	output valid,									//池化结果有效性标志
	output done										//池化完成标志
);

数据输入

由于大部分池化层的窗口为2×2,这里没有做通用设计。除此之外,池化层的格式池化层的数据输入与卷积层相同,不再展开分析:

	//tap input
	assign m01 = tap[DATA_WIDTH*2-1:DATA_WIDTH];
	assign m11 = tap[DATA_WIDTH-1:0];
	
	//shift in
	always@(posedge clk or negedge rst_n)
		if(!rst_n) begin
			m00 <= 0;
			m10 <= 0;
		end
		else if(clear) begin
			m00 <= 0;
			m10 <= 0;			
		end
		else if(ena) begin
			m00 <= m01;
			m10 <= m11;
		end

池化运算逻辑

比较窗口内的4个数据,找到最大值并输出,就能实现窗口的最大池化,如下:

assign maxval = (m00 > m01) ? ((m00 > m10) ? ((m00 > m11) ? m00 : m11) : ((m10 > m11) ? m10 : m11)) : ((m01 > m10) ? ((m01 > m11) ? m01 : m11) : ((m10 > m11) ? m10 : m11));

计数信号

与卷积电路相同,这里使用了四个计数变量cntcnt_colstride_valid_colstride_valid_row。前两个用于输入像素位置统计,后两个用于stride间隔判断,代码如下:

	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt <= 0;
		else if(clear)
			cnt <= 0;
		else if(!ena)
			cnt <= cnt;
		else
			cnt <= cnt + 1;
			
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt_col <= 0;
		else if(clear)
			cnt_col <= 0;
		else if(!ena)
			cnt_col <= cnt_col;
		else if(shift) begin
			if(cnt_col == INPUT_SIZE - 1)
				cnt_col <= 0;
			else
				cnt_col <= cnt_col + 1;
		end
	
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			stride_valid_col <= 0;
		else if(clear)
			stride_valid_col <= 0;
		else if(ena) begin
			if(shift) begin
				if(stride_valid_col < SCALE - 1)
					stride_valid_col <= stride_valid_col + 1;
				else
					stride_valid_col = 0;
			end
		end

	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			stride_valid_row <= 0;
		else if(clear)
			stride_valid_row <= 0;
		else if(ena) begin
			if(cnt_col == INPUT_SIZE - 1) begin
				if(stride_valid_row < SCALE - 1)
					stride_valid_row <= stride_valid_row + 1;
				else
					stride_valid_row = 0;
			end
		end

池化数据有效性

与卷积电路一样,由于Shift RAM的无效输出,池化电路的输出也并非每个周期都有效。由于池化运算逻辑由纯组合电路构成,窗口数据的有效性即为池化电路输出的有效性。此外,还需注意池化窗口的stride为2,利用stride_valid_colstride_valid_row进行判断。由于池化窗口为2×2,stride为2,Shift RAM的数据不会出现错行情况,只需要判断当前窗口是否在有效移动周期和是否在步长间隔。有效性判断代码如下:

	assign valid = (shift & stride_valid_col == 0 & stride_valid_row == 0) ? 1 : 0;
	
	always @(posedge clk or negedge rst_n)
	    if(!rst_n)
	    	shift <= 0;
	    else if(clear)
	    	shift <= 0;
	    else if(!ena)
	    	shift <= 0;
	    else if(cnt == INPUT_SIZE * SCALE + (SCALE - 1) - 1)
	    	shift <= 1;
		else if(cnt == INPUT_SIZE * SCALE + INPUT_SIZE * (INPUT_SIZE - SCALE + 1) - 1)
	    	shift <= 0;

电路完整代码

module MAXPOOLING
#(
	parameter DATA_WIDTH = 16,
	parameter INPUT_SIZE = 28,
	parameter SCALE = 2
)
(
	input clk, 
	input rst_n,
	input clear,
	input ena,
	input [DATA_WIDTH*2-1:0]tap,
	output [DATA_WIDTH-1:0]pooling_out,
	output valid,
	output done
);
	
	function integer clogb2 (input integer bit_depth);
	begin
		for(clogb2=0; bit_depth>0; clogb2=clogb2+1)
		bit_depth = bit_depth >> 1;
	end
	endfunction
	
	localparam CNT_BIT_NUM = clogb2((INPUT_SIZE * (INPUT_SIZE + 1)));
	localparam CNT_LINE_BIT_NUM = clogb2(INPUT_SIZE);
	
	wire [DATA_WIDTH-1:0]m01,m11;
	reg [DATA_WIDTH-1:0]m00, m10;
	wire [DATA_WIDTH-1:0]maxval;
	reg shift;
	reg[CNT_BIT_NUM-1:0] cnt;
	reg[CNT_LINE_BIT_NUM-1:0] cnt_col;
	reg[CNT_LINE_BIT_NUM-1:0] stride_valid_col;
	reg[CNT_LINE_BIT_NUM-1:0] stride_valid_row;
	
	assign m01 = tap[DATA_WIDTH*2-1:DATA_WIDTH];
	assign m11 = tap[DATA_WIDTH-1:0];
	
	//find max value in 2×2 window
	assign maxval = (m00 > m01) ? ((m00 > m10) ? ((m00 > m11) ? m00 : m11) : ((m10 > m11) ? m10 : m11)) : ((m01 > m10) ? ((m01 > m11) ? m01 : m11) : ((m10 > m11) ? m10 : m11));
	assign valid = (shift & stride_valid_col == 0 & stride_valid_row == 0) ? 1 : 0;
	assign done = (cnt == INPUT_SIZE * 2 + INPUT_SIZE * (INPUT_SIZE - SCALE + 1) - 1) ? 1 : 0;
	assign pooling_out = (valid) ? ((maxval > 0) ? maxval : 0) : 0;
	
	always@(posedge clk or negedge rst_n)
		if(!rst_n) begin
			m00 <= 0;
			m10 <= 0;
		end
		else if(clear) begin
			m00 <= 0;
			m10 <= 0;			
		end
		else if(ena) begin
			m00 <= m01;
			m10 <= m11;
		end
	
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt <= 0;
		else if(clear)
			cnt <= 0;
		else if(!ena)
			cnt <= cnt;
		else
			cnt <= cnt + 1;
			
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt_col <= 0;
		else if(clear)
			cnt_col <= 0;
		else if(!ena)
			cnt_col <= cnt_col;
		else if(shift) begin
			if(cnt_col == INPUT_SIZE - 1)
				cnt_col <= 0;
			else
				cnt_col <= cnt_col + 1;
		end
			
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			stride_valid_col <= 0;
		else if(clear)
			stride_valid_col <= 0;
		else if(ena) begin
			if(shift) begin
				if(stride_valid_col < SCALE - 1)
					stride_valid_col <= stride_valid_col + 1;
				else
					stride_valid_col = 0;
			end
		end
		
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			stride_valid_row <= 0;
		else if(clear)
			stride_valid_row <= 0;
		else if(ena) begin
			if(cnt_col == INPUT_SIZE - 1) begin
				if(stride_valid_row < SCALE - 1)
					stride_valid_row <= stride_valid_row + 1;
				else
					stride_valid_row = 0;
			end
		end
	
	always @(posedge clk or negedge rst_n)
	    if(!rst_n)
	    	shift <= 0;
	    else if(clear)
	    	shift <= 0;
	    else if(!ena)
	    	shift <= 0;
	    else if(cnt == INPUT_SIZE * SCALE + (SCALE - 1) - 1)
	    	shift <= 1;
		else if(cnt == INPUT_SIZE * SCALE + INPUT_SIZE * (INPUT_SIZE - SCALE + 1) - 1)
	    	shift <= 0;
	
endmodule

TestBench

`timescale 1ns / 1ns

module tb_maxpooling();

	parameter 	DATA_WIDTH = 16,
				FMAP_SIZE = 28,
				SCALE = 2;
   	
   	reg clk, rst_n;
   	reg ena;
   	reg[DATA_WIDTH-1:0] shift_in;
   	wire[DATA_WIDTH-1:0] shift_out;
   	wire[SCALE*DATA_WIDTH-1:0] tap;
   	wire[DATA_WIDTH-1:0] pooling_out;
   	wire pooling_done;
   	wire pooling_valid;
   	
	initial begin
		clk = 0;
		rst_n = 0;
		ena = 0;
		shift_in = 0;

		#100;
		rst_n = 1;
		ena = 1;
		repeat(FMAP_SIZE * FMAP_SIZE) begin
			#100
			clk = ~clk;
			#100
			clk = ~clk;			
			shift_in = shift_in + 1;
		end
		
		forever begin
			#100
			clk = ~clk;
			#100
			clk = ~clk;			
		end	
	end
	
	SHIFT_RAM
	# (	
		.DATA_WIDTH(DATA_WIDTH),
		.TAP_NUM(SCALE),
	   	.TAP_LENGTH(FMAP_SIZE)
	)
	u_shift_ram_1(
		.clk(clk),
		.rst_n(rst_n),
		.clear(),
		.ena(ena),
		.shift_in(shift_in),
		.shift_out(shift_out),
		.taps(tap)
	);
	
	MAXPOOLING
	#(
		.DATA_WIDTH(DATA_WIDTH),
		.FMAP_SIZE(FMAP_SIZE),
		.KERNEL_SIZE(SCALE),
		.STRIDE(1)
	)
	u_pooling
	(
		.clk(clk), 
		.rst_n(rst_n),
		.ena(ena),
		.clear(),
		.tap(tap),
		.pooling_out(pooling_out),
		.valid(pooling_valid),
		.done(pooling_done)
	);
	
endmodule

仿真结果

卷积神经网络LeNet-5的RTL实现(四)_第2张图片
卷积神经网络LeNet-5的RTL实现(四)_第3张图片

从波形图可以看到,池化电路的第一个有效输出为29,最后一个有效输出为783,分别对应第一个和最后一个窗口的数据最大值,在最后一个有效输出到来时将done信号拉高一个周期。

总结

本期文章介绍了池化电路的实现。池化电路与卷积电路对数据窗口的处理类似,具体处理细节可参考上期文章。下期文章将介绍全连接电路的实现。

你可能感兴趣的:(硬件加速器和SoC设计)