From 4c640df4ca2cbb2f881b384c8f74bc740e342b29 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Thu, 17 Oct 2024 15:05:01 +0800 Subject: [PATCH 1/4] fix avx512 poseidon issue --- .../src/hash/arch/x86_64/goldilocks_avx2.rs | 13 +- .../src/hash/arch/x86_64/goldilocks_avx512.rs | 31 ++- .../hash/arch/x86_64/poseidon_bn128_avx2.rs | 21 -- .../arch/x86_64/poseidon_goldilocks_avx2.rs | 13 +- .../arch/x86_64/poseidon_goldilocks_avx512.rs | 38 +++- plonky2/src/hash/poseidon_goldilocks.rs | 196 ++++++++++++++---- 6 files changed, 235 insertions(+), 77 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs index 7be01a9cb8..9df0f2be12 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs @@ -41,9 +41,16 @@ pub fn add_avx_a_sc(a_sc: &__m256i, b: &__m256i) -> __m256i { #[inline(always)] pub fn add_avx(a: &__m256i, b: &__m256i) -> __m256i { - let a_sc = shift_avx(a); - // let a_sc = toCanonical_avx_s(&a_s); - add_avx_a_sc(&a_sc, b) + unsafe { + let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1); + let a_sc = _mm256_xor_si256(*a, msb); + let c0_s = _mm256_add_epi64(a_sc, *b); + let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1); + let mask_ = _mm256_cmpgt_epi64(a_sc, c0_s); + let corr_ = _mm256_and_si256(mask_, p_n); + let c_s = _mm256_add_epi64(c0_s, corr_); + _mm256_xor_si256(c_s, msb) + } } #[inline(always)] diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs index ce86ce67de..adb3ca702b 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs @@ -46,15 +46,37 @@ pub fn to_canonical_avx512(a: &__m512i) -> __m512i { #[inline(always)] pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i { + /* unsafe { // let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); - let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); + let p8_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); let c0 = _mm512_add_epi64(*a, *b); let result_mask = _mm512_cmpgt_epu64_mask(*a, c0); _mm512_mask_add_epi64(c0, result_mask, c0, p8_n) } + */ + unsafe { + let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::()); + let a_sc = _mm512_xor_si512(*a, msb); + let c0_s = _mm512_add_epi64(a_sc, *b); + let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); + let mask_ = _mm512_cmpgt_epi64_mask(a_sc, c0_s); + let c_s = _mm512_mask_add_epi64(c0_s, mask_, c0_s, p_n); + _mm512_xor_si512(c_s, msb) + } } +#[inline(always)] +pub fn add_avx512_s_b_small(a_s: &__m512i, b_small: &__m512i) -> __m512i { + unsafe { + let corr = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); + let c0_s = _mm512_add_epi64(*a_s, *b_small); + let mask_ = _mm512_cmpgt_epi64_mask(*a_s, c0_s); + _mm512_mask_add_epi64(c0_s, mask_, c0_s, corr) + } +} + + #[inline(always)] pub fn sub_avx512(a: &__m512i, b: &__m512i) -> __m512i { unsafe { @@ -82,11 +104,12 @@ pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i { #[inline(always)] pub fn reduce_avx512_96_64(c_h: &__m512i, c_l: &__m512i) -> __m512i { unsafe { - let msb = _mm512_load_si512(FC.MSB_V.as_ptr().cast::()); - let p_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); + let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::()); + let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); let c_ls = _mm512_xor_si512(*c_l, msb); let c2 = _mm512_mul_epu32(*c_h, p_n); - let c_s = add_avx512(&c_ls, &c2); + let c_s = add_avx512_s_b_small(&c_ls, &c2); + // let c_s = add_avx512(&c_ls, &c2); _mm512_xor_si512(c_s, msb) } } diff --git a/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs index 04ce1262bf..2f7039f147 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs @@ -84,19 +84,6 @@ unsafe fn sub64(a: &__m256i, b: &__m256i, bin: &__m256i) -> (__m256i, __m256i) { let zeros = _mm256_set_epi64x(0, 0, 0, 0); let (r1, b1) = sub64_no_borrow(a, b); - // TODO - delete - /* - let mut v = [0i64; 4]; - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), *a); - println!("a: {:?}", v); - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), *b); - println!("b: {:?}", v); - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), r1); - println!("r: {:?}", v); - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), b1); - println!("b: {:?}", v); - */ - let m1 = _mm256_cmpeq_epi64(*bin, ones); let m2 = _mm256_cmpeq_epi64(r1, zeros); let m = _mm256_and_si256(m1, m2); @@ -104,14 +91,6 @@ unsafe fn sub64(a: &__m256i, b: &__m256i, bin: &__m256i) -> (__m256i, __m256i) { let r = _mm256_sub_epi64(r1, *bin); let bo = _mm256_or_si256(bo, b1); - // TODO - delete - /* - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), r); - println!("r: {:?}", v); - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), bo); - println!("b: {:?}", v); - */ - (r, bo) } diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs index db86c76b98..d10fe828ce 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs @@ -1164,7 +1164,7 @@ unsafe fn mds_layer_avx(s0: &mut __m256i, s1: &mut __m256i, s2: &mut __m256i) { let (rl0, c0) = add64_no_carry(&sl0, &shl0); let (rh0, _) = add64_no_carry(&shh0, &c0); let r0 = reduce_avx_128_64(&rh0, &rl0); - + let (rl1, c1) = add64_no_carry(&sl1, &shl1); let (rh1, _) = add64_no_carry(&shh1, &c1); *s1 = reduce_avx_128_64(&rh1, &rl1); @@ -1393,7 +1393,7 @@ where F: PrimeField64 + Poseidon, { let mut state = &mut input.clone(); - let mut round_ctr = 0; + let mut round_ctr = 0; unsafe { // load state @@ -1410,12 +1410,13 @@ where let rc2 = _mm256_loadu_si256((&rc[8..12]).as_ptr().cast::<__m256i>()); let ss0 = add_avx(&s0, &rc0); let ss1 = add_avx(&s1, &rc1); - let ss2 = add_avx(&s2, &rc2); + let ss2 = add_avx(&s2, &rc2); + (s0, s1, s2) = sbox_avx_m256i(&ss0, &ss1, &ss2); mds_layer_avx(&mut s0, &mut s1, &mut s2); - round_ctr += 1; + round_ctr += 1; } - + // this does partial_first_constant_layer_avx(&mut state); let c0 = _mm256_loadu_si256( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..4]) @@ -1441,7 +1442,7 @@ where _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), s0); _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), s1); _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), s2); - + for i in 0..N_PARTIAL_ROUNDS { state[0] = sbox_monomial(state[0]); state[0] = state[0].add_canonical_u64(FAST_PARTIAL_ROUND_CONSTANTS[i]); diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 9dc7c70898..86bc98d60d 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -1271,11 +1271,36 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51 let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb))); let co = _mm512_mask_blend_epi64(m, zeros, ones); (r, co) + /* + let mut va: [u64; 8] = [0; 8]; + let mut vb: [u64; 8] = [0; 8]; + let mut vr: [u64; 8] = [0; 8]; + let mut vc: [u64; 8] = [0; 8]; + _mm512_storeu_epi64(va.as_mut_ptr().cast::(), *a); + _mm512_storeu_epi64(vb.as_mut_ptr().cast::(), *b); + for i in 0..8 { + vr[i] = va[i].wrapping_add(vb[i]); + vc[i] = if vr[i] < va[i] { 1 } else { 0 }; + } + let r = _mm512_loadu_epi64(vr.as_ptr().cast::()); + let c = _mm512_loadu_epi64(vc.as_ptr().cast::()); + (r, c) + */ } #[inline] pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { - _mm512_mullo_epi64(*a, *b) + // _mm512_mullo_epi64(*a, *b) + let r = _mm512_mul_epu32(*a, *b); + let ah = _mm512_srli_epi64(*a, 32); + let bh = _mm512_srli_epi64(*b, 32); + let r1 = _mm512_mul_epu32(*a, bh); + let r1 = _mm512_slli_epi64(r1, 32); + let r = _mm512_add_epi64(r, r1); + let r1 = _mm512_mul_epu32(ah, *b); + let r1 = _mm512_slli_epi64(r1, 32); + let r = _mm512_add_epi64(r, r1); + r } #[inline(always)] @@ -1479,8 +1504,8 @@ unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut let f0 = block1_avx512(&u0, MDS_FREQ_BLOCK_ONE); // let [v1, v5, v9] = block2([(u[0], v[0]), (u[1], v[1]), (u[2], v[2])], MDS_FREQ_BLOCK_TWO); - // let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); - let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + // let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); // let [v2, v6, v10] = block3_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); // [u[0], u[1], u[2]] are all in u3 @@ -1795,6 +1820,7 @@ where let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let ss2 = add_avx512(&s2, &rc2); + s0 = sbox_avx512_one(&ss0); s1 = sbox_avx512_one(&ss1); s2 = sbox_avx512_one(&ss2); @@ -1900,13 +1926,13 @@ where leaf_size / SPONGE_RATE + 1 }; for _ in 0..loops { - let end1 = if idx1 + SPONGE_RATE > leaf_size { + let end1 = if idx1 + SPONGE_RATE >= leaf_size { leaf_size } else { idx1 + SPONGE_RATE }; - let end2 = if idx2 + SPONGE_RATE > inputs.len() { - inputs.len() + let end2 = if idx2 + SPONGE_RATE >= 2 * leaf_size { + 2 * leaf_size } else { idx2 + SPONGE_RATE }; diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs index f78c3f4a24..fa3752629d 100644 --- a/plonky2/src/hash/poseidon_goldilocks.rs +++ b/plonky2/src/hash/poseidon_goldilocks.rs @@ -450,6 +450,8 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField as F; use crate::field::types::{Field, PrimeField64}; #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + use crate::hash::arch::x86_64::poseidon_goldilocks_avx512::hash_leaf_avx512; + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] use crate::hash::poseidon::test_helpers::check_test_vectors_avx512; use crate::hash::poseidon::test_helpers::{check_consistency, check_test_vectors}; use crate::hash::poseidon::{Poseidon, PoseidonHash}; @@ -488,6 +490,12 @@ mod tests { [0xa89280105650c4ec, 0xab542d53860d12ed, 0x5704148e9ccab94f, 0xd3a826d4b62da9f5, 0x8a7a6ca87892574f, 0xc7017e1cad1a674e, 0x1f06668922318e34, 0xa3b203bc8102676f, 0xfcc781b0ce382bf2, 0x934c69ff3ed14ba5, 0x504688a5996e8f13, 0x401f3f2ed524a2ba, ]), + ([0xf2cc0ce426e7eddd, 0x91ad40f14cfdcb78, 0xc516c642346aabc, 0xa79a0411d96de0, + 0xf256c881b6167069, 0x5c767aa6354a647b, 0x79a821313415b9dc, 0xf083bc2f276b99e1, + 0x9aa0ac0171df5ac7, 0xc3c705daf69d66e0, 0x3b0468abe66c5ed, 0xdcf835c4d4cffd73, ], + [0x96d91d333e5e038d, 0x114395c7cfb7e18f, 0x19b1ea99556391ff, 0xd53855a776b4582a, + 0x378d8ea4ffbb7545, 0x168319892eff226a, 0x5f09f06508283bd, 0xb92d599c947cc2f1, + 0xf078fc732200e4d4, 0xcaf95e4285f3099d, 0x8532be1f10f23cd0, 0xc3260991186909ff, ]) ]; check_test_vectors::(test_vectors12.clone()); @@ -576,40 +584,144 @@ mod tests { #[test] fn test_hash_no_pad_gl() { - let inputs: [u64; 32] = [ - 9972144316416239374, - 7195869958086994472, - 12805395537960412263, - 6755149769410714396, - 16592921959755212957, - 1370750654791741308, - 11186995120529280354, - 288690570896506034, - 2896720011649362435, - 13870686984275550055, - 12288026009924247278, - 15608864109019511973, - 15690944173815210604, - 17535150735055770942, - 4265223756233917229, - 17236464151311603291, - 15180455466814482598, - 12377438429067983442, - 11274960245127600167, - 5684300978461808754, - 1918159483831849502, - 15340265949423289730, - 181633163915570313, - 12684059848091546996, - 10060377187090493210, - 13523019938818230572, - 16846214147461656883, - 13560222746484567233, - 2150999602305437005, - 9103462636082953981, - 16341057499572706412, - 842265247111451937, + let inputs = [ + 0xb8f463d7cb4f24f6, + 0xe94ad9aba668af65, + 0x4a31c8cee787786a, + 0x7f8ed7050aeadcf9, + 0x516c34f52a5c8b14, + 0x542c22306722b175, + 0x6feba1eb9030ecb9, + 0xe103d491fa784080, + 0x31d9a62ea39f4ec9, + 0xbf0ccc95d9b4c697, + 0x5a9d230167523b2e, + 0x7ff277e12091d2f2, + 0xf2af521b9537abf3, + 0xe39e815313da5c12, + 0xe5feaa1e4f46b87b, + 0x76b772a9e6eda11c, + 0x9005e1c8fbf27eed, + 0x78ea9242b53108ac, + 0x5561d33040b6affb, + 0x61ded48ffee1f243, + 0xebbe0c4034afb9e5, + 0x7973d462ab14d331, + 0x76a23e459a0849b, + 0x9fa93d23d8b84515, + 0x1e19bba2ce8042dd, + 0xb1159302625b71a3, + 0x792e2e4171fd7e83, + 0xc9088b032be7eff0, + 0x6540b29fbec19cb2, + 0x8c4f849dd68f4cdc, + 0xb91969b7cfcd1ec8, + 0x4d450eff6a3b0c7c, + 0xcace16a8345de56e, + 0xe5bac07b93e1f0e2, + 0x35088bde4f1bd3a9, + 0x2e0bd8e257386e40, + 0xed67fe1bd44680f0, + 0x887a32a6049105f, + 0x3ae86d4d60b87a67, + 0x665a656a217edacf, + 0x2eb451b933acbd2d, + 0x63876760e9570fb4, + 0x2b11da28eb95d7d6, + 0x138ea36659579c0a, + 0x457f674d92cfcd72, + 0xba4b8ffc7287142d, + 0x2b9bd3cd64e65cb6, + 0x2780e8b0e66848e8, + 0xe18303c5010835a4, + 0x6c4e379aba35e21e, + 0xf9c3f2f33320d9cd, + 0x82429ba2d6263c9a, + 0x11e81115fa995e88, + 0x75a7fb5681cd15e4, + 0xa54b2a0b6d57e340, + 0x884b3d9cc9b7f720, + 0xdac1b985f5b0ff19, + 0x5938c0405a01dbd4, + 0x13fb2d9399c3ef2e, + 0xeaed82d3706dccec, + 0xf8d853012e56f7fb, + 0xa4c639bbaf484525, + 0xe3b35501c21797ba, + 0x1a645013fcb5e3a0, + 0xf2eb2337ba169178, + 0xcc94fd9269c7d33, + 0x82a9aaa398b13f1, + 0xe9b5ecbe6576234, + 0x252287d7ed9ec792, + 0x30629bee322f17cc, + 0x9ae26078f44e8afb, + 0xabdc35ac8f527136, + 0x4b2a3be4ef4c231f, + 0x23074d5363eeba58, + 0x75cfe940f6967c16, + 0xfb185a23f6225406, + 0xda8a21bd2ba64cc3, + 0xd623bde11eb8c989, + 0x76201928e4523ba3, + 0x1c20cb194495b643, + 0x3e70ce2fddc52451, + 0x86c698ca61fdae8e, + 0x9855dd30ad0c1309, + 0x271541a781755737, + 0x209b4ccf7db16277, + 0xff27cae2771d1d8c, + 0xd7795488a7bfe6ee, + 0x9cf1875ec535778e, + 0x9fad94c126427390, + 0x199b482c029f3d9d, + 0x92ae2055bb3f6d6, + 0x29d6100b44167374, + 0x88e8c8ffdefe0f33, + 0xa3d8d929ea748a62, + 0xd5dbe1a3d99e113d, + 0x438639f8f0e3ff25, + 0xf2cc0ce426e7eddd, + 0x91ad40f14cfdcb78, + 0xc516c642346aabc, + 0xa79a0411d96de0, + 0xf256c881b6167069, + 0x5c767aa6354a647b, + 0x79a821313415b9dc, + 0xf083bc2f276b99e1, + 0x9d47fc86eb2de7c2, + 0x3370a8711a678a03, + 0x1572c8a8bf872b26, + 0xdbb7de1fc45360a1, + 0x5f87c0fe24bafdd4, + 0x2f6a5784207d118a, + 0x640c588afcf0cc14, + 0xe609f3cbb7cb015, + 0x8e4907544019be80, + 0xde2f553ac4ab68c3, + 0x29cd0d2800262365, + 0x3bf736a6fbc14ce2, + 0xab059c3c3cba4912, + 0xe609e14997bd2f5c, + 0x694189d934ff1f8d, + 0x54570348f45e3a9, + 0x90ef5b98b0a08a34, + 0x1b09b93749616de8, + 0x89be3144389d48c1, + 0xdaa7e268d0fd82d8, + 0xc46956b67fa89c61, + 0xec88a7133e4fefc, + 0xe41596ca682069f4, + 0x297f55e46472431b, + 0x33ada14fd813218d, + 0x22c57ca5e77249ad, + 0x4e2f2c7cc99f2d47, + 0x78d11ba2efc7556f, + 0xdfc98976b6e3ad0d, + 0x59d88f72bf5ad1d8, + 0x19ca05690b8e1ad9, ]; + let inputs = inputs .iter() .map(|x| F::from_canonical_u64(*x)) @@ -617,15 +729,25 @@ mod tests { let output = PoseidonHash::hash_no_pad(&inputs); let expected_out: [u64; 4] = [ - 8197835875512527937, - 7109417654116018994, - 18237163116575285904, - 17017896878738047012, + 0xc19dccf6ec4f3df3, + 0x1bf0d65af6925451, + 0xee9dbf2c8dcad9a2, + 0xae46323715f528a1, ]; let expected_out = expected_out .iter() .map(|x| F::from_canonical_u64(*x)) .collect::>(); assert_eq!(output.elements.to_vec(), expected_out); + + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + { + let mut dleaf: Vec = vec![F::from_canonical_u64(0); 2 * inputs.len()]; + dleaf[0..inputs.len()].copy_from_slice(&inputs); + dleaf[inputs.len()..2 * inputs.len()].copy_from_slice(&inputs); + let (h1, h2) = hash_leaf_avx512(dleaf.as_slice(), inputs.len()); + assert_eq!(h1, expected_out); + assert_eq!(h2, expected_out); + } } } From 7b892617d7f6fb6b9ec84a9b0b6247d4ea35f33d Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Thu, 17 Oct 2024 18:08:30 +0800 Subject: [PATCH 2/4] fix mul issue in avx512 --- .../src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 86bc98d60d..4a20c49202 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -1290,7 +1290,8 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51 #[inline] pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { - // _mm512_mullo_epi64(*a, *b) + /* + // long version let r = _mm512_mul_epu32(*a, *b); let ah = _mm512_srli_epi64(*a, 32); let bh = _mm512_srli_epi64(*b, 32); @@ -1301,6 +1302,8 @@ pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { let r1 = _mm512_slli_epi64(r1, 32); let r = _mm512_add_epi64(r, r1); r + */ + _mm512_mullo_epi64(*a, *b) } #[inline(always)] @@ -1504,8 +1507,8 @@ unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut let f0 = block1_avx512(&u0, MDS_FREQ_BLOCK_ONE); // let [v1, v5, v9] = block2([(u[0], v[0]), (u[1], v[1]), (u[2], v[2])], MDS_FREQ_BLOCK_TWO); - let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); - // let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + // let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); // let [v2, v6, v10] = block3_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); // [u[0], u[1], u[2]] are all in u3 From 8cc4ca276fd18e51c696d6df2db20875994286c9 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Fri, 18 Oct 2024 10:25:54 +0800 Subject: [PATCH 3/4] optimize avx512 code --- .../src/hash/arch/x86_64/goldilocks_avx512.rs | 13 +- .../arch/x86_64/poseidon_goldilocks_avx512.rs | 274 ++++++++---------- 2 files changed, 122 insertions(+), 165 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs index adb3ca702b..e67818e102 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs @@ -6,19 +6,22 @@ use crate::hash::hash_types::RichField; const MSB_: i64 = 0x8000000000000000u64 as i64; const P8_: i64 = 0xFFFFFFFF00000001u64 as i64; const P8_N_: i64 = 0xFFFFFFFF; +const ONE_: i64 = 1; #[allow(non_snake_case)] #[repr(align(64))] -struct FieldConstants { - MSB_V: [i64; 8], - P8_V: [i64; 8], - P8_N_V: [i64; 8], +pub(crate) struct FieldConstants { + pub(crate) MSB_V: [i64; 8], + pub(crate) P8_V: [i64; 8], + pub(crate) P8_N_V: [i64; 8], + pub(crate) ONE_V: [i64; 8], } -const FC: FieldConstants = FieldConstants { +pub(crate) const FC: FieldConstants = FieldConstants { MSB_V: [MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_], P8_V: [P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_], P8_N_V: [P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_], + ONE_V: [ONE_, ONE_, ONE_, ONE_, ONE_, ONE_, ONE_, ONE_], }; #[allow(dead_code)] diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 4a20c49202..282c19a44e 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -1134,8 +1134,8 @@ where let mut result = [F::ZERO; SPONGE_WIDTH]; let res0 = state[0]; unsafe { - let mut r0 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); - let mut r1 = _mm512_loadu_si512((&mut result[4..12]).as_mut_ptr().cast::()); + let mut r0 = _mm512_loadu_epi64((&mut result[0..8]).as_mut_ptr().cast::()); + let mut r1 = _mm512_loadu_epi64((&mut result[4..12]).as_mut_ptr().cast::()); for r in 1..12 { let sr512 = _mm512_set_epi64( @@ -1148,23 +1148,23 @@ where state[r].to_canonical_u64() as i64, state[r].to_canonical_u64() as i64, ); - let t0 = _mm512_loadu_si512( + let t0 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][0..8]) .as_ptr() - .cast::(), + .cast::(), ); - let t1 = _mm512_loadu_si512( + let t1 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][4..12]) .as_ptr() - .cast::(), + .cast::(), ); let m0 = mult_avx512(&sr512, &t0); let m1 = mult_avx512(&sr512, &t1); r0 = add_avx512(&r0, &m0); r1 = add_avx512(&r1, &m1); } - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), r0); - _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), r1); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), r0); + _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::(), r1); state[0] = res0; } } @@ -1177,22 +1177,22 @@ where F: PrimeField64, { unsafe { - let c0 = _mm512_loadu_si512( + let c0 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..8]) .as_ptr() - .cast::(), + .cast::(), ); - let c1 = _mm512_loadu_si512( + let c1 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[4..12]) .as_ptr() - .cast::(), + .cast::(), ); - let mut s0 = _mm512_loadu_si512((state[0..8]).as_ptr().cast::()); - let mut s1 = _mm512_loadu_si512((state[4..12]).as_ptr().cast::()); + let mut s0 = _mm512_loadu_epi64((state[0..8]).as_ptr().cast::()); + let mut s1 = _mm512_loadu_epi64((state[4..12]).as_ptr().cast::()); s0 = add_avx512(&s0, &c0); s1 = add_avx512(&s1, &c1); - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); - _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), s1); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), s0); + _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::(), s1); } } @@ -1262,30 +1262,16 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51 * - (test 3): if a + b < 2^64 (this means a + b is negative in signed representation) => no overflow so cout = 0 * - (test 3): if a + b >= 2^64 (this means a + b becomes positive in signed representation, that is, a + b >= 0) => there is overflow so cout = 1 */ - let ones = _mm512_set_epi64(1, 1, 1, 1, 1, 1, 1, 1); + let ones = _mm512_load_epi64(FC.ONE_V.as_ptr().cast::()); let zeros = _mm512_xor_si512(*a, *a); // faster 0 let r = _mm512_add_epi64(*a, *b); let ma = _mm512_cmpgt_epi64_mask(zeros, *a); let mb = _mm512_cmpgt_epi64_mask(zeros, *b); let mc = _mm512_cmpgt_epi64_mask(zeros, r); - let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb))); + // let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb))); + let m = (ma & mb) | (!mc & (ma ^ mb)); let co = _mm512_mask_blend_epi64(m, zeros, ones); (r, co) - /* - let mut va: [u64; 8] = [0; 8]; - let mut vb: [u64; 8] = [0; 8]; - let mut vr: [u64; 8] = [0; 8]; - let mut vc: [u64; 8] = [0; 8]; - _mm512_storeu_epi64(va.as_mut_ptr().cast::(), *a); - _mm512_storeu_epi64(vb.as_mut_ptr().cast::(), *b); - for i in 0..8 { - vr[i] = va[i].wrapping_add(vb[i]); - vc[i] = if vr[i] < va[i] { 1 } else { 0 }; - } - let r = _mm512_loadu_epi64(vr.as_ptr().cast::()); - let c = _mm512_loadu_epi64(vc.as_ptr().cast::()); - (r, c) - */ } #[inline] @@ -1303,6 +1289,7 @@ pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { let r = _mm512_add_epi64(r, r1); r */ + // short version _mm512_mullo_epi64(*a, *b) } @@ -1329,16 +1316,16 @@ unsafe fn block1_avx512(x: &__m512i, y: [i64; 3]) -> __m512i { unsafe fn block2_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> (__m512i, __m512i) { let mut vxr: [i64; 8] = [0; 8]; let mut vxi: [i64; 8] = [0; 8]; - _mm512_storeu_si512(vxr.as_mut_ptr().cast::(), *xr); - _mm512_storeu_si512(vxi.as_mut_ptr().cast::(), *xi); + _mm512_storeu_epi64(vxr.as_mut_ptr().cast::(), *xr); + _mm512_storeu_epi64(vxi.as_mut_ptr().cast::(), *xi); let x1: [(i64, i64); 3] = [(vxr[0], vxi[0]), (vxr[1], vxi[1]), (vxr[2], vxi[2])]; let x2: [(i64, i64); 3] = [(vxr[4], vxi[4]), (vxr[5], vxi[5]), (vxr[6], vxi[6])]; let b1 = block2(x1, y); let b2 = block2(x2, y); vxr = [b1[0].0, b1[1].0, b1[2].0, 0, b2[0].0, b2[1].0, b2[2].0, 0]; vxi = [b1[0].1, b1[1].1, b1[2].1, 0, b2[0].1, b2[1].1, b2[2].1, 0]; - let rr = _mm512_loadu_si512(vxr.as_ptr().cast::()); - let ri = _mm512_loadu_si512(vxi.as_ptr().cast::()); + let rr = _mm512_loadu_epi64(vxr.as_ptr().cast::()); + let ri = _mm512_loadu_epi64(vxi.as_ptr().cast::()); (rr, ri) } @@ -1371,18 +1358,9 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif1perm2 = _mm512_permutex_epi64(dif1, 0x2); let z0i = _mm512_add_epi64(dif3, dif1perm1); let z0i = _mm512_add_epi64(z0i, dif1perm2); - let mask = _mm512_set_epi64( - 0, - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - ); - let z0r = _mm512_and_si512(z0r, mask); - let z0i = _mm512_and_si512(z0i, mask); + let zeros = _mm512_xor_si512(z0r, z0r); + let z0r = _mm512_mask_blend_epi64(0x11, zeros, z0r); + let z0i = _mm512_mask_blend_epi64(0x11, zeros, z0i); // z1 // z1r = dif2[0] + dif2[1] + prod[2] - sum[2]; @@ -1405,18 +1383,8 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif1perm = _mm512_permutex_epi64(dif1, 0x8); let z1i = _mm512_add_epi64(dif3, dif3perm); let z1i = _mm512_add_epi64(z1i, dif1perm); - let mask = _mm512_set_epi64( - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - ); - let z1r = _mm512_and_si512(z1r, mask); - let z1i = _mm512_and_si512(z1i, mask); + let z1r = _mm512_mask_blend_epi64(0x22, zeros, z1r); + let z1i = _mm512_mask_blend_epi64(0x22, zeros, z1i); // z2 // z2r = dif2[0] + dif2[1] + dif2[2]; @@ -1438,18 +1406,8 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif3perm2 = _mm512_permutex_epi64(dif3, 0x10); let z2i = _mm512_add_epi64(dif3, dif3perm1); let z2i = _mm512_add_epi64(z2i, dif3perm2); - let mask = _mm512_set_epi64( - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - 0, - ); - let z2r = _mm512_and_si512(z2r, mask); - let z2i = _mm512_and_si512(z2i, mask); + let z2r = _mm512_mask_blend_epi64(0x44, zeros, z2r); + let z2i = _mm512_mask_blend_epi64(0x44, zeros, z2i); let zr = _mm512_or_si512(z0r, z1r); let zr = _mm512_or_si512(zr, z2r); @@ -1528,10 +1486,7 @@ unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut #[inline(always)] #[unroll_for_loops] unsafe fn mds_layer_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut __m512i) { - let mask = _mm512_set_epi64( - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, - ); + let mask = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); let mut sl0 = _mm512_and_si512(*s0, mask); let mut sl1 = _mm512_and_si512(*s1, mask); let mut sl2 = _mm512_and_si512(*s2, mask); @@ -1573,48 +1528,48 @@ unsafe fn mds_partial_layer_init_avx512(s0: &mut __m512i, s1: &mut __m512i, s where F: PrimeField64, { - let mut result = [F::ZERO; 2 * SPONGE_WIDTH]; let res0 = *s0; - - let mut r0 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); - let mut r1 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); - let mut r2 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); + let mut r0 = _mm512_xor_epi64(res0, res0); + let mut r1 = r0; + let mut r2 = r0; for r in 1..12 { - let sr = match r { - 1 => _mm512_permutex_epi64(*s0, 0x55), - 2 => _mm512_permutex_epi64(*s0, 0xAA), - 3 => _mm512_permutex_epi64(*s0, 0xFF), - 4 => _mm512_permutex_epi64(*s1, 0x0), - 5 => _mm512_permutex_epi64(*s1, 0x55), - 6 => _mm512_permutex_epi64(*s1, 0xAA), - 7 => _mm512_permutex_epi64(*s1, 0xFF), - 8 => _mm512_permutex_epi64(*s2, 0x0), - 9 => _mm512_permutex_epi64(*s2, 0x55), - 10 => _mm512_permutex_epi64(*s2, 0xAA), - 11 => _mm512_permutex_epi64(*s2, 0xFF), - _ => _mm512_permutex_epi64(*s0, 0x55), - }; - let t0 = _mm512_loadu_si512( - (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][0..8]) - .as_ptr() - .cast::(), - ); - let t1 = _mm512_loadu_si512( - (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][8..16]) - .as_ptr() - .cast::(), - ); - let t2 = _mm512_loadu_si512( - (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][16..24]) - .as_ptr() - .cast::(), - ); - let m0 = mult_avx512(&sr, &t0); - let m1 = mult_avx512(&sr, &t1); - let m2 = mult_avx512(&sr, &t2); - r0 = add_avx512(&r0, &m0); - r1 = add_avx512(&r1, &m1); - r2 = add_avx512(&r2, &m2); + if r < 12 { + let sr = match r { + 1 => _mm512_permutex_epi64(*s0, 0x55), + 2 => _mm512_permutex_epi64(*s0, 0xAA), + 3 => _mm512_permutex_epi64(*s0, 0xFF), + 4 => _mm512_permutex_epi64(*s1, 0x0), + 5 => _mm512_permutex_epi64(*s1, 0x55), + 6 => _mm512_permutex_epi64(*s1, 0xAA), + 7 => _mm512_permutex_epi64(*s1, 0xFF), + 8 => _mm512_permutex_epi64(*s2, 0x0), + 9 => _mm512_permutex_epi64(*s2, 0x55), + 10 => _mm512_permutex_epi64(*s2, 0xAA), + 11 => _mm512_permutex_epi64(*s2, 0xFF), + _ => _mm512_permutex_epi64(*s0, 0x55), + }; + let t0 = _mm512_loadu_epi64( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][0..8]) + .as_ptr() + .cast::(), + ); + let t1 = _mm512_loadu_epi64( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][8..16]) + .as_ptr() + .cast::(), + ); + let t2 = _mm512_loadu_epi64( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][16..24]) + .as_ptr() + .cast::(), + ); + let m0 = mult_avx512(&sr, &t0); + let m1 = mult_avx512(&sr, &t1); + let m2 = mult_avx512(&sr, &t2); + r0 = add_avx512(&r0, &m0); + r1 = add_avx512(&r1, &m1); + r2 = add_avx512(&r2, &m2); + } } *s0 = _mm512_mask_blend_epi64(0x11, r0, res0); *s1 = r1; @@ -1677,20 +1632,20 @@ unsafe fn mds_partial_layer_fast_avx512( state[0].to_noncanonical_u64() as i64, state[0].to_noncanonical_u64() as i64, ); - let rc0 = _mm512_loadu_si512( + let rc0 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_VS_AVX512[r][0..8]) .as_ptr() - .cast::(), + .cast::(), ); - let rc1 = _mm512_loadu_si512( + let rc1 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_VS_AVX512[r][8..16]) .as_ptr() - .cast::(), + .cast::(), ); - let rc2 = _mm512_loadu_si512( + let rc2 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_VS_AVX512[r][16..24]) .as_ptr() - .cast::(), + .cast::(), ); let (mh, ml) = mult_avx512_128(&ss0, &rc0); let m = reduce_avx512_128_64(&mh, &ml); @@ -1715,9 +1670,9 @@ unsafe fn mds_partial_layer_fast_avx512( let m = reduce_avx512_128_64(&mh, &ml); *s2 = add_avx512(s2, &m); - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), *s0); - _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), *s1); - _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), *s2); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), *s0); + _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::(), *s1); + _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::(), *s2); } #[allow(unused)] @@ -1732,22 +1687,22 @@ where // Self::full_rounds(&mut state, &mut round_ctr); for _ in 0..HALF_N_FULL_ROUNDS { // load state - let s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::()); - let s1 = _mm512_loadu_si512((&state[4..12]).as_ptr().cast::()); + let s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::()); + let s1 = _mm512_loadu_epi64((&state[4..12]).as_ptr().cast::()); let rc: &[u64; 12] = &ALL_ROUND_CONSTANTS[SPONGE_WIDTH * round_ctr..][..SPONGE_WIDTH] .try_into() .unwrap(); - let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); - let rc1 = _mm512_loadu_si512((&rc[4..12]).as_ptr().cast::()); + let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_epi64((&rc[4..12]).as_ptr().cast::()); let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let r0 = sbox_avx512_one(&ss0); let r1 = sbox_avx512_one(&ss1); // store state - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), r0); - _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), r1); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), r0); + _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::(), r1); *state = ::mds_layer(&state); round_ctr += 1; @@ -1765,22 +1720,22 @@ where // Self::full_rounds(&mut state, &mut round_ctr); for _ in 0..HALF_N_FULL_ROUNDS { // load state - let s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::()); - let s1 = _mm512_loadu_si512((&state[4..12]).as_ptr().cast::()); + let s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::()); + let s1 = _mm512_loadu_epi64((&state[4..12]).as_ptr().cast::()); let rc: &[u64; 12] = &ALL_ROUND_CONSTANTS[SPONGE_WIDTH * round_ctr..][..SPONGE_WIDTH] .try_into() .unwrap(); - let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); - let rc1 = _mm512_loadu_si512((&rc[4..12]).as_ptr().cast::()); + let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_epi64((&rc[4..12]).as_ptr().cast::()); let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let r0 = sbox_avx512_one(&ss0); let r1 = sbox_avx512_one(&ss1); // store state - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), r0); - _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), r1); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), r0); + _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::(), r1); *state = ::mds_layer(&state); // mds_layer_avx::(&mut s0, &mut s1, &mut s2); @@ -1808,22 +1763,21 @@ where unsafe { // load state - let mut s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::()); - let mut s1 = _mm512_loadu_si512((&state[8..16]).as_ptr().cast::()); - let mut s2 = _mm512_loadu_si512((&state[16..24]).as_ptr().cast::()); + let mut s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::()); + let mut s1 = _mm512_loadu_epi64((&state[8..16]).as_ptr().cast::()); + let mut s2 = _mm512_loadu_epi64((&state[16..24]).as_ptr().cast::()); for _ in 0..HALF_N_FULL_ROUNDS { let rc: &[u64; 24] = &ALL_ROUND_CONSTANTS_AVX512[2 * SPONGE_WIDTH * round_ctr..] [..2 * SPONGE_WIDTH] .try_into() .unwrap(); - let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); - let rc1 = _mm512_loadu_si512((&rc[8..16]).as_ptr().cast::()); - let rc2 = _mm512_loadu_si512((&rc[16..24]).as_ptr().cast::()); + let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_epi64((&rc[8..16]).as_ptr().cast::()); + let rc2 = _mm512_loadu_epi64((&rc[16..24]).as_ptr().cast::()); let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let ss2 = add_avx512(&s2, &rc2); - s0 = sbox_avx512_one(&ss0); s1 = sbox_avx512_one(&ss1); s2 = sbox_avx512_one(&ss2); @@ -1832,20 +1786,20 @@ where } // this does partial_first_constant_layer_avx(&mut state); - let c0 = _mm512_loadu_si512( + let c0 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[0..8]) .as_ptr() - .cast::(), + .cast::(), ); - let c1 = _mm512_loadu_si512( + let c1 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[8..16]) .as_ptr() - .cast::(), + .cast::(), ); - let c2 = _mm512_loadu_si512( + let c2 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[16..24]) .as_ptr() - .cast::(), + .cast::(), ); s0 = add_avx512(&s0, &c0); s1 = add_avx512(&s1, &c1); @@ -1853,9 +1807,9 @@ where mds_partial_layer_init_avx512::(&mut s0, &mut s1, &mut s2); - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); - _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), s1); - _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), s2); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), s0); + _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::(), s1); + _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::(), s2); for i in 0..N_PARTIAL_ROUNDS { state[0] = sbox_monomial(state[0]); @@ -1873,9 +1827,9 @@ where [..2 * SPONGE_WIDTH] .try_into() .unwrap(); - let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); - let rc1 = _mm512_loadu_si512((&rc[8..16]).as_ptr().cast::()); - let rc2 = _mm512_loadu_si512((&rc[16..24]).as_ptr().cast::()); + let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_epi64((&rc[8..16]).as_ptr().cast::()); + let rc2 = _mm512_loadu_epi64((&rc[16..24]).as_ptr().cast::()); let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let ss2 = add_avx512(&s2, &rc2); @@ -1887,9 +1841,9 @@ where } // store state - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); - _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), s1); - _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), s2); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), s0); + _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::(), s1); + _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::(), s2); debug_assert_eq!(round_ctr, N_ROUNDS); }; From 6b4bce20bce2aebcc9a1d143fc1fc3d3aa425f06 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Fri, 18 Oct 2024 10:26:19 +0800 Subject: [PATCH 4/4] cargo fmt --- plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs | 3 +-- .../src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs index e67818e102..dd305d5c8e 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs @@ -59,7 +59,7 @@ pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i { } */ unsafe { - let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::()); + let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::()); let a_sc = _mm512_xor_si512(*a, msb); let c0_s = _mm512_add_epi64(a_sc, *b); let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); @@ -79,7 +79,6 @@ pub fn add_avx512_s_b_small(a_s: &__m512i, b_small: &__m512i) -> __m512i { } } - #[inline(always)] pub fn sub_avx512(a: &__m512i, b: &__m512i) -> __m512i { unsafe { diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs index d10fe828ce..301e043446 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs @@ -1164,7 +1164,7 @@ unsafe fn mds_layer_avx(s0: &mut __m256i, s1: &mut __m256i, s2: &mut __m256i) { let (rl0, c0) = add64_no_carry(&sl0, &shl0); let (rh0, _) = add64_no_carry(&shh0, &c0); let r0 = reduce_avx_128_64(&rh0, &rl0); - + let (rl1, c1) = add64_no_carry(&sl1, &shl1); let (rh1, _) = add64_no_carry(&shh1, &c1); *s1 = reduce_avx_128_64(&rh1, &rl1); @@ -1393,7 +1393,7 @@ where F: PrimeField64 + Poseidon, { let mut state = &mut input.clone(); - let mut round_ctr = 0; + let mut round_ctr = 0; unsafe { // load state @@ -1410,13 +1410,13 @@ where let rc2 = _mm256_loadu_si256((&rc[8..12]).as_ptr().cast::<__m256i>()); let ss0 = add_avx(&s0, &rc0); let ss1 = add_avx(&s1, &rc1); - let ss2 = add_avx(&s2, &rc2); + let ss2 = add_avx(&s2, &rc2); (s0, s1, s2) = sbox_avx_m256i(&ss0, &ss1, &ss2); mds_layer_avx(&mut s0, &mut s1, &mut s2); - round_ctr += 1; + round_ctr += 1; } - + // this does partial_first_constant_layer_avx(&mut state); let c0 = _mm256_loadu_si256( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..4]) @@ -1442,7 +1442,7 @@ where _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), s0); _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), s1); _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), s2); - + for i in 0..N_PARTIAL_ROUNDS { state[0] = sbox_monomial(state[0]); state[0] = state[0].add_canonical_u64(FAST_PARTIAL_ROUND_CONSTANTS[i]);