吃透Chisel语言.40.Chisel实战之单周期RISC-V处理器实现(下)——具体实现和最终测试

Chisel实战之单周期RISC-V处理器实现(下)——具体实现和最终测试

上一篇文章中我们对本项目的需求进行了分析,并得到了初步的设计,这一篇文章我们就可以基于该设计来实现我们的单周期RISC-V处理器了。实现之后也必须用实际代码来测试一下,至少也得能运行递归版本的斐波那契数列计算。完整项目代码可以在本人的Github仓库获取:github-3rr0r/RV32ISC: A RISC-V RV32I ISA Single Cycle CPU。接下来我们直接进入正题!

实现思路

再次放上我们的设计图:

吃透Chisel语言.40.Chisel实战之单周期RISC-V处理器实现(下)——具体实现和最终测试_第1张图片

根据上篇文章的分析,我们设计的CPU中应该至少需要包含以下组件:

  1. 指令内存(MemInst):接收一个32位的指令地址,读取出指令;
  2. PC寄存器(PCReg):为指令内存提供指令地址,每个时钟周期地址都会+4,当前指令为跳转时,下一条指令为跳转目的地址,当前指令为分支指令且分支成功时,下一条指令为分支目标地址;
  3. 通用寄存器堆(Registers):可读可写的寄存器,接收寄存器号,为运算单元提供操作数,接收运算结果或从数据内存读取到的值;
  4. 数据内存(MemData):根据加载/存储地址,加载或存储数据,加载或存储依赖于译码器的译码;
  5. 指令译码器(Decoder):对指令进行译码,解析得到立即数、操作码、寄存器号等信息;
  6. 运算单元(ALU):根据操作数和操作码进行运算,运算结果写到寄存器,分支指令时将比较结果发送给PC,加载存储指令时计算地址;
  7. 控制单元(Controller):根据译码结果,给出对数据通路进行控制的相关信号;

模块化思维必须要有,如果把一个复杂的系统全都塞到一个代码文件里面,不管是编写、测试还是调试、迭代,都会有巨大的麻烦在等着你。而模块化的设计可以让模块之间相互独立,可以分别编写、测试各个模块,修改时也不会影响其他模块。

由于我们在设计时已经将系统划分出各个模块,因此分别实现为一个Chisel的Module,然后用一个顶层的Module将它们连接起来就可以了。

全局配置文件

磨刀不误砍柴工,在实现各个模块之前,我们做一件事会对我们很有帮助,那就是创建一个全局配置文件。

为什么要这么做呢?比如,我们实现的是32位的处理器,那么32这个数字肯定在编写过程中经常使用。我们当然可以在所有文件中都使用32这个数字,但是扩展性就会有很大的问题。比如要把项目迁移到64位实现,我们只能一个个数字去修改,而如果有全局配置文件,我们就可以直接在里面用一个对象来存放32这个数,在需要用32的地方使用这个对象,如果需要改成64,那么直接修改该对象的值就行了。

上面的说法可能不太严谨,但适用场合很多,拿我们的项目来说,虽然只是个单周期32位处理器,但也还是有其他东西是可以配置的,比如内存初始化的地址等。复杂的项目可能还可以用于配置Cache的参数、内存的端口数等。总之,这是个好的做法,我们应该形成这种思维。

那么我们在src/main/scala/config文件夹下创建一个Configs.scala文件,内容暂时如下:

package config

import chisel3._

object Configs {
    val ADDR_WIDTH = 32 // 地址位宽
    val ADDR_BYTE_WIDTH = ADDR_WIDTH / 8    // 地址位宽按字节算
    val DATA_WIDTH = 32 // 数据位宽
    val DATA_WIDTH_H = 16   // 半字数据位宽
    val DATA_WIDTH_B = 8    // 字节数据位宽
}

注意,config文件夹是自己创建的,使用Configs.scala中的对象时,将文件夹名作为包名使用,导入所有参数即可:

import config.Configs._

后续如果有需要添加的配置,在此配置文件中添加即可。

接下来我们开始各个模块的实现,首先就是PC寄存器模块。

PC寄存器的实现

PC寄存器的功能如下:

  1. 32位的指令地址输出,为指令内存提供地址以获取指令;
  2. 每个时钟周期,PC寄存器中的指令地址都会自增4,以获取下一个地址;
  3. 如果当前指令为跳转指令,会接收到控制单元的信号ctrlJump,接收计算结果(跳转地址)作为下一个地址;
  4. 如果当前指令为分支指令,会接收到控制单元的信号ctrlBranch,以及运算单元的分支结果,如果分支成功,接收计算结果(跳转地址)作为下一个地址;
  5. 寄存器初始化时,输出的指令地址为0;

所以我们可以这么实现PCReg模块(src/main/scala/rv32isc/PCReg.scala):

package rv32isc

import chisel3._
import chisel3.util._

import config.Configs._

// PCReg的模块接口
class PCRegIO extends Bundle {
    val addrOut = Output(UInt(ADDR_WIDTH.W))      // 地址输出
    val ctrlJump = Input(Bool())                // 当前指令是否为跳转指令
    val ctrlBranch = Input(Bool())              // 当前指令是否为分支指令
    val resultBranch = Input(Bool())            // 分支结果是否为分支成功
    val addrTarget = Input(UInt(ADDR_WIDTH.W))    // 跳转/分支的目的地址
}

// PCReg模块
class PCReg extends Module {
    val io = IO(new PCRegIO())  // 输入输出接口

    val regPC = RegInit(UInt(ADDR_WIDTH.W), START_ADDR.U)   // PC寄存器,初始化时重置为START_ADDR

    when (io.ctrlJump || (io.ctrlBranch && io.resultBranch)) {  // 跳转或分支成功时,更新为目的地址
        regPC := io.addrTarget
    } .otherwise {  // 否则自增4
        regPC := regPC + ADDR_BYTE_WIDTH.U
    }

    io.addrOut := regPC // 每个时钟周期输出当前PC寄存内的地址
}

注意,这里我们用到了一个常量START_ADDR用于表示起始执行地址,应该在Configs对象中包含:

val START_ADDR: Long = 0x00000000  // 起始执行地址

接着,我们可以创建PCReg模块对应的测试(src/test/scala/rv32isc/PCRegTest.scala):

package rv32isc

import chisel3._
import chiseltest._
import chisel3.util._
import org.scalatest.flatspec.AnyFlatSpec

import config.Configs._

trait PCRegTestFunc {
    // 生成十个随机地址用于测试
    val target_list =
        Seq.fill(10)(scala.util.Random.nextInt().toLong & 0x00ffffffffL)

    def testFn(dut: PCReg): Unit = {
        // 初始化状态
        dut.io.ctrlBranch.poke(false.B)
        dut.io.ctrlJump.poke(false.B)
        dut.io.resultBranch.poke(false.B)
        dut.io.addrTarget.poke(START_ADDR)
        dut.io.addrOut.expect(START_ADDR)
        
        var addr: Long = START_ADDR

        // 正常自增功能
        for (target <- target_list) {
            dut.io.addrTarget.poke(target.U)
            addr += ADDR_BYTE_WIDTH
            dut.clock.step()
            dut.io.addrOut.expect(addr.U)
        }

        // 跳转功能测试
        dut.io.ctrlJump.poke(true.B)
        for (target <- target_list) {
            dut.io.addrTarget.poke(target.U)
            dut.clock.step()
            dut.io.addrOut.expect(target.U)
            addr = target
        }
        dut.io.ctrlJump.poke(false.B)

        // 分支功能测试
        // 分支指令,但分支不成功
        dut.io.ctrlBranch.poke(true.B)
        for (target <- target_list) {
            dut.io.addrTarget.poke(target.U)
            addr += ADDR_BYTE_WIDTH
            dut.clock.step()
            dut.io.addrOut.expect(addr.U)
        }

        // 分支指令,且分支成功
        dut.io.resultBranch.poke(true.B)
        for (target <- target_list) {
            dut.io.addrTarget.poke(target.U)
            addr += ADDR_BYTE_WIDTH
            dut.clock.step()
            dut.io.addrOut.expect(target.U)
            addr = target
        }
    }
}

class PCRegTest extends AnyFlatSpec with ChiselScalatestTester with PCRegTestFunc {
    "PCReg" should "pass" in {
        test(new PCReg) { dut =>
            testFn(dut)
        } 
    }
}

测试通过。

PC寄存器将指令地址给了指令内存,指令内存给出这一周期要执行的指令,那么接下来我们就实现指令内存。

指令内存的实现

指令内存的工作逻辑特别简单:

  1. 接收一个32位的指令地址;
  2. 根据指令地址输出对应的指令;

Mem生成的内存为异步读、同步写的内存,得到地址后输出指令不含时序,为组合电路,所以实现起来也特别简单:

package rv32isc

import chisel3._
import chisel3.util._

import config.Configs._

class MemInstIO extends Bundle {
    val addr = Input(UInt(ADDR_WIDTH.W))    // 指令地址
    val inst = Output(UInt(INST_WIDTH.W))   // 指令输出
}

class MemInst extends Module {
    val io = IO(new MemInstIO())    // 输入输出接口

    // 指令内存,能存放MEM_INST_SIZE条INST_WIDTH位的指令
    val mem = Mem(MEM_INST_SIZE, UInt(INST_WIDTH.W))

    io.inst := mem.read(io.addr >> INST_BYTE_WIDTH_LOG.U)    // 读取对应位置的指令并输出
}

同样,这里我们使用了两个常量需要添加到Configs对象里:

val INST_WIDTH = 32 // 指令位宽
val INST_BYTE_WIDTH = INST_WIDTH / 8 // 指令位宽按字节算
val INST_BYTE_WIDTH_LOG = ceil(log(INST_BYTE_WIDTH) / log(2)).toInt // 指令地址对齐的偏移量
val MEM_INST_SIZE = 1024 // 指令内存大小

我这里着重解释一下mem.read(io.addr >> INST_BYTE_WIDTH_LOG.U)和常量INST_BYTE_WIDTH_LOG = ceil(log(INST_BYTE_WIDTH) / log(2)).toInt:由于指令的宽度为32位,因此每次读取指令时指令地址都要对齐到4字节。而我们的指令内存的每个数据项都是四字节的,因此应该用输入的指令地址的高30位访问,即右移两位。

不过,在测试上会稍微麻烦一点,我们需要先向内存中写入一些模拟的指令填满内存,在内存初始化的时候用loadMemoryFromFile写入mem。我们在MemInst类中添加初始化相关代码:

loadMemoryFromFile(mem, "src/test/scala/rv32isc/MemInst.hex", MemoryLoadFileType.Hex)

注意

  1. loadMemoryFromFile需要导入包:import chisel3.util.experimental.loadMemoryFromFile
  2. MemoryLoadFileType.Hex需要导入包:import firrtl.annotations.MemoryLoadFileType
  3. src/test/scala/rv32isc/MemInst.hex文件为在测试代码中生成的随机文本文件,共1024行,每一行都是一串8字符的随机的16进制数,用于模拟32位的指令,具体生成代码见src/test/scala/rv32isc/MemInstTest.scala
  4. loadMemoryFromFile只会在测试的时候执行;

下面我们就可以写测试代码了:

package rv32isc

import chisel3._
import chiseltest._
import chisel3.util._
import org.scalatest.flatspec.AnyFlatSpec

import java.io.PrintWriter
import java.io.File

import config.Configs._

trait MemInstTestFunc {
    // 生成MEM_INST_SIZE条随机指令进行测试
    val inst_list =
        Seq.fill(MEM_INST_SIZE)(scala.util.Random.nextInt().toLong & 0x00ffffffffL)

    // 为随机指令生成hex文本文件
    def genMemInstHex(): Unit = {
        val memFile = new File(
          System.getProperty("user.dir") + "/src/test/scala/rv32isc/MemInst.hex"
        )
        memFile.createNewFile()
        val memPrintWriter = new PrintWriter(memFile)
        for (i <- 0 to MEM_INST_SIZE - 1) {
            memPrintWriter.println(inst_list(i).toHexString)
        }
        memPrintWriter.close()
    }

    def testFn(dut: MemInst): Unit = {
        // 依次读取所有的指令,与inst_list进行匹配
        for (i <- 0 to MEM_INST_SIZE - 1) {
            dut.io.addr.poke((i * INST_BYTE_WIDTH).U)   // 作为地址,应该左移两位,即乘以4
            dut.io.inst.expect(inst_list(i).U)
        }
    }
}

class MemInstTest extends AnyFlatSpec with ChiselScalatestTester with MemInstTestFunc {
    "MemInst" should "pass" in {
        // 先生成随机hex文件,再进行测试
        genMemInstHex()
        test(new MemInst) { dut =>
            testFn(dut)
        } 
    }
}

测试通过。

得到指令之后,我们就可以对指令进行译码,所以我们下一步要实现译码单元。

译码单元的实现

要对指令进行译码,首先我们需要对指令格式有清晰地了解,再次放上指令格式的图:

吃透Chisel语言.40.Chisel实战之单周期RISC-V处理器实现(下)——具体实现和最终测试_第2张图片

我们用inst表示指令,那么很显然有以下几个规律:

  1. inst[6:0]都是opcode
  2. rs1rs2rd分别固定在inst[19:15]inst[24:20]inst[11:7]上;
  3. 功能码funct3funct7分别位于inst[14:12]inst[31:25]
  4. 立即数的分布比较复杂;

我们的思路可以这么来:

  1. opcodefunct3funct7的解析结果交给控制模块(Controller),由控制模块控制ALU的行为;
  2. rs1rs2rd直接交给寄存器堆(Registers),由控制模块控制是否写寄存器,而寄存器时钟读寄存并交给ALU
  3. 译码单元自己计算立即数imm,根据指令格式分类选择输出正确的立即数给ALU

接下来我们分析RV32I的指令中,可以如何根据opcode对应到指令格式类型上:

  1. U类型只有两个条指令LUIAUIPCinst[6:2]分别为b01101b00101
  2. J类型只有一个JALinst[6:2]b11011
  3. I类型有JALR、LOAD类指令、立即数算术逻辑类指令,inst[6:2]b11001b00000b00100三种;
  4. B类型都是条件分支类指令,inst[6:2]b11000
  5. S类型都是STORE类指令,inst[6:2]b01000
  6. R类型都是算术逻辑类指令,inst[6:2]b01100

我们再分析指令在ALU上的行为:

  1. 算术运算:加法、减法
  2. 逻辑运算:与、或、异或
  3. 位运算:逻辑左移、逻辑右移、算术右移
  4. 比较运算:等于、不等于、小于、大于等于
  5. 无:空操作

可以用四位二进制数对上面的行为进行编码(空操作归为算术运算):

操作 类型 信号编码
NOP 算术运算 00_00
加法 算术运算 00_01
减法 算术运算 00_10
逻辑运算 01_00
逻辑运算 01_01
异或 逻辑运算 01_11
逻辑左移 位运算 10_00
逻辑右移 位运算 10_01
算术右移 位运算 10_11
等于 比较运算 11_00
不等于 比较运算 11_01
小于 比较运算 11_10
大于等于 比较运算 11_11

通过对opcodefunct3funct7可以得到上面的编码,在此不做赘述,后面看代码就行。

对于这种硬编码,我们不能每次都照着二进制数去写,有一个好方法就是专门用一个文件来存放这种硬编码,我们这里就创建一个src/main/scala/utils/HardCodes.scala文件,内容如下:

package utils

import chisel3._

object OP_TYPES {
    val OP_TYPES_WIDTH = 4
    val OP_NOP = "b0000".U
    val OP_ADD = "b0001".U
    val OP_SUB = "b0010".U
    val OP_AND = "b0100".U
    val OP_OR = "b0101".U
    val OP_XOR = "b0111".U
    val OP_SLL = "b1000".U
    val OP_SRL = "b1001".U
    val OP_SRA = "b1011".U
    val OP_EQ = "b1100".U
    val OP_NEQ = "b1101".U
    val OP_LT = "b1110".U
    val OP_GE = "b1111".U
}

另外,我们还注意到几条末尾是U的指令,它们要求将操作数作为无符号数来进行运算,因此ALU应该还有一个输入用于指示ALU进行无符号运算还是有符号运算,用一位信号ctrlSigned就行。

还有其他的是否分支、是否跳转、是否加载、是否存储、rs1是否为PC、rs2是否为立即数等,都比较简单,分析方法类似,直接放上代码吧。下面就是src/main/scala/rv32isc/Decoder.scala的具体实现:

package rv32isc

import chisel3._
import chisel3.util._

import config.Configs._
import utils.OP_TYPES._
import utils.LS_TYPES._
import utils._

class DecoderIO extends Bundle {
    val inst = Input(UInt(INST_WIDTH.W))
    val bundleCtrl = new BundleControl()
    val bundleReg = new BundleReg()
    val imm = Output(UInt(DATA_WIDTH.W))
}

class Decoder extends Module {
    val io = IO(new DecoderIO())

    // 三个寄存器号
    io.bundleReg.rs1 := io.inst(19, 15)
    io.bundleReg.rs2 := io.inst(24, 20)
    io.bundleReg.rd := io.inst(11, 7)

    // 五种立即数
    val imm_i = Cat(Fill(20, io.inst(31)), io.inst(31, 20))
    val imm_s = Cat(Fill(20, io.inst(31)), io.inst(31, 25), io.inst(11, 7))
    val imm_b = Cat(Fill(20, io.inst(31)), io.inst(7), io.inst(30, 25), io.inst(11, 8), 0.U(1.W))
    val imm_u = Cat(io.inst(31, 12), Fill(12, 0.U))
    val imm_j = Cat(Fill(12, io.inst(31)), io.inst(31), io.inst(19, 12), io.inst(20), io.inst(30, 21), Fill(1, 0.U))
    // 和用于移位的shamt
    val imm_shamt = Cat(Fill(27, 0.U), io.inst(24, 20))

    // 用于立即数输出
    val imm = WireDefault(0.U(32.W))

    // 用于控制信号
    val ctrlJump = WireDefault(false.B)
    val ctrlBranch = WireDefault(false.B)
    val ctrlRegWrite = WireDefault(true.B)
    val ctrlLoad = WireDefault(false.B)
    val ctrlStore = WireDefault(false.B)
    val ctrlALUSrc = WireDefault(false.B)
    val ctrlJAL = WireDefault(false.B)
    val ctrlOP = WireDefault(0.U(OP_TYPES_WIDTH.W))
    val ctrlSigned = WireDefault(true.B)

    // 根据opcode对控制信号赋值
    switch (io.inst(6, 2)) {
        // U: LUI, AUIPC
        is ("b01101".U, "b00101".U) {
            ctrlALUSrc := true.B
            ctrlOP := OP_ADD
            imm := imm_u
        }
        // J: JAL
        is ("b11011".U) {
            ctrlALUSrc := true.B
            ctrlJump := true.B
            ctrlOP := OP_ADD
            ctrlJAL := true.B
            imm := imm_j
        }
        // I: JALR, 
        // I: LB, LH, LW, LBU, LHU
        // I: ADDI, SLTI, SLTIU, XORI, ORI, ANDI, SLLI, SRLI, SRAI
        is ("b11001".U, "b00000".U, "b00100".U) {
            ctrlALUSrc := true.B
            // JALR
            when (io.inst(6, 2) === "b11001".U) {
                ctrlJump := true.B
                ctrlOP := OP_ADD
                imm := imm_i
            }
            // LOAD
            .elsewhen (io.inst(6, 2) === "b00000".U) {
                ctrlLoad := true.B
                ctrlOP := OP_ADD
                imm := imm_i
                when (io.inst(14, 12) === "b100".U | io.inst(14, 12) === "b101".U) {
                    ctrlSigned := false.B
                }
            }
            // AL
            .elsewhen (io.inst(6, 2) === "b00100".U && (io.inst(14, 12) === "b001".U || io.inst(14, 12) === "b101".U)) {
                imm := imm_shamt
                switch (Cat(io.inst(30), io.inst(14, 12))) {
                    // SLLI
                    is ("b0001".U) {
                        ctrlOP := OP_SLL
                    }
                    // SRLI
                    is ("b0101".U) {
                        ctrlOP := OP_SRL
                    }
                    // SRAI
                    is ("b1101".U) {
                        ctrlOP := OP_SRA
                    }
                }
            } .otherwise {
                imm := imm_i
                switch (io.inst(14, 12)) {
                    // ADDI
                    is ("b000".U) {
                        ctrlOP := OP_ADD
                    }
                    // SLTI
                    is ("b010".U) {
                        ctrlOP := OP_LT
                    }
                    // SLTIU
                    is ("b011".U) {
                        ctrlOP := OP_LT
                        ctrlSigned := false.B
                    }
                    // XORI
                    is ("b100".U) {
                        ctrlOP := OP_XOR
                    }
                    // ORI
                    is ("b110".U) {
                        ctrlOP := OP_OR
                    }
                    // ANDI
                    is ("b111".U) {
                        ctrlOP := OP_AND
                    }
                }
            }
        }
        // B: BEQ, BNE, BLT, BGE, BLTU, BGEU
        is ("b11000".U) {
            ctrlALUSrc := false.B
            ctrlBranch := true.B
            ctrlRegWrite := false.B
            imm := imm_b
            switch (io.inst(14, 12)) {
                // BEQ
                is ("b000".U) {
                    ctrlOP := OP_EQ
                }
                // BNE
                is ("b001".U) {
                    ctrlOP := OP_NEQ
                }
                // BLT
                is ("b100".U) {
                    ctrlOP := OP_LT
                }
                // BGE
                is ("b101".U) {
                    ctrlOP := OP_GE
                }
                // BLTU
                is ("b110".U) {
                    ctrlOP := OP_LT
                    ctrlSigned := false.B
                }
                // BGEU
                is ("b111".U) {
                    ctrlOP := OP_GE
                    ctrlSigned := false.B
                }
            }
        }
        // S: SB, SH, SW
        is ("b01000".U) {
            ctrlALUSrc := true.B
            ctrlStore := true.B
            ctrlRegWrite := false.B
            ctrlOP := OP_ADD
            imm := imm_s
            when (io.inst(14, 12) === "b000".U) {
                ctrlLSType := LS_B
            }
            when (io.inst(14, 12) === "b001".U) {
                ctrlLSType := LS_H
            }
        }
        // R: ADD, SUB, SLL, SLT, SLTU, XOR, SRL, SRA, OR, AND
        is ("b01100".U) {
            switch (io.inst(14, 12)) {
                // ADD, SUB
                is ("b000".U) {
                    when (io.inst(30)) {
                        ctrlOP := OP_SUB
                    } .otherwise {
                        ctrlOP := OP_ADD
                    }
                }
                // SLL
                is ("b001".U) {
                    ctrlOP := OP_SLL
                }
                // SLT
                is ("b010".U) {
                    ctrlOP := OP_LT
                }
                // SLTU
                is ("b011".U) {
                    ctrlOP := OP_LT
                    ctrlOP := false.B
                }
                // XOR
                is ("b100".U) {
                    ctrlOP := OP_XOR
                }
                // SRL, SRA
                is ("b101".U) {
                    when (io.inst(30)) {
                        ctrlOP := OP_SRA
                    } .otherwise {
                        ctrlOP := OP_SRL
                    }
                }
                // OR
                is ("b110".U) {
                    ctrlOP := OP_OR
                }
                // AND
                is ("b111".U) {
                    ctrlOP := OP_AND
                }
            }
        }
    }

    // 连接控制信号和立即数
    io.bundleCtrl.ctrlALUSrc := ctrlALUSrc
    io.bundleCtrl.ctrlBranch := ctrlBranch
    io.bundleCtrl.ctrlJAL := ctrlJAL
    io.bundleCtrl.ctrlJump := ctrlJump
    io.bundleCtrl.ctrlLoad := ctrlLoad
    io.bundleCtrl.ctrlOP := ctrlOP
    io.bundleCtrl.ctrlRegWrite := ctrlRegWrite
    io.bundleCtrl.ctrlSigned := ctrlSigned
    io.bundleCtrl.ctrlStore := ctrlStore
    io.imm := imm
}

其中,我们使用了两个Bundle,由于这个Bundle可以在模块之间复用,所以我们将其放到一个单独的Bundle文件中供使用(src/main/scala/utils/Bundles.scala):

package utils

import chisel3._

import config.Configs._
import utils.OP_TYPES._

// 用于连接控制模块的Bundle
class BundleControl extends Bundle {
    val ctrlJump = Output(Bool())
    val ctrlBranch = Output(Bool())
    val ctrlRegWrite = Output(Bool())
    val ctrlLoad = Output(Bool())
    val ctrlStore = Output(Bool())
    val ctrlALUSrc = Output(Bool())
    val ctrlJAL = Output(Bool())
    val ctrlOP = Output(UInt(OP_TYPES_WIDTH.W))
    val ctrlSigned = Output(Bool())
}

// 用于连接寄存器模块的Bundle
class BundleReg extends Bundle {
    val rs1 = Output(UInt(REG_NUMS_LOG.W))
    val rs2 = Output(UInt(REG_NUMS_LOG.W))
    val rd = Output(UInt(REG_NUMS_LOG.W))
}

这一部分其实硬写测试并不明智,这里还是放上一个随意的测试代码(src/test/scala/rv32isc/DecoderTest.scala):

package rv32isc

import chisel3._
import chiseltest._
import chisel3.util._
import org.scalatest.flatspec.AnyFlatSpec

import config.Configs._

trait DecoderTestFunc {
    def testFn(dut: Decoder): Unit = {
        // LUI x1, 0x1111
        // AUIPC x2, 0x2222
        // JAL x3, L0
        // L0:
        // JALR x4, x3, 4
        // L3:
        // BEQ x5, x6, L1
        // L1:
        // BNE x1, x2, L2
        // L2:
        // BLT x1, x2, L4
        // L4:
        // BGE x1, x2, L5
        // L5:
        // BLTU x1, x2, L6
        // L6:
        // BGEU x1, x2, L7
        // L7:
        // LB x1, 0x4, x2
        // LH x1, 0x8, x2
        // LW x1, 0x4, x2
        // LBU x1, 0x8, x2
        // LHU x1, 0xc, x2
        // SB x1, 0x4, x1
        // SH x1, 0x4, x1
        // SW x1, 0x4, x1
        // ADDI x1, x1, 0x4
        // SLTI x1, x1, 0x4
        // SLTIU x1, x1, 0x4
        // XORI x1, x1, 0x4
        // ORI x1, x1, 0x4
        // ANDI x1, x1, 0x4
        // SLLI x1, x1, 0x4
        // SRLI x1, x1, 0x4
        // SRAI x1, x1, 0x4
        // ADD x1, x1, x2
        // SUB x1, x1, x2
        // SLL x1, x1, x2
        // SLT x1, x1, x2
        // SLTU x1, x1, x2
        // XOR x1, x1, x2
        // SRL x1, x1, x2
        // SRA x1, x1, x2
        // OR x1, x1, x2
        // AND x1, x1, x2

        val inst_list = Seq(
            "h011110b7".U,
            "h02222117".U,
            "h004001ef".U,
            "h00418267".U,
            "h00628263".U,
            "h00209263".U,
            "h0020c263".U,
            "h0020d263".U,
            "h0020e263".U,
            "h0020f263".U,
            "h00410083".U,
            "h00811083".U,
            "h00412083".U,
            "h00814083".U,
            "h00c15083".U,
            "h00108223".U,
            "h00109223".U,
            "h0010a223".U,
            "h00408093".U,
            "h0040a093".U,
            "h0040b093".U,
            "h0040c093".U,
            "h0040e093".U,
            "h0040f093".U,
            "h00409093".U,
            "h0040d093".U,
            "h4040d093".U,
            "h002080b3".U,
            "h402080b3".U,
            "h002090b3".U,
            "h0020a0b3".U,
            "h0020b0b3".U,
            "h0020c0b3".U,
            "h0020d0b3".U,
            "h4020d0b3".U,
            "h0020e0b3".U,
            "h0020f0b3".U
        )

        def test_Decoder(dut: Decoder): Unit = {
            for (inst <- inst_list) {
                dut.io.inst.poke(inst)
                println(dut.io.bundleReg.rs1.peekInt())
                println(dut.io.bundleReg.rs2.peekInt())
                println(dut.io.imm.peek())
                println(dut.io.bundleReg.rd.peekInt())
                println(dut.io.bundleCtrl.peek())
            }
        }
    }
}

class DecoderTest
    extends AnyFlatSpec
    with ChiselScalatestTester
    with DecoderTestFunc {
    "Decoder" should "pass" in {
        test(new Decoder) { dut =>
            testFn(dut)
        }
    }
}

测试通过。

解码单元的信号输出一部分给到寄存器,用于读取寄存器中的数据,下面我们就实现寄存器。

寄存器组的实现

寄存器的实现比较简单,我们前面的文章中也有相应的例子,这里就不分析了,直接放上代码(src/main/scala/rv32isc/Registers.scala):

package rv32isc

import chisel3._
import chisel3.util._

import config.Configs._
import utils._

class RegistersIO extends Bundle {
    val ctrlRegWrite = Input(Bool())
    val dataWrite = Input(UInt(DATA_WIDTH.W))
    val bundleReg = Output(Flipped(new BundleReg))
    val dataRead1 = Output(UInt(DATA_WIDTH.W))
    val dataRead2 = Output(UInt(DATA_WIDTH.W))
}

class Registers extends Module {
    val io = IO(new RegistersIO())

    // 寄存器组,REG_NUMS个,位宽DATA_WIDTH
    val regs = Reg(Vec(REG_NUMS, UInt(DATA_WIDTH.W)))

    // 寄存器号为0时读到0
    when (io.bundleReg.rs1 === 0.U) {
        io.dataRead1 := 0.U
    }
    when (io.bundleReg.rs2 === 0.U) {
        io.dataRead2 := 0.U
    }
    // 否则给出数据
    io.dataRead1 := regs(io.bundleReg.rs1)
    io.dataRead2 := regs(io.bundleReg.rs2)
    // 给出写信号,且rd不为0时写寄存器
    when (io.ctrlRegWrite && io.bundleReg.rd =/= 0.U) {
        regs(io.bundleReg.rd) := io.dataWrite
    }
}

然后是测试代码:

package rv32isc

import chisel3._
import chiseltest._
import chisel3.util._
import org.scalatest.flatspec.AnyFlatSpec

import config.Configs._

trait RegistersTestFunc {
    // 随机填入数据
    val oprand_list = Seq.fill(REG_NUMS)(scala.util.Random.nextInt().toLong & 0x00ffffffffL)

    def testRegs(dut: Registers): Unit = {
        // 初始化状态
        for (i <- 0 to REG_NUMS - 1) {
            dut.io.bundleReg.rs1.poke(i.U)
            dut.io.dataRead1.expect(0.U)
            dut.io.bundleReg.rs2.poke(i.U)
            dut.io.dataRead2.expect(0.U)
        }
        // 写入
        for (i <- 0 to REG_NUMS - 1) {
            dut.io.ctrlRegWrite.poke(true.B)
            dut.io.bundleReg.rd.poke(i.U)
            dut.io.dataWrite.poke(oprand_list(i))
            dut.clock.step()
        }
        // 读取
        for (i <- 0 to REG_NUMS - 1) {
            dut.io.bundleReg.rs1.poke(i.U)
            if (i == 0) {
                dut.io.dataRead1.expect(0.U)
            } else {
                dut.io.dataRead1.expect(oprand_list(i))
            }
            
            dut.io.bundleReg.rs2.poke(i.U)
            if (i == 0) {
                dut.io.dataRead2.expect(0.U)
            } else {
                dut.io.dataRead2.expect(oprand_list(i))
            }
        }
        // 不能写时尝试写0
        for (i <- 0 to REG_NUMS - 1) {
            dut.io.ctrlRegWrite.poke(false.B)
            dut.io.bundleReg.rd.poke(i.U)
            dut.io.dataWrite.poke(0)
            dut.clock.step()
        }
        // 再次读
        for (i <- 0 to REG_NUMS - 1) {
            dut.io.bundleReg.rs1.poke(i.U)
            if (i == 0) {
                dut.io.dataRead1.expect(0.U)
            } else {
                dut.io.dataRead1.expect(oprand_list(i))
            }
            
            dut.io.bundleReg.rs2.poke(i.U)
            if (i == 0) {
                dut.io.dataRead2.expect(0.U)
            } else {
                dut.io.dataRead2.expect(oprand_list(i))
            }
        }
    }
}

class RegistersTest extends AnyFlatSpec with ChiselScalatestTester with RegistersTestFunc {
    "Registers" should "pass" in {
        test(new Registers) { dut =>
            testRegs(dut)
        }
    }
}

测试通过。

数据也有了,控制信号也有了,下面我们就来实现最关键的ALU部分吧。

ALU模块的实现

虽然说ALU模块很关键,但是由于前面打下了较好的基础,所以在实现Alu模块时轻松了很多。现在控制单元到Alu有四个控制信号,所以我们还是用Bundle的方式,在src/main/scala/utils/Bundles.scala中加入:

class BundleAluControl extends Bundle {
    val ctrlALUSrc = Input(Bool())
    val ctrlJAL = Input(Bool())
    val ctrlOP = Input(UInt(OP_TYPES_WIDTH.W))
    val ctrlSigned = Input(Bool())
    val ctrlBranch = Input(Bool())
}

需要注意到的是,JAL指令需要PC寄存器寄存器的值作为操作数1,所以需要增加一个输入接口。然后就可以轻松实现Alu了,具体代码(src/main/scala/rv32isc/Alu.scala)如下:

package rv32isc

import chisel3._
import chisel3.util._

import config.Configs._
import utils.OP_TYPES._
import utils._

class AluIO extends Bundle {
    val bundleAluControl = new BundleAluControl()
    val dataRead1 = Input(UInt(DATA_WIDTH.W))
    val dataRead2 = Input(UInt(DATA_WIDTH.W))
    val imm = Input(UInt(DATA_WIDTH.W))
    val pc = Input(UInt(ADDR_WIDTH.W))
    val resultBranch = Output(Bool())
    val resultAlu = Output(UInt(DATA_WIDTH.W))
}

class Alu extends Module {
    val io = IO(new AluIO())

    // 用于输出比较结果和计算结果
    val resultBranch = WireDefault(false.B)
    val resultAlu = WireDefault(0.U(DATA_WIDTH.W))

    // 用于得到操作数
    val oprand1 = WireDefault(0.U(DATA_WIDTH.W))
    val oprand2 = WireDefault(0.U(DATA_WIDTH.W))

    oprand1 := Mux(io.bundleAluControl.ctrlJAL, io.pc, io.dataRead1)
    oprand2 := Mux(io.bundleAluControl.ctrlALUSrc, io.imm, io.dataRead2)

    // 根据bundleAluControl中的信号进行选择
    switch(io.bundleAluControl.ctrlOP) {
        is(OP_NOP) { // 啥也不干
            resultAlu := 0.U
            resultBranch := false.B
        }
        is(OP_ADD) {
            resultAlu := oprand1 +& oprand2
        }
        is(OP_SUB) {
            resultAlu := oprand1 -& oprand2
        }
        is(OP_AND) {
            resultAlu := oprand1 & oprand2
        }
        is(OP_OR) {
            resultAlu := oprand1 | oprand2
        }
        is(OP_XOR) {
            resultAlu := oprand1 ^ oprand2
        }
        is(OP_SLL) {
            resultAlu := oprand1 << oprand2(4, 0)
        }
        is(OP_SRL) {
            resultAlu := oprand1 >> oprand2(4, 0)
        }
        is(OP_SRA) { // 需要注意算术右移的写法
            resultAlu := (oprand1.asSInt >> oprand2(4, 0)).asUInt
        }
        is(OP_EQ) {
            resultBranch := oprand1.asSInt === oprand2.asSInt
            resultAlu := io.pc +& io.imm
        }
        is(OP_NEQ) {
            resultBranch := oprand1.asSInt =/= oprand2.asSInt
            resultAlu := io.pc +& io.imm
        }
        is(OP_LT) { // 区分有符号比较和无符号比较、分支和SLT
            when(io.bundleAluControl.ctrlBranch) {
                when(io.bundleAluControl.ctrlSigned) {
                    resultBranch := oprand1.asSInt < oprand2.asSInt
                }.otherwise {
                    resultBranch := oprand1 < oprand2
                }
                resultAlu := io.pc +& io.imm
            }.otherwise {
                when(io.bundleAluControl.ctrlSigned) {
                    resultAlu := oprand1.asSInt < oprand2.asSInt
                }.otherwise {
                    resultAlu := oprand1 < oprand2
                }
            }
        }
        is(OP_GE) { // 区分有符号比较和无符号比较
            when(io.bundleAluControl.ctrlSigned) {
                resultBranch := oprand1.asSInt >= oprand2.asSInt
            }.otherwise {
                resultBranch := oprand1 >= oprand2
            }
            resultAlu := io.pc +& io.imm
        }
    }

    io.resultAlu := resultAlu
    io.resultBranch := resultBranch
}

测试代码如下(因为赶时间,这里写的是有问题的,大家可以自行认真编写):

package rv32isc

import chisel3._
import chiseltest._
import chisel3.util._
import org.scalatest.flatspec.AnyFlatSpec

import config.Configs._
import utils.OP_TYPES._

trait AluTestFunc {
    // 测试所有功能
    val OP_TYPES_LIST = Seq(
        OP_NOP,
        OP_ADD,
        OP_SUB,
        OP_AND,
        OP_OR,
        OP_XOR,
        OP_SLL,
        OP_SRL,
        OP_SRA,
        OP_EQ,
        OP_NEQ,
        OP_LT,
        // OP_GE
    )

    // 随机的操作数
    val oprand_list =
        Seq.fill(10)(scala.util.Random.nextInt().toLong & 0x00ffffffffL)

    // 用于比对的正确结果
    def alu(a: Long, b: Long, op: UInt, sign: Boolean): (Long, Boolean) = {
        op match {
            case OP_NOP => (0, false)
            case OP_ADD => (a + b, false)
            case OP_SUB => (a - b, false)
            case OP_AND => (a & b, false)
            case OP_OR  => (a | b, false)
            case OP_XOR => (a ^ b, false)
            case OP_SLL => (a << (b & 0x000000001f), false)
            case OP_SRL => (a >>> (b & 0x000000001f), false)
            case OP_SRA => (a.toInt >> (b & 0x000000001f).toInt, false)
            case OP_EQ  => (0, a == b)
            case OP_NEQ => (0, a != b)
            case OP_LT => {
                if (sign) {
                    (0, (a << 32) < (b << 32))
                } else {
                    (0, a < b)
                }
            }
            case OP_GE => {
                if (sign) {
                    (0, (a << 32) >= (b << 32))
                } else {
                    (0, a >= b)
                }
            }
            case _ => (0, false)
        }
    }

    def testOne(dut: Alu, a: Long, b: Long, op: UInt, sign: Boolean): Unit = {
        // 正常测试
        dut.io.bundleAluControl.ctrlALUSrc.poke(false.B)
        dut.io.bundleAluControl.ctrlJAL.poke(false.B)
        dut.io.bundleAluControl.ctrlOP.poke(op)
        dut.io.bundleAluControl.ctrlSigned.poke(sign)
        dut.io.pc.poke(0.U)
        dut.io.imm.poke(0.U)
        dut.io.dataRead1.poke(a.U)
        dut.io.dataRead2.poke(b.U)
        val (resultAlu, resultBranch) = alu(a, b, op, sign)
        dut.io.resultAlu.expect((resultAlu.toLong & 0x00ffffffffL).U)
        dut.io.resultBranch.expect(resultBranch.B)
        // JAL+IMM
        dut.io.bundleAluControl.ctrlALUSrc.poke(true.B)
        dut.io.bundleAluControl.ctrlJAL.poke(true.B)
        dut.io.pc.poke(a.U)
        dut.io.imm.poke(b.U)
        dut.io.dataRead1.poke(0.U)
        dut.io.dataRead2.poke(0.U)
        val (resultAluJAL, resultBranchJAL) = alu(a, b, op, sign)
        dut.io.resultAlu.expect((resultAluJAL.toLong & 0x00ffffffffL).U)
        dut.io.resultBranch.expect(resultBranchJAL.B)
        // IMM
        dut.io.bundleAluControl.ctrlALUSrc.poke(true.B)
        dut.io.bundleAluControl.ctrlJAL.poke(false.B)
        dut.io.pc.poke(0.U)
        dut.io.imm.poke(b.U)
        dut.io.dataRead1.poke(a.U)
        dut.io.dataRead2.poke(0.U)
        val (resultAluIMM, resultBranchIMM) = alu(a, b, op, sign)
        dut.io.resultAlu.expect((resultAluIMM.toLong & 0x00ffffffffL).U)
        dut.io.resultBranch.expect(resultBranchIMM.B)
    }

    // 遍历功能和操作数进行测试
    def testFn(dut: Alu): Unit = {
        for (a <- oprand_list) {
            for (b <- oprand_list) {
                for (op <- OP_TYPES_LIST) {
                    testOne(dut, a, b, op, true)
                    testOne(dut, a, b, op, false)
                }
            }
        }
    }
}

class AluTest extends AnyFlatSpec with ChiselScalatestTester with AluTestFunc {
    "ALU" should "pass" in {
        test(new Alu) { dut =>
            testFn(dut)
        }
    }
}

测试通过。

数据内存的实现

Alu的计算结果可以直接给寄存器,也可以先给数据内存。因为LOAD、STORE类指令需要使用ALU计算得到的内存地址,所以我们完全可以把计算结果给数据内存,如果是LOAD或STORE指令,那就以此为地址读写数据,否则将结果直接发送给寄存器。

另外一点在之前的设计中忽略了的是,LOAD、STORE类指令需要区分字、半字和字节,因此我们需要在译码阶段多给控制单元一个信号,来指示操作数。并且,LOAD类指令也区分有符号无符号,因此也需要提供这个信号。

还有,STORE的数据源从哪里来?是rs2,可是我们计算地址用的是imm作为操作数2,因此,这里我们还需要把rs2里面的值给出到数据内存。

最后数据内存的实现如下:

package rv32isc

import chisel3._
import chisel3.util._

import config.Configs._
import utils.OP_TYPES._
import utils.LS_TYPES._
import utils._

class MemDataIO extends Bundle {
    val bundleMemDataControl = new BundleMemDataControl()
    val resultALU = Input(UInt(DATA_WIDTH.W))
    val dataStore = Input(UInt(DATA_WIDTH.W))
    val result = Output(UInt(DATA_WIDTH.W))
}

class MemData extends Module {
    val io = IO(new MemDataIO)

    // 数据内存
    val mem = Mem(MEM_DATA_SIZE, UInt(DATA_WIDTH.W))

    // 用于输出的结果
    val result = WireDefault(0.U(DATA_WIDTH.W))

    // 从内存中读取的数
    val dataLoad = WireDefault(0.U(DATA_WIDTH.W))

    // 不论是STORE还是LOAD,都需要用到这个读数
    dataLoad := mem.read(io.resultALU >> DATA_BYTE_WIDTH_LOG.U)

    // STORE指令
    when(io.bundleMemDataControl.ctrlStore) {
        when(io.bundleMemDataControl.ctrlLSType === LS_W) { // 修改全部4字节
            mem.write(io.resultALU >> DATA_BYTE_WIDTH_LOG.U, io.dataStore)
        }.elsewhen(io.bundleMemDataControl.ctrlLSType === LS_H) {   // 修改低2字节
            mem.write(io.resultALU >> DATA_BYTE_WIDTH_LOG.U, Cat(dataLoad(31, 16), io.dataStore(15, 0)))
        }.otherwise {   // 修改最低一个字节
            mem.write(io.resultALU >> DATA_BYTE_WIDTH_LOG.U, Cat(dataLoad(31, 8), io.dataStore(7, 0)))
        }
    }
    // LOAD指令
    when (io.bundleMemDataControl.ctrlLoad) {
        when(io.bundleMemDataControl.ctrlLSType === LS_W) {
            result := dataLoad
        }.elsewhen(io.bundleMemDataControl.ctrlLSType === LS_H) {
            when (io.bundleMemDataControl.ctrlSigned) {
                result := Cat(Fill(16, dataLoad(15)), dataLoad(15, 0))
            } .otherwise {
                result := Cat(Fill(16, 0.U), dataLoad(15, 0))
            }
        }.otherwise {
            when (io.bundleMemDataControl.ctrlSigned) {
                result := Cat(Fill(24, dataLoad(7)), dataLoad(7, 0))
            } .otherwise {
                result := Cat(Fill(24, 0.U), dataLoad(7, 0))
            }
        } 
    // 非LOAD指令
    } .otherwise {
        result := io.resultALU
    }
    
    // 输出
    io.result := result
}

测试代码如下:

package rv32isc

import chisel3._
import chiseltest._
import chisel3.util._
import org.scalatest.flatspec.AnyFlatSpec

import config.Configs._
import utils.OP_TYPES._
import utils.LS_TYPES._

trait MemDataTestFunc {
    // 生成MEM_DATA_SIZE条随机数据进行测试
    val data_list =
        Seq.fill(MEM_DATA_SIZE)(
            scala.util.Random.nextInt().toLong & 0x00ffffffffL
        )

    def testFn(dut: MemData): Unit = {
        // 初始化状态
        dut.clock.setTimeout(0)
        dut.io.bundleMemDataControl.ctrlLoad.poke(false.B)
        dut.io.bundleMemDataControl.ctrlStore.poke(false.B)
        dut.io.bundleMemDataControl.ctrlLSType.poke(LS_W)
        dut.io.bundleMemDataControl.ctrlSigned.poke(false.B)
        dut.io.dataStore.poke(0.U(DATA_WIDTH.W))
        dut.io.resultALU.poke(0.U(DATA_WIDTH.W))
        // 非LS指令
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U)
            dut.clock.step(1)
            dut.io.result.expect((i * DATA_BYTE_WIDTH).U)
        }
        // SW指令
        dut.io.bundleMemDataControl.ctrlStore.poke(true.B)
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.dataStore.poke(data_list(i))
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U) // 作为地址,应该左移两位,即乘以4
            dut.io.result.expect((i * DATA_BYTE_WIDTH).U)
            dut.clock.step(1)
        }
        // LW指令
        dut.io.bundleMemDataControl.ctrlStore.poke(false.B)
        dut.io.bundleMemDataControl.ctrlLoad.poke(true.B)
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.dataStore.poke(0.U)
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U) // 作为地址,应该左移两位,即乘以4
            dut.io.result.expect(data_list(i).U)
        }
        // SH清零低16比特
        dut.io.bundleMemDataControl.ctrlStore.poke(true.B)
        dut.io.bundleMemDataControl.ctrlLoad.poke(false.B)
        dut.io.bundleMemDataControl.ctrlLSType.poke(LS_H)
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.dataStore.poke(0)
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U) // 作为地址,应该左移两位,即乘以4
            dut.io.result.expect((i * DATA_BYTE_WIDTH).U)
            dut.clock.step(1)
        }
        // LHU指令读低16比特
        dut.io.bundleMemDataControl.ctrlStore.poke(false.B)
        dut.io.bundleMemDataControl.ctrlLoad.poke(true.B)
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.dataStore.poke(0.U)
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U) // 作为地址,应该左移两位,即乘以4
            dut.io.result.expect(0.U)
        }
        // LH指令读低16比特
        dut.io.bundleMemDataControl.ctrlSigned.poke(true.B)
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.dataStore.poke(0.U)
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U) // 作为地址,应该左移两位,即乘以4
            dut.io.result.expect(0.U)
        }
        dut.io.bundleMemDataControl.ctrlSigned.poke(false.B)
        // SB存储低8比特
        dut.io.bundleMemDataControl.ctrlStore.poke(true.B)
        dut.io.bundleMemDataControl.ctrlLoad.poke(false.B)
        dut.io.bundleMemDataControl.ctrlLSType.poke(LS_B)
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.dataStore.poke(data_list(i))
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U) // 作为地址,应该左移两位,即乘以4
            dut.io.result.expect((i * DATA_BYTE_WIDTH).U)
            dut.clock.step(1)
        }
        // LBU指令读低8比特
        dut.io.bundleMemDataControl.ctrlStore.poke(false.B)
        dut.io.bundleMemDataControl.ctrlLoad.poke(true.B)
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.dataStore.poke(0.U)
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U) // 作为地址,应该左移两位,即乘以4
            dut.io.result.expect((data_list(i).toLong & 0x00000000ffL).U)
        }
        // LB指令读低8比特
        dut.io.bundleMemDataControl.ctrlSigned.poke(true.B)
        for (i <- 0 to MEM_DATA_SIZE - 1) {
            dut.io.dataStore.poke(0.U)
            dut.io.resultALU.poke((i * DATA_BYTE_WIDTH).U) // 作为地址,应该左移两位,即乘以4
            if ((data_list(i).toLong & 0x0000000080L) == 0) {
                dut.io.result.expect((data_list(i).toLong & 0x00000000ffL).U)
            } else {
                dut.io.result.expect(((data_list(i).toLong & 0x00000000ffL) | 0x00ffffff00L).U)
            }
        }
    }
}

class MemDataTest
    extends AnyFlatSpec
    with ChiselScalatestTester
    with MemDataTestFunc {
    "MemData" should "pass" in {
        test(new MemData) { dut =>
            testFn(dut)
        }
    }
}

测试通过。

控制单元的实现

最后控制单元,但在控制单元之前,我们需要注意到JALJALR这两条指令的特殊性,它们存放的寄存器的值是当前的指令的地址+4,所以我们需要对寄存器堆先做一些修改,一个方面是要从控制模块给一个ctrlJump信号,另一方面是要选择写入的数据是PC+4还是数据内存返回的计算结果/加载的数据。修改后如下:

when(io.ctrlRegWrite && io.bundleReg.rd =/= 0.U) {
    when(io.ctrlJump) {
        regs(io.bundleReg.rd) := io.pc + INST_BYTE_WIDTH.U
    }.otherwise {
        regs(io.bundleReg.rd) := io.dataWrite
    }
}

那我们就可以写控制单元了,都是些连线,没什么技术含量:

package rv32isc

import chisel3._
import chisel3.util._

import utils._

class ControllerIO extends Bundle {
    val bundleControlIn = Flipped(new BundleControl()) // 来自译码器
    val bundleAluControl = Flipped(new BundleAluControl())  // 到ALU
    val bundleMemDataControl = Flipped(new BundleMemDataControl())  // 到数据内存
    val bundleControlOut = new BundleControl()  // 到其他
}

class Controller extends Module {
    val io = IO(new ControllerIO)

    // alu
    io.bundleAluControl.ctrlALUSrc := io.bundleControlIn.ctrlALUSrc
    io.bundleAluControl.ctrlJAL := io.bundleControlIn.ctrlJAL
    io.bundleAluControl.ctrlOP := io.bundleControlIn.ctrlOP
    io.bundleAluControl.ctrlSigned := io.bundleControlIn.ctrlSigned
    io.bundleAluControl.ctrlBranch := io.bundleControlIn.ctrlBranch

    // 内存单元
    io.bundleMemDataControl.ctrlLSType := io.bundleControlIn.ctrlALUSrc
    io.bundleMemDataControl.ctrlLoad := io.bundleControlIn.ctrlLoad
    io.bundleMemDataControl.ctrlSigned := io.bundleControlIn.ctrlSigned
    io.bundleMemDataControl.ctrlStore := io.bundleControlIn.ctrlStore
    
    // 其他
    io.bundleControlOut <> io.bundleControlIn
}

因为只有连线,就不测试了,我们直接进入最后一步,把各模块连接成一个处理器。

把各模块连接成一个处理器!

这一步仍然没有什么技术含量,把各个模块连接到一起就好了,一定要注意连线不要错连、漏连:

package rv32isc

import chisel3._
import chisel3.util._

import config.Configs._
import utils._

// Top的模块接口,用于测试
class TopIO extends Bundle {
    val addr = Output(UInt(ADDR_WIDTH.W))
    val inst = Output(UInt(INST_WIDTH.W))
    val bundleCtrl = new BundleControl()
    val resultALU = Output(UInt(DATA_WIDTH.W))
    val rs1 = Output(UInt(DATA_WIDTH.W))
    val rs2 = Output(UInt(DATA_WIDTH.W))
    val imm = Output(UInt(DATA_WIDTH.W))
    val resultBranch = Output(Bool())
    val result = Output(UInt(DATA_WIDTH.W))
}

class Top extends Module {
    val io = IO(new TopIO())

    val pcReg = Module(new PCReg())
    val memInst = Module(new MemInst())
    val decoder = Module(new Decoder())
    val registers = Module(new Registers())
    val alu = Module(new Alu())
    val memData = Module(new MemData())
    val controller = Module(new Controller())

    // PCReg in
    pcReg.io.resultBranch <> alu.io.resultBranch
    pcReg.io.addrTarget <> memData.io.result
    pcReg.io.ctrlBranch <> controller.io.bundleControlOut.ctrlBranch
    pcReg.io.ctrlJump <> controller.io.bundleControlOut.ctrlJump
    
    // MemInst in
    memInst.io.addr <> pcReg.io.addrOut

    // Decoder in
    decoder.io.inst <> memInst.io.inst

    // Registers in
    registers.io.bundleReg <> decoder.io.bundleReg
    registers.io.ctrlRegWrite <> controller.io.bundleControlOut.ctrlRegWrite
    registers.io.ctrlJump <> controller.io.bundleControlOut.ctrlJump
    registers.io.dataWrite <> memData.io.result
    registers.io.pc <> pcReg.io.addrOut

    // ALU in
    alu.io.bundleAluControl <> controller.io.bundleAluControl
    alu.io.dataRead1 <> registers.io.dataRead1
    alu.io.dataRead2 <> registers.io.dataRead2
    alu.io.imm <> decoder.io.imm
    alu.io.pc <> pcReg.io.addrOut
    
    // MemData in
    memData.io.bundleMemDataControl <> controller.io.bundleMemDataControl
    memData.io.dataStore <> registers.io.dataRead2
    memData.io.resultALU <> alu.io.resultAlu

    // Controller in
    controller.io.bundleControlIn <> decoder.io.bundleCtrl
    
    // top
    io.addr <> pcReg.io.addrOut
    io.bundleCtrl <> decoder.io.bundleCtrl
    io.inst <> memInst.io.inst
    io.result <> memData.io.result
    io.resultALU <> alu.io.resultAlu
    io.resultBranch <> alu.io.resultBranch
    io.imm <> decoder.io.imm
    io.rs1 <> registers.io.dataRead1
    io.rs2 <> registers.io.dataRead2
}

object main extends App {
    println(getVerilogString(new Top()))
}

通过sbt run运行,可以得到最终的Verilog代码,限于篇幅,这里就不放上来了,至少可以说明,编译生成Verilog代码是看起来没问题的。

但是具体的处理器的功能验证还需要对Top模块进行测试,下面就着重说一下。

CPU的整体测试

既然是CPU,那就必须得能跑程序,也就是说我们至少能做到这样的事:

写一段C程序,然后用RISC-V的工具链编译出二进制代码,我们的CPU可以运行这样的代码。

我们可以通过loadMemoryFromFile向指令内存的内存中加载十六进制文本格式的代码,所以我们首先需要从C源文件生成这样可以加载到内存的代码。假设你已经装好了rv32i的工具链,源文件为test.c,那么生成过程如下:

  1. 生成汇编:

    riscv32-unknown-elf-gcc -march=rv32i -S test.c
    
  2. 生成目标文件:

    riscv32-unknown-elf-as -march=rv32i test.s -o test.o
    
  3. 转换出二进制文件:

    riscv32-unknown-elf-objcopy -O binary test.o test.bin
    
  4. 自己写一个脚本,将二进制文件转换为十六进制文本文件,以Python为例:

    import sys
    
    if __name__ == "__main__":
        f_bin = open(str(sys.argv[1]), "rb")
        f_hex = open(str(sys.argv[2]), "w")
        while True:
            buf = f_bin.read(4)
            buf_len = len(buf)
            if buf_len > 0:
                s_hex = ''
                s_hex += hex(buf[3])[2:].zfill(2) + hex(buf[2])[2:].zfill(2) + hex(buf[1])[2:].zfill(2) + hex(buf[0])[2:].zfill(2)
                f_hex.write(s_hex + '\n')
            else:
                break
    

    然后运行:

    python3 bin2hex.py test.bin test.hex
    

我们可以简单写个斐波那契数列计算的C程序:

int test(int n)
{
    if (n == 1 || n == 2) // 数列前两项
    {
        return 1;
    }
    else // 从第三项开始
    {
        return test(n - 1) + test(n - 2);
    }
    return 0;
}
int main()
{
    int n = 10;
    int ret = test(n); // 计算斐波那契数列
    return ret;
}

根据上面的步骤,可以生成如下的内容:

fe010113
00112e23
00812c23
00912a23
02010413
fea42623
fec42703
00100793
00f70863
fec42703
00200793
00f71663
00100793
0380006f
fec42783
fff78793
00078513
00000097
000080e7
00050493
fec42783
ffe78793
00078513
00000097
000080e7
00050793
00f487b3
00078513
01c12083
01812403
01412483
02010113
00008067
fe010113
00112e23
00812c23
02010413
00a00793
fef42623
fec42503
00000097
000080e7
fea42423
fe842783
00078513
01c12083
01812403
02010113
00008067

但是不是没法看?都是十六进制,也不知道跟汇编指令怎么对应。没关系,一句话,让代码更好懂:

riscv32-unknown-elf-objdump -d test.o > test.dump

生成的文件如下:

test.o:     file format elf32-littleriscv


Disassembly of section .text:

00000000 :
   0:	fe010113          	addi	sp,sp,-32
   4:	00112e23          	sw	ra,28(sp)
   8:	00812c23          	sw	s0,24(sp)
   c:	00912a23          	sw	s1,20(sp)
  10:	02010413          	addi	s0,sp,32
  14:	fea42623          	sw	a0,-20(s0)
  18:	fec42703          	lw	a4,-20(s0)
  1c:	00100793          	li	a5,1
  20:	00f70863          	beq	a4,a5,30 <.L2>
  24:	fec42703          	lw	a4,-20(s0)
  28:	00200793          	li	a5,2
  2c:	00f71663          	bne	a4,a5,38 <.L3>

00000030 <.L2>:
  30:	00100793          	li	a5,1
  34:	0380006f          	j	6c <.L4>

00000038 <.L3>:
  38:	fec42783          	lw	a5,-20(s0)
  3c:	fff78793          	addi	a5,a5,-1
  40:	00078513          	mv	a0,a5
  44:	00000097          	auipc	ra,0x0
  48:	000080e7          	jalr	ra # 44 <.L3+0xc>
  4c:	00050493          	mv	s1,a0
  50:	fec42783          	lw	a5,-20(s0)
  54:	ffe78793          	addi	a5,a5,-2
  58:	00078513          	mv	a0,a5
  5c:	00000097          	auipc	ra,0x0
  60:	000080e7          	jalr	ra # 5c <.L3+0x24>
  64:	00050793          	mv	a5,a0
  68:	00f487b3          	add	a5,s1,a5

0000006c <.L4>:
  6c:	00078513          	mv	a0,a5
  70:	01c12083          	lw	ra,28(sp)
  74:	01812403          	lw	s0,24(sp)
  78:	01412483          	lw	s1,20(sp)
  7c:	02010113          	addi	sp,sp,32
  80:	00008067          	ret

00000084 
: 84: fe010113 addi sp,sp,-32 88: 00112e23 sw ra,28(sp) 8c: 00812c23 sw s0,24(sp) 90: 02010413 addi s0,sp,32 94: 00a00793 li a5,10 98: fef42623 sw a5,-20(s0) 9c: fec42503 lw a0,-20(s0) a0: 00000097 auipc ra,0x0 a4: 000080e7 jalr ra # a0 a8: fea42423 sw a0,-24(s0) ac: fe842783 lw a5,-24(s0) b0: 00078513 mv a0,a5 b4: 01c12083 lw ra,28(sp) b8: 01812403 lw s0,24(sp) bc: 02010113 addi sp,sp,32 c0: 00008067 ret

这样我们就有十六进制指令和汇编指令的对应关系了,下面我们看代码。

一个程序在执行的时候是从main函数进入的,在这段程序中,程序入口就是00000084

。对应的,要想正确执行这段程序,我们需要让我们的CPU从0x00000084开始执行。这时候,用Configs.scala存放全局变量的好处就体现出来了,我们只需要将:

val START_ADDR: Long = 0x00000000  // 起始执行地址

修改为:

val START_ADDR: Long = 0x00000084  // 起始执行地址

就行了。

有程序入口就行了嘛?当然不是。我们看第一条语句:

addi	sp,sp,-32

这对sp寄存器减了32,这个sp就是栈指针,程序进行函数调用都需要保护现场,将一些数据压栈,方便调用返回的时候恢复现场。然而这个栈应该是提前分配好空间的,栈底的地址比栈顶的地址要大。sp寄存器其实对应x2寄存器,所以在我们的处理器中初始值是0,那么减32之后就成负数了,我们1024大小的数据内存在索引数据时就会有问题。所以,我们要手动完成栈空间分配这个工作,只需要在程序入口处添加一条指令就行:

addi	sp,sp,1024 # 对应十六进制指令40010113

这里的1024作为地址是字节,相当于我们在1024*4大小的数据内存上分配了1024字节的空间,即可存放256个32位数据的栈,如果觉得不够,可以放4条该指令,刚好完全用完数据内存。

另外,程序的结束处是个返回指令,如果让它返回,它就会返回到调用main函数的地方,但当时栈是空的,所以会返回到地址0处开始执行,陷入无限循环。因此,我们还需要将最后一条指令替换为00000000,用于提示测试模块程序已经结束了。

于是,修改之后的十六进制指令序列如下:

fe010113
00112e23
00812c23
00912a23
02010413
fea42623
fec42703
00100793
00f70863
fec42703
00200793
00f71663
00100793
0380006f
fec42783
fff78793
00078513
00000097
000080e7
00050493
fec42783
ffe78793
00078513
00000097
000080e7
00050793
00f487b3
00078513
01c12083
01812403
01412483
02010113
00008067
40010113
40010113
40010113
40010113
fe010113
00112e23
00812c23
02010413
00a00793
fef42623
fec42503
00000097
000080e7
fea42423
fe842783
00078513
01c12083
01812403
02010113
00000000

这些工作一般由loader完成,我们这里临时手动完成就行,不用编写程序加载器。

下面我们就可以将代码放到MemInst.hex文件中,然后编写测试程序。测试代码如下:

package rv32isc

import chisel3._
import chiseltest._
import chisel3.util._
import org.scalatest.flatspec.AnyFlatSpec

import java.io.PrintWriter
import java.io.File

import config.Configs._

trait TopTestFunc {

    def testFn(dut: Top): Unit = {
        dut.clock.setTimeout(0)
        while (dut.io.inst.peekInt() != 0) {	// 运行到程序结束处停止运行
                println("PC", dut.io.addr.peekInt().toLong.toHexString)
                println("INST", dut.io.inst.peekInt().toLong.toHexString)
            if (dut.io.addr.peekInt() == 0xb8) {	// 调用返回的下一条指令对应的地址是0xa8,加上4条添加的指令,0xa8+0x10=0xb8,此时的rs2就是计算结果
                println("RES", dut.io.result.peekInt())
                println("RESALU", dut.io.resultALU.peekInt())
                println("RESBRANCH", dut.io.resultBranch.peek())
                println("RESJUMP", dut.io.bundleCtrl.ctrlJump.peek())
                println("SRCCCCC", dut.io.bundleCtrl.ctrlALUSrc.peek())
                println("STORE", dut.io.bundleCtrl.ctrlStore.peek())
                println("LOAD", dut.io.bundleCtrl.ctrlLoad.peek())
                println("RESJAL", dut.io.bundleCtrl.ctrlJAL.peek())
                println("OP:\t", dut.io.bundleCtrl.ctrlOP.peek())
                println("isBranch:\t", dut.io.bundleCtrl.ctrlBranch.peek())
                println("IMM:\t", dut.io.imm.peekInt())
                println("RS1:\t", dut.io.rs1.peekInt())
                println("RS2:\t", dut.io.rs2.peekInt())
                println("PC", dut.io.addr.peekInt().toLong.toHexString)
                println("INST", dut.io.inst.peekInt().toLong.toHexString)
                println("++++++++++++++++++++")
            }
            dut.clock.step(1)
        }
    }
}

class TopTest extends AnyFlatSpec with ChiselScalatestTester with TopTestFunc {
    "Top" should "pass" in {
        test(new Top) { dut =>
            testFn(dut)
        }
    }
}

运行测试,最后一部分输出如下:

(PC,b8)
(INST,fea42423)
(RES,4040)
(RESALU,4040)
(RESBRANCH,Bool(false))
(RESJUMP,Bool(false))
(SRCCCCC,Bool(true))
(STORE,Bool(true))
(LOAD,Bool(false))
(RESJAL,Bool(false))
(OP:	,UInt<4>(1))
(isBranch:	,Bool(false))
(IMM:	,4294967272)
(RS1:	,4064)
(RS2:	,55)
(PC,b8)
(INST,fea42423)
++++++++++++++++++++
(PC,bc)
(INST,fe842783)
(PC,c0)
(INST,78513)
(PC,c4)
(INST,1c12083)
(PC,c8)
(INST,1812403)
(PC,cc)
(INST,2010113)

可以看到,第16行显示rs2值为55,确实是第十项斐波那契数,测试通过。

说明和结语

文中的代码并非最终版本的代码,一些在调试过程中的修改未体现在文中。

完整项目代码可以在本人的Github仓库获取:github-3rr0r/RV32ISC: A RISC-V RV32I ISA Single Cycle CPU。

由于写得很仓促,也没有使用什么复杂的Chisel语法,很多好用的特性也没用上,甚至很多地方风格跟屎山一样,又懒得改,所以希望有兴趣的读者可以帮忙维护一下这个仓库。虽然最后测试通过了,但并不严谨,没有覆盖所有指令和边界情况,如果有不对的地方欢迎大家提出修改意见或直接git commit。

虽然只是个单周期的CPU,但编写起来并没有看起来那么顺利,有些脑抽写出来的错误逻辑找了很久。不过好在用的是Chisel,有更加直观的调试方法,如果用波形图可能就没那么顺利了。

本系列的Chisel实战部分到这里就完结了,本系列后续可能看情况更新一些Chisel的高阶内容,但不承诺一定会有。

下一步的计划是开辟一个新的专栏,还是实现一个RISC-V处理器,但是会更全面、更深入。会包括一些现代处理器的基本特性,比如流水线、乱序、多发射、分支预测、Cache等等,还有外设啥的,届时欢迎大家关注。

你可能感兴趣的:(吃透Chisel语言!!!,risc-v,Chisel,单周期CPU实现,完整代码,RV32I)