diff --git a/d4ft/native/obara_saika/eri_kernel.cu b/d4ft/native/obara_saika/eri_kernel.cu index 9b02dea..15f57c2 100644 --- a/d4ft/native/obara_saika/eri_kernel.cu +++ b/d4ft/native/obara_saika/eri_kernel.cu @@ -19,6 +19,28 @@ HEMI_DEV_CALLABLE void triu_ij_from_index(int n, int index, int *i, *j = j_; } +HEMI_DEV_CALLABLE void get_symmetry_count(int i, int j, int k, int l, int *count) { + int count_ = 1; + if(i == k && j == l){ + if(i != j){ + count_ *= 4; + } + } else{ + count_ *= 2; + if(i == j){ + if(k != l){ + count_ *= 2; + } + } else{ + count_ *= 2; + if(k != l){ + count_ *= 2; + } + } + } + *count = count_; +} + // template void Hartree_32::Gpu(cudaStream_t stream, Array& N, @@ -74,7 +96,12 @@ void Hartree_64::Gpu(cudaStream_t stream, Array& sorted_cd_idx, Array& screened_cd_idx_start, Array& screened_idx_offset, - Array& output) { + Array& pgto_coeff, + Array& pgto_normalization_factor, + Array& pgto_idx_to_cgto_idx, + Array& rdm1, + Array& n_cgto, + Array& output) { // Prescreening int* idx_4c; int idx_length; @@ -92,8 +119,6 @@ void Hartree_64::Gpu(cudaStream_t stream, loc = screened_idx_offset.ptr[index] + i - screened_cd_idx_start.ptr[index]; idx_4c[loc] = sorted_ab_idx.ptr[index]; // ab idx_4c[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd - output.ptr[loc] = sorted_ab_idx.ptr[index]; // ab - output.ptr[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd } __syncthreads(); }); @@ -104,8 +129,28 @@ void Hartree_64::Gpu(cudaStream_t stream, int a, b, c, d; // pgto 4c idx int i, j, k, l; // cgto 4c idx double eri_result; + double Na, Nb, Nc, Nd; + double Ca, Cb, Cc, Cd; + double Mab, Mcd; + int count; triu_ij_from_index(N.ptr[0], idx_4c[index], &a, &b); triu_ij_from_index(N.ptr[0], idx_4c[index + screened_length.ptr[0]], &c, &d); + get_symmetry_count(a, b, c, d, &count); + double dcount = static_cast<double>(count); + Ca = pgto_coeff.ptr[a]; + Cb = pgto_coeff.ptr[b]; + Cc = pgto_coeff.ptr[c]; + Cd = pgto_coeff.ptr[d]; + Na = 
pgto_normalization_factor.ptr[a]; + Nb = pgto_normalization_factor.ptr[b]; + Nc = pgto_normalization_factor.ptr[c]; + Nd = pgto_normalization_factor.ptr[d]; + i = pgto_idx_to_cgto_idx.ptr[a]; + j = pgto_idx_to_cgto_idx.ptr[b]; + k = pgto_idx_to_cgto_idx.ptr[c]; + l = pgto_idx_to_cgto_idx.ptr[d]; + Mab = rdm1.ptr[i*n_cgto.ptr[0] + j]; + Mcd = rdm1.ptr[k*n_cgto.ptr[0] + l]; eri_result = eri(n.ptr[0 * N.ptr[0] + a], n.ptr[1 * N.ptr[0] + a], n.ptr[2 * N.ptr[0] + a], // a n.ptr[0 * N.ptr[0] + b], n.ptr[1 * N.ptr[0] + b], n.ptr[2 * N.ptr[0] + b], // b n.ptr[0 * N.ptr[0] + c], n.ptr[1 * N.ptr[0] + c], n.ptr[2 * N.ptr[0] + c], // c @@ -116,6 +161,9 @@ void Hartree_64::Gpu(cudaStream_t stream, r.ptr[0 * N.ptr[0] + d], r.ptr[1 * N.ptr[0] + d], r.ptr[2 * N.ptr[0] + d], // d z.ptr[a], z.ptr[b], z.ptr[c], z.ptr[d], // z min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr); + eri_result = eri_result * dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd; + // prod result from rdm1 + atomicAdd(output.ptr, eri_result); }); // std::cout<shape[0]<& shape11, const Spec& shape12, const Spec& shape13, - const Spec& shape14) { + const Spec& shape14, + const Spec& shape15, + const Spec& shape16, + const Spec& shape17, + const Spec& shape18, + const Spec& shape19) { // double n2 = shape4.shape[0]*(shape4.shape[0]+1)/2; // double n4 = n2*(n2+1)/2; // int n4_int = static_cast(n4); - std::vector outshape={2*shape11.shape[0]*shape12.shape[0]}; - // std::vector outshape={shape1.shape[0]}; - Spec out(outshape); + // std::vector outshape={2*shape11.shape[0]*shape12.shape[0]}; + std::vector outshape={shape1.shape[0]}; + Spec out(outshape); return std::make_tuple(out); } // static void Cpu(Array& arg1, Array& arg2, @@ -128,7 +133,12 @@ class Hartree_64 { Array& sorted_cd_idx, Array& screened_cd_idx_start, Array& screened_idx_offset, - Array& output){ + Array& pgto_coeff, + Array& pgto_normalization_factor, + Array& pgto_idx_to_cgto_idx, + Array& rdm1, + Array& n_cgto, + Array& output){ // 
std::memcpy(output.ptr, outshape.ptr, sizeof(float) * outshape.spec->Size()); } @@ -148,7 +158,12 @@ class Hartree_64 { Array& sorted_cd_idx, Array& screened_cd_idx_start, Array& screened_idx_offset, - Array& output); + Array& pgto_coeff, + Array& pgto_normalization_factor, + Array& pgto_idx_to_cgto_idx, + Array& rdm1, + Array& n_cgto, + Array& output); }; class Hartree_32_uncontracted { diff --git a/d4ft/native/obara_saika/eri_kernel.so b/d4ft/native/obara_saika/eri_kernel.so index 06cfcfb..2462842 100755 Binary files a/d4ft/native/obara_saika/eri_kernel.so and b/d4ft/native/obara_saika/eri_kernel.so differ diff --git a/tests/native/xla/eri_kernel_test.py b/tests/native/xla/eri_kernel_test.py index 9570958..3d3b971 100644 --- a/tests/native/xla/eri_kernel_test.py +++ b/tests/native/xla/eri_kernel_test.py @@ -22,6 +22,7 @@ from d4ft.integral.gto import symmetry, tensorization from copy import deepcopy from d4ft.types import AngularStats, CGTOSymTensorIncore, Tensor2C, Tensor4C +from d4ft.utils import get_rdm1 # from obsa.obara_saika import get_coulomb, get_kinetic, get_nuclear, get_overlap # from jax.interpreters import ad, batching, mlir, xla @@ -70,9 +71,9 @@ def num_unique_ijkl(n): # To support higher angular, first adjust constants in eri.h: MAX_XYZ, MAX_YZ.. 
# pyscf_mol = get_pyscf_mol("C180-0", "sto-3g") - # pyscf_mol = get_pyscf_mol("C60-Ih", "sto-3g") + pyscf_mol = get_pyscf_mol("C60-Ih", "sto-3g") # pyscf_mol = get_pyscf_mol("O2", "6-31G") - pyscf_mol = get_pyscf_mol("O2", "sto-3g") + # pyscf_mol = get_pyscf_mol("O2", "sto-3g") mol = Mol.from_pyscf_mol(pyscf_mol) cgto = CGTO.from_mol(mol) self.s = angular_stats.angular_static_args(*[cgto.pgto.angular] * 4) @@ -80,13 +81,16 @@ def num_unique_ijkl(n): self.ab_idx_counts = symmetry.get_2c_sym_idx(cgto.n_pgtos) n_2c_idx = len(self.ab_idx_counts) + key = jax.random.PRNGKey(42) + self.Mo_coeff = jax.random.normal(key,(2, self.cgto.n_cgtos, self.cgto.n_cgtos)) + # 4c tensors - ab_idx, counts_ab = self.ab_idx_counts[:, :2], self.ab_idx_counts[:, 2] - self.abab_idx_count = jnp.hstack([ab_idx, ab_idx, - counts_ab[:, None]*counts_ab[:, None]]).astype(int) + # ab_idx, counts_ab = self.ab_idx_counts[:, :2], self.ab_idx_counts[:, 2] + # self.abab_idx_count = jnp.hstack([ab_idx, ab_idx, + # counts_ab[:, None]*counts_ab[:, None]]).astype(int) - num_4c_idx = symmetry.num_unique_ij(n_2c_idx) - self.num_4c_idx = num_4c_idx + # num_4c_idx = symmetry.num_unique_ij(n_2c_idx) + # self.num_4c_idx = num_4c_idx # self.num_4c_idx = num_4c_idx # batch_size: int = 2**23 # i = 0 @@ -164,7 +168,10 @@ def num_unique_ijkl(n): # np.testing.assert_array_equal(self.outshape, out) def test_abab(self) -> None: - compute_hartree_test(self.cgto, self.s) + # print(len(self.abcd_idx_counts)) + # for e in self.abcd_idx_counts[:10000]: + # np.testing.assert_equal(jnp.array(get_symmetry_count(e[0],e[1],e[2],e[3])), e[4]) + compute_hartree_test(self.cgto, self.s, self.Mo_coeff) # pgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(self.s, cgto=False) # pgto_4c_fn_gt = tensorization.tensorize_4c_cgto(electron_repulsion_integral, self.s, cgto=False) # cgto_4c_fn = tensorization.tensorize_4c_cgto_range(eri_fn, s4) @@ -223,8 +230,23 @@ def test_abab(self) -> None: # logging.info(f"block diag (ab|ab) computed, size: 
{eri_abab.shape}") -def compute_hartree_test(cgto: CGTO, static_args: AngularStats): - pass +def get_symmetry_count(i,j,k,l): + ret = 1 + if i==k and j==l: + if i != j: + ret *= 4 + else: + ret *= 2 + if i == j: + if k != l: + ret *= 2 + else: + ret *= 2 + if k != l: + ret *= 2 + return ret + +def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin): l_xyz = jnp.sum(cgto.pgto.angular, 1) orig_idx = jnp.argsort(l_xyz) @@ -245,6 +267,15 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats): max_cd = jnp.array(static_args.max_cd, dtype=jnp.int32) Ms = jnp.array([static_args.max_xyz+1, static_args.max_yz+1, static_args.max_z+1], dtype=jnp.int32) + pgto_coeff = jnp.array(cgto.coeff[orig_idx]) + pgto_normalization_factor = jnp.array(cgto.N[orig_idx]) + + cgto_seg_idx = jnp.cumsum(jnp.array(cgto.cgto_splits)) + pgto_idx_to_cgto_idx = jnp.array(jnp.argmax(orig_idx[:, None] < cgto_seg_idx, axis=-1),dtype=jnp.int32) + + rdm1 = get_rdm1(Mo_coeff_spin).sum(0).flatten() + + ab_idx_counts = symmetry.get_2c_sym_idx(cgto.n_pgtos) rank_ab_idx = jnp.arange(ab_idx_counts.shape[0]) ss_mask = (ab_idx_counts[:, 1] < s_num) @@ -260,30 +291,6 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats): pp_idx = rank_ab_idx[pp_mask] pd_idx = rank_ab_idx[pd_mask] dd_idx = rank_ab_idx[dd_mask] - - # ab_idx_counts = jnp.vstack([ab_idx_counts[ss_mask], ab_idx_counts[sp_mask], ab_idx_counts[sd_mask], - # ab_idx_counts[pp_mask], ab_idx_counts[pd_mask], ab_idx_counts[dd_mask]]) - - # ss_num = jnp.count_nonzero(ss_mask) - # sp_num = jnp.count_nonzero(sp_mask) - # sd_num = jnp.count_nonzero(sd_mask) - # pp_num = jnp.count_nonzero(pp_mask) - # pd_num = jnp.count_nonzero(pd_mask) - # dd_num = jnp.count_nonzero(dd_mask) - # ss_start = 0 - # ss_end = ss_start + ss_num - # sp_start = ss_end - # sp_end = sp_start + sp_num - # sd_start = sp_end - # sd_end = sd_start + sd_num - # pp_start = sd_end - # pp_end = pp_start + pp_num - # pd_start = pp_end - # pd_end = 
pd_start + pd_num - # dd_start = pd_end - # dd_end = dd_start + dd_num - # ab_range = jnp.array([[ss_start, sp_start, sd_start, pp_start, pd_start, dd_start], - # [ss_end, sp_end, sd_end, pp_end, pd_end, dd_end]],dtype=jnp.int32) ab_idx, counts_ab = ab_idx_counts[:, :2], ab_idx_counts[:, 2] abab_idx_counts = jnp.hstack([ab_idx, ab_idx, @@ -306,32 +313,58 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats): eri_abab[sorted_idx[4]], eri_abab[sorted_idx[5]]] + + # ss,ss # for (ss, ss) (pp, pp) (dd, dd), (sp, sp) ... need ensure idx > cnt. For anyone else, no need + import time + i = 1 + j = 1 eps = 1e-10 - sorted_ab_idx = sorted_idx[0] - sorted_cd_idx = sorted_idx[0] - sorted_eri_abab = sorted_eri[0] - sorted_eri_cdcd = sorted_eri[0] + sorted_ab_idx = sorted_idx[i] + sorted_cd_idx = sorted_idx[j] + + # if len(sorted_ab_idx) == 0 or len(sorted_cd_idx) == 0: + # continue + + sorted_eri_abab = sorted_eri[i] + sorted_eri_cdcd = sorted_eri[j] sorted_ab_thres = (eps / jnp.sqrt(sorted_eri_abab))**2 cnt = jnp.array([e for e in range(len(sorted_eri_abab))]) cd_idx = jnp.searchsorted(sorted_eri_cdcd, sorted_ab_thres) - cd_idx = jnp.maximum(cnt, cd_idx) + if i == j: + cd_idx = jnp.maximum(cnt, cd_idx) cdcd_len = len(sorted_eri_cdcd) start_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(cdcd_len-cd_idx)[:-1]), dtype=jnp.int32) screened_cnt = jnp.sum(cdcd_len-cd_idx) - output = hartree(jnp.array([N], dtype=jnp.int32), jnp.array([screened_cnt], dtype=jnp.int32), + + # if screened_cnt <= 0: + # continue + + t1 = time.time() + output = hartree(N, jnp.array([screened_cnt], dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms, jnp.array(sorted_ab_idx, dtype=jnp.int32), jnp.array(sorted_cd_idx, dtype=jnp.int32), jnp.array(cd_idx, dtype=jnp.int32), - jnp.array(start_offset, dtype=jnp.int32)) - - # print(s_num, p_num) + jnp.array(start_offset, dtype=jnp.int32), + pgto_coeff, + pgto_normalization_factor, + pgto_idx_to_cgto_idx, + rdm1, + jnp.array([cgto.n_cgtos], 
dtype=jnp.int32)) + jax.block_until_ready(output) + t2 = time.time() + print(t2 - t1) + print(len(sorted_ab_idx)*(len(sorted_ab_idx)+1)/2) + print(screened_cnt) + print(output) + # # print(s_num, p_num) # abcd_idx = output[:2*screened_cnt].reshape((2,screened_cnt)) - # print(abcd_idx[:,-100:]) - - + # # print(abcd_idx[:,-100:]) + # for i in range(100): + # print(ab_idx[abcd_idx[0,i]], symmetry.get_triu_ij_from_idx(N,abcd_idx[0,i])) + # print(ab_idx[abcd_idx[1,i]], symmetry.get_triu_ij_from_idx(N,abcd_idx[1,i]))