diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs
index 7be01a9cb8..9df0f2be12 100644
--- a/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs
+++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs
@@ -41,9 +41,16 @@ pub fn add_avx_a_sc(a_sc: &__m256i, b: &__m256i) -> __m256i {
 #[inline(always)]
 pub fn add_avx(a: &__m256i, b: &__m256i) -> __m256i {
-    let a_sc = shift_avx(a);
-    // let a_sc = toCanonical_avx_s(&a_s);
-    add_avx_a_sc(&a_sc, b)
+    unsafe {
+        let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
+        let a_sc = _mm256_xor_si256(*a, msb);
+        let c0_s = _mm256_add_epi64(a_sc, *b);
+        let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
+        let mask_ = _mm256_cmpgt_epi64(a_sc, c0_s);
+        let corr_ = _mm256_and_si256(mask_, p_n);
+        let c_s = _mm256_add_epi64(c0_s, corr_);
+        _mm256_xor_si256(c_s, msb)
+    }
 }
 
 #[inline(always)]
diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs
index ce86ce67de..dd305d5c8e 100644
--- a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs
+++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs
@@ -6,19 +6,22 @@ use crate::hash::hash_types::RichField;
 const MSB_: i64 = 0x8000000000000000u64 as i64;
 const P8_: i64 = 0xFFFFFFFF00000001u64 as i64;
 const P8_N_: i64 = 0xFFFFFFFF;
+const ONE_: i64 = 1;
 
 #[allow(non_snake_case)]
 #[repr(align(64))]
-struct FieldConstants {
-    MSB_V: [i64; 8],
-    P8_V: [i64; 8],
-    P8_N_V: [i64; 8],
+pub(crate) struct FieldConstants {
+    pub(crate) MSB_V: [i64; 8],
+    pub(crate) P8_V: [i64; 8],
+    pub(crate) P8_N_V: [i64; 8],
+    pub(crate) ONE_V: [i64; 8],
 }
 
-const FC: FieldConstants = FieldConstants {
+pub(crate) const FC: FieldConstants = FieldConstants {
     MSB_V: [MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_],
     P8_V: [P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_],
     P8_N_V: [P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_],
+    ONE_V: [ONE_, ONE_, ONE_, ONE_, ONE_, ONE_, ONE_, ONE_],
 };
 
 #[allow(dead_code)]
@@ -46,13 +49,34 @@ pub fn to_canonical_avx512(a: &__m512i) -> __m512i {
 #[inline(always)]
 pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i {
+    /*
     unsafe {
         // let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
-        let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
+        let p8_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
         let c0 = _mm512_add_epi64(*a, *b);
         let result_mask = _mm512_cmpgt_epu64_mask(*a, c0);
         _mm512_mask_add_epi64(c0, result_mask, c0, p8_n)
     }
+    */
+    unsafe {
+        let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::<i64>());
+        let a_sc = _mm512_xor_si512(*a, msb);
+        let c0_s = _mm512_add_epi64(a_sc, *b);
+        let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
+        let mask_ = _mm512_cmpgt_epi64_mask(a_sc, c0_s);
+        let c_s = _mm512_mask_add_epi64(c0_s, mask_, c0_s, p_n);
+        _mm512_xor_si512(c_s, msb)
+    }
+}
+
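+// Variant of `add_avx512` for a first operand that is already in the shifted
+// domain (XORed with 2^63): `a_s` comes in shifted and the sum is returned
+// shifted, so callers chaining several shifted-domain ops save two XORs per
+// add. The name suggests `b_small` is assumed small enough (e.g. a 32x32-bit
+// product) that a single `+ 0xFFFFFFFF` wrap-around correction suffices.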
+#[inline(always)]
+pub fn add_avx512_s_b_small(a_s: &__m512i, b_small: &__m512i) -> __m512i {
+    unsafe {
+        let corr = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
+        let c0_s = _mm512_add_epi64(*a_s, *b_small);
+        let mask_ = _mm512_cmpgt_epi64_mask(*a_s, c0_s);
+        _mm512_mask_add_epi64(c0_s, mask_, c0_s, corr)
+    }
 }
 
 #[inline(always)]
@@ -82,11 +106,12 @@ pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
 #[inline(always)]
 pub fn reduce_avx512_96_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
     unsafe {
-        let msb = _mm512_load_si512(FC.MSB_V.as_ptr().cast::<i32>());
-        let p_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
+        let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::<i64>());
+        let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
         let c_ls = _mm512_xor_si512(*c_l, msb);
         let c2 = _mm512_mul_epu32(*c_h, p_n);
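+        // c_ls is still in the shifted domain here, and c2 is a 32x32-bit
+        // product from _mm512_mul_epu32, so the `_s_b_small` variant applies
+        // and skips the XOR/un-XOR pair that add_avx512 would redo.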
-        let c_s = add_avx512(&c_ls, &c2);
+        let c_s = add_avx512_s_b_small(&c_ls, &c2);
+        // let c_s = add_avx512(&c_ls, &c2);
         _mm512_xor_si512(c_s, msb)
     }
 }
diff --git a/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs
index 04ce1262bf..2f7039f147 100644
--- a/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs
+++ b/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs
@@ -84,19 +84,6 @@ unsafe fn sub64(a: &__m256i, b: &__m256i, bin: &__m256i) -> (__m256i, __m256i) {
     let zeros = _mm256_set_epi64x(0, 0, 0, 0);
     let (r1, b1) = sub64_no_borrow(a, b);
 
-    // TODO - delete
-    /*
-    let mut v = [0i64; 4];
-    _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), *a);
-    println!("a: {:?}", v);
-    _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), *b);
-    println!("b: {:?}", v);
-    _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), r1);
-    println!("r: {:?}", v);
-    _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), b1);
-    println!("b: {:?}", v);
-    */
-
     let m1 = _mm256_cmpeq_epi64(*bin, ones);
     let m2 = _mm256_cmpeq_epi64(r1, zeros);
     let m = _mm256_and_si256(m1, m2);
@@ -104,14 +91,6 @@ unsafe fn sub64(a: &__m256i, b: &__m256i, bin: &__m256i) -> (__m256i, __m256i) {
     let r = _mm256_sub_epi64(r1, *bin);
     let bo = _mm256_or_si256(bo, b1);
 
-    // TODO - delete
-    /*
-    _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), r);
-    println!("r: {:?}", v);
-    _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), bo);
-    println!("b: {:?}", v);
-    */
-
     (r, bo)
 }
diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs
index db86c76b98..301e043446 100644
--- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs
+++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs
@@ -1411,6 +1411,7 @@ where
             let ss0 = add_avx(&s0, &rc0);
             let ss1 = add_avx(&s1, &rc1);
             let ss2 = add_avx(&s2, &rc2);
+            (s0, s1, s2) = sbox_avx_m256i(&ss0, &ss1, &ss2);
             mds_layer_avx(&mut s0, &mut s1, &mut s2);
             round_ctr += 1;
diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs
index 9dc7c70898..282c19a44e 100644
--- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs
+++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs
@@ -1134,8 +1134,8 @@ where
     let mut result = [F::ZERO; SPONGE_WIDTH];
     let res0 = state[0];
     unsafe {
-        let mut r0 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::<i32>());
-        let mut r1 = _mm512_loadu_si512((&mut result[4..12]).as_mut_ptr().cast::<i32>());
+        let mut r0 = _mm512_loadu_epi64((&mut result[0..8]).as_mut_ptr().cast::<i64>());
+        let mut r1 = _mm512_loadu_epi64((&mut result[4..12]).as_mut_ptr().cast::<i64>());
 
         for r in 1..12 {
             let sr512 = _mm512_set_epi64(
@@ -1148,23 +1148,23 @@ where
                 state[r].to_canonical_u64() as i64,
                 state[r].to_canonical_u64() as i64,
                 state[r].to_canonical_u64() as i64,
                 state[r].to_canonical_u64() as i64,
                 state[r].to_canonical_u64() as i64,
                 state[r].to_canonical_u64() as i64,
                 state[r].to_canonical_u64() as i64,
                 state[r].to_canonical_u64() as i64,
             );
-            let t0 = _mm512_loadu_si512(
+            let t0 = _mm512_loadu_epi64(
                 (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][0..8])
                     .as_ptr()
-                    .cast::<i32>(),
+                    .cast::<i64>(),
             );
-            let t1 = _mm512_loadu_si512(
+            let t1 = _mm512_loadu_epi64(
                 (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][4..12])
                     .as_ptr()
-                    .cast::<i32>(),
+                    .cast::<i64>(),
             );
             let m0 = mult_avx512(&sr512, &t0);
             let m1 = mult_avx512(&sr512, &t1);
             r0 = add_avx512(&r0, &m0);
             r1 = add_avx512(&r1, &m1);
         }
-        _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::<i32>(), r0);
-        _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::<i32>(), r1);
+        _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::<i64>(), r0);
+        _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::<i64>(), r1);
         state[0] = res0;
     }
 }
@@ -1177,22 +1177,22 @@ where
     F: PrimeField64,
 {
     unsafe {
-        let c0 = _mm512_loadu_si512(
+        let c0 = _mm512_loadu_epi64(
             (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..8])
                 .as_ptr()
-                .cast::<i32>(),
+                .cast::<i64>(),
         );
-        let c1 = _mm512_loadu_si512(
+        let c1 = _mm512_loadu_epi64(
             (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[4..12])
                 .as_ptr()
-                .cast::<i32>(),
+                .cast::<i64>(),
         );
-        let mut s0 = _mm512_loadu_si512((state[0..8]).as_ptr().cast::<i32>());
-        let mut s1 = _mm512_loadu_si512((state[4..12]).as_ptr().cast::<i32>());
+        let mut s0 = _mm512_loadu_epi64((state[0..8]).as_ptr().cast::<i64>());
+        let mut s1 = _mm512_loadu_epi64((state[4..12]).as_ptr().cast::<i64>());
         s0 = add_avx512(&s0, &c0);
         s1 = add_avx512(&s1, &c1);
-        _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::<i32>(), s0);
-        _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::<i32>(), s1);
+        _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::<i64>(), s0);
+        _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::<i64>(), s1);
     }
 }
 
@@ -1262,19 +1262,34 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51
      * - (test 3): if a + b < 2^64 (this means a + b is negative in signed representation) => no overflow so cout = 0
      * - (test 3): if a + b >= 2^64 (this means a + b becomes positive in signed representation, that is, a + b >= 0) => there is overflow so cout = 1
      */
-    let ones = _mm512_set_epi64(1, 1, 1, 1, 1, 1, 1, 1);
+    let ones = _mm512_load_epi64(FC.ONE_V.as_ptr().cast::<i64>());
     let zeros = _mm512_xor_si512(*a, *a); // faster 0
     let r = _mm512_add_epi64(*a, *b);
     let ma = _mm512_cmpgt_epi64_mask(zeros, *a);
     let mb = _mm512_cmpgt_epi64_mask(zeros, *b);
     let mc = _mm512_cmpgt_epi64_mask(zeros, r);
-    let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb)));
+    // let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb)));
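+    // Carry out of bit 63: either both MSBs are set, or exactly one is set
+    // while the result's MSB is clear; (!ma & mb) | (ma & !mb) is just ma ^ mb.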
+    let m = (ma & mb) | (!mc & (ma ^ mb));
     let co = _mm512_mask_blend_epi64(m, zeros, ones);
     (r, co)
 }
 
 #[inline]
 pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i {
+    /*
+    // long version
+    let r = _mm512_mul_epu32(*a, *b);
+    let ah = _mm512_srli_epi64(*a, 32);
+    let bh = _mm512_srli_epi64(*b, 32);
+    let r1 = _mm512_mul_epu32(*a, bh);
+    let r1 = _mm512_slli_epi64(r1, 32);
+    let r = _mm512_add_epi64(r, r1);
+    let r1 = _mm512_mul_epu32(ah, *b);
+    let r1 = _mm512_slli_epi64(r1, 32);
+    let r = _mm512_add_epi64(r, r1);
+    r
+    */
+    // short version
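+    // _mm512_mullo_epi64 (VPMULLQ) requires AVX512DQ; the commented-out long
+    // version above is the AVX512F-only decomposition, kept for reference.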
     _mm512_mullo_epi64(*a, *b)
 }
 
@@ -1301,16 +1316,16 @@ unsafe fn block1_avx512(x: &__m512i, y: [i64; 3]) -> __m512i {
 unsafe fn block2_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> (__m512i, __m512i) {
     let mut vxr: [i64; 8] = [0; 8];
     let mut vxi: [i64; 8] = [0; 8];
-    _mm512_storeu_si512(vxr.as_mut_ptr().cast::<i32>(), *xr);
-    _mm512_storeu_si512(vxi.as_mut_ptr().cast::<i32>(), *xi);
+    _mm512_storeu_epi64(vxr.as_mut_ptr().cast::<i64>(), *xr);
+    _mm512_storeu_epi64(vxi.as_mut_ptr().cast::<i64>(), *xi);
     let x1: [(i64, i64); 3] = [(vxr[0], vxi[0]), (vxr[1], vxi[1]), (vxr[2], vxi[2])];
     let x2: [(i64, i64); 3] = [(vxr[4], vxi[4]), (vxr[5], vxi[5]), (vxr[6], vxi[6])];
     let b1 = block2(x1, y);
     let b2 = block2(x2, y);
     vxr = [b1[0].0, b1[1].0, b1[2].0, 0, b2[0].0, b2[1].0, b2[2].0, 0];
     vxi = [b1[0].1, b1[1].1, b1[2].1, 0, b2[0].1, b2[1].1, b2[2].1, 0];
-    let rr = _mm512_loadu_si512(vxr.as_ptr().cast::<i32>());
-    let ri = _mm512_loadu_si512(vxi.as_ptr().cast::<i32>());
+    let rr = _mm512_loadu_epi64(vxr.as_ptr().cast::<i64>());
+    let ri = _mm512_loadu_epi64(vxi.as_ptr().cast::<i64>());
     (rr, ri)
 }
 
@@ -1343,18 +1358,9 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) ->
     let dif1perm2 = _mm512_permutex_epi64(dif1, 0x2);
     let z0i = _mm512_add_epi64(dif3, dif1perm1);
     let z0i = _mm512_add_epi64(z0i, dif1perm2);
-    let mask = _mm512_set_epi64(
-        0,
-        0,
-        0,
-        0xFFFFFFFFFFFFFFFFu64 as i64,
-        0,
-        0,
-        0,
-        0xFFFFFFFFFFFFFFFFu64 as i64,
-    );
-    let z0r = _mm512_and_si512(z0r, mask);
-    let z0i = _mm512_and_si512(z0i, mask);
+    let zeros = _mm512_xor_si512(z0r, z0r);
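+    // Keep only lanes 0 and 4 (k-mask 0x11) by blending against zero, which
+    // replaces the old 512-bit AND constant; z1/z2 below use 0x22 and 0x44.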
+    let z0r = _mm512_mask_blend_epi64(0x11, zeros, z0r);
+    let z0i = _mm512_mask_blend_epi64(0x11, zeros, z0i);
 
     // z1
     // z1r = dif2[0] + dif2[1] + prod[2] - sum[2];
@@ -1377,18 +1383,8 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) ->
     let dif1perm = _mm512_permutex_epi64(dif1, 0x8);
     let z1i = _mm512_add_epi64(dif3, dif3perm);
     let z1i = _mm512_add_epi64(z1i, dif1perm);
-    let mask = _mm512_set_epi64(
-        0,
-        0,
-        0xFFFFFFFFFFFFFFFFu64 as i64,
-        0,
-        0,
-        0,
-        0xFFFFFFFFFFFFFFFFu64 as i64,
-        0,
-    );
-    let z1r = _mm512_and_si512(z1r, mask);
-    let z1i = _mm512_and_si512(z1i, mask);
+    let z1r = _mm512_mask_blend_epi64(0x22, zeros, z1r);
+    let z1i = _mm512_mask_blend_epi64(0x22, zeros, z1i);
 
     // z2
     // z2r = dif2[0] + dif2[1] + dif2[2];
@@ -1410,18 +1406,8 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) ->
     let dif3perm2 = _mm512_permutex_epi64(dif3, 0x10);
     let z2i = _mm512_add_epi64(dif3, dif3perm1);
     let z2i = _mm512_add_epi64(z2i, dif3perm2);
-    let mask = _mm512_set_epi64(
-        0,
-        0xFFFFFFFFFFFFFFFFu64 as i64,
-        0,
-        0,
-        0,
-        0xFFFFFFFFFFFFFFFFu64 as i64,
-        0,
-        0,
-    );
-    let z2r = _mm512_and_si512(z2r, mask);
-    let z2i = _mm512_and_si512(z2i, mask);
+    let z2r = _mm512_mask_blend_epi64(0x44, zeros, z2r);
+    let z2i = _mm512_mask_blend_epi64(0x44, zeros, z2i);
 
     let zr = _mm512_or_si512(z0r, z1r);
     let zr = _mm512_or_si512(zr, z2r);
@@ -1500,10 +1486,7 @@ unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut
 #[inline(always)]
 #[unroll_for_loops]
 unsafe fn mds_layer_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut __m512i) {
-    let mask = _mm512_set_epi64(
-        0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
-        0xFFFFFFFF,
-    );
+    let mask = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
     let mut sl0 = _mm512_and_si512(*s0, mask);
     let mut sl1 = _mm512_and_si512(*s1, mask);
     let mut sl2 = _mm512_and_si512(*s2, mask);
@@ -1545,48 +1528,48 @@ unsafe fn mds_partial_layer_init_avx512<F>(s0: &mut __m512i, s1: &mut __m512i, s
 where
     F: PrimeField64,
 {
-    let mut result = [F::ZERO; 2 * SPONGE_WIDTH];
     let res0 = *s0;
-
-    let mut r0 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::<i32>());
-    let mut r1 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::<i32>());
-    let mut r2 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::<i32>());
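+    // Start the accumulators at zero directly (XOR of a register with itself)
+    // instead of loading a zeroed scratch array back from memory.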
+    let mut r0 = _mm512_xor_epi64(res0, res0);
+    let mut r1 = r0;
+    let mut r2 = r0;
     for r in 1..12 {
-        let sr = match r {
-            1 => _mm512_permutex_epi64(*s0, 0x55),
-            2 => _mm512_permutex_epi64(*s0, 0xAA),
-            3 => _mm512_permutex_epi64(*s0, 0xFF),
-            4 => _mm512_permutex_epi64(*s1, 0x0),
-            5 => _mm512_permutex_epi64(*s1, 0x55),
-            6 => _mm512_permutex_epi64(*s1, 0xAA),
-            7 => _mm512_permutex_epi64(*s1, 0xFF),
-            8 => _mm512_permutex_epi64(*s2, 0x0),
-            9 => _mm512_permutex_epi64(*s2, 0x55),
-            10 => _mm512_permutex_epi64(*s2, 0xAA),
-            11 => _mm512_permutex_epi64(*s2, 0xFF),
-            _ => _mm512_permutex_epi64(*s0, 0x55),
-        };
-        let t0 = _mm512_loadu_si512(
-            (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][0..8])
-                .as_ptr()
-                .cast::<i32>(),
-        );
-        let t1 = _mm512_loadu_si512(
-            (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][8..16])
-                .as_ptr()
-                .cast::<i32>(),
-        );
-        let t2 = _mm512_loadu_si512(
-            (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][16..24])
-                .as_ptr()
-                .cast::<i32>(),
-        );
-        let m0 = mult_avx512(&sr, &t0);
-        let m1 = mult_avx512(&sr, &t1);
-        let m2 = mult_avx512(&sr, &t2);
-        r0 = add_avx512(&r0, &m0);
-        r1 = add_avx512(&r1, &m1);
-        r2 = add_avx512(&r2, &m2);
+        if r < 12 {
+            let sr = match r {
+                1 => _mm512_permutex_epi64(*s0, 0x55),
+                2 => _mm512_permutex_epi64(*s0, 0xAA),
+                3 => _mm512_permutex_epi64(*s0, 0xFF),
+                4 => _mm512_permutex_epi64(*s1, 0x0),
+                5 => _mm512_permutex_epi64(*s1, 0x55),
+                6 => _mm512_permutex_epi64(*s1, 0xAA),
+                7 => _mm512_permutex_epi64(*s1, 0xFF),
+                8 => _mm512_permutex_epi64(*s2, 0x0),
+                9 => _mm512_permutex_epi64(*s2, 0x55),
+                10 => _mm512_permutex_epi64(*s2, 0xAA),
+                11 => _mm512_permutex_epi64(*s2, 0xFF),
+                _ => _mm512_permutex_epi64(*s0, 0x55),
+            };
+            let t0 = _mm512_loadu_epi64(
+                (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][0..8])
+                    .as_ptr()
+                    .cast::<i64>(),
+            );
+            let t1 = _mm512_loadu_epi64(
+                (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][8..16])
+                    .as_ptr()
+                    .cast::<i64>(),
+            );
+            let t2 = _mm512_loadu_epi64(
+                (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][16..24])
+                    .as_ptr()
+                    .cast::<i64>(),
+            );
+            let m0 = mult_avx512(&sr, &t0);
+            let m1 = mult_avx512(&sr, &t1);
+            let m2 = mult_avx512(&sr, &t2);
+            r0 = add_avx512(&r0, &m0);
+            r1 = add_avx512(&r1, &m1);
+            r2 = add_avx512(&r2, &m2);
+        }
     }
     *s0 = _mm512_mask_blend_epi64(0x11, r0, res0);
     *s1 = r1;
@@ -1649,20 +1632,20 @@ unsafe fn mds_partial_layer_fast_avx512(
         state[0].to_noncanonical_u64() as i64,
         state[0].to_noncanonical_u64() as i64,
     );
-    let rc0 = _mm512_loadu_si512(
+    let rc0 = _mm512_loadu_epi64(
         (&FAST_PARTIAL_ROUND_VS_AVX512[r][0..8])
            .as_ptr()
-           .cast::<i32>(),
+           .cast::<i64>(),
     );
-    let rc1 = _mm512_loadu_si512(
+    let rc1 = _mm512_loadu_epi64(
        (&FAST_PARTIAL_ROUND_VS_AVX512[r][8..16])
           .as_ptr()
-          .cast::<i32>(),
+          .cast::<i64>(),
    );
-    let rc2 = _mm512_loadu_si512(
+    let rc2 = _mm512_loadu_epi64(
        (&FAST_PARTIAL_ROUND_VS_AVX512[r][16..24])
           .as_ptr()
-          .cast::<i32>(),
+          .cast::<i64>(),
    );
    let (mh, ml) = mult_avx512_128(&ss0, &rc0);
    let m = reduce_avx512_128_64(&mh, &ml);
@@ -1687,9 +1670,9 @@ unsafe fn mds_partial_layer_fast_avx512(
     let m = reduce_avx512_128_64(&mh, &ml);
     *s2 = add_avx512(s2, &m);
 
-    _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::<i32>(), *s0);
-    _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::<i32>(), *s1);
-    _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::<i32>(), *s2);
+    _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::<i64>(), *s0);
+    _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::<i64>(), *s1);
+    _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::<i64>(), *s2);
 }
 
 #[allow(unused)]
@@ -1704,22 +1687,22 @@ where
     // Self::full_rounds(&mut state, &mut round_ctr);
     for _ in 0..HALF_N_FULL_ROUNDS {
         // load state
-        let s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::<i32>());
-        let s1 = _mm512_loadu_si512((&state[4..12]).as_ptr().cast::<i32>());
+        let s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::<i64>());
+        let s1 = _mm512_loadu_epi64((&state[4..12]).as_ptr().cast::<i64>());
 
         let rc: &[u64; 12] = &ALL_ROUND_CONSTANTS[SPONGE_WIDTH * round_ctr..][..SPONGE_WIDTH]
             .try_into()
            .unwrap();
-        let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::<i32>());
-        let rc1 = _mm512_loadu_si512((&rc[4..12]).as_ptr().cast::<i32>());
+        let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::<i64>());
+        let rc1 = _mm512_loadu_epi64((&rc[4..12]).as_ptr().cast::<i64>());
         let ss0 = add_avx512(&s0, &rc0);
         let ss1 = add_avx512(&s1, &rc1);
         let r0 = sbox_avx512_one(&ss0);
         let r1 = sbox_avx512_one(&ss1);
         // store state
-        _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::<i32>(), r0);
-        _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::<i32>(), r1);
+        _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::<i64>(), r0);
+        _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::<i64>(), r1);
 
         *state = <F as Poseidon>::mds_layer(&state);
         round_ctr += 1;
@@ -1737,22 +1720,22 @@ where
     // Self::full_rounds(&mut state, &mut round_ctr);
     for _ in 0..HALF_N_FULL_ROUNDS {
         // load state
-        let s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::<i32>());
-        let s1 = _mm512_loadu_si512((&state[4..12]).as_ptr().cast::<i32>());
+        let s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::<i64>());
+        let s1 = _mm512_loadu_epi64((&state[4..12]).as_ptr().cast::<i64>());
 
         let rc: &[u64; 12] = &ALL_ROUND_CONSTANTS[SPONGE_WIDTH * round_ctr..][..SPONGE_WIDTH]
            .try_into()
            .unwrap();
-        let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::<i32>());
-        let rc1 = _mm512_loadu_si512((&rc[4..12]).as_ptr().cast::<i32>());
+        let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::<i64>());
+        let rc1 = _mm512_loadu_epi64((&rc[4..12]).as_ptr().cast::<i64>());
         let ss0 = add_avx512(&s0, &rc0);
         let ss1 = add_avx512(&s1, &rc1);
         let r0 = sbox_avx512_one(&ss0);
         let r1 = sbox_avx512_one(&ss1);
         // store state
-        _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::<i32>(), r0);
-        _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::<i32>(), r1);
+        _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::<i64>(), r0);
+        _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::<i64>(), r1);
 
         *state = <F as Poseidon>::mds_layer(&state);
         // mds_layer_avx::<F>(&mut s0, &mut s1, &mut s2);
@@ -1780,18 +1763,18 @@ where
 
     unsafe {
         // load state
-        let mut s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::<i32>());
-        let mut s1 = _mm512_loadu_si512((&state[8..16]).as_ptr().cast::<i32>());
-        let mut s2 = _mm512_loadu_si512((&state[16..24]).as_ptr().cast::<i32>());
+        let mut s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::<i64>());
+        let mut s1 = _mm512_loadu_epi64((&state[8..16]).as_ptr().cast::<i64>());
+        let mut s2 = _mm512_loadu_epi64((&state[16..24]).as_ptr().cast::<i64>());
 
         for _ in 0..HALF_N_FULL_ROUNDS {
             let rc: &[u64; 24] = &ALL_ROUND_CONSTANTS_AVX512[2 * SPONGE_WIDTH * round_ctr..]
                 [..2 * SPONGE_WIDTH]
                 .try_into()
                 .unwrap();
-            let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::<i32>());
-            let rc1 = _mm512_loadu_si512((&rc[8..16]).as_ptr().cast::<i32>());
-            let rc2 = _mm512_loadu_si512((&rc[16..24]).as_ptr().cast::<i32>());
+            let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::<i64>());
+            let rc1 = _mm512_loadu_epi64((&rc[8..16]).as_ptr().cast::<i64>());
+            let rc2 = _mm512_loadu_epi64((&rc[16..24]).as_ptr().cast::<i64>());
             let ss0 = add_avx512(&s0, &rc0);
             let ss1 = add_avx512(&s1, &rc1);
             let ss2 = add_avx512(&s2, &rc2);
@@ -1803,20 +1786,20 @@ where
         }
 
         // this does partial_first_constant_layer_avx(&mut state);
-        let c0 = _mm512_loadu_si512(
+        let c0 = _mm512_loadu_epi64(
             (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[0..8])
                 .as_ptr()
-                .cast::<i32>(),
+                .cast::<i64>(),
         );
-        let c1 = _mm512_loadu_si512(
+        let c1 = _mm512_loadu_epi64(
             (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[8..16])
                 .as_ptr()
-                .cast::<i32>(),
+                .cast::<i64>(),
         );
-        let c2 = _mm512_loadu_si512(
+        let c2 = _mm512_loadu_epi64(
             (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[16..24])
                 .as_ptr()
-                .cast::<i32>(),
+                .cast::<i64>(),
         );
         s0 = add_avx512(&s0, &c0);
         s1 = add_avx512(&s1, &c1);
@@ -1824,9 +1807,9 @@ where
 
         mds_partial_layer_init_avx512::<F>(&mut s0, &mut s1, &mut s2);
 
-        _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::<i32>(), s0);
-        _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::<i32>(), s1);
-        _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::<i32>(), s2);
+        _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::<i64>(), s0);
+        _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::<i64>(), s1);
+        _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::<i64>(), s2);
 
         for i in 0..N_PARTIAL_ROUNDS {
             state[0] = sbox_monomial(state[0]);
@@ -1844,9 +1827,9 @@ where
                 [..2 * SPONGE_WIDTH]
                 .try_into()
                 .unwrap();
-            let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::<i32>());
-            let rc1 = _mm512_loadu_si512((&rc[8..16]).as_ptr().cast::<i32>());
-            let rc2 = _mm512_loadu_si512((&rc[16..24]).as_ptr().cast::<i32>());
+            let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::<i64>());
+            let rc1 = _mm512_loadu_epi64((&rc[8..16]).as_ptr().cast::<i64>());
+            let rc2 = _mm512_loadu_epi64((&rc[16..24]).as_ptr().cast::<i64>());
             let ss0 = add_avx512(&s0, &rc0);
             let ss1 = add_avx512(&s1, &rc1);
             let ss2 = add_avx512(&s2, &rc2);
@@ -1858,9 +1841,9 @@ where
         }
 
         // store state
-        _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::<i32>(), s0);
-        _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::<i32>(), s1);
-        _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::<i32>(), s2);
+        _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::<i64>(), s0);
+        _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::<i64>(), s1);
+        _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::<i64>(), s2);
 
         debug_assert_eq!(round_ctr, N_ROUNDS);
     };
@@ -1900,13 +1883,13 @@ where
             leaf_size / SPONGE_RATE + 1
         };
        for _ in 0..loops {
-            let end1 = if idx1 + SPONGE_RATE > leaf_size {
+            let end1 = if idx1 + SPONGE_RATE >= leaf_size {
                 leaf_size
             } else {
                 idx1 + SPONGE_RATE
             };
-            let end2 = if idx2 + SPONGE_RATE > inputs.len() {
-                inputs.len()
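+            // The second leaf occupies inputs[leaf_size..2 * leaf_size], so
+            // the chunk end is clamped to 2 * leaf_size, not inputs.len().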
+            let end2 = if idx2 + SPONGE_RATE >= 2 * leaf_size {
+                2 * leaf_size
             } else {
                 idx2 + SPONGE_RATE
             };
diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs
index f78c3f4a24..fa3752629d 100644
--- a/plonky2/src/hash/poseidon_goldilocks.rs
+++ b/plonky2/src/hash/poseidon_goldilocks.rs
@@ -450,6 +450,8 @@ mod tests {
     use crate::field::goldilocks_field::GoldilocksField as F;
     use crate::field::types::{Field, PrimeField64};
     #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))]
+    use crate::hash::arch::x86_64::poseidon_goldilocks_avx512::hash_leaf_avx512;
+    #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))]
     use crate::hash::poseidon::test_helpers::check_test_vectors_avx512;
     use crate::hash::poseidon::test_helpers::{check_consistency, check_test_vectors};
     use crate::hash::poseidon::{Poseidon, PoseidonHash};
@@ -488,6 +490,12 @@ mod tests {
          [0xa89280105650c4ec, 0xab542d53860d12ed, 0x5704148e9ccab94f, 0xd3a826d4b62da9f5,
           0x8a7a6ca87892574f, 0xc7017e1cad1a674e, 0x1f06668922318e34, 0xa3b203bc8102676f,
           0xfcc781b0ce382bf2, 0x934c69ff3ed14ba5, 0x504688a5996e8f13, 0x401f3f2ed524a2ba, ]),
+        ([0xf2cc0ce426e7eddd, 0x91ad40f14cfdcb78, 0xc516c642346aabc, 0xa79a0411d96de0,
+          0xf256c881b6167069, 0x5c767aa6354a647b, 0x79a821313415b9dc, 0xf083bc2f276b99e1,
+          0x9aa0ac0171df5ac7, 0xc3c705daf69d66e0, 0x3b0468abe66c5ed, 0xdcf835c4d4cffd73, ],
+         [0x96d91d333e5e038d, 0x114395c7cfb7e18f, 0x19b1ea99556391ff, 0xd53855a776b4582a,
+          0x378d8ea4ffbb7545, 0x168319892eff226a, 0x5f09f06508283bd, 0xb92d599c947cc2f1,
+          0xf078fc732200e4d4, 0xcaf95e4285f3099d, 0x8532be1f10f23cd0, 0xc3260991186909ff, ])
     ];
 
     check_test_vectors::<F>(test_vectors12.clone());
@@ -576,40 +584,144 @@ mod tests {
     #[test]
     fn test_hash_no_pad_gl() {
-        let inputs: [u64; 32] = [
-            9972144316416239374,
-            7195869958086994472,
-            12805395537960412263,
-            6755149769410714396,
-            16592921959755212957,
-            1370750654791741308,
-            11186995120529280354,
-            288690570896506034,
-            2896720011649362435,
-            13870686984275550055,
-            12288026009924247278,
-            15608864109019511973,
-            15690944173815210604,
-            17535150735055770942,
-            4265223756233917229,
-            17236464151311603291,
-            15180455466814482598,
-            12377438429067983442,
-            11274960245127600167,
-            5684300978461808754,
-            1918159483831849502,
-            15340265949423289730,
-            181633163915570313,
-            12684059848091546996,
-            10060377187090493210,
-            13523019938818230572,
-            16846214147461656883,
-            13560222746484567233,
-            2150999602305437005,
-            9103462636082953981,
-            16341057499572706412,
-            842265247111451937,
+        let inputs = [
+            0xb8f463d7cb4f24f6,
+            0xe94ad9aba668af65,
+            0x4a31c8cee787786a,
+            0x7f8ed7050aeadcf9,
+            0x516c34f52a5c8b14,
+            0x542c22306722b175,
+            0x6feba1eb9030ecb9,
+            0xe103d491fa784080,
+            0x31d9a62ea39f4ec9,
+            0xbf0ccc95d9b4c697,
+            0x5a9d230167523b2e,
+            0x7ff277e12091d2f2,
+            0xf2af521b9537abf3,
+            0xe39e815313da5c12,
+            0xe5feaa1e4f46b87b,
+            0x76b772a9e6eda11c,
+            0x9005e1c8fbf27eed,
+            0x78ea9242b53108ac,
+            0x5561d33040b6affb,
+            0x61ded48ffee1f243,
+            0xebbe0c4034afb9e5,
+            0x7973d462ab14d331,
+            0x76a23e459a0849b,
+            0x9fa93d23d8b84515,
+            0x1e19bba2ce8042dd,
+            0xb1159302625b71a3,
+            0x792e2e4171fd7e83,
+            0xc9088b032be7eff0,
+            0x6540b29fbec19cb2,
+            0x8c4f849dd68f4cdc,
+            0xb91969b7cfcd1ec8,
+            0x4d450eff6a3b0c7c,
+            0xcace16a8345de56e,
+            0xe5bac07b93e1f0e2,
+            0x35088bde4f1bd3a9,
+            0x2e0bd8e257386e40,
+            0xed67fe1bd44680f0,
+            0x887a32a6049105f,
+            0x3ae86d4d60b87a67,
+            0x665a656a217edacf,
+            0x2eb451b933acbd2d,
+            0x63876760e9570fb4,
+            0x2b11da28eb95d7d6,
+            0x138ea36659579c0a,
+            0x457f674d92cfcd72,
+            0xba4b8ffc7287142d,
+            0x2b9bd3cd64e65cb6,
+            0x2780e8b0e66848e8,
+            0xe18303c5010835a4,
+            0x6c4e379aba35e21e,
+            0xf9c3f2f33320d9cd,
+            0x82429ba2d6263c9a,
+            0x11e81115fa995e88,
+            0x75a7fb5681cd15e4,
+            0xa54b2a0b6d57e340,
+            0x884b3d9cc9b7f720,
+            0xdac1b985f5b0ff19,
+            0x5938c0405a01dbd4,
+            0x13fb2d9399c3ef2e,
+            0xeaed82d3706dccec,
+            0xf8d853012e56f7fb,
+            0xa4c639bbaf484525,
+            0xe3b35501c21797ba,
+            0x1a645013fcb5e3a0,
+            0xf2eb2337ba169178,
+            0xcc94fd9269c7d33,
+            0x82a9aaa398b13f1,
+            0xe9b5ecbe6576234,
+            0x252287d7ed9ec792,
+            0x30629bee322f17cc,
+            0x9ae26078f44e8afb,
+            0xabdc35ac8f527136,
+            0x4b2a3be4ef4c231f,
+            0x23074d5363eeba58,
+            0x75cfe940f6967c16,
+            0xfb185a23f6225406,
+            0xda8a21bd2ba64cc3,
+            0xd623bde11eb8c989,
+            0x76201928e4523ba3,
+            0x1c20cb194495b643,
+            0x3e70ce2fddc52451,
+            0x86c698ca61fdae8e,
+            0x9855dd30ad0c1309,
+            0x271541a781755737,
+            0x209b4ccf7db16277,
+            0xff27cae2771d1d8c,
+            0xd7795488a7bfe6ee,
+            0x9cf1875ec535778e,
+            0x9fad94c126427390,
+            0x199b482c029f3d9d,
+            0x92ae2055bb3f6d6,
+            0x29d6100b44167374,
+            0x88e8c8ffdefe0f33,
+            0xa3d8d929ea748a62,
+            0xd5dbe1a3d99e113d,
+            0x438639f8f0e3ff25,
+            0xf2cc0ce426e7eddd,
+            0x91ad40f14cfdcb78,
+            0xc516c642346aabc,
+            0xa79a0411d96de0,
+            0xf256c881b6167069,
+            0x5c767aa6354a647b,
+            0x79a821313415b9dc,
+            0xf083bc2f276b99e1,
+            0x9d47fc86eb2de7c2,
+            0x3370a8711a678a03,
+            0x1572c8a8bf872b26,
+            0xdbb7de1fc45360a1,
+            0x5f87c0fe24bafdd4,
+            0x2f6a5784207d118a,
+            0x640c588afcf0cc14,
+            0xe609f3cbb7cb015,
+            0x8e4907544019be80,
+            0xde2f553ac4ab68c3,
+            0x29cd0d2800262365,
+            0x3bf736a6fbc14ce2,
+            0xab059c3c3cba4912,
+            0xe609e14997bd2f5c,
+            0x694189d934ff1f8d,
+            0x54570348f45e3a9,
+            0x90ef5b98b0a08a34,
+            0x1b09b93749616de8,
+            0x89be3144389d48c1,
+            0xdaa7e268d0fd82d8,
+            0xc46956b67fa89c61,
+            0xec88a7133e4fefc,
+            0xe41596ca682069f4,
+            0x297f55e46472431b,
+            0x33ada14fd813218d,
+            0x22c57ca5e77249ad,
+            0x4e2f2c7cc99f2d47,
+            0x78d11ba2efc7556f,
+            0xdfc98976b6e3ad0d,
+            0x59d88f72bf5ad1d8,
+            0x19ca05690b8e1ad9,
         ];
+
         let inputs = inputs
             .iter()
             .map(|x| F::from_canonical_u64(*x))
             .collect::<Vec<F>>();
         let output = PoseidonHash::hash_no_pad(&inputs);
         let expected_out: [u64; 4] = [
-            8197835875512527937,
-            7109417654116018994,
-            18237163116575285904,
-            17017896878738047012,
+            0xc19dccf6ec4f3df3,
+            0x1bf0d65af6925451,
+            0xee9dbf2c8dcad9a2,
+            0xae46323715f528a1,
         ];
         let expected_out = expected_out
             .iter()
             .map(|x| F::from_canonical_u64(*x))
             .collect::<Vec<F>>();
         assert_eq!(output.elements.to_vec(), expected_out);
+
+        #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))]
+        {
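+            // Feed the same leaf into both halves of the AVX512 double-leaf
+            // kernel; each half must reproduce the scalar digest above.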
+            let mut dleaf: Vec<F> = vec![F::from_canonical_u64(0); 2 * inputs.len()];
+            dleaf[0..inputs.len()].copy_from_slice(&inputs);
+            dleaf[inputs.len()..2 * inputs.len()].copy_from_slice(&inputs);
+            let (h1, h2) = hash_leaf_avx512(dleaf.as_slice(), inputs.len());
+            assert_eq!(h1, expected_out);
+            assert_eq!(h2, expected_out);
+        }
     }
 }