Skip to content

Commit

Permalink
Merge pull request #37 from okx/dev-dumi
Browse files Browse the repository at this point in the history
add LDE bench and salt support
  • Loading branch information
dloghin authored Oct 15, 2024
2 parents 9a917ba + 3e2cc3f commit 376d690
Show file tree
Hide file tree
Showing 23 changed files with 2,347 additions and 278 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ members = ["field", "maybe_rayon", "plonky2", "starky", "util", "gen", "u32", "e
resolver = "2"

[workspace.dependencies]
cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "2a7c42d29ee72d7c2c2da9378ae816384c43cdec" }
cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "547192b2ef42dc7519435059c86f88431b8de999" }
ahash = { version = "0.8.7", default-features = false, features = [
"compile-time-rng",
] } # NOTE: Be sure to keep this version the same as the dependency in `hashbrown`.
Expand Down
16 changes: 0 additions & 16 deletions field/src/fft.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use alloc::vec::Vec;
use core::cmp::{max, min};

#[cfg(feature = "cuda")]
use cryptography_cuda::{ntt, types::NTTInputOutputOrder};
use plonky2_util::{log2_strict, reverse_index_bits_in_place};
use unroll::unroll_for_loops;

Expand Down Expand Up @@ -34,20 +32,6 @@ pub fn fft_root_table<F: Field>(n: usize) -> FftRootTable<F> {
root_table
}

/// Dispatches an in-place transform of `input` to the GPU or the CPU.
///
/// When `F::CUDA_SUPPORT` is true, `input` is passed to `ntt` with
/// `NTTInputOutputOrder::NN`; `zero_factor` and `root_table` are not used on
/// that path. Otherwise the call is forwarded unchanged to `fft_dispatch_cpu`.
// NOTE(review): the leading `0` passed to `ntt` is presumably a device index — confirm
// against the `cryptography_cuda` API.
#[allow(dead_code)]
#[cfg(feature = "cuda")]
fn fft_dispatch_gpu<F: Field>(
    input: &mut [F],
    zero_factor: Option<usize>,
    root_table: Option<&FftRootTable<F>>,
) {
    // Fields without CUDA support fall back to the CPU implementation.
    if !F::CUDA_SUPPORT {
        return fft_dispatch_cpu(input, zero_factor, root_table);
    }
    ntt(0, input, NTTInputOutputOrder::NN)
}

fn fft_dispatch_cpu<F: Field>(
input: &mut [F],
zero_factor: Option<usize>,
Expand Down
6 changes: 5 additions & 1 deletion plonky2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ serde = { workspace = true, features = ["rc"] }
static_assertions = { workspace = true }
unroll = { workspace = true }
web-time = { version = "1.0.0", optional = true }
once_cell = { version = "1.18.0" }
once_cell = { version = "1.20.2" }
papi-bindings = { version = "0.5.2" }

# Local dependencies
Expand Down Expand Up @@ -80,6 +80,10 @@ harness = false
name = "ffts"
harness = false

[[bench]]
name = "lde"
harness = false

[[bench]]
name = "hashing"
harness = false
Expand Down
59 changes: 59 additions & 0 deletions plonky2/benches/lde.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
mod allocator;

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
#[cfg(feature = "cuda")]
use cryptography_cuda::init_cuda_degree_rs;
use plonky2::field::extension::Extendable;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::polynomial::PolynomialCoeffs;
use plonky2::fri::oracle::PolynomialBatch;
use plonky2::hash::hash_types::RichField;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::util::timing::TimingTree;
use tynm::type_name;

/// Benchmarks `PolynomialBatch::from_coeffs` on batches of random polynomials.
///
/// For each LDE size `2^13`, `2^14` and `2^15`, a batch of 16 random
/// polynomials with `lde_size >> RATE_BITS` coefficients each is generated
/// once, outside the timed loop. Each timed iteration then clones that batch
/// and feeds it to `from_coeffs` with `RATE_BITS = 3`, so the cost of the
/// clone is part of the measurement.
///
/// With the `cuda` feature enabled, `init_cuda_degree_rs(16)` is called once
/// before any measurement.
pub(crate) fn bench_batch_lde<
    F: RichField + Extendable<D>,
    C: GenericConfig<D, F = F>,
    const D: usize,
>(
    c: &mut Criterion,
) {
    const RATE_BITS: usize = 3;
    const BATCH_SIZE: usize = 1 << 4;

    let mut group = c.benchmark_group(format!("lde<{}>", type_name::<F>()));

    #[cfg(feature = "cuda")]
    init_cuda_degree_rs(16);

    for size_log in [13, 14, 15] {
        let orig_size = 1 << (size_log - RATE_BITS);
        let lde_size = 1 << size_log;

        group.bench_with_input(BenchmarkId::from_parameter(lde_size), &lde_size, |b, _| {
            // Inputs are built once per benchmark, not once per iteration.
            let polynomials: Vec<PolynomialCoeffs<F>> = (0..BATCH_SIZE)
                .map(|_| PolynomialCoeffs::new(F::rand_vec(orig_size)))
                .collect();
            let mut timing = TimingTree::new("lde", log::Level::Error);
            b.iter(|| {
                // NOTE(review): `false` and `1` are presumably `blinding` and `cap_height`,
                // by analogy with `from_coeffs_gpu` — confirm against `from_coeffs`.
                PolynomialBatch::<F, C, D>::from_coeffs(
                    polynomials.clone(),
                    RATE_BITS,
                    false,
                    1,
                    &mut timing,
                    None,
                )
            });
        });
    }

    // Close the group explicitly rather than relying on its drop.
    group.finish();
}

/// Criterion entry point: runs the batch LDE benchmark for the Goldilocks field
/// with the Poseidon configuration and extension degree 2.
fn criterion_benchmark(c: &mut Criterion) {
    type F = GoldilocksField;
    type Config = PoseidonGoldilocksConfig;
    const D: usize = 2;

    bench_batch_lde::<F, Config, D>(c);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
56 changes: 44 additions & 12 deletions plonky2/src/fri/oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose};

/// Shared flag recording whether the GPU has been initialised for tests.
///
/// `init_gpu` locks this value and treats `0` as "not yet initialised"; it
/// stores `1` once `init_cuda_rs` has run. Only compiled for test/doctest
/// builds with the `cuda` feature.
#[cfg(all(feature = "cuda", any(test, doctest)))]
pub static GPU_INIT: once_cell::sync::Lazy<std::sync::Arc<std::sync::Mutex<u64>>> =
once_cell::sync::Lazy::new(|| std::sync::Arc::new(std::sync::Mutex::new(0)));

/// Initialises the GPU at most once per process for test builds.
///
/// Takes the `GPU_INIT` lock, and if the stored value is still `0`, prints a
/// notice, calls `init_cuda_rs` and records `1`. The lock is held for the
/// duration of the call, so concurrent callers wait for initialisation to
/// finish.
///
/// # Panics
///
/// Panics if the `GPU_INIT` mutex is poisoned.
#[cfg(all(feature = "cuda", any(test, doctest)))]
fn init_gpu() {
    use cryptography_cuda::init_cuda_rs;

    let mut state = GPU_INIT.lock().unwrap();
    if *state != 0 {
        // Already initialised by an earlier caller.
        return;
    }

    println!("Init GPU!");
    init_cuda_rs();
    *state = 1;
}

/// Four (~64 bit) field elements gives ~128 bit security.
pub const SALT_SIZE: usize = 4;

Expand Down Expand Up @@ -192,10 +208,17 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
timing: &mut TimingTree,
fft_root_table: Option<&FftRootTable<F>>,
) -> Self {
let pols = polynomials.len();
let degree = polynomials[0].len();
let log_n = log2_strict(degree);

if log_n + rate_bits > 1 && polynomials.len() > 0 {
#[cfg(any(test, doctest))]
init_gpu();

if log_n + rate_bits > 1
&& polynomials.len() > 0
&& pols * (1 << (log_n + rate_bits)) < (1 << 31)
{
let _num_gpus: usize = std::env::var("NUM_OF_GPUS")
.expect("NUM_OF_GPUS should be set")
.parse()
Expand Down Expand Up @@ -232,17 +255,17 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
}

#[cfg(feature = "cuda")]
pub fn from_coeffs_gpu(
fn from_coeffs_gpu(
polynomials: &[PolynomialCoeffs<F>],
rate_bits: usize,
_blinding: bool,
blinding: bool,
cap_height: usize,
timing: &mut TimingTree,
_fft_root_table: Option<&FftRootTable<F>>,
log_n: usize,
_degree: usize,
) -> MerkleTree<F, <C as GenericConfig<D>>::Hasher> {
// let salt_size = if blinding { SALT_SIZE } else { 0 };
let salt_size = if blinding { SALT_SIZE } else { 0 };
// println!("salt_size: {:?}", salt_size);
let output_domain_size = log_n + rate_bits;

Expand All @@ -255,8 +278,9 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
let total_num_of_fft = polynomials.len();
// println!("total_num_of_fft: {:?}", total_num_of_fft);

let num_of_cols = total_num_of_fft + salt_size; // if blinding, extend by salt_size
let total_num_input_elements = total_num_of_fft * (1 << log_n);
let total_num_output_elements = total_num_of_fft * (1 << output_domain_size);
let total_num_output_elements = num_of_cols * (1 << output_domain_size);

let mut gpu_input: Vec<F> = polynomials
.into_iter()
Expand All @@ -270,6 +294,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
cfg_lde.are_outputs_on_device = true;
cfg_lde.with_coset = true;
cfg_lde.is_multi_gpu = true;
cfg_lde.salt_size = salt_size as u32;

let mut device_output_data: HostOrDeviceSlice<'_, F> =
HostOrDeviceSlice::cuda_malloc(0 as i32, total_num_output_elements).unwrap();
Expand Down Expand Up @@ -302,7 +327,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
}

let mut cfg_trans = TransposeConfig::default();
cfg_trans.batches = total_num_of_fft as u32;
cfg_trans.batches = num_of_cols as u32;
cfg_trans.are_inputs_on_device = true;
cfg_trans.are_outputs_on_device = true;

Expand All @@ -327,10 +352,14 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
MerkleTree::new_from_gpu_leaves(
&device_transpose_data,
1 << output_domain_size,
total_num_of_fft,
num_of_cols,
cap_height
)
);

drop(device_transpose_data);
drop(device_output_data);

mt
}

Expand All @@ -340,6 +369,9 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
blinding: bool,
fft_root_table: Option<&FftRootTable<F>>,
) -> Vec<Vec<F>> {
#[cfg(all(feature = "cuda", any(test, doctest)))]
init_gpu();

let degree = polynomials[0].len();
#[cfg(all(feature = "cuda", feature = "batch"))]
let log_n = log2_strict(degree) + rate_bits;
Expand Down Expand Up @@ -443,11 +475,11 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
println!("collect data from gpu used: {:?}", start.elapsed());
r
})
// .chain(
// (0..salt_size)
// .into_par_iter()
// .map(|_| F::rand_vec(degree << rate_bits)),
// )
.chain(
(0..salt_size)
.into_par_iter()
.map(|_| F::rand_vec(degree << rate_bits)),
)
.collect();
println!("real lde elapsed: {:?}", start_lde.elapsed());
return ret;
Expand Down
2 changes: 1 addition & 1 deletion plonky2/src/gates/gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use core::ops::Range;
use std::sync::Arc;

use hashbrown::HashMap;
use serde::{ Serialize, Serializer};
use serde::{Serialize, Serializer};

use crate::field::batch_util::batch_multiply_inplace;
use crate::field::extension::{Extendable, FieldExtension};
Expand Down
6 changes: 1 addition & 5 deletions plonky2/src/gates/low_degree_interpolation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
// Builds the gate identifier from the `Debug` rendering of `self` followed by
// the extension degree, in the form `<debug><D=N>`.
fn id(&self) -> String {
format!("{self:?}<D={D}>")
}
fn serialize(
&self,
dst: &mut Vec<u8>,
_common_data: &CommonCircuitData<F, D>,
) -> IoResult<()> {
fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
dst.write_usize(self.subgroup_bits)?;
Ok(())
}
Expand Down
56 changes: 45 additions & 11 deletions plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,26 @@ const MSB_: i64 = 0x8000000000000000u64 as i64;
const P8_: i64 = 0xFFFFFFFF00000001u64 as i64;
const P8_N_: i64 = 0xFFFFFFFF;

/// Eight-lane broadcast copies of the field constants, stored for 512-bit loads.
///
/// Every field is 64 bytes (`8 x i64`) and the struct is aligned to 64 bytes,
/// so each field begins on a 64-byte boundary.
#[allow(non_snake_case)]
#[repr(align(64))]
struct FieldConstants {
    /// `MSB_` repeated in all eight lanes.
    MSB_V: [i64; 8],
    /// `P8_` repeated in all eight lanes.
    P8_V: [i64; 8],
    /// `P8_N_` repeated in all eight lanes.
    P8_N_V: [i64; 8],
}

/// The broadcast constants read by the AVX-512 helpers in this module.
const FC: FieldConstants = FieldConstants {
    MSB_V: [MSB_; 8],
    P8_V: [P8_; 8],
    P8_N_V: [P8_N_; 8],
};

#[allow(dead_code)]
#[inline(always)]
pub fn shift_avx512(a: &__m512i) -> __m512i {
unsafe {
let msb = _mm512_set_epi64(MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_);
// let msb = _mm512_set_epi64(MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_);
let msb = _mm512_load_si512(FC.MSB_V.as_ptr().cast::<i32>());
_mm512_xor_si512(*a, msb)
}
}
Expand All @@ -20,27 +35,31 @@ pub fn shift_avx512(a: &__m512i) -> __m512i {
#[inline(always)]
pub fn to_canonical_avx512(a: &__m512i) -> __m512i {
unsafe {
let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
// let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
// let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let p8 = _mm512_load_si512(FC.P8_V.as_ptr().cast::<i32>());
let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let result_mask = _mm512_cmpge_epu64_mask(*a, p8);
_mm512_mask_add_epi64(*a, result_mask, *a, p8_n)
}
}

#[inline(always)]
pub fn add_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i {
unsafe {
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
// let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let c0 = _mm512_add_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*a, c0);
_mm512_mask_add_epi64(c0, result_mask, c0, p8_n)
}
}

#[inline(always)]
pub fn sub_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
pub fn sub_avx512(a: &__m512i, b: &__m512i) -> __m512i {
unsafe {
let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
// let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
let p8 = _mm512_load_si512(FC.P8_V.as_ptr().cast::<i32>());
let c0 = _mm512_sub_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*b, *a);
_mm512_mask_add_epi64(c0, result_mask, c0, p8)
Expand All @@ -50,11 +69,25 @@ pub fn sub_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
#[inline(always)]
pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
unsafe {
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
// let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let c_hh = _mm512_srli_epi64(*c_h, 32);
let c1 = sub_avx512_b_c(c_l, &c_hh);
let c1 = sub_avx512(c_l, &c_hh);
let c2 = _mm512_mul_epu32(*c_h, p8_n);
add_avx512_b_c(&c1, &c2)
add_avx512(&c1, &c2)
}
}

/// Folds a high part `c_h` and a low part `c_l` into one 64-bit value per lane.
///
/// Here we suppose c_h < 2^32.
///
/// Per lane: the top bit of `c_l` is flipped, the product of the low 32 bits
/// of `c_h` and `P8_N_` is added with `add_avx512`, and the top bit of the sum
/// is flipped back.
#[inline(always)]
pub fn reduce_avx512_96_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
    // SAFETY: both loads read 64 bytes from fields of `FC`, whose type is
    // `#[repr(align(64))]` with 64-byte fields, satisfying the alignment that
    // `_mm512_load_si512` requires.
    // NOTE(review): assumes AVX-512F is available wherever this module is compiled — confirm.
    unsafe {
        let sign_bit = _mm512_load_si512(FC.MSB_V.as_ptr().cast::<i32>());
        let epsilon = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());

        let high_term = _mm512_mul_epu32(*c_h, epsilon);
        let low_shifted = _mm512_xor_si512(*c_l, sign_bit);

        let sum_shifted = add_avx512(&low_shifted, &high_term);
        _mm512_xor_si512(sum_shifted, sign_bit)
    }
}

Expand All @@ -69,7 +102,8 @@ pub fn mult_avx512_128(a: &__m512i, b: &__m512i) -> (__m512i, __m512i) {
let c_ll = _mm512_mul_epu32(*a, *b);
let c_ll_h = _mm512_srli_epi64(c_ll, 32);
let r0 = _mm512_add_epi64(c_hl, c_ll_h);
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
// let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let r0_l = _mm512_and_si512(r0, p8_n);
let r0_h = _mm512_srli_epi64(r0, 32);
let r1 = _mm512_add_epi64(c_lh, r0_l);
Expand Down
2 changes: 1 addition & 1 deletion plonky2/src/hash/arch/x86_64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod goldilocks_avx512;
pub mod poseidon2_goldilocks_avx2;
#[cfg(target_feature = "avx2")]
pub mod poseidon_bn128_avx2;
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512dq")))]
#[cfg(target_feature = "avx2")]
pub mod poseidon_goldilocks_avx2;
#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))]
pub mod poseidon_goldilocks_avx512;
Loading

0 comments on commit 376d690

Please sign in to comment.