fpga基于shift ram的卷积实现

verilog实现卷积运算

设输入矩阵为NixNi,卷积核大小为KxK,卷积步长为S,则可以选用长度为Ni,高度为K的shift_ram实现,具体如下。
记输入 shift_ram 的时间为 clk0,则易知经过 $N_i \times K$ 个周期后 shift_ram 被充满,也即产生第一列有效输出 $X_{0,0},X_{1,0},X_{2,0}$;再经过 $N_i \times (N_i-K+1)-1$ 个周期,产生最后一列有效输出 $X_{N_i-1,N_i-1},X_{N_i-2,N_i-1},X_{N_i-3,N_i-1}$。设数据读出延时为 2+1 个 clk,计算卷积时乘法需要 1 个周期,加法需要 2 个周期(先每列相加,再把各列的和相加),则从输入第一个数据的读地址开始,需要经过 $N_i \times K+(2+1)+(1+2)+(K-1)$ 个 clk(其中 $K-1$ 是等待余下 $K-1$ 列读出)才会产生第一个卷积输出,经过 $N_i \times K+(2+1)+(N_i-K+1)\times N_i-1+(1+2)$ 个周期产生最后一个卷积输出。当然,这之间的输出并不都有效,应视步长 S 有选择地读取结果。
这种方式计算一个矩阵的卷积大约需要 $O(N_i^2)$ 个 clk 周期。
代码如下,已通过仿真。卷积核大小为 3×3,其他尺寸的卷积核需要修改部分代码,但思路一样。
conv.v

`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////
// Company: 
// Engineer: 
// 
// Create Date: 2020/02/13 20:54:56
// Design Name: 
// Module Name: top
// Project Name: 
// Target Devices: 
// Tool Versions: 
// Description: 
// 
// Dependencies: 
// 
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
// 
//////////////////////////////////////////////////////////////////////////////////


//------------------------------------------------------------------------------
// conv: 3x3 convolution of an Ni x Ni map with stride S, built on a K-row
// SHIFT_RAM line buffer. K must be 3: the tap split, the 3x3 kernel registers
// and the adder tree below are written out for exactly three rows/columns.
//
// Interface behaviour (as implemented below):
//   start       - pulse; loads K*K weights and streams Ni*Ni pixels.
//   weight_addr - 0..K*K-1 while loading; the kernel register is selected by
//                 weight_addr delayed two cycles (assumes a 2-cycle weight read).
//   rd_addr     - 0..Ni*Ni-1, one pixel per cycle; rd_data is registered into
//                 din (the timing constants below count this as 2+1 cycles).
//   wren/wr_addr/wr_data - No*No results written in raster order.
//   done        - high for one cycle when sum_valid falls.
// Fixed point: 32-bit products; the 9-term sum is arithmetic-shifted right by
// F and truncated (no saturation) to 16 bits.
//------------------------------------------------------------------------------
module conv
# (
    parameter Ni=5,
    parameter K=3,
    parameter S=2,
    parameter No=(Ni-K)/S+1,
    parameter F=8)         // F: number of fractional bits; the summed products are shifted right by F
(
input clk,
input rst,
input start,
input signed[15:0]weight,
output reg[7:0]weight_addr,
input signed [15:0]rd_data,
output signed [15:0]wr_data,
output reg[15:0]rd_addr,
output reg[15:0]wr_addr,
output wren,
output done
    );


wire [K*16-1:0]taps;                 // K row taps from SHIFT_RAM, 16 bits each
wire signed[15:0]dout,dout0,dout1,dout2;
reg read_weight;
reg read;
reg sum_valid;
reg sum_valid_d1;
reg wr_data_valid;
reg signed[15:0]din;
reg signed[15:0]dout0_d1,dout0_d2,dout1_d1,dout1_d2,dout2_d1,dout2_d2;
reg signed [31:0]mul1,mul2,mul0,mul1_d1,mul2_d1,mul0_d1,mul2_d2,mul1_d2,mul0_d2;
reg signed [31:0]sum0,sum1,sum2;
reg signed [15:0]sum;
reg [7:0]weight_addr_d1,weight_addr_d2;
// cnt1 must reach Ni*Ni+Ni+4 (the sum_valid end compare below). A 10-bit
// counter wraps before that for Ni >= 32, so sum_valid and done would never
// deassert; 20 bits (as in the 5x5 module) covers every Ni a 16-bit rd_addr allows.
reg [19:0]cnt1;
reg [9:0]cnt2;                       // column index of the current result in its row
reg [9:0]cnt_s1,cnt_s2;              // column / row phase modulo the stride S
reg busy;
wire signed [31:0]p11,p12,p13,p21,p22,p23,p31,p32,p33;
reg signed[15:0]k11,k12,k13,k21,k22,k23,k31,k32,k33;

// Window products. Rows: dout0 = oldest buffered row (top), dout2 = newest
// (bottom). Columns: *_d2 = oldest (left), *_d1 = middle, undelayed = right.
assign        p13=dout0*k13;
assign        p23=dout1*k23;
assign        p33=dout2*k33;
assign        p12=dout0_d1*k12;
assign        p22=dout1_d1*k22;
assign        p32=dout2_d1*k32;
assign        p11=dout0_d2*k11;
assign        p21=dout1_d2*k21;
assign        p31=dout2_d2*k31;

// taps[47:32] is the deepest SHIFT_RAM tap, taps[15:0] the shallowest.
assign {dout0,dout1,dout2}=taps;

// busy: set by start, cleared by done; gates cnt1.
always@(posedge clk,posedge rst)
if(rst)
    busy<=1'b0;
else if(start)
    busy<=1'b1;
else if(done)
    busy<=1'b0;

// Weight read address: 0..K*K-1 while read_weight, wraps on the last entry.
always@(posedge clk,posedge rst)
if(rst)
    weight_addr<=8'd0;
else if(read_weight)
    if(weight_addr==K*K-1)
        weight_addr<=8'd0;
    else
        weight_addr<=8'd1+weight_addr;
else
    weight_addr<=8'd0;

always@(posedge clk,posedge rst)
if(rst)
    read_weight<=1'b0;
else if(start)
    read_weight<=1'b1;
else if(weight_addr==K*K-1)
    read_weight<=1'b0;

// Delay the weight address by two cycles so it lines up with `weight`.
always@(posedge clk,posedge rst)
if(rst)
begin
    weight_addr_d1<=8'd0;
    weight_addr_d2<=8'd0;
end
else
begin 
    weight_addr_d1<=weight_addr;
    weight_addr_d2<=weight_addr_d1;
end

// Kernel registers, row-major: address 3*row+col -> k{row+1}{col+1}.
// NOTE(review): while idle weight_addr_d2 stays 0, so k11 keeps being reloaded
// from `weight`; this relies on the weight port continuing to return entry 0.
always@(posedge clk,posedge rst)
if(rst)
begin 
     k11<=16'd0;
     k12<=16'd0;
     k13<=16'd0;
     k21<=16'd0;
     k22<=16'd0;
     k23<=16'd0;
     k31<=16'd0;
     k32<=16'd0;
     k33<=16'd0;
end
else 
     case(weight_addr_d2)
     8'd0:k11<=weight;
     8'd1:k12<=weight;
     8'd2:k13<=weight;
     8'd3:k21<=weight;
     8'd4:k22<=weight;
     8'd5:k23<=weight;
     8'd6:k31<=weight;
     8'd7:k32<=weight;
     8'd8:k33<=weight;
     default:;                       // hold all kernel registers
     endcase

// Pixel read address: 0..Ni*Ni-1 while read, wraps on the last pixel.
always@(posedge clk,posedge rst)
if(rst)
    rd_addr<=16'd0;
else if(read)
    if(rd_addr==(Ni*Ni-1))
       rd_addr<=16'd0;
    else
       rd_addr<=rd_addr+16'd1;
else
    rd_addr<=16'd0;

// Register the incoming pixel before it enters the line buffer.
always@(posedge clk,posedge rst)
if(rst)
    din<=16'd0;
else
    din<=rd_data;

always@(posedge clk,posedge rst)
if(rst)
    read<=1'b0;
else if(start)
    read<=1'b1;
else if(busy&&rd_addr==(Ni*Ni-1))
    read<=1'b0;

// Cycle counter since start; all result timing is derived from it.
always@(posedge clk,posedge rst)
if(rst)
    cnt1<=20'd0;
else if(busy)
    cnt1<=cnt1+1'b1;
else
    cnt1<=20'd0;

// Two-stage column delay of each row tap (builds the 3 window columns).
always@(posedge clk,posedge rst)
if(rst)
begin
        {dout0_d2,dout0_d1}<=32'd0;
        {dout1_d2,dout1_d1}<=32'd0;
        {dout2_d2,dout2_d1}<=32'd0;
end
else
begin
        {dout0_d2,dout0_d1}<={dout0_d1,dout0};
        {dout1_d2,dout1_d1}<={dout1_d1,dout1};
        {dout2_d2,dout2_d1}<={dout2_d1,dout2};
end

// Pipeline stage 1: register the nine products.
always@(posedge clk,posedge rst)
if(rst)
begin
         mul0<=32'd0;
         mul1<=32'd0;
         mul2<=32'd0;
         mul0_d1<=32'd0;
         mul1_d1<=32'd0;
         mul2_d1<=32'd0;
         mul0_d2<=32'd0;
         mul1_d2<=32'd0;
         mul2_d2<=32'd0;
end
else
begin
         mul0<=p13;
         mul1<=p23;
         mul2<=p33;
         mul0_d1<=p12;
         mul1_d1<=p22;
         mul2_d1<=p32;
         mul0_d2<=p11;
         mul1_d2<=p21;
         mul2_d2<=p31;
end

// Pipeline stage 2: per-column sums.
always@(posedge clk,posedge rst)
if(rst)
begin
   sum0<=32'd0;
   sum1<=32'd0;
   sum2<=32'd0;
end
else
begin
    sum0<=mul0+mul1+mul2;
    sum1<=mul0_d1+mul1_d1+mul2_d1;
    sum2<=mul0_d2+mul1_d2+mul2_d2;
end

// Pipeline stage 3: total, rescaled by F fractional bits, truncated to 16 bits.
always@(posedge clk,posedge rst)
if(rst)
    sum<=16'd0;
else
    sum<=(sum1+sum2+sum0)>>>F;

// Result window. Assuming the 2+1 read latency, the first window is complete
// after Ni*K pixels plus K-1 more columns, and stages 1-3 add 3 cycles, so the
// first result is on `sum` at cnt1 == Ni*K+(2+1)+(K-1)+3. sum_valid is raised
// two cycles earlier because both sum_valid and wr_data_valid are registered.
// The last result is on `sum` at cnt1 == (2+1)+Ni*K+(Ni*(Ni-K+1)-1)+3;
// clearing sum_valid one cycle earlier still lets the registered wren cover it.
always@(posedge clk,posedge rst)
if(rst)
     sum_valid<=1'b0;
else if(cnt1==Ni*K+2+1+(K-1)+3-1-1)
     sum_valid<=1'b1;                           
else if(cnt1==2+1+Ni*K+Ni*(Ni-K+1)-1+3-1)
     sum_valid<=1'b0;

// cnt2: column position 0..Ni-1 of the current result within its row.
always@(posedge clk,posedge rst)
if(rst)
    cnt2<=10'd0;
else if(sum_valid)
    if(cnt2==Ni-1)
        cnt2<=10'd0;
    else
        cnt2<=cnt2+1'd1;
else
    cnt2<=10'd0; 

// cnt_s1: column phase modulo S, restarted at the start of each row.
always@(posedge clk,posedge rst)
if(rst)
    cnt_s1<=10'd0;
else if(sum_valid)
    if(cnt_s1==S-1||cnt2==Ni-1)
         cnt_s1<=10'd0;
    else
         cnt_s1<=cnt_s1+1'd1;
else
   cnt_s1<=10'd0;

// cnt_s2: row phase modulo S, advanced at the end of each row.
always@(posedge clk,posedge rst)
if(rst)
     cnt_s2<=10'd0;
else if(sum_valid)
     if(cnt2==Ni-1)
        if(cnt_s2==S-1)
           cnt_s2<=10'd0;
        else
           cnt_s2<=cnt_s2+1'd1;
     else 
        cnt_s2<=cnt_s2;
else
     cnt_s2<=10'd0;

// Keep only results whose window fits horizontally and lands on the stride grid.
always@(posedge clk,posedge rst)
if(rst)
     wr_data_valid<=1'b0;
else if(sum_valid&&cnt2<(Ni-K+1)&&(cnt_s1==1'b0)&&(cnt_s2==1'b0))
     wr_data_valid<=1'b1;
else
     wr_data_valid<=1'b0;

// Output address: advances after each write, wraps after No*No results.
always@(posedge clk,posedge rst)
if(rst)
    wr_addr<=16'd0;
else if(wr_data_valid)
    if(wr_addr==No*No-1)
        wr_addr<=16'd0;
    else
        wr_addr<=16'd1+wr_addr;
else 
    wr_addr<=wr_addr;

always@(posedge clk,posedge rst)
if(rst)
    sum_valid_d1<=1'b0;
else 
    sum_valid_d1<=sum_valid;
    
assign done=sum_valid_d1&(~sum_valid);   // falling edge of sum_valid
assign wren=wr_data_valid;
assign wr_data=sum;

// K-row line buffer, Ni words per row.
SHIFT_RAM 
#(.TAP_NUM(K),
  .TAP_LENGTH(Ni))
U (
      .clk(clk),    // input wire clk
      .din(din),    // input wire [15 : 0] din
      .dout(dout),  // output wire [15 : 0] dout
      .taps(taps)  // output wire [47 : 0] taps
    );
endmodule

SHIFT_RAM.v

`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////
// Company: 
// Engineer: 
// 
// Create Date: 2020/02/13 19:08:21
// Design Name: 
// Module Name: SHIFT_RAM
// Project Name: 
// Target Devices: 
// Tool Versions: 
// Description: 
// 
// Dependencies: 
// 
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
// 
//////////////////////////////////////////////////////////////////////////////////


//------------------------------------------------------------------------------
// SHIFT_RAM: an N = TAP_NUM*TAP_LENGTH stage shift register of DATA_WIDTH-bit
// words. Each clock, din enters data[0] and every word moves one stage deeper.
//   taps slice t (t = 1..TAP_NUM, bits [t*DATA_WIDTH-1 -: DATA_WIDTH]) shows
//   data[t*TAP_LENGTH-1], the last stage of row t. The lowest slice is the
//   most recently entered row; the highest slice is the oldest row.
//   dout shows the final stage data[N-1].
// DATA_WIDTH is appended last with default 16 (the original fixed width), so
// existing instantiations are unaffected.
//------------------------------------------------------------------------------
module SHIFT_RAM
# (parameter TAP_NUM=3,
   parameter TAP_LENGTH=5,
   parameter N=TAP_NUM*TAP_LENGTH,
   parameter DATA_WIDTH=16
   )
(
input clk,
input [DATA_WIDTH-1:0]din,
output [DATA_WIDTH-1:0]dout,
output [TAP_NUM*DATA_WIDTH-1:0]taps
    );

reg [DATA_WIDTH-1:0] data [0:N-1];
integer i;

// Shift chain: data[0] <= din, data[i+1] <= data[i].
always@(posedge clk)
begin
    data[0]<=din;
    for(i=0;i<N-1;i=i+1)
        data[i+1]<=data[i];
end

// Taps as continuous assigns from constant indices. This avoids relying on a
// memory array inside an @* sensitivity list and shares no loop variable
// with the clocked block above.
genvar t;
generate
    for(t=1;t<=TAP_NUM;t=t+1)
    begin: g_tap
        assign taps[(DATA_WIDTH*t-1)-:DATA_WIDTH]=data[t*TAP_LENGTH-1];
    end
endgenerate

assign dout=data[N-1];
endmodule

更新

卷积核大小为5x5的代码

//------------------------------------------------------------------------------
// convolution: 5x5 (K=5) shift-RAM convolution of an Ni x Ni map, stride S.
// A start pulse loads K*K weights (weight_addr) and streams Ni*Ni pixels
// (rd_addr). Results are written through wren/wr_addr/wr_data. done is high
// for one cycle when sum_valid falls.
// There is no reset input: start (re)initializes the control state, and the
// address counters return to 0 whenever they are not reading.
// Fixed point: 32-bit products; the 25-term sum is arithmetic-shifted right
// by F and truncated to DATA_WIDTH bits.
// NOTE(review): the SHIFT_RAM instance uses its default 16-bit word width and
// taps are sliced in 16-bit lanes, so DATA_WIDTH is assumed to be 16.
//------------------------------------------------------------------------------
module convolution
#(parameter Ni=9,
  parameter S=2,
  parameter F=8,
  parameter DATA_WIDTH=16,
  parameter MAP_ADDR_WIDTH=16,
  parameter WEIGHT_ADDR_WIDTH=8)
(
input clk,
input start,
input signed[DATA_WIDTH-1:0]rd_data,
input signed[DATA_WIDTH-1:0]weight,
output reg signed[DATA_WIDTH-1:0]wr_data,
output reg[MAP_ADDR_WIDTH-1:0]rd_addr,
output reg[MAP_ADDR_WIDTH-1:0]wr_addr,
output reg[WEIGHT_ADDR_WIDTH-1:0]weight_addr,
output done,
output reg wren
);
parameter K=5;                       // kernel size; the datapath below is written for 5x5

reg read_map;
reg read_weight;
reg busy;
reg [19:0]cnt1;               // cycle counter since start
reg [9:0]cnt2;                // column index of the current result in its row
reg [9:0]cnt2s;               // column phase modulo S
reg [9:0]cnt3s;               // row phase modulo S
reg sum_valid;
reg sum_valid_ff;
reg [WEIGHT_ADDR_WIDTH-1:0] weight_addr_ff1,weight_addr_ff2;
wire [K*16-1:0]taps;          // K row taps from SHIFT_RAM, 16 bits each
reg signed[DATA_WIDTH-1:0]din;
wire signed[DATA_WIDTH-1:0]m04,m14,m24,m34,m44;
reg signed[DATA_WIDTH-1:0]k00,k01,k02,k03,k04,k10,k11,k12,k13,k14,k20,k21,k22,k23,k24,k30,k31,k32,k33,k34,k40,k41,k42,k43,k44;
reg signed[DATA_WIDTH-1:0]m00,m01,m02,m03,m10,m11,m12,m13,m20,m21,m22,m23,m30,m31,m32,m33,m40,m41,m42,m43;
reg signed[DATA_WIDTH*2-1:0]p00,p01,p02,p03,p04,p10,p11,p12,p13,p14,p20,p21,p22,p23,p24,p30,p31,p32,p33,p34,p40,p41,p42,p43,p44;
reg signed[DATA_WIDTH*2-1:0]sum0,sum1,sum2,sum3,sum4;

// Newest column of the window. taps[79:64] is the deepest tap, so row 0
// (m0x) is the oldest buffered row and row 4 (m4x) the newest.
assign m04=taps[79:64];
assign m14=taps[63:48];
assign m24=taps[47:32];
assign m34=taps[31:16];
assign m44=taps[15:0];

// Column delay line per row: mr0 is the oldest (left) column, mr4 the newest.
always@(posedge clk)
begin
    {m00,m01,m02,m03}<={m01,m02,m03,m04};
    {m10,m11,m12,m13}<={m11,m12,m13,m14};
    {m20,m21,m22,m23}<={m21,m22,m23,m24};
    {m30,m31,m32,m33}<={m31,m32,m33,m34};
    {m40,m41,m42,m43}<={m41,m42,m43,m44};
end

// Pipeline stage 1: register the 25 products.
always@(posedge clk)
begin
     p00<=k00*m00;
     p01<=k01*m01;
     p02<=k02*m02;
     p03<=k03*m03;
     p04<=k04*m04;
     p10<=k10*m10;
     p11<=k11*m11;
     p12<=k12*m12;
     p13<=k13*m13;
     p14<=k14*m14;
     p20<=k20*m20;
     p21<=k21*m21;
     p22<=k22*m22;
     p23<=k23*m23;
     p24<=k24*m24;
     p30<=k30*m30;
     p31<=k31*m31;
     p32<=k32*m32;
     p33<=k33*m33;
     p34<=k34*m34;
     p40<=k40*m40;
     p41<=k41*m41;
     p42<=k42*m42;
     p43<=k43*m43;
     p44<=k44*m44;
end

// Pipeline stage 2: per-column sums (sumN = column N).
always@(posedge clk)
begin 
    sum0<=p00+p10+p20+p30+p40;
    sum1<=p01+p11+p21+p31+p41;
    sum2<=p02+p12+p22+p32+p42;
    sum3<=p03+p13+p23+p33+p43;
    sum4<=p04+p14+p24+p34+p44;
end

// Pipeline stage 3: total, rescaled by F fractional bits.
always@(posedge clk)
    wr_data<=(sum0+sum1+sum2+sum3+sum4)>>>F;

always@(posedge clk)
if(start)
    read_map<=1'b1;
else if(rd_addr==Ni*Ni-1)
    read_map<=1'b0;

always@(posedge clk)
if(start)
    read_weight<=1'b1;
else if(weight_addr==K*K-1)
    read_weight<=1'b0;

// busy: set by start, cleared by done; gates cnt1.
always@(posedge clk)
if(start)
    busy<=1'b1;
else if(done)
    busy<=1'b0;

always@(posedge clk)
if(busy)
    cnt1<=cnt1+1'd1;
else
    cnt1<=20'd0;

// cnt2: column position 0..Ni-1 of the current result within its row.
always@(posedge clk)
if(sum_valid)
    if(cnt2==Ni-1)
         cnt2<=10'd0;
    else
         cnt2<=cnt2+10'd1;
else
    cnt2<=10'd0;

// cnt3s: row phase modulo S, advanced at the end of each row.
always@(posedge clk)
if(sum_valid)
    if(cnt2==Ni-1)
         if(cnt3s==S-1)
             cnt3s<=10'd0;
         else
             cnt3s<=cnt3s+1'd1;
    else
       cnt3s<=cnt3s;
else
   cnt3s<=10'd0;

// cnt2s: column phase modulo S, restarted at the start of each row.
always@(posedge clk)
if(sum_valid)
    if(cnt2==Ni-1||cnt2s==S-1)
         cnt2s<=10'd0;
    else 
         cnt2s<=cnt2s+10'd1;
else
    cnt2s<=10'd0;

// Keep only results whose window fits horizontally and lands on the stride grid.
always@(posedge clk)
if(sum_valid&&cnt2<Ni-K+1&&cnt2s==0&&cnt3s==0)
    wren<=1'b1;
else
    wren<=1'b0;

always@(posedge clk)
if(wren)
    wr_addr<=wr_addr+1'd1;
else if(start)
    wr_addr<={MAP_ADDR_WIDTH{1'b0}};

// Result window. Assuming a 2+1 read latency, the first window is complete
// after Ni*K pixels plus K-1 more columns, and stages 1-3 add 3 cycles, so the
// first result is on wr_data at cnt1 == Ni*K+(2+1)+(K-1)+3. sum_valid is raised
// two cycles earlier because both sum_valid and wren are registered.
// The last result is on wr_data at cnt1 == (2+1)+Ni*K+(Ni*(Ni-K+1)-1)+3;
// clearing sum_valid one cycle earlier still lets the registered wren cover it.
always@(posedge clk)
if(cnt1==Ni*K+2+1+K-1+3-1-1)
    sum_valid<=1'b1;
else if(cnt1==Ni*K+2+1+Ni*(Ni-K+1)-1+3-1)
    sum_valid<=1'b0;
else if(start)
    sum_valid<=1'b0;

always@(posedge clk)
   sum_valid_ff<=sum_valid;

// Pixel read address: 0..Ni*Ni-1 while read_map, wrapping on the last pixel
// (as the 3x3 module does) instead of issuing the out-of-range address Ni*Ni.
always@(posedge clk)
if(read_map)
    if(rd_addr==Ni*Ni-1)
        rd_addr<={MAP_ADDR_WIDTH{1'b0}};
    else
        rd_addr<=rd_addr+1'd1;
else 
    rd_addr<={MAP_ADDR_WIDTH{1'b0}};

// Weight read address: 0..K*K-1 while read_weight, wrapping on the last entry
// instead of issuing the out-of-range address K*K.
always@(posedge clk)
if(read_weight)
    if(weight_addr==K*K-1)
        weight_addr<={WEIGHT_ADDR_WIDTH{1'b0}};
    else
        weight_addr<=weight_addr+1'd1;
else 
    weight_addr<={WEIGHT_ADDR_WIDTH{1'b0}};

// Register the incoming pixel before it enters the line buffer.
always@(posedge clk)
    din<=rd_data;

// Delay the weight address by two cycles so it lines up with `weight`.
always@(posedge clk)
begin
    weight_addr_ff1<=weight_addr;
    weight_addr_ff2<=weight_addr_ff1;
end

// Kernel registers, row-major: address 5*row+col -> k{row}{col}.
// NOTE(review): while idle weight_addr_ff2 stays 0, so k00 keeps being reloaded
// from `weight`; this relies on the weight port continuing to return entry 0.
always@(posedge clk)
case(weight_addr_ff2)
     8'd0:k00<=weight;
     8'd1:k01<=weight;
     8'd2:k02<=weight;
     8'd3:k03<=weight;
     8'd4:k04<=weight;
     8'd5:k10<=weight;
     8'd6:k11<=weight;
     8'd7:k12<=weight;
     8'd8:k13<=weight;
     8'd9:k14<=weight;
     8'd10:k20<=weight;
     8'd11:k21<=weight;
     8'd12:k22<=weight;
     8'd13:k23<=weight;
     8'd14:k24<=weight;
     8'd15:k30<=weight;
     8'd16:k31<=weight;
     8'd17:k32<=weight;
     8'd18:k33<=weight;
     8'd19:k34<=weight;
     8'd20:k40<=weight;
     8'd21:k41<=weight;
     8'd22:k42<=weight;
     8'd23:k43<=weight;
     8'd24:k44<=weight;
     default:;                       // hold all kernel registers
endcase

assign done=~sum_valid&&sum_valid_ff;   // falling edge of sum_valid

// K-row line buffer, Ni words per row.
SHIFT_RAM 
# (.TAP_NUM(K),
   .TAP_LENGTH(Ni)
   )
U
(
.clk(clk),
.din(din),
.dout(),
.taps(taps)
    );  
    
endmodule

你可能感兴趣的:(fpga基于shift ram的卷积实现)