Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

verify_zero5 and reduce_3 optimisations #90

Merged
merged 1 commit into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ protostar test-cairo0 --max-steps 10000000 tests/protostar_tests/bls12_381/test_
## Benchmarks
| Operation on curve BN254 | Cairo steps or estimation |
|---------|---------------|
| miller_loop | 770 159 |
| multi_miller_loop (N points) | ~ N * 489677 + 280482 |
| miller_loop | 749 693 |
| multi_miller_loop (N points) | ~ N * 478602 + 271091 |
| final_exponentiation | 610 997 |
| Groth16 circuit example | 2 974 539|
| Groth16 circuit example | 2 875 612|

| Operation on curve BLS12-381| Cairo steps (number) (OBSOLETE: Wait for optimisation) |
|---------|---------------|
Expand Down
203 changes: 137 additions & 66 deletions src/bn254/fq.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -740,62 +740,130 @@ func verify_zero5{range_check_ptr}(val: UnreducedBigInt5) {
setattr(ids, 'q'+str(i), carries[i])
%}

// This ensure q_i * BASE or -q_i * BASE doesn't overlfow PRIME.
assert [range_check_ptr + 0] = q0;
assert [range_check_ptr + 1] = q1;
assert [range_check_ptr + 2] = q2;
assert [range_check_ptr + 3] = q3;
assert [range_check_ptr + 4] = BASE_MIN_1 - q.d0;
assert [range_check_ptr + 5] = BASE_MIN_1 - q.d1;
assert [range_check_ptr + 6] = BASE_MIN_1 - q.d2;

// This ensure all (q*P +r) limbs don't overlfow by restricting q limbs in [-2**127, 2**127).

assert [range_check_ptr + 4] = 2 ** 127 + q.d0;
assert [range_check_ptr + 5] = 2 ** 127 + q.d1;
assert [range_check_ptr + 6] = 2 ** 127 + q.d2;

// diff = q*p - val
// diff(base) = 0

tempvar diff_d0 = q.d0 * P0 - val.d0;
tempvar diff_d1 = q.d0 * P1 + q.d1 * P0 - val.d1;
tempvar diff_d2 = q.d0 * P2 + q.d1 * P1 + q.d2 * P0 - val.d2;
tempvar diff_d3 = q.d1 * P2 + q.d2 * P1 - val.d3;
tempvar diff_d4 = q.d2 * P2 - val.d4;

local carry0: felt;
local carry1: felt;
local carry2: felt;
local carry3: felt;
// tempvar diff_d0 = q.d0 * P0 - val.d0;
// tempvar diff_d1 = q.d0 * P1 + q.d1 * P0 - val.d1;
// tempvar diff_d2 = q.d0 * P2 + q.d1 * P1 + q.d2 * P0 - val.d2;
// tempvar diff_d3 = q.d1 * P2 + q.d2 * P1 - val.d3;
// tempvar diff_d4 = q.d2 * P2 - val.d4;

// Checks that diff(base) = 0 depending on q limbs signs
// Since diff(base) = 0, diff_i has the form diff_i = k * BASE + 0
// See reduce_5 for more details.
if (flag0 != 0) {
assert diff_d0 = q0 * BASE;
assert carry0 = q0;
} else {
assert carry0 = (-1) * q0;
assert diff_d0 = carry0 * BASE;
}

if (flag1 != 0) {
assert diff_d1 + carry0 = q1 * BASE;
assert carry1 = q1;
} else {
assert carry1 = (-1) * q1;
assert diff_d1 + carry0 = carry1 * BASE;
}

if (flag2 != 0) {
assert diff_d2 + carry1 = q2 * BASE;
assert carry2 = q2;
} else {
assert carry2 = (-1) * q2;
assert diff_d2 + carry1 = carry2 * BASE;
}

if (flag3 != 0) {
assert diff_d3 + carry2 = q3 * BASE;
assert carry3 = q3;
assert q.d0 * P0 - val.d0 = q0 * BASE;
if (flag1 != 0) {
assert q.d0 * P1 + q.d1 * P0 - val.d1 + q0 = q1 * BASE;
if (flag2 != 0) {
assert q.d0 * P2 + q.d1 * P1 + q.d2 * P0 - val.d2 + q1 = q2 * BASE;
if (flag3 != 0) {
assert q.d1 * P2 + q.d2 * P1 - val.d3 + q2 = q3 * BASE;
assert q.d2 * P2 = val.d4 - q3;
} else {
// let q3 = (-1) * q3;
assert q.d1 * P2 + q.d2 * P1 + q2 + q3 * BASE = val.d3;
assert q.d2 * P2 = val.d4 + q3;
}
} else {
assert q.d0 * P2 + q.d1 * P1 + q.d2 * P0 + q1 + q2 * BASE = val.d2;
if (flag3 != 0) {
assert q.d1 * P2 + q.d2 * P1 - val.d3 - q2 = q3 * BASE;
assert q.d2 * P2 = val.d4 - q3;
} else {
// let q3 = (-1) * q3;
assert q.d1 * P2 + q.d2 * P1 - val.d3 + q3 * BASE = q2;
assert q.d2 * P2 = val.d4 + q3;
}
}
} else {
assert q.d0 * P1 + q.d1 * P0 + q0 + q1 * BASE = val.d1;
if (flag2 != 0) {
assert q.d0 * P2 + q.d1 * P1 + q.d2 * P0 - val.d2 - q1 = q2 * BASE;
if (flag3 != 0) {
assert q.d1 * P2 + q.d2 * P1 - val.d3 + q2 = q3 * BASE;
assert q.d2 * P2 = val.d4 - q3;
} else {
// let q3 = (-1) * q3;
assert q.d1 * P2 + q.d2 * P1 + q2 + q3 * BASE = val.d3;
assert q.d2 * P2 = val.d4 + q3;
}
} else {
assert q.d0 * P2 + q.d1 * P1 + q.d2 * P0 - val.d2 + q2 * BASE = q1;
if (flag3 != 0) {
assert q.d1 * P2 + q.d2 * P1 - val.d3 - q2 = q3 * BASE;
assert q.d2 * P2 = val.d4 - q3;
} else {
// let q3 = (-1) * q3;
assert q.d1 * P2 + q.d2 * P1 - val.d3 + q3 * BASE = q2;
assert q.d2 * P2 = val.d4 + q3;
}
}
}
} else {
assert carry3 = (-1) * q3;
assert diff_d3 + carry2 = carry3 * BASE;
assert q.d0 * P0 + q0 * BASE = val.d0;
if (flag1 != 0) {
assert q.d0 * P1 + q.d1 * P0 - val.d1 - q0 = q1 * BASE;
if (flag2 != 0) {
assert q.d0 * P2 + q.d1 * P1 + q.d2 * P0 - val.d2 + q1 = q2 * BASE;
if (flag3 != 0) {
assert q.d1 * P2 + q.d2 * P1 - val.d3 + q2 = q3 * BASE;
assert q.d2 * P2 = val.d4 - q3;
} else {
// let q3 = (-1) * q3;
assert q.d1 * P2 + q.d2 * P1 + q2 + q3 * BASE = val.d3;
assert q.d2 * P2 = val.d4 + q3;
}
} else {
assert q.d0 * P2 + q.d1 * P1 + q.d2 * P0 + q1 + q2 * BASE = val.d2;
if (flag3 != 0) {
assert q.d1 * P2 + q.d2 * P1 - val.d3 - q2 = q3 * BASE;
assert q.d2 * P2 = val.d4 - q3;
} else {
// let q3 = (-1) * q3;
assert q.d1 * P2 + q.d2 * P1 - val.d3 + q3 * BASE = q2;
assert q.d2 * P2 = val.d4 + q3;
}
}
} else {
assert q.d0 * P1 + q.d1 * P0 - q0 + q1 * BASE = val.d1;
if (flag2 != 0) {
assert q.d0 * P2 + q.d1 * P1 + q.d2 * P0 - val.d2 - q1 = q2 * BASE;
if (flag3 != 0) {
assert q.d1 * P2 + q.d2 * P1 - val.d3 + q2 = q3 * BASE;
assert q.d2 * P2 = val.d4 - q3;
} else {
// let q3 = (-1) * q3;
assert q.d1 * P2 + q.d2 * P1 + q2 + q3 * BASE = val.d3;
assert q.d2 * P2 = val.d4 + q3;
}
} else {
assert q.d0 * P2 + q.d1 * P1 + q.d2 * P0 - q1 + q2 * BASE = val.d2;
if (flag3 != 0) {
assert q.d1 * P2 + q.d2 * P1 - val.d3 - q2 = q3 * BASE;
assert q.d2 * P2 = val.d4 - q3;
} else {
// let q3 = (-1) * q3;
assert q.d1 * P2 + q.d2 * P1 - val.d3 + q3 * BASE = q2;
assert q.d2 * P2 = val.d4 + q3;
}
}
}
}

assert diff_d4 + carry3 = 0;

tempvar range_check_ptr = range_check_ptr + 7;
return ();
}
Expand Down Expand Up @@ -858,7 +926,7 @@ func reduce_5{range_check_ptr}(val: UnreducedBigInt5) -> BigInt3* {
assert [range_check_ptr + 2] = q2;
assert [range_check_ptr + 3] = q3;

// This ensure all (q*P +r) limbs don't overlfow.
// This ensure all (q*P +r) limbs don't overlfow by restricting q limbs in [-2**127, 2**127).
assert [range_check_ptr + 4] = 2 ** 127 + q.d0;
assert [range_check_ptr + 5] = 2 ** 127 + q.d1;
assert [range_check_ptr + 6] = 2 ** 127 + q.d2;
Expand Down Expand Up @@ -1041,45 +1109,48 @@ func reduce_3{range_check_ptr}(val: UnreducedBigInt3) -> BigInt3* {
// It is very important as we can assert diff_i has the form diff_i = k * BASE + 0.
// Since the euclidean division gives uniqueness and RC_BOUND * BASE = 2**214 < PRIME, it is enough.
// See https://github.com/starkware-libs/cairo-lang/blob/40404870166edc1e1fc5778fe39a29f981121ef9/src/starkware/cairo/common/math.cairo#L289-L312
let q_abs = abs_value(q);
// let q_abs = abs_value(q);

assert [range_check_ptr + 0] = q0;
assert [range_check_ptr + 1] = q1;
assert [range_check_ptr + 2] = 100 - q_abs;

// This ensure all (q*P +r) limbs don't overlfow by restricting q in [-2**127, 2**127).

assert [range_check_ptr + 2] = 2 ** 127 + q;

// diff = q*p + r - val
// diff(base) = 0

tempvar diff_d0 = q * P0 + r.d0 - val.d0;
tempvar diff_d1 = q * P1 + r.d1 - val.d1;
tempvar diff_d2 = q * P2 + r.d2 - val.d2;

local carry0: felt;
local carry1: felt;
// tempvar diff_d0 = q * P0 + r.d0 - val.d0;
// tempvar diff_d1 = q * P1 + r.d1 - val.d1;
// tempvar diff_d2 = q * P2 + r.d2 - val.d2;

// Since diff(base) = 0, diff_i has the form diff_i = k * BASE + 0
// When we reduce each limb % BASE and propagate the carries k=(limb//BASE), all coefficients should be 0.
// So for each i diff_i%BASE is 0 and we propagate the carry k to diff_(i+1), until the end,
// ensuring diff(base) is indeed 0.

if (flag0 != 0) {
assert diff_d0 = q0 * BASE;
assert carry0 = q0;
} else {
assert carry0 = (-1) * q0;
assert diff_d0 = carry0 * BASE;
}

if (flag1 != 0) {
assert diff_d1 + carry0 = q1 * BASE;
assert carry1 = q1;
assert q * P0 + r.d0 - val.d0 = q0 * BASE;
if (flag1 != 0) {
assert q * P1 + r.d1 - val.d1 + q0 = q1 * BASE;
assert q * P2 + r.d2 = val.d2 - q1;
} else {
assert q * P1 + r.d1 + q0 + q1 * BASE = val.d1;
assert q * P2 + r.d2 = val.d2 + q1;
}
} else {
assert carry1 = (-1) * q1;
assert diff_d1 + carry0 = carry1 * BASE;
assert q * P0 + r.d0 + q0 * BASE = val.d0;
if (flag1 != 0) {
assert q * P1 + r.d1 - val.d1 - q0 = q1 * BASE;
assert q * P2 + r.d2 = val.d2 - q1;
} else {
assert q * P1 + r.d1 - q0 + q1 * BASE = val.d1;
assert q * P2 + r.d2 = val.d2 + q1;
}
}

assert diff_d2 + carry1 = 0;

// ensure r is a reduced field element
assert [range_check_ptr + 3] = BASE_MIN_1 - r.d0;
assert [range_check_ptr + 4] = BASE_MIN_1 - r.d1;
assert [range_check_ptr + 5] = P2 - r.d2;
Expand Down