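The previous article implemented the convolution-layer circuit using a Shift RAM and a convolution module; this article covers the implementation of the pooling circuit. Pooling and convolution are both window operations, so both can build their windows with a Shift RAM, and the window-validity check is also similar to that of the convolution layer.
Pooling, also called downsampling, takes the average or the maximum of the data in each window. It enlarges the receptive field and shrinks the feature map, which reduces the computation and the number of parameters in later layers. The pooling layers of the original LeNet-5 architecture use average pooling, whereas this project uses max pooling, which takes the maximum of each window, as shown in the figure below:
The figure shows that pooling builds its window data in the same way as convolution; the only difference is the operation applied to the data inside the window. The pooling layer therefore also uses a Shift RAM to construct its windows.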
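As a small numeric illustration (the input values are made up for this example), 2 × 2 max pooling with stride 2 maps a 4 × 4 input to a 2 × 2 output, each output element being the maximum of one non-overlapping window:

```latex
\begin{bmatrix} 1 & 3 & 2 & 0 \\ 5 & 6 & 1 & 2 \\ 7 & 2 & 9 & 4 \\ 0 & 1 & 3 & 8 \end{bmatrix}
\xrightarrow{2 \times 2\ \max,\ \text{stride } 2}
\begin{bmatrix} \max(1,3,5,6) & \max(2,0,1,2) \\ \max(7,2,0,1) & \max(9,4,3,8) \end{bmatrix}
=
\begin{bmatrix} 6 & 2 \\ 7 & 9 \end{bmatrix},
\qquad
\frac{28 - 2}{2} + 1 = 14
```

With the same window and stride, the 28 × 28 feature map used in this project shrinks to 14 × 14, as the right-hand formula shows.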
module MAXPOOLING
#(
parameter DATA_WIDTH = 16,
parameter INPUT_SIZE = 28, //input feature-map size
parameter SCALE = 2 //pooling scale
)
(
input clk,
input rst_n,
input clear,
input ena,
input [DATA_WIDTH*2-1:0]tap, //pooling Shift RAM taps
output [DATA_WIDTH-1:0]pooling_out, //pooling result
output valid, //pooling result valid flag
output done //pooling done flag
);
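Since most pooling layers use a 2×2 window, the design is not generalized: the window is fixed at 2×2, and SCALE only enters the stride counters and the shift-window thresholds. Apart from that, the pooling layer receives its input data in the same way as the convolution layer, so it is not analyzed again and only the code is given: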
//tap input
assign m01 = tap[DATA_WIDTH*2-1:DATA_WIDTH];
assign m11 = tap[DATA_WIDTH-1:0];
//shift in
always@(posedge clk or negedge rst_n)
if(!rst_n) begin
m00 <= 0;
m10 <= 0;
end
else if(clear) begin
m00 <= 0;
m10 <= 0;
end
else if(ena) begin
m00 <= m01;
m10 <= m11;
end
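In other words, the two taps supply the current column of the window and the two registers hold the previous column. Writing tap_hi(t) and tap_lo(t) (notation introduced here, not in the code) for tap[DATA_WIDTH*2-1:DATA_WIDTH] and tap[DATA_WIDTH-1:0] in cycle t, following the code's naming where the first index of m is the window row, and assuming ena stays high, the window seen in cycle t is:

```latex
W(t) =
\begin{bmatrix} m_{00} & m_{01} \\ m_{10} & m_{11} \end{bmatrix}
=
\begin{bmatrix} \mathrm{tap}_{\mathrm{hi}}(t-1) & \mathrm{tap}_{\mathrm{hi}}(t) \\ \mathrm{tap}_{\mathrm{lo}}(t-1) & \mathrm{tap}_{\mathrm{lo}}(t) \end{bmatrix}
```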
Max pooling of the window then comes down to comparing its four values and outputting the largest one:
assign maxval = (m00 > m01) ? ((m00 > m10) ? ((m00 > m11) ? m00 : m11) : ((m10 > m11) ? m10 : m11)) : ((m01 > m10) ? ((m01 > m11) ? m01 : m11) : ((m10 > m11) ? m10 : m11));
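The nested ternary above is correct, but as written it contains six distinct comparisons (synthesis may share or simplify some of them). A balanced tree that first compares the two pixels of each row and then the two row maxima computes the same value with three comparators. A minimal sketch, as a standalone module named max4_tree (a hypothetical name, not part of the original project), using the same unsigned comparison as the original:

```verilog
// Hypothetical helper (not from the original project):
// maximum of a 2x2 window as a balanced comparator tree.
// Comparisons are unsigned, exactly like the original maxval expression.
module max4_tree
#(
    parameter DATA_WIDTH = 16
)
(
    input  [DATA_WIDTH-1:0] m00,
    input  [DATA_WIDTH-1:0] m01,
    input  [DATA_WIDTH-1:0] m10,
    input  [DATA_WIDTH-1:0] m11,
    output [DATA_WIDTH-1:0] maxval
);
    // Stage 1: one comparator per window row
    wire [DATA_WIDTH-1:0] max_row0 = (m00 > m01) ? m00 : m01;
    wire [DATA_WIDTH-1:0] max_row1 = (m10 > m11) ? m10 : m11;

    // Stage 2: one comparator between the two row maxima
    assign maxval = (max_row0 > max_row1) ? max_row0 : max_row1;
endmodule
```

Inside MAXPOOLING it could replace the assign maxval line with an instance such as max4_tree #(.DATA_WIDTH(DATA_WIDTH)) u_max(.m00(m00), .m01(m01), .m10(m10), .m11(m11), .maxval(maxval));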
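As in the convolution circuit, four counters are used here: cnt, cnt_col, stride_valid_col and stride_valid_row. The first two track the position of the incoming pixels (cnt counts every enabled cycle, while cnt_col counts columns only while windows are being shifted out), and the last two track the stride interval. The code is as follows: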
always @(posedge clk or negedge rst_n)
if(!rst_n)
cnt <= 0;
else if(clear)
cnt <= 0;
else if(!ena)
cnt <= cnt;
else
cnt <= cnt + 1;
always @(posedge clk or negedge rst_n)
if(!rst_n)
cnt_col <= 0;
else if(clear)
cnt_col <= 0;
else if(!ena)
cnt_col <= cnt_col;
else if(shift) begin
if(cnt_col == INPUT_SIZE - 1)
cnt_col <= 0;
else
cnt_col <= cnt_col + 1;
end
always @(posedge clk or negedge rst_n)
if(!rst_n)
stride_valid_col <= 0;
else if(clear)
stride_valid_col <= 0;
else if(ena) begin
if(shift) begin
if(stride_valid_col < SCALE - 1)
stride_valid_col <= stride_valid_col + 1;
else
stride_valid_col <= 0;
end
end
always @(posedge clk or negedge rst_n)
if(!rst_n)
stride_valid_row <= 0;
else if(clear)
stride_valid_row <= 0;
else if(ena) begin
if(cnt_col == INPUT_SIZE - 1) begin
if(stride_valid_row < SCALE - 1)
stride_valid_row <= stride_valid_row + 1;
else
stride_valid_row <= 0;
end
end
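As with the convolution circuit, the Shift RAM does not present a valid window on every cycle, so the pooling output is not valid on every cycle either. Because the pooling logic is purely combinational, the output is valid exactly when the window data is valid. The pooling stride of 2 must also be respected; this is checked with stride_valid_col and stride_valid_row. With a 2×2 window, a stride of 2 and the 28-pixel rows used here, the misaligned windows that wrap from the end of one row to the start of the next always fall on positions the stride skips, so it is enough to check whether the current window lies within the valid shift period and on a stride step. The validity logic is: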
assign valid = shift & (stride_valid_col == 0) & (stride_valid_row == 0);
always @(posedge clk or negedge rst_n)
if(!rst_n)
shift <= 0;
else if(clear)
shift <= 0;
else if(!ena)
shift <= 0;
else if(cnt == INPUT_SIZE * SCALE + (SCALE - 1) - 1)
shift <= 1;
else if(cnt == INPUT_SIZE * SCALE + INPUT_SIZE * (INPUT_SIZE - SCALE + 1) - 1)
shift <= 0;
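Plugging the default parameters (INPUT_SIZE = 28, SCALE = 2) into the two thresholds of the shift block gives concrete numbers. Since shift is a register, it changes one clock after cnt reaches each threshold, so with ena held high throughout (as in the testbench) shift is 1 exactly while cnt runs from 57 to 811:

```latex
\begin{aligned}
\text{rise threshold: } & 28 \cdot 2 + (2 - 1) - 1 = 56 \\
\text{fall threshold: } & 28 \cdot 2 + 28 \cdot (28 - 2 + 1) - 1 = 56 + 756 - 1 = 811 \\
\text{cycles with shift} = 1\text{: } & 811 - 57 + 1 = 755 = 27 \cdot 28 - 1 \\
\text{valid outputs: } & \left(\tfrac{28}{2}\right)^2 = 14 \cdot 14 = 196
\end{aligned}
```

Within those 755 cycles, stride_valid_col is 0 on every other cycle and stride_valid_row is 0 on every other block of 28 cycles, which leaves the 196 windows of the 14 × 14 output map.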
module MAXPOOLING
#(
parameter DATA_WIDTH = 16,
parameter INPUT_SIZE = 28,
parameter SCALE = 2
)
(
input clk,
input rst_n,
input clear,
input ena,
input [DATA_WIDTH*2-1:0]tap,
output [DATA_WIDTH-1:0]pooling_out,
output valid,
output done
);
function integer clogb2 (input integer bit_depth);
begin
for(clogb2=0; bit_depth>0; clogb2=clogb2+1)
bit_depth = bit_depth >> 1;
end
endfunction
localparam CNT_BIT_NUM = clogb2((INPUT_SIZE * (INPUT_SIZE + 1)));
localparam CNT_LINE_BIT_NUM = clogb2(INPUT_SIZE);
wire [DATA_WIDTH-1:0]m01,m11;
reg [DATA_WIDTH-1:0]m00, m10;
wire [DATA_WIDTH-1:0]maxval;
reg shift;
reg[CNT_BIT_NUM-1:0] cnt;
reg[CNT_LINE_BIT_NUM-1:0] cnt_col;
reg[CNT_LINE_BIT_NUM-1:0] stride_valid_col;
reg[CNT_LINE_BIT_NUM-1:0] stride_valid_row;
assign m01 = tap[DATA_WIDTH*2-1:DATA_WIDTH];
assign m11 = tap[DATA_WIDTH-1:0];
//find max value in 2×2 window
assign maxval = (m00 > m01) ? ((m00 > m10) ? ((m00 > m11) ? m00 : m11) : ((m10 > m11) ? m10 : m11)) : ((m01 > m10) ? ((m01 > m11) ? m01 : m11) : ((m10 > m11) ? m10 : m11));
assign valid = shift & (stride_valid_col == 0) & (stride_valid_row == 0);
assign done = (cnt == INPUT_SIZE * SCALE + INPUT_SIZE * (INPUT_SIZE - SCALE + 1) - 1) ? 1 : 0;
assign pooling_out = (valid) ? ((maxval > 0) ? maxval : 0) : 0;
always@(posedge clk or negedge rst_n)
if(!rst_n) begin
m00 <= 0;
m10 <= 0;
end
else if(clear) begin
m00 <= 0;
m10 <= 0;
end
else if(ena) begin
m00 <= m01;
m10 <= m11;
end
always @(posedge clk or negedge rst_n)
if(!rst_n)
cnt <= 0;
else if(clear)
cnt <= 0;
else if(!ena)
cnt <= cnt;
else
cnt <= cnt + 1;
always @(posedge clk or negedge rst_n)
if(!rst_n)
cnt_col <= 0;
else if(clear)
cnt_col <= 0;
else if(!ena)
cnt_col <= cnt_col;
else if(shift) begin
if(cnt_col == INPUT_SIZE - 1)
cnt_col <= 0;
else
cnt_col <= cnt_col + 1;
end
always @(posedge clk or negedge rst_n)
if(!rst_n)
stride_valid_col <= 0;
else if(clear)
stride_valid_col <= 0;
else if(ena) begin
if(shift) begin
if(stride_valid_col < SCALE - 1)
stride_valid_col <= stride_valid_col + 1;
else
stride_valid_col <= 0;
end
end
always @(posedge clk or negedge rst_n)
if(!rst_n)
stride_valid_row <= 0;
else if(clear)
stride_valid_row <= 0;
else if(ena) begin
if(cnt_col == INPUT_SIZE - 1) begin
if(stride_valid_row < SCALE - 1)
stride_valid_row <= stride_valid_row + 1;
else
stride_valid_row <= 0;
end
end
always @(posedge clk or negedge rst_n)
if(!rst_n)
shift <= 0;
else if(clear)
shift <= 0;
else if(!ena)
shift <= 0;
else if(cnt == INPUT_SIZE * SCALE + (SCALE - 1) - 1)
shift <= 1;
else if(cnt == INPUT_SIZE * SCALE + INPUT_SIZE * (INPUT_SIZE - SCALE + 1) - 1)
shift <= 0;
endmodule
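One detail in the listing deserves a note. pooling_out is computed as (maxval > 0) ? maxval : 0, which reads like a fused ReLU, but m00 to m11 and maxval are declared as plain (unsigned) vectors, so every comparison in this module is unsigned. If the convolution layer already delivers non-negative values, that is fine. If its output can be negative two's-complement fixed-point (the number format is not stated in this article, so this is an assumption), negative values compare as large positive ones and the clamp never changes anything. A sketch of a signed variant of the max-plus-ReLU step; max4_relu_signed is a hypothetical module name, not part of the original project:

```verilog
// Hypothetical signed variant (not from the original project).
// Assumes the window values are two's-complement fixed-point numbers.
module max4_relu_signed
#(
    parameter DATA_WIDTH = 16
)
(
    input  signed [DATA_WIDTH-1:0] m00, m01, m10, m11,
    output        [DATA_WIDTH-1:0] relu_max
);
    // Signed operands make every comparison below a signed comparison
    wire signed [DATA_WIDTH-1:0] max_row0 = (m00 > m01) ? m00 : m01;
    wire signed [DATA_WIDTH-1:0] max_row1 = (m10 > m11) ? m10 : m11;
    wire signed [DATA_WIDTH-1:0] maxval   = (max_row0 > max_row1) ? max_row0 : max_row1;

    // ReLU: a negative window maximum is clamped to zero
    assign relu_max = (maxval > 0) ? maxval : {DATA_WIDTH{1'b0}};
endmodule
```

The key change is the signed keyword on the ports and intermediate wires; with it, maxval > 0 becomes a signed comparison because the unsized literal 0 is also signed.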
`timescale 1ns / 1ns
module tb_maxpooling();
parameter DATA_WIDTH = 16,
FMAP_SIZE = 28,
SCALE = 2;
reg clk, rst_n;
reg ena;
reg[DATA_WIDTH-1:0] shift_in;
wire[DATA_WIDTH-1:0] shift_out;
wire[SCALE*DATA_WIDTH-1:0] tap;
wire[DATA_WIDTH-1:0] pooling_out;
wire pooling_done;
wire pooling_valid;
initial begin
clk = 0;
rst_n = 0;
ena = 0;
shift_in = 0;
#100;
rst_n = 1;
ena = 1;
repeat(FMAP_SIZE * FMAP_SIZE) begin
#100
clk = ~clk;
#100
clk = ~clk;
shift_in = shift_in + 1;
end
forever begin
#100
clk = ~clk;
#100
clk = ~clk;
end
end
SHIFT_RAM
# (
.DATA_WIDTH(DATA_WIDTH),
.TAP_NUM(SCALE),
.TAP_LENGTH(FMAP_SIZE)
)
u_shift_ram_1(
.clk(clk),
.rst_n(rst_n),
.clear(1'b0),
.ena(ena),
.shift_in(shift_in),
.shift_out(shift_out),
.taps(tap)
);
MAXPOOLING
#(
.DATA_WIDTH(DATA_WIDTH),
.INPUT_SIZE(FMAP_SIZE),
.SCALE(SCALE)
)
u_pooling
(
.clk(clk),
.rst_n(rst_n),
.ena(ena),
.clear(1'b0),
.tap(tap),
.pooling_out(pooling_out),
.valid(pooling_valid),
.done(pooling_done)
);
endmodule
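The testbench relies on reading the waveform. A small monitor can check the same things automatically; the sketch below (pooling_monitor is a hypothetical module, not part of the original project) counts valid outputs up to the first done pulse and compares the count, the first value and the last value with expected values. The defaults assume the 28 × 28 counting pattern of tb_maxpooling: 196 outputs, first 29 and last 783, matching the waveform results reported below. It judges only the first done pulse because ena stays high forever in tb_maxpooling, so if the simulation runs long enough the 10-bit cnt wraps around and the pooling module starts a second pass.

```verilog
// Hypothetical self-checking monitor (not from the original project).
module pooling_monitor
#(
    parameter DATA_WIDTH = 16,
    parameter EXP_COUNT  = 196,  // (28/2) * (28/2) windows
    parameter EXP_FIRST  = 29,   // expected first valid output
    parameter EXP_LAST   = 783   // expected last valid output
)
(
    input                  clk,
    input                  valid,
    input [DATA_WIDTH-1:0] pooling_out,
    input                  done
);
    integer              count     = 0;
    reg [DATA_WIDTH-1:0] first_val = 0;
    reg [DATA_WIDTH-1:0] last_val  = 0;
    reg                  checked   = 0;  // set after the first done pulse
    reg                  pass      = 0;

    // Blocking assignments are used on purpose: done arrives in the same
    // cycle as the last valid output, and the check must see that sample.
    always @(posedge clk) begin
        if (!checked) begin
            if (valid) begin
                if (count == 0)
                    first_val = pooling_out;
                last_val = pooling_out;
                count    = count + 1;
            end
            if (done) begin
                checked = 1;
                pass = (count == EXP_COUNT) && (first_val == EXP_FIRST) && (last_val == EXP_LAST);
                if (pass)
                    $display("pooling_monitor: PASS (%0d outputs)", count);
                else
                    $display("pooling_monitor: FAIL count=%0d first=%0d last=%0d",
                             count, first_val, last_val);
            end
        end
    end
endmodule
```

Inside tb_maxpooling it would be instantiated as pooling_monitor u_monitor(.clk(clk), .valid(pooling_valid), .pooling_out(pooling_out), .done(pooling_done));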
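The waveform shows that the first valid output of the pooling circuit is 29 and the last is 783, the maxima of the first and last windows respectively, and that done is pulled high for one cycle together with the last valid output.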
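Both numbers can be checked by hand. The testbench feeds shift_in = 0, 1, 2, … one pixel per clock, so pixel (r, c) of the 28 × 28 map has the value 28r + c, and the first and last 2 × 2 windows give:

```latex
\begin{aligned}
\text{first window (rows 0, 1; cols 0, 1): } & \max(0,\ 1,\ 28,\ 29) = 29 \\
\text{last window (rows 26, 27; cols 26, 27): } & \max(754,\ 755,\ 782,\ 783) = 783
\end{aligned}
```

Because the values grow along every row and down every column, each window's maximum is its bottom-right pixel, so the outputs in between run 29, 31, …, 55, then 85, 87, …, up to 783.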
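This article presented the implementation of the pooling circuit. The pooling circuit handles its data windows much like the convolution circuit; see the previous article for the details. The next article will cover the implementation of the fully connected circuit.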