ZK-Learning MOOC课程笔记
//Variable creation
cs.add_var(p, v) → id
//Linear Combination creation
cs.zero()
lc.add(c, id) → lc_
//lc_ := lc + c * id
//Adding constraints
cs.constrain(lcA, lcB, lcC)
//Adds a constraint lcA × lcB = lcC
// main.rs
use ark_ff::PrimeField;
use ark_r1cs_std::{
prelude::{Boolean, EqGadget, AllocVar},
uint8::UInt8
};
use ark_relations::r1cs::{SynthesisError, ConstraintSystem};
use cmp::CmpGadget;
mod cmp;
mod alloc;
/// An `N`×`N` grid of `UInt8` circuit variables describing the puzzle.
/// `check_puzzle_matches_solution` treats a cell equal to 0 as an unfilled slot.
pub struct Puzzle<const N: usize, ConstraintF: PrimeField>([[UInt8<ConstraintF>; N]; N]);
/// An `N`×`N` grid of `UInt8` circuit variables holding a candidate solution.
/// `check_puzzle_matches_solution` constrains every cell to the range `[1, N]`.
pub struct Solution<const N: usize, ConstraintF: PrimeField>([[UInt8<ConstraintF>; N]; N]);
/// Enforces that no value occurs twice within any single row of `solution`.
///
/// For each cell, one "not equal" constraint is emitted against every cell
/// that precedes it in the same row.
///
/// # Errors
/// Propagates any `SynthesisError` raised while building the constraints.
fn check_rows<const N: usize, ConstraintF: PrimeField>(
    solution: &Solution<N, ConstraintF>,
) -> Result<(), SynthesisError> {
    for row in solution.0.iter() {
        for (column, current) in row.iter().enumerate() {
            // Only compare against cells to the left; pairs further right are
            // covered when the outer loop reaches them.
            for earlier in row.iter().take(column) {
                let differs = current.is_neq(earlier)?;
                differs.enforce_equal(&Boolean::TRUE)?;
            }
        }
    }
    Ok(())
}
fn check_puzzle_matches_solution<const N: usize, ConstraintF: PrimeField>(
puzzle: &Puzzle<N, ConstraintF>,
solution: &Solution<N, ConstraintF>,
) -> Result<(), SynthesisError> {
for (p_row, s_row) in puzzle.0.iter().zip(&solution.0) {
for (p, s) in p_row.iter().zip(s_row) {
// Ensure that the solution `s` is in the range [1, N]
s.is_leq(&UInt8::constant(N as u8))?
.and(&s.is_geq(&UInt8::constant(1))?)?
.enforce_equal(&Boolean::TRUE)?;
// Ensure that either the puzzle slot is 0, or that
// the slot matches equivalent slot in the solution
(p.is_eq(s)?.or(&p.is_eq(&UInt8::constant(0))?)?)
.enforce_equal(&Boolean::TRUE)?;
}
}
Ok(())
}
/// Builds a fresh constraint system for the given puzzle/solution pair and
/// asserts that it is satisfied.
///
/// The puzzle is allocated as input variables and the solution as witness
/// variables; both the puzzle-match and the row-uniqueness checks are applied.
///
/// # Panics
/// Panics if allocation or constraint generation fails, or if the resulting
/// constraint system is not satisfied.
fn check_helper<const N: usize, ConstraintF: PrimeField>(
    puzzle: &[[u8; N]; N],
    solution: &[[u8; N]; N],
) {
    let cs = ConstraintSystem::<ConstraintF>::new_ref();

    // Allocate both grids inside the same constraint system.
    let public_puzzle = Puzzle::new_input(cs.clone(), || Ok(puzzle)).unwrap();
    let private_solution = Solution::new_witness(cs.clone(), || Ok(solution)).unwrap();

    // Emit the constraints.
    check_puzzle_matches_solution(&public_puzzle, &private_solution).unwrap();
    check_rows(&private_solution).unwrap();

    assert!(cs.is_satisfied().unwrap());
}
/// Runs the checker on one valid and one invalid solution of the same 2×2 puzzle.
fn main() {
    use ark_bls12_381::Fq as F;

    // Both cases use the same puzzle.
    let puzzle = [[1, 0], [0, 2]];

    // Case 1: a solution that agrees with the puzzle and has no repeated
    // number within a row; the constraint system is expected to be satisfied.
    let valid_solution = [[1, 2], [1, 2]];
    check_helper::<2, F>(&puzzle, &valid_solution);

    // Case 2: an invalid solution. Its first row contains 0, which is outside
    // the allowed range [1, N] (no number is repeated within a row here).
    // NOTE: `check_helper` asserts that the constraint system is satisfied,
    // so this call is expected to panic rather than return.
    let invalid_solution = [[1, 0], [1, 2]];
    check_helper::<2, F>(&puzzle, &invalid_solution);
}
// cmp.rs
use ark_ff::PrimeField;
use ark_r1cs_std::{prelude::{Boolean, EqGadget}, R1CSVar, uint8::UInt8, ToBitsGadget};
use ark_relations::r1cs::SynthesisError;
/// Ordering comparisons for circuit variables.
///
/// Implementors supply only [`CmpGadget::is_lt`]; the other three comparisons
/// are derived from it by swapping operands and/or negating the result.
pub trait CmpGadget<ConstraintF: PrimeField>: R1CSVar<ConstraintF> + EqGadget<ConstraintF> {
    /// Returns a `Boolean` that is true iff `self < other`.
    fn is_lt(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError>;

    /// Returns a `Boolean` that is true iff `self >= other`.
    #[inline]
    fn is_geq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
        // self >= other  <=>  !(self < other)
        let less_than = self.is_lt(other)?;
        Ok(less_than.not())
    }

    /// Returns a `Boolean` that is true iff `self <= other`.
    #[inline]
    fn is_leq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
        // self <= other  <=>  other >= self
        other.is_geq(self)
    }

    /// Returns a `Boolean` that is true iff `self > other`.
    #[inline]
    fn is_gt(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
        // self > other  <=>  !(self <= other)
        let less_or_equal = self.is_leq(other)?;
        Ok(less_or_equal.not())
    }
}
impl<ConstraintF: PrimeField> CmpGadget<ConstraintF> for UInt8<ConstraintF> {
    /// Returns a `Boolean` that is true iff `self < other` as unsigned bytes.
    ///
    /// When both operands are constants the comparison is done natively and a
    /// constant `Boolean` is returned. Otherwise the bits are scanned from the
    /// most significant end: `self < other` iff, at the first position where
    /// the bits differ, `self` has a 0 and `other` has a 1.
    ///
    /// # Errors
    /// Propagates any `SynthesisError` from reading values or from the
    /// underlying bit gadgets.
    fn is_lt(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
        if self.is_constant() && other.is_constant() {
            // Propagate a missing value as an error instead of panicking.
            let lhs = self.value()?;
            let rhs = other.value()?;
            return Ok(Boolean::constant(lhs < rhs));
        }

        // Bit i of `differing_bits` is set iff the operands differ at bit i.
        let differing_bits = self.xor(other)?.to_bits_be()?;
        let lhs_bits = self.to_bits_be()?;
        let rhs_bits = other.to_bits_be()?;

        let mut is_less = Boolean::FALSE;
        // True while every more-significant bit seen so far was equal.
        let mut prefix_equal = Boolean::TRUE;
        for ((differs, lhs), rhs) in differing_bits.into_iter().zip(lhs_bits).zip(rhs_bits) {
            // At this position: lhs bit is 0 and rhs bit is 1.
            let lhs_below_rhs = lhs.not().and(&rhs)?;
            // It only decides the result if all higher bits were equal.
            is_less = is_less.or(&lhs_below_rhs.and(&prefix_equal)?)?;
            prefix_equal = prefix_equal.and(&differs.not())?;
        }
        Ok(is_less)
    }
}
#[cfg(test)]
mod test {
    use ark_r1cs_std::{prelude::{AllocationMode, AllocVar, Boolean, EqGadget}, uint8::UInt8};
    use ark_relations::r1cs::{ConstraintSystem, SynthesisMode};
    use ark_bls12_381::Fr as Fp;
    use itertools::Itertools;
    use crate::cmp::CmpGadget;

    /// Exhaustively compares every pair of `u8` values under every pairing of
    /// allocation modes and checks that each gadget comparison agrees with
    /// the native one.
    #[test]
    fn test_comparison_for_u8() {
        let modes = [AllocationMode::Constant, AllocationMode::Input, AllocationMode::Witness];

        // Enforces that a gadget result equals the natively computed answer.
        let expect = |actual: Boolean<Fp>, expected: bool| {
            actual.enforce_equal(&Boolean::constant(expected)).unwrap();
        };

        for (a, a_mode) in (0..=u8::MAX).cartesian_product(modes) {
            for (b, b_mode) in (0..=u8::MAX).cartesian_product(modes) {
                let cs = ConstraintSystem::<Fp>::new_ref();
                cs.set_mode(SynthesisMode::Prove { construct_matrices: true });

                let a_var = UInt8::new_variable(cs.clone(), || Ok(a), a_mode).unwrap();
                let b_var = UInt8::new_variable(cs.clone(), || Ok(b), b_mode).unwrap();

                expect(a_var.is_lt(&b_var).unwrap(), a < b);
                expect(a_var.is_leq(&b_var).unwrap(), a <= b);
                expect(a_var.is_gt(&b_var).unwrap(), a > b);
                expect(a_var.is_geq(&b_var).unwrap(), a >= b);

                assert!(cs.is_satisfied().unwrap(), "a: {a}, b: {b}");
            }
        }
    }
}
//alloc.rs
use std::borrow::Borrow;
use ark_ff::PrimeField;
use ark_r1cs_std::{prelude::{AllocVar, AllocationMode}, uint8::UInt8};
use ark_relations::r1cs::{Namespace, SynthesisError};
use crate::{Puzzle, Solution};
impl<const N: usize, F: PrimeField> AllocVar<[[u8; N]; N], F> for Puzzle<N, F> {
    /// Allocates one `UInt8` per puzzle cell, in row-major order, using `mode`.
    ///
    /// If `f` returns an error, an all-zero grid is allocated instead.
    ///
    /// # Errors
    /// Propagates any `SynthesisError` from allocating an individual cell.
    fn new_variable<T: Borrow<[[u8; N]; N]>>(
        cs: impl Into<Namespace<F>>,
        f: impl FnOnce() -> Result<T, SynthesisError>,
        mode: AllocationMode,
    ) -> Result<Self, SynthesisError> {
        let cs = cs.into();

        // Start from constant placeholders; every slot is overwritten below.
        let mut grid = [(); N].map(|_| [(); N].map(|_| UInt8::constant(0)));

        let bytes: [[u8; N]; N] = match f() {
            Ok(provided) => {
                let borrowed: &[[u8; N]; N] = provided.borrow();
                *borrowed
            }
            Err(_) => [[0; N]; N],
        };

        for (slots, row_bytes) in grid.iter_mut().zip(bytes) {
            for (slot, byte) in slots.iter_mut().zip(row_bytes) {
                *slot = UInt8::new_variable(cs.clone(), || Ok(byte), mode)?;
            }
        }
        Ok(Puzzle(grid))
    }
}
impl<const N: usize, F: PrimeField> AllocVar<[[u8; N]; N], F> for Solution<N, F> {
    /// Allocates one `UInt8` per solution cell, in row-major order, using `mode`.
    ///
    /// If `f` returns an error, an all-zero grid is allocated instead.
    ///
    /// # Errors
    /// Propagates any `SynthesisError` from allocating an individual cell.
    fn new_variable<T: Borrow<[[u8; N]; N]>>(
        cs: impl Into<Namespace<F>>,
        f: impl FnOnce() -> Result<T, SynthesisError>,
        mode: AllocationMode,
    ) -> Result<Self, SynthesisError> {
        let cs = cs.into();

        // Start from constant placeholders; every slot is overwritten below.
        let mut grid = [(); N].map(|_| [(); N].map(|_| UInt8::constant(0)));

        let bytes: [[u8; N]; N] = match f() {
            Ok(provided) => {
                let borrowed: &[[u8; N]; N] = provided.borrow();
                *borrowed
            }
            Err(_) => [[0; N]; N],
        };

        for (slots, row_bytes) in grid.iter_mut().zip(bytes) {
            for (slot, byte) in slots.iter_mut().zip(row_bytes) {
                *slot = UInt8::new_variable(cs.clone(), || Ok(byte), mode)?;
            }
        }
        Ok(Solution(grid))
    }
}
// An N×N puzzle grid. `check_puzzle_matches_solution` treats a cell equal
// to 0 as an unfilled slot.
struct Puzzle<N> {
    u8[N][N] elems;
}

// An N×N candidate solution grid.
struct Solution<N> {
    u8[N][N] elems;
}

// Asserts that no value is repeated within any row of `sol`.
// Returns true if every assertion holds.
def check_rows<N>(Solution<N> sol) -> bool {
    // for each row
    for u32 i in 0..N {
        // for each column
        for u32 j in 0..N {
            // Check that the (i, j)-th element is not equal to any of
            // the elements preceding it in the same row.
            for u32 k in 0..j {
                assert(sol.elems[i][j] != sol.elems[i][k]);
            }
        }
    }
    return true;
}

// Asserts that every solution cell is strictly between 0 and 10, and that
// every puzzle cell is either 0 or equal to the corresponding solution cell.
// Returns true if every assertion holds.
def check_puzzle_matches_solution<N>(Solution<N> sol, Puzzle<N> puzzle) -> bool {
    for u32 i in 0..N {
        for u32 j in 0..N {
            assert((sol.elems[i][j] > 0) && (sol.elems[i][j] < 10));
            assert(\
                (puzzle.elems[i][j] == 0) ||\
                (puzzle.elems[i][j] == sol.elems[i][j])\
            );
        }
    }
    return true;
}

// Entry point: the puzzle is a public input, the solution a private one.
def main(public Puzzle<2> puzzle, private Solution<2> sol) {
    assert(check_puzzle_matches_solution(sol, puzzle));
    assert(check_rows(sol));
}