设输入矩阵为NixNi,卷积核大小为KxK,卷积步长为S,则可以选用长度为Ni,高度为K的shift_ram实现,具体如下。
记输入shift_ram的时间为clk0,则易知经过NixK个周期后shift_ram被充满,也即产生第一列有效输出 $X_{0,0},X_{1,0},X_{2,0}$,再经过Nix(Ni-K+1)-1个周期,产生最后一列有效输出 $X_{Ni-1,Ni-1},X_{Ni-2,Ni-1},X_{Ni-3,Ni-1}$。设数据读出延时为2+1个clk,计算卷积需要乘法1个周期,加法需要2个周期(每列相加,再把每列的和相加),则从输入第一个数据的读地址开始,需要经过NixK+(2+1)+(1+2)+K-1(需等待余下K-1列读出)个clk才会产生第一个卷积输出,经过NixK+(2+1)+(Ni-K+1)xNi-1+(1+2)个周期产生最后一个卷积输出。当然,这之间的输出并不都有效,应视步长S有选择地读取结果。
这种方式计算一个矩阵的卷积大约需要 $O(N_i^2)$ 个clk周期。
代码如下,已通过仿真,卷积核大小为3,其他卷积核大小需要修改部分代码,但思路一样。
conv.v
`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////
// Company:
// Engineer:
//
// Create Date: 2020/02/13 20:54:56
// Design Name:
// Module Name: top
// Project Name:
// Target Devices:
// Tool Versions:
// Description:
//
// Dependencies:
//
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
//
//////////////////////////////////////////////////////////////////////////////////
// conv: streaming 3x3 convolution engine for an Ni x Ni feature map.
//
// Operation, as implemented below:
//   * `start` raises busy / read / read_weight. The nine kernel weights are fetched
//     through weight_addr/weight, and the map is fetched through rd_addr/rd_data in
//     address order 0 .. Ni*Ni-1.
//   * rd_data is registered into din and pushed into SHIFT_RAM (K taps, Ni words per
//     tap); the `taps` bus returns one 16-bit sample per tap on every clock.
//   * Each tap sample is additionally delayed by one and two clocks, which gives a
//     3x3 window. The window is multiplied by k11..k33, summed per kernel column and
//     then summed again.
//   * The total is arithmetically shifted right by F and driven on wr_data.
//     wren / wr_addr pick the results that match the stride S.
//   * done is high for one clock after sum_valid falls.
//
// Limitations:
//   * K is a parameter, but the datapath (48-bit taps, nine k registers, two column
//     delays) is written for K=3 only.
//   * The timing constants assume a 2-clock memory read latency (plus the din
//     register) and a 2-clock weight read latency -- confirm against the memories
//     actually attached.
module conv
#(
    parameter Ni = 5,            // input feature map is Ni x Ni
    parameter K  = 3,            // kernel size (datapath below supports K=3 only)
    parameter S  = 2,            // stride
    parameter No = (Ni-K)/S+1,   // output feature map is No x No
    parameter F  = 8             // number of fractional bits removed from the sum
)
(
    input                clk,
    input                rst,          // asynchronous, active high
    input                start,        // pulse to begin one convolution
    input  signed [15:0] weight,       // weight memory read data
    output reg    [7:0]  weight_addr,  // weight memory read address
    input  signed [15:0] rd_data,      // feature map read data
    output signed [15:0] wr_data,      // convolution result
    output reg    [15:0] rd_addr,      // feature map read address
    output reg    [15:0] wr_addr,      // result write address
    output               wren,         // result write enable
    output               done          // one-clock pulse after the last result
);

    // Value of cnt1 at which sum_valid is set:
    //   read latency (2+1) + Ni*K clocks to fill the shift RAM
    //   + compute latency K-1+3 (wait for K-1 more columns, multiply, add, add).
    // The trailing -1-1 is kept from the original, simulated timing (presumably the
    // zero-based count and the registered, one-clock-early assertion).
    localparam SUM_VALID_ON  = Ni*K+(2+1)+(K-1+3)-1-1;

    // Value of cnt1 at which sum_valid is cleared:
    //   read latency (2+1) + Ni*K + distance from the first to the last window
    //   (Ni*(Ni-K+1)-1) + compute latency 3 (multiply, add, add), minus 1.
    localparam SUM_VALID_OFF = (2+1)+Ni*K+Ni*(Ni-K+1)-1+3-1;

    wire        [47:0] taps;
    wire signed [15:0] dout, dout0, dout1, dout2;

    reg read_weight;
    reg read;
    reg sum_valid;
    reg sum_valid_d1;
    reg wr_data_valid;
    reg busy;

    reg signed [15:0] din;
    reg signed [15:0] dout0_d1, dout0_d2, dout1_d1, dout1_d2, dout2_d1, dout2_d2;
    reg signed [31:0] mul1, mul2, mul0, mul1_d1, mul2_d1, mul0_d1, mul2_d2, mul1_d2, mul0_d2;
    reg signed [31:0] sum0, sum1, sum2;
    reg signed [15:0] sum;

    reg [7:0]  weight_addr_d1, weight_addr_d2;

    // Cycle counter since start. It must be able to reach SUM_VALID_OFF, which grows
    // with Ni*Ni; a 10-bit counter wraps before that for Ni >= 32, so 20 bits are used.
    reg [19:0] cnt1;
    reg [9:0]  cnt2;            // column position of the current result, 0 .. Ni-1
    reg [9:0]  cnt_s1, cnt_s2;  // modulo-S counters along a row / across rows

    wire signed [31:0] p11, p12, p13, p21, p22, p23, p31, p32, p33;
    reg  signed [15:0] k11, k12, k13, k21, k22, k23, k31, k32, k33;

    // dout0 takes the upper 16 bits of taps, dout2 the lower 16 bits.
    assign {dout0, dout1, dout2} = taps;

    // 3x3 products: undelayed samples use kernel column 3, samples delayed by one
    // clock use column 2, samples delayed by two clocks use column 1.
    assign p13 = dout0    * k13;
    assign p23 = dout1    * k23;
    assign p33 = dout2    * k33;
    assign p12 = dout0_d1 * k12;
    assign p22 = dout1_d1 * k22;
    assign p32 = dout2_d1 * k32;
    assign p11 = dout0_d2 * k11;
    assign p21 = dout1_d2 * k21;
    assign p31 = dout2_d2 * k31;

    // busy: high from start until done.
    always @(posedge clk, posedge rst)
        if (rst)
            busy <= 1'b0;
        else if (start)
            busy <= 1'b1;
        else if (done)
            busy <= 1'b0;

    // Weight read address: counts 0 .. K*K-1 while read_weight is high, otherwise 0.
    always @(posedge clk, posedge rst)
        if (rst)
            weight_addr <= 8'd0;
        else if (read_weight) begin
            if (weight_addr == K*K-1)
                weight_addr <= 8'd0;
            else
                weight_addr <= 8'd1 + weight_addr;
        end
        else
            weight_addr <= 8'd0;

    always @(posedge clk, posedge rst)
        if (rst)
            read_weight <= 1'b0;
        else if (start)
            read_weight <= 1'b1;
        else if (weight_addr == K*K-1)
            read_weight <= 1'b0;

    // Address delayed by two clocks so it lines up with the returned weight.
    always @(posedge clk, posedge rst)
        if (rst) begin
            weight_addr_d1 <= 8'd0;
            weight_addr_d2 <= 8'd0;
        end
        else begin
            weight_addr_d1 <= weight_addr;
            weight_addr_d2 <= weight_addr_d1;
        end

    // Kernel registers, selected by the delayed address.
    // NOTE(review): weight_addr_d2 is also 0 while idle, so k11 is rewritten with
    // `weight` on every such clock; this relies on the weight memory continuing to
    // return word 0 -- confirm against the attached memory.
    always @(posedge clk, posedge rst)
        if (rst) begin
            k11 <= 16'd0;
            k12 <= 16'd0;
            k13 <= 16'd0;
            k21 <= 16'd0;
            k22 <= 16'd0;
            k23 <= 16'd0;
            k31 <= 16'd0;
            k32 <= 16'd0;
            k33 <= 16'd0;
        end
        else
            case (weight_addr_d2)
                8'd0: k11 <= weight;
                8'd1: k12 <= weight;
                8'd2: k13 <= weight;
                8'd3: k21 <= weight;
                8'd4: k22 <= weight;
                8'd5: k23 <= weight;
                8'd6: k31 <= weight;
                8'd7: k32 <= weight;
                8'd8: k33 <= weight;
                default: ;  // all kernel registers hold their value
            endcase

    // Feature map read address: counts 0 .. Ni*Ni-1 while read is high, otherwise 0.
    always @(posedge clk, posedge rst)
        if (rst)
            rd_addr <= 16'd0;
        else if (read) begin
            if (rd_addr == (Ni*Ni-1))
                rd_addr <= 16'd0;
            else
                rd_addr <= rd_addr + 16'd1;
        end
        else
            rd_addr <= 16'd0;

    // Input register in front of the shift RAM.
    always @(posedge clk, posedge rst)
        if (rst)
            din <= 16'd0;
        else
            din <= rd_data;

    always @(posedge clk, posedge rst)
        if (rst)
            read <= 1'b0;
        else if (start)
            read <= 1'b1;
        else if (busy && rd_addr == (Ni*Ni-1))
            read <= 1'b0;

    // Cycle counter used to time sum_valid.
    always @(posedge clk, posedge rst)
        if (rst)
            cnt1 <= 20'd0;
        else if (busy)
            cnt1 <= cnt1 + 1'b1;
        else
            cnt1 <= 20'd0;

    // One- and two-clock delayed copies of each tap sample.
    always @(posedge clk, posedge rst)
        if (rst) begin
            {dout0_d2, dout0_d1} <= 32'd0;
            {dout1_d2, dout1_d1} <= 32'd0;
            {dout2_d2, dout2_d1} <= 32'd0;
        end
        else begin
            {dout0_d2, dout0_d1} <= {dout0_d1, dout0};
            {dout1_d2, dout1_d1} <= {dout1_d1, dout1};
            {dout2_d2, dout2_d1} <= {dout2_d1, dout2};
        end

    // Pipeline stage 1: register the nine products.
    always @(posedge clk, posedge rst)
        if (rst) begin
            mul0    <= 32'd0;
            mul1    <= 32'd0;
            mul2    <= 32'd0;
            mul0_d1 <= 32'd0;
            mul1_d1 <= 32'd0;
            mul2_d1 <= 32'd0;
            mul0_d2 <= 32'd0;
            mul1_d2 <= 32'd0;
            mul2_d2 <= 32'd0;
        end
        else begin
            mul0    <= p13;
            mul1    <= p23;
            mul2    <= p33;
            mul0_d1 <= p12;
            mul1_d1 <= p22;
            mul2_d1 <= p32;
            mul0_d2 <= p11;
            mul1_d2 <= p21;
            mul2_d2 <= p31;
        end

    // Pipeline stage 2: one partial sum per kernel column.
    always @(posedge clk, posedge rst)
        if (rst) begin
            sum0 <= 32'd0;
            sum1 <= 32'd0;
            sum2 <= 32'd0;
        end
        else begin
            sum0 <= mul0    + mul1    + mul2;
            sum1 <= mul0_d1 + mul1_d1 + mul2_d1;
            sum2 <= mul0_d2 + mul1_d2 + mul2_d2;
        end

    // Pipeline stage 3: total, arithmetic shift right by F, truncated to 16 bits.
    always @(posedge clk, posedge rst)
        if (rst)
            sum <= 16'd0;
        else
            sum <= (sum1 + sum2 + sum0) >>> F;

    // sum_valid frames the clocks during which the pipeline holds window results.
    always @(posedge clk, posedge rst)
        if (rst)
            sum_valid <= 1'b0;
        else if (cnt1 == SUM_VALID_ON)
            sum_valid <= 1'b1;
        else if (cnt1 == SUM_VALID_OFF)
            sum_valid <= 1'b0;

    // Column counter within a row of results.
    always @(posedge clk, posedge rst)
        if (rst)
            cnt2 <= 10'd0;
        else if (sum_valid) begin
            if (cnt2 == Ni-1)
                cnt2 <= 10'd0;
            else
                cnt2 <= cnt2 + 1'd1;
        end
        else
            cnt2 <= 10'd0;

    // Modulo-S counter along a row; restarts at every row boundary.
    always @(posedge clk, posedge rst)
        if (rst)
            cnt_s1 <= 10'd0;
        else if (sum_valid) begin
            if (cnt_s1 == S-1 || cnt2 == Ni-1)
                cnt_s1 <= 10'd0;
            else
                cnt_s1 <= cnt_s1 + 1'd1;
        end
        else
            cnt_s1 <= 10'd0;

    // Modulo-S counter across rows; advances only at a row boundary.
    always @(posedge clk, posedge rst)
        if (rst)
            cnt_s2 <= 10'd0;
        else if (sum_valid) begin
            if (cnt2 == Ni-1) begin
                if (cnt_s2 == S-1)
                    cnt_s2 <= 10'd0;
                else
                    cnt_s2 <= cnt_s2 + 1'd1;
            end
        end
        else
            cnt_s2 <= 10'd0;

    // A result is written when the window lies inside the row (cnt2 < Ni-K+1) and
    // both stride counters are at zero.
    always @(posedge clk, posedge rst)
        if (rst)
            wr_data_valid <= 1'b0;
        else if (sum_valid && cnt2 < (Ni-K+1) && (cnt_s1 == 1'b0) && (cnt_s2 == 1'b0))
            wr_data_valid <= 1'b1;
        else
            wr_data_valid <= 1'b0;

    // Write address: advances after every write, wraps after No*No results.
    always @(posedge clk, posedge rst)
        if (rst)
            wr_addr <= 16'd0;
        else if (wr_data_valid) begin
            if (wr_addr == No*No-1)
                wr_addr <= 16'd0;
            else
                wr_addr <= 16'd1 + wr_addr;
        end

    always @(posedge clk, posedge rst)
        if (rst)
            sum_valid_d1 <= 1'b0;
        else
            sum_valid_d1 <= sum_valid;

    assign done    = sum_valid_d1 & (~sum_valid);  // falling edge of sum_valid
    assign wren    = wr_data_valid;
    assign wr_data = sum;

    SHIFT_RAM
    #(.TAP_NUM(K),
      .TAP_LENGTH(Ni))
    U (
        .clk(clk),    // input wire clk
        .din(din),    // input wire [15 : 0] din
        .dout(dout),  // output wire [15 : 0] dout
        .taps(taps)   // output wire [47 : 0] taps
    );
endmodule
SHIFT_RAM.v
`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////
// Company:
// Engineer:
//
// Create Date: 2020/02/13 19:08:21
// Design Name:
// Module Name: SHIFT_RAM
// Project Name:
// Target Devices:
// Tool Versions:
// Description:
//
// Dependencies:
//
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
//
//////////////////////////////////////////////////////////////////////////////////
// SHIFT_RAM: tapped shift register of N = TAP_NUM*TAP_LENGTH words.
//
// On every rising clock edge din enters word 0 and every word moves one place deeper.
// Tap t (t = 1 .. TAP_NUM) exposes word t*TAP_LENGTH-1 on
// taps[t*DATA_WIDTH-1 -: DATA_WIDTH], so the lowest slice is the shallowest tap and
// the highest slice is the deepest one. dout is the deepest word, data[N-1].
// There is no reset or enable: the contents are undefined until N words were shifted in.
//
// Parameters:
//   TAP_NUM    - number of taps
//   TAP_LENGTH - number of words between consecutive taps
//   N          - total depth; derived, not meant to be overridden
//   DATA_WIDTH - word width (default 16, the width used by the existing callers)
module SHIFT_RAM
#(
    parameter TAP_NUM    = 3,
    parameter TAP_LENGTH = 5,
    parameter N          = TAP_NUM*TAP_LENGTH,
    parameter DATA_WIDTH = 16
)
(
    input                               clk,
    input      [DATA_WIDTH-1:0]         din,
    output     [DATA_WIDTH-1:0]         dout,
    output reg [TAP_NUM*DATA_WIDTH-1:0] taps
);

    reg [DATA_WIDTH-1:0] data [0:N-1];

    // Each always block owns its loop variable. Sharing one integer between the
    // clocked block and the combinational block makes it a variable written from two
    // processes and puts the other block's loop counter into the @(*) sensitivity list.
    integer shift_idx;
    integer tap_idx;

    // Shift chain.
    always @(posedge clk)
    begin
        data[0] <= din;
        for (shift_idx = 0; shift_idx < N-1; shift_idx = shift_idx + 1)
            data[shift_idx+1] <= data[shift_idx];
    end

    // Tap outputs (combinational view of the chain).
    always @(*)
    begin
        for (tap_idx = 1; tap_idx <= TAP_NUM; tap_idx = tap_idx + 1)
            taps[(DATA_WIDTH*tap_idx-1) -: DATA_WIDTH] = data[tap_idx*TAP_LENGTH-1];
    end

    assign dout = data[N-1];

endmodule
卷积核大小为5x5的代码
// convolution: streaming 5x5 convolution over an Ni x Ni feature map with stride S.
//
// After `start`, the 25 weights are read through weight_addr/weight and the map is
// read through rd_addr/rd_data. The registered map data feeds SHIFT_RAM (K taps, Ni
// words per tap); each tap sample is delayed by up to four clocks to form a 5x5
// window. Products, per-column sums and the final sum are each registered. The final
// sum is arithmetically shifted right by F into wr_data. wren/wr_addr select the
// results that match the stride. done pulses for one clock after the valid window ends.
//
// The module has no reset input: every register is undefined until `start` has been
// applied and the pipeline has filled.
// NOTE(review): taps is declared 80 bits wide, which matches K*DATA_WIDTH only for
// DATA_WIDTH=16.
module convolution
#(
    parameter Ni                = 9,    // feature map is Ni x Ni
    parameter S                 = 2,    // stride
    parameter F                 = 8,    // right shift applied to the final sum
    parameter DATA_WIDTH        = 16,
    parameter MAP_ADDR_WIDTH    = 16,
    parameter WEIGHT_ADDR_WIDTH = 8
)
(
    input                               clk,
    input                               start,
    input      signed [DATA_WIDTH-1:0]  rd_data,
    input      signed [DATA_WIDTH-1:0]  weight,
    output reg signed [DATA_WIDTH-1:0]  wr_data,
    output reg [MAP_ADDR_WIDTH-1:0]     rd_addr,
    output reg [MAP_ADDR_WIDTH-1:0]     wr_addr,
    output reg [WEIGHT_ADDR_WIDTH-1:0]  weight_addr,
    output                              done,
    output reg                          wren
);

    parameter K = 5;   // kernel size; the datapath below is written for 5x5

    // cycle_cnt values at which the result-valid window opens and closes.
    localparam ACC_VALID_SET = Ni*K+2+1+K-1+3-1-1;
    localparam ACC_VALID_CLR = Ni*K+2+1+Ni*(Ni-K+1)-1+3-1;

    // ------------------------------------------------------------------ control
    reg        map_rd_en;      // map read in progress
    reg        wgt_rd_en;      // weight read in progress
    reg        running;        // high from start until done
    reg [19:0] cycle_cnt;      // clocks elapsed while running
    reg [9:0]  col_cnt;        // column counter, 0 .. Ni-1
    reg [9:0]  col_mod_s;      // modulo-S counter along a row
    reg [9:0]  row_mod_s;      // modulo-S counter across rows
    reg        acc_valid;
    reg        acc_valid_q;
    reg [WEIGHT_ADDR_WIDTH-1:0] wgt_addr_q1, wgt_addr_q2;

    // ----------------------------------------------------------------- datapath
    wire [79:0] taps;
    reg  signed [DATA_WIDTH-1:0] map_din;
    wire signed [DATA_WIDTH-1:0] m04, m14, m24, m34, m44;
    reg  signed [DATA_WIDTH-1:0] k00, k01, k02, k03, k04,
                                 k10, k11, k12, k13, k14,
                                 k20, k21, k22, k23, k24,
                                 k30, k31, k32, k33, k34,
                                 k40, k41, k42, k43, k44;
    reg  signed [DATA_WIDTH-1:0] m00, m01, m02, m03,
                                 m10, m11, m12, m13,
                                 m20, m21, m22, m23,
                                 m30, m31, m32, m33,
                                 m40, m41, m42, m43;
    reg  signed [DATA_WIDTH*2-1:0] p00, p01, p02, p03, p04,
                                   p10, p11, p12, p13, p14,
                                   p20, p21, p22, p23, p24,
                                   p30, p31, p32, p33, p34,
                                   p40, p41, p42, p43, p44;
    reg  signed [DATA_WIDTH*2-1:0] sum0, sum1, sum2, sum3, sum4;

    // ------------------------------------------------------- start / stop flags
    always @(posedge clk) begin
        if (start)
            running <= 1'b1;
        else if (done)
            running <= 1'b0;
    end

    always @(posedge clk) begin
        if (start)
            map_rd_en <= 1'b1;
        else if (rd_addr == Ni*Ni-1)
            map_rd_en <= 1'b0;
    end

    always @(posedge clk) begin
        if (start)
            wgt_rd_en <= 1'b1;
        else if (weight_addr == K*K-1)
            wgt_rd_en <= 1'b0;
    end

    // ---------------------------------------------------------- address counters
    always @(posedge clk) begin
        if (map_rd_en)
            rd_addr <= rd_addr + 1'd1;
        else
            rd_addr <= 16'd0;
    end

    always @(posedge clk) begin
        if (wgt_rd_en)
            weight_addr <= weight_addr + 1'd1;
        else
            weight_addr <= 8'd0;
    end

    // Weight address delayed by two clocks, used to steer the returned weight.
    always @(posedge clk) begin
        wgt_addr_q1 <= weight_addr;
        wgt_addr_q2 <= wgt_addr_q1;
    end

    // ------------------------------------------------------------ kernel storage
    always @(posedge clk) begin
        case (wgt_addr_q2)
            8'd0:  k00 <= weight;
            8'd1:  k01 <= weight;
            8'd2:  k02 <= weight;
            8'd3:  k03 <= weight;
            8'd4:  k04 <= weight;
            8'd5:  k10 <= weight;
            8'd6:  k11 <= weight;
            8'd7:  k12 <= weight;
            8'd8:  k13 <= weight;
            8'd9:  k14 <= weight;
            8'd10: k20 <= weight;
            8'd11: k21 <= weight;
            8'd12: k22 <= weight;
            8'd13: k23 <= weight;
            8'd14: k24 <= weight;
            8'd15: k30 <= weight;
            8'd16: k31 <= weight;
            8'd17: k32 <= weight;
            8'd18: k33 <= weight;
            8'd19: k34 <= weight;
            8'd20: k40 <= weight;
            8'd21: k41 <= weight;
            8'd22: k42 <= weight;
            8'd23: k43 <= weight;
            8'd24: k44 <= weight;
            default: ;
        endcase
    end

    // ------------------------------------------------------------- 5x5 window
    always @(posedge clk) begin
        map_din <= rd_data;
    end

    // Newest sample of each row comes straight from the shift RAM taps.
    assign m04 = taps[79:64];
    assign m14 = taps[63:48];
    assign m24 = taps[47:32];
    assign m34 = taps[31:16];
    assign m44 = taps[15:0];

    // Per-row delay line: column 4 is the newest sample, column 0 the oldest.
    always @(posedge clk) begin
        m03 <= m04;  m02 <= m03;  m01 <= m02;  m00 <= m01;
        m13 <= m14;  m12 <= m13;  m11 <= m12;  m10 <= m11;
        m23 <= m24;  m22 <= m23;  m21 <= m22;  m20 <= m21;
        m33 <= m34;  m32 <= m33;  m31 <= m32;  m30 <= m31;
        m43 <= m44;  m42 <= m43;  m41 <= m42;  m40 <= m41;
    end

    // --------------------------------------------- multiply / accumulate stages
    always @(posedge clk) begin
        p00 <= k00*m00;  p01 <= k01*m01;  p02 <= k02*m02;  p03 <= k03*m03;  p04 <= k04*m04;
        p10 <= k10*m10;  p11 <= k11*m11;  p12 <= k12*m12;  p13 <= k13*m13;  p14 <= k14*m14;
        p20 <= k20*m20;  p21 <= k21*m21;  p22 <= k22*m22;  p23 <= k23*m23;  p24 <= k24*m24;
        p30 <= k30*m30;  p31 <= k31*m31;  p32 <= k32*m32;  p33 <= k33*m33;  p34 <= k34*m34;
        p40 <= k40*m40;  p41 <= k41*m41;  p42 <= k42*m42;  p43 <= k43*m43;  p44 <= k44*m44;
    end

    // One partial sum per kernel column.
    always @(posedge clk) begin
        sum0 <= p00 + p10 + p20 + p30 + p40;
        sum1 <= p01 + p11 + p21 + p31 + p41;
        sum2 <= p02 + p12 + p22 + p32 + p42;
        sum3 <= p03 + p13 + p23 + p33 + p43;
        sum4 <= p04 + p14 + p24 + p34 + p44;
    end

    // Final sum with the F-bit arithmetic right shift.
    always @(posedge clk) begin
        wr_data <= (sum0 + sum1 + sum2 + sum3 + sum4) >>> F;
    end

    // ------------------------------------------------------------ result timing
    always @(posedge clk) begin
        if (running)
            cycle_cnt <= cycle_cnt + 1'd1;
        else
            cycle_cnt <= 20'd0;
    end

    always @(posedge clk) begin
        if (cycle_cnt == ACC_VALID_SET)
            acc_valid <= 1'b1;
        else if (cycle_cnt == ACC_VALID_CLR)
            acc_valid <= 1'b0;
        else if (start)
            acc_valid <= 1'b0;
    end

    always @(posedge clk) begin
        acc_valid_q <= acc_valid;
    end

    // One-clock pulse on the falling edge of acc_valid.
    assign done = acc_valid_q && ~acc_valid;

    // ------------------------------------------------------- stride bookkeeping
    always @(posedge clk) begin
        if (acc_valid) begin
            if (col_cnt == Ni-1)
                col_cnt <= 10'd0;
            else
                col_cnt <= col_cnt + 10'd1;
        end
        else begin
            col_cnt <= 10'd0;
        end
    end

    // Restarts at each row boundary and every S columns.
    always @(posedge clk) begin
        if (acc_valid) begin
            if (col_cnt == Ni-1 || col_mod_s == S-1)
                col_mod_s <= 10'd0;
            else
                col_mod_s <= col_mod_s + 10'd1;
        end
        else begin
            col_mod_s <= 10'd0;
        end
    end

    // Advances only at a row boundary; holds otherwise.
    always @(posedge clk) begin
        if (acc_valid) begin
            if (col_cnt == Ni-1) begin
                if (row_mod_s == S-1)
                    row_mod_s <= 10'd0;
                else
                    row_mod_s <= row_mod_s + 1'd1;
            end
        end
        else begin
            row_mod_s <= 10'd0;
        end
    end

    // ------------------------------------------------------------- write port
    always @(posedge clk) begin
        if (acc_valid && (col_cnt < Ni-K+1) && (col_mod_s == 0) && (row_mod_s == 0))
            wren <= 1'b1;
        else
            wren <= 1'b0;
    end

    always @(posedge clk) begin
        if (wren)
            wr_addr <= wr_addr + 1'd1;
        else if (start)
            wr_addr <= 16'd0;
    end

    // ------------------------------------------------------------ line buffer
    SHIFT_RAM
    #(
        .TAP_NUM   (K),
        .TAP_LENGTH(Ni)
    )
    U
    (
        .clk (clk),
        .din (map_din),
        .dout(),
        .taps(taps)
    );

endmodule