Skip to content

Commit

Permalink
finish contraction, need A100 test
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonLeeJSL committed Dec 4, 2023
1 parent 96b5e09 commit 6554d67
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 55 deletions.
54 changes: 51 additions & 3 deletions d4ft/native/obara_saika/eri_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,28 @@ HEMI_DEV_CALLABLE void triu_ij_from_index(int n, int index, int *i,
*j = j_;
}

HEMI_DEV_CALLABLE void get_symmetry_count(int i, int j, int k, int l, int *count) {
int count_ = 1;
if(i == k & j == l){
if(i != j){
count_ *= 4;
}
} else{
count_ *= 2;
if(i == j){
if(k != l){
count_ *= 2;
}
} else{
count_ *= 2;
if(k != l){
count_ *= 2;
}
}
}
*count = count_;
}

// template <typename FLOAT>
void Hartree_32::Gpu(cudaStream_t stream,
Array<const int>& N,
Expand Down Expand Up @@ -74,7 +96,12 @@ void Hartree_64::Gpu(cudaStream_t stream,
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
Array<int>& output) {
Array<const double>& pgto_coeff,
Array<const double>& pgto_normalization_factor,
Array<const int>& pgto_idx_to_cgto_idx,
Array<const double>& rdm1,
Array<const int>& n_cgto,
Array<double>& output) {
// Prescreening
int* idx_4c;
int idx_length;
Expand All @@ -92,8 +119,6 @@ void Hartree_64::Gpu(cudaStream_t stream,
loc = screened_idx_offset.ptr[index] + i - screened_cd_idx_start.ptr[index];
idx_4c[loc] = sorted_ab_idx.ptr[index]; // ab
idx_4c[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
output.ptr[loc] = sorted_ab_idx.ptr[index]; // ab
output.ptr[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
}
__syncthreads();
});
Expand All @@ -104,8 +129,28 @@ void Hartree_64::Gpu(cudaStream_t stream,
int a, b, c, d; // pgto 4c idx
int i, j, k, l; // cgto 4c idx
double eri_result;
double Na, Nb, Nc, Nd;
double Ca, Cb, Cc, Cd;
double Mab, Mcd;
int count;
triu_ij_from_index(N.ptr[0], idx_4c[index], &a, &b);
triu_ij_from_index(N.ptr[0], idx_4c[index + screened_length.ptr[0]], &c, &d);
get_symmetry_count(a, b, c, d, &count);
double dcount = static_cast<double>(count);
Ca = pgto_coeff.ptr[a];
Cb = pgto_coeff.ptr[b];
Cc = pgto_coeff.ptr[c];
Cd = pgto_coeff.ptr[d];
Na = pgto_normalization_factor.ptr[a];
Nb = pgto_normalization_factor.ptr[b];
Nc = pgto_normalization_factor.ptr[c];
Nd = pgto_normalization_factor.ptr[d];
i = pgto_idx_to_cgto_idx.ptr[a];
j = pgto_idx_to_cgto_idx.ptr[b];
k = pgto_idx_to_cgto_idx.ptr[c];
l = pgto_idx_to_cgto_idx.ptr[d];
Mab = rdm1.ptr[i*n_cgto.ptr[0] + j];
Mcd = rdm1.ptr[k*n_cgto.ptr[0] + k];
eri_result = eri<double>(n.ptr[0 * N.ptr[0] + a], n.ptr[1 * N.ptr[0] + a], n.ptr[2 * N.ptr[0] + a], // a
n.ptr[0 * N.ptr[0] + b], n.ptr[1 * N.ptr[0] + b], n.ptr[2 * N.ptr[0] + b], // b
n.ptr[0 * N.ptr[0] + c], n.ptr[1 * N.ptr[0] + c], n.ptr[2 * N.ptr[0] + c], // c
Expand All @@ -116,6 +161,9 @@ void Hartree_64::Gpu(cudaStream_t stream,
r.ptr[0 * N.ptr[0] + d], r.ptr[1 * N.ptr[0] + d], r.ptr[2 * N.ptr[0] + d], // d
z.ptr[a], z.ptr[b], z.ptr[c], z.ptr[d], // z
min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr);
eri_result = eri_result * dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd;
// prod result from rdm1
atomicAdd(output.ptr, eri_result);
});

// std::cout<<index_4c.spec->shape[0]<<std::endl;
Expand Down
27 changes: 21 additions & 6 deletions d4ft/native/obara_saika/eri_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,18 @@ class Hartree_64 {
const Spec<int>& shape11,
const Spec<int>& shape12,
const Spec<int>& shape13,
const Spec<int>& shape14) {
const Spec<int>& shape14,
const Spec<double>& shape15,
const Spec<double>& shape16,
const Spec<int>& shape17,
const Spec<double>& shape18,
const Spec<int>& shape19) {
// double n2 = shape4.shape[0]*(shape4.shape[0]+1)/2;
// double n4 = n2*(n2+1)/2;
// int n4_int = static_cast<int>(n4);
std::vector<int> outshape={2*shape11.shape[0]*shape12.shape[0]};
// std::vector<int> outshape={shape1.shape[0]};
Spec<int> out(outshape);
// std::vector<int> outshape={2*shape11.shape[0]*shape12.shape[0]};
std::vector<int> outshape={shape1.shape[0]};
Spec<double> out(outshape);
return std::make_tuple(out);
}
// static void Cpu(Array<const float>& arg1, Array<const int>& arg2,
Expand All @@ -128,7 +133,12 @@ class Hartree_64 {
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
Array<int>& output){
Array<const double>& pgto_coeff,
Array<const double>& pgto_normalization_factor,
Array<const int>& pgto_idx_to_cgto_idx,
Array<const double>& rdm1,
Array<const int>& n_cgto,
Array<double>& output){
// std::memcpy(output.ptr, outshape.ptr, sizeof(float) * outshape.spec->Size());
}

Expand All @@ -148,7 +158,12 @@ class Hartree_64 {
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
Array<int>& output);
Array<const double>& pgto_coeff,
Array<const double>& pgto_normalization_factor,
Array<const int>& pgto_idx_to_cgto_idx,
Array<const double>& rdm1,
Array<const int>& n_cgto,
Array<double>& output);
};

class Hartree_32_uncontracted {
Expand Down
Binary file modified d4ft/native/obara_saika/eri_kernel.so
Binary file not shown.
125 changes: 79 additions & 46 deletions tests/native/xla/eri_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from d4ft.integral.gto import symmetry, tensorization
from copy import deepcopy
from d4ft.types import AngularStats, CGTOSymTensorIncore, Tensor2C, Tensor4C
from d4ft.utils import get_rdm1
# from obsa.obara_saika import get_coulomb, get_kinetic, get_nuclear, get_overlap

# from jax.interpreters import ad, batching, mlir, xla
Expand Down Expand Up @@ -70,23 +71,26 @@ def num_unique_ijkl(n):

# To support higher angular, first adjust constants in eri.h: MAX_XYZ, MAX_YZ..
# pyscf_mol = get_pyscf_mol("C180-0", "sto-3g")
# pyscf_mol = get_pyscf_mol("C60-Ih", "sto-3g")
pyscf_mol = get_pyscf_mol("C60-Ih", "sto-3g")
# pyscf_mol = get_pyscf_mol("O2", "6-31G")
pyscf_mol = get_pyscf_mol("O2", "sto-3g")
# pyscf_mol = get_pyscf_mol("O2", "sto-3g")
mol = Mol.from_pyscf_mol(pyscf_mol)
cgto = CGTO.from_mol(mol)
self.s = angular_stats.angular_static_args(*[cgto.pgto.angular] * 4)
self.cgto = cgto
self.ab_idx_counts = symmetry.get_2c_sym_idx(cgto.n_pgtos)
n_2c_idx = len(self.ab_idx_counts)

key = jax.random.PRNGKey(42)
self.Mo_coeff = jax.random.normal(key,(2, self.cgto.n_cgtos, self.cgto.n_cgtos))

# 4c tensors
ab_idx, counts_ab = self.ab_idx_counts[:, :2], self.ab_idx_counts[:, 2]
self.abab_idx_count = jnp.hstack([ab_idx, ab_idx,
counts_ab[:, None]*counts_ab[:, None]]).astype(int)
# ab_idx, counts_ab = self.ab_idx_counts[:, :2], self.ab_idx_counts[:, 2]
# self.abab_idx_count = jnp.hstack([ab_idx, ab_idx,
# counts_ab[:, None]*counts_ab[:, None]]).astype(int)

num_4c_idx = symmetry.num_unique_ij(n_2c_idx)
self.num_4c_idx = num_4c_idx
# num_4c_idx = symmetry.num_unique_ij(n_2c_idx)
# self.num_4c_idx = num_4c_idx
# self.num_4c_idx = num_4c_idx
# batch_size: int = 2**23
# i = 0
Expand Down Expand Up @@ -164,7 +168,10 @@ def num_unique_ijkl(n):
# np.testing.assert_array_equal(self.outshape, out)

def test_abab(self) -> None:
compute_hartree_test(self.cgto, self.s)
# print(len(self.abcd_idx_counts))
# for e in self.abcd_idx_counts[:10000]:
# np.testing.assert_equal(jnp.array(get_symmetry_count(e[0],e[1],e[2],e[3])), e[4])
compute_hartree_test(self.cgto, self.s, self.Mo_coeff)
# pgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(self.s, cgto=False)
# pgto_4c_fn_gt = tensorization.tensorize_4c_cgto(electron_repulsion_integral, self.s, cgto=False)
# cgto_4c_fn = tensorization.tensorize_4c_cgto_range(eri_fn, s4)
Expand Down Expand Up @@ -223,8 +230,23 @@ def test_abab(self) -> None:

# logging.info(f"block diag (ab|ab) computed, size: {eri_abab.shape}")

def compute_hartree_test(cgto: CGTO, static_args: AngularStats):
pass
def get_symmetry_count(i,j,k,l):
ret = 1
if i==k and j==l:
if i != j:
ret *= 4
else:
ret *= 2
if i == j:
if k != l:
ret *= 2
else:
ret *= 2
if k != l:
ret *= 2
return ret

def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin):
l_xyz = jnp.sum(cgto.pgto.angular, 1)
orig_idx = jnp.argsort(l_xyz)

Expand All @@ -245,6 +267,15 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats):
max_cd = jnp.array(static_args.max_cd, dtype=jnp.int32)
Ms = jnp.array([static_args.max_xyz+1, static_args.max_yz+1, static_args.max_z+1], dtype=jnp.int32)

pgto_coeff = jnp.array(cgto.coeff[orig_idx])
pgto_normalization_factor = jnp.array(cgto.N[orig_idx])

cgto_seg_idx = jnp.cumsum(jnp.array(cgto.cgto_splits))
pgto_idx_to_cgto_idx = jnp.array(jnp.argmax(orig_idx[:, None] < cgto_seg_idx, axis=-1),dtype=jnp.int32)

rdm1 = get_rdm1(Mo_coeff_spin).sum(0).flatten()


ab_idx_counts = symmetry.get_2c_sym_idx(cgto.n_pgtos)
rank_ab_idx = jnp.arange(ab_idx_counts.shape[0])
ss_mask = (ab_idx_counts[:, 1] < s_num)
Expand All @@ -260,30 +291,6 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats):
pp_idx = rank_ab_idx[pp_mask]
pd_idx = rank_ab_idx[pd_mask]
dd_idx = rank_ab_idx[dd_mask]

# ab_idx_counts = jnp.vstack([ab_idx_counts[ss_mask], ab_idx_counts[sp_mask], ab_idx_counts[sd_mask],
# ab_idx_counts[pp_mask], ab_idx_counts[pd_mask], ab_idx_counts[dd_mask]])

# ss_num = jnp.count_nonzero(ss_mask)
# sp_num = jnp.count_nonzero(sp_mask)
# sd_num = jnp.count_nonzero(sd_mask)
# pp_num = jnp.count_nonzero(pp_mask)
# pd_num = jnp.count_nonzero(pd_mask)
# dd_num = jnp.count_nonzero(dd_mask)
# ss_start = 0
# ss_end = ss_start + ss_num
# sp_start = ss_end
# sp_end = sp_start + sp_num
# sd_start = sp_end
# sd_end = sd_start + sd_num
# pp_start = sd_end
# pp_end = pp_start + pp_num
# pd_start = pp_end
# pd_end = pd_start + pd_num
# dd_start = pd_end
# dd_end = dd_start + dd_num
# ab_range = jnp.array([[ss_start, sp_start, sd_start, pp_start, pd_start, dd_start],
# [ss_end, sp_end, sd_end, pp_end, pd_end, dd_end]],dtype=jnp.int32)

ab_idx, counts_ab = ab_idx_counts[:, :2], ab_idx_counts[:, 2]
abab_idx_counts = jnp.hstack([ab_idx, ab_idx,
Expand All @@ -306,32 +313,58 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats):
eri_abab[sorted_idx[4]],
eri_abab[sorted_idx[5]]]



# ss,ss
# for (ss, ss) (pp, pp) (dd, dd), (sp, sp) ... need ensure idx > cnt. For anyone else, no need
import time
i = 1
j = 1
eps = 1e-10
sorted_ab_idx = sorted_idx[0]
sorted_cd_idx = sorted_idx[0]
sorted_eri_abab = sorted_eri[0]
sorted_eri_cdcd = sorted_eri[0]
sorted_ab_idx = sorted_idx[i]
sorted_cd_idx = sorted_idx[j]

# if len(sorted_ab_idx) == 0 or len(sorted_cd_idx) == 0:
# continue

sorted_eri_abab = sorted_eri[i]
sorted_eri_cdcd = sorted_eri[j]
sorted_ab_thres = (eps / jnp.sqrt(sorted_eri_abab))**2
cnt = jnp.array([e for e in range(len(sorted_eri_abab))])
cd_idx = jnp.searchsorted(sorted_eri_cdcd, sorted_ab_thres)
cd_idx = jnp.maximum(cnt, cd_idx)
if i == j:
cd_idx = jnp.maximum(cnt, cd_idx)
cdcd_len = len(sorted_eri_cdcd)
start_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(cdcd_len-cd_idx)[:-1]), dtype=jnp.int32)
screened_cnt = jnp.sum(cdcd_len-cd_idx)
output = hartree(jnp.array([N], dtype=jnp.int32), jnp.array([screened_cnt], dtype=jnp.int32),

# if screened_cnt <= 0:
# continue

t1 = time.time()
output = hartree(N, jnp.array([screened_cnt], dtype=jnp.int32),
n, r, z, min_a, min_c, max_ab, max_cd, Ms,
jnp.array(sorted_ab_idx, dtype=jnp.int32),
jnp.array(sorted_cd_idx, dtype=jnp.int32),
jnp.array(cd_idx, dtype=jnp.int32),
jnp.array(start_offset, dtype=jnp.int32))

# print(s_num, p_num)
jnp.array(start_offset, dtype=jnp.int32),
pgto_coeff,
pgto_normalization_factor,
pgto_idx_to_cgto_idx,
rdm1,
jnp.array([cgto.n_cgtos], dtype=jnp.int32))
jax.block_until_ready(output)
t2 = time.time()
print(t2 - t1)
print(len(sorted_ab_idx)*(len(sorted_ab_idx)+1)/2)
print(screened_cnt)
print(output)
# # print(s_num, p_num)
# abcd_idx = output[:2*screened_cnt].reshape((2,screened_cnt))
# print(abcd_idx[:,-100:])


# # print(abcd_idx[:,-100:])
# for i in range(100):
# print(ab_idx[abcd_idx[0,i]], symmetry.get_triu_ij_from_idx(N,abcd_idx[0,i]))
# print(ab_idx[abcd_idx[1,i]], symmetry.get_triu_ij_from_idx(N,abcd_idx[1,i]))



Expand Down

0 comments on commit 6554d67

Please sign in to comment.