Skip to content

Commit

Permalink
integrate compute_hartree into d4ft
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonLeeJSL committed Feb 20, 2024
1 parent 100288e commit 3213665
Show file tree
Hide file tree
Showing 4 changed files with 15,404 additions and 46 deletions.
11 changes: 6 additions & 5 deletions d4ft/hamiltonian/cgto_intors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pyscf
from jaxtyping import Array, Float

from d4ft.integral.gto import symmetry
from d4ft.integral.gto import symmetry, tensorization
from d4ft.integral.gto.cgto import CGTO
from d4ft.integral.obara_saika.driver import CGTOSymTensorFns
from d4ft.types import (
Expand Down Expand Up @@ -95,14 +95,15 @@ def ext_fn(mo_coeff: MoCoeff) -> Float[Array, ""]:
# rate = 0.5

def har_fn(mo_coeff: MoCoeff) -> Float[Array, ""]:
rdm1 = get_rdm1(mo_coeff).sum(0) # sum over spin
rdm1_ab = rdm1[mo_abcd_idx_counts[:, 0], mo_abcd_idx_counts[:, 1]]
rdm1_cd = rdm1[mo_abcd_idx_counts[:, 2], mo_abcd_idx_counts[:, 3]]
# rdm1 = get_rdm1(mo_coeff).sum(0) # sum over spin
# rdm1_ab = rdm1[mo_abcd_idx_counts[:, 0], mo_abcd_idx_counts[:, 1]]
# rdm1_cd = rdm1[mo_abcd_idx_counts[:, 2], mo_abcd_idx_counts[:, 3]]
# key = hk.next_rng_key()
# mask = jax.random.bernoulli(key, rate, shape=eri.shape)
# e_har = jnp.sum(eri * mask * rdm1_ab * rdm1_cd) / rate
# NOTE: 0.5 prefactor already included in the eri
e_har = jnp.sum(cgto_e_tensors.eri_abcd * rdm1_ab * rdm1_cd)
# e_har = jnp.sum(cgto_e_tensors.eri_abcd * rdm1_ab * rdm1_cd)
e_har = tensorization.compute_hartree(cgto, mo_coeff)
return e_har

def exc_fn(mo_coeff: MoCoeff) -> Float[Array, ""]:
Expand Down
228 changes: 189 additions & 39 deletions d4ft/integral/gto/tensorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@
from d4ft.types import IdxCount2C, IdxCount4C
from d4ft.native.obara_saika.eri_kernel import _Hartree_32, _Hartree_32_uncontracted, _Hartree_64, _Hartree_64_uncontracted
from d4ft.native.xla.custom_call import CustomCallMeta
from d4ft.utils import get_rdm1
from d4ft.integral import obara_saika as obsa

# from jax.interpreters import ad, batching, mlir, xla
Hartree_64 = CustomCallMeta("Hartree_64", (_Hartree_64,), {})
Hartree_64_uncontracted = CustomCallMeta("Hartree_64_uncontracted", (_Hartree_64_uncontracted,), {})
Hartree_32 = CustomCallMeta("Hartree_32", (_Hartree_32,), {})
Hartree_32_uncontracted = CustomCallMeta("Hartree_32_uncontracted", (_Hartree_32_uncontracted,), {})
if jax.config.jax_enable_x64:
hartree = Hartree_64()
hartree_uncontracted = Hartree_64_uncontracted()
else:
hartree = Hartree_32()
hartree_uncontracted = Hartree_32_uncontracted()

# if jax.config.jax_enable_x64:
# hartree = Hartree_64()
# hartree_uncontracted = Hartree_64_uncontracted()
# else:
# hartree = Hartree_32()
# hartree_uncontracted = Hartree_32_uncontracted()
hartree = Hartree_64()
hartree_uncontracted = Hartree_64_uncontracted()
def tensorize_2c_cgto(f: Callable, static_args, cgto: bool = True):
"""2c centers tensorization with provided index set,
where the tensor is contracted to cgto.
Expand Down Expand Up @@ -117,59 +121,101 @@ def tensorize(

return tensorize

def tensorize_4c_cgto_cuda(static_args, cgto: bool = True):
"""4c centers tensorization with provided index set.
where the tensor is contracted to cgto.
Used for incore/precompute.
def get_abab_fun(static_args):
"""Get the function that computes the (ab|ab) ERIs on CUDA,
used for pre-screening.
Args:
cgto: if True, contract the tensor into cgto basis
static_args: static arguments for orbitals
"""

@partial(jax.jit, static_argnames=["n_segs"])
# @partial(jax.jit, static_argnames=["n_segs"])
def tensorize(
gtos: CGTO,
idx_counts: IdxCount4C,
cgto_seg_id,
n_segs: int,
idx_counts,
orig_idx,
):
N = gtos.n_pgtos
Ns = gtos.N

# Why: Reshape n r z to 1D will significantly reduce computing time
n = jnp.array(gtos.pgto.angular.T, dtype=jnp.int32)
r = jnp.array(gtos.pgto.center.T)
z = jnp.array(gtos.pgto.exponent)
n = jnp.array(gtos.pgto.angular[orig_idx].T, dtype=jnp.int32)
r = jnp.array(gtos.pgto.center[orig_idx].T)
z = jnp.array(gtos.pgto.exponent)[orig_idx]

min_a = jnp.array(static_args.min_a, dtype=jnp.int32)
min_c = jnp.array(static_args.min_c, dtype=jnp.int32)
max_ab = jnp.array(static_args.max_ab, dtype=jnp.int32)
max_cd = jnp.array(static_args.max_cd, dtype=jnp.int32)
Ms = jnp.array([static_args.max_xyz+1, static_args.max_yz+1, static_args.max_z+1], dtype=jnp.int32)
abcd_idx = idx_counts[:, :4]

gtos_abcd, coeffs_abcd = zip(
*[
gtos.map_pgto_params(lambda gto_param, i=i: gto_param[abcd_idx[:, i]])
for i in range(4)
]
)
import time
t1 = time.time()
t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)

har_jit = jax.jit(hartree_uncontracted)
t_abcd = har_jit(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
jax.block_until_ready(t_abcd)
t2 = time.time()
print("Current abab cuda time =",t2-t1)
if not cgto:
return t_abcd
counts_abcd_i = idx_counts[:, 4]
N_abcd = Ns[abcd_idx].prod(-1) * counts_abcd_i
abcd = jnp.einsum("k,k,k,k,k,k->k", t_abcd, N_abcd, *coeffs_abcd)
cgto_abcd = jax.ops.segment_sum(abcd, cgto_seg_id, n_segs)
return cgto_abcd
return t_abcd

return tensorize

def get_4c_contracted_hartree_fun(static_args):
    """Build a closure that launches the contracted `hartree` custom call.

    Args:
        static_args: static angular-momentum arguments for the orbitals; the
            fields `min_a`, `min_c`, `max_ab`, `max_cd`, `max_xyz`, `max_yz`
            and `max_z` are read from it.

    Returns:
        A function that packs its inputs into arrays, runs the jitted
        `hartree` kernel, waits for the result and returns it.
    """

    def _i32(value):
        # All index-like kernel inputs are handed over as int32 arrays.
        return jnp.array(value, dtype=jnp.int32)

    def _i64_scalar(value):
        # Counters are wrapped in one-element int64 arrays.
        return jnp.array([value], dtype=jnp.int64)

    # @partial(jax.jit)
    def tensorize(
        cgto: CGTO,
        orig_idx,
        sorted_ab_idx,
        sorted_cd_idx,
        screened_cd_idx_start,
        # start_offset,
        screened_cnt,
        pgto_idx_to_cgto_idx,
        rdm1,
        thread_load,
        thread_num,
        ab_thread_num,
        ab_thread_offset
    ):
        """Run the contracted Hartree kernel for one (ab, cd) pair block.

        Primitive parameters (angular, center, exponent, coeff, N) are
        reordered with `orig_idx` before being passed to the kernel.
        NOTE(review): the meaning of the kernel output (presumably the
        Hartree energy contribution of this block) is defined by the native
        `hartree` implementation -- confirm there.
        """
        primitives = cgto.pgto

        # Primitive data, permuted by `orig_idx`; angular and center are
        # transposed after the permutation.
        angular = jnp.array(primitives.angular[orig_idx].T, dtype=jnp.int32)
        center = jnp.array(primitives.center[orig_idx].T)
        exponent = jnp.array(primitives.exponent)[orig_idx]
        coeff = jnp.array(cgto.coeff[orig_idx])
        norm_factor = jnp.array(cgto.N[orig_idx])

        # Static angular-momentum bounds for the kernel.
        min_a = _i32(static_args.min_a)
        min_c = _i32(static_args.min_c)
        max_ab = _i32(static_args.max_ab)
        max_cd = _i32(static_args.max_cd)
        dims = _i32(
            [
                static_args.max_xyz + 1,
                static_args.max_yz + 1,
                static_args.max_z + 1,
            ]
        )

        # The positional order below is the kernel's calling convention and
        # must not be changed.
        kernel_args = (
            _i32([cgto.n_pgtos]),
            _i32([thread_load]),
            _i64_scalar(thread_num),
            _i64_scalar(screened_cnt),
            angular,
            center,
            exponent,
            min_a,
            min_c,
            max_ab,
            max_cd,
            dims,
            _i32(sorted_ab_idx),
            _i32(sorted_cd_idx),
            _i32(screened_cd_idx_start),
            # _i32(start_offset),
            _i32(ab_thread_num),
            _i32(ab_thread_offset),
            coeff,
            norm_factor,
            pgto_idx_to_cgto_idx,
            rdm1,
            _i32([cgto.n_cgtos]),
            _i32([cgto.n_pgtos]),
        )

        kernel = jax.jit(hartree)
        result = kernel(*kernel_args)
        # Wait for the asynchronous dispatch to finish before returning.
        jax.block_until_ready(result)
        return result

    return tensorize

def tensorize_4c_cgto_range(f: Callable, static_args, cgto: bool = True):
"""Currently not used.
Expand Down Expand Up @@ -220,3 +266,107 @@ def tensorize(
return cgto_abcd

return tensorize

def compute_hartree(cgto: CGTO, Mo_coeff_spin, eps=1e-10, thread_load=2**10):
    """Accumulate the screened, contracted Hartree kernel output.

    Steps:
      1. Order the primitives by total angular momentum (`orig_idx`).
      2. Evaluate the diagonal integrals (ab|ab) for every symmetric pair
         returned by `symmetry.get_2c_sym_idx`.
      3. Split the pairs into six classes (ss, sp, sd, pp, pd, dd) and sort
         each class by ascending (ab|ab).
      4. For every class pair (i <= j), keep for each ab only those cd with
         (cd|cd) >= eps**2 / (ab|ab) (a Schwarz-type bound) and launch the
         contracted kernel on the surviving quartets.

    Args:
        cgto: contracted GTO basis of the molecule.
        Mo_coeff_spin: MO coefficients with a leading spin axis; it is passed
            to `get_rdm1` and the result is summed over axis 0.
        eps: screening threshold used in the bound above.
        thread_load: number of cd entries assigned to one kernel thread; used
            to derive the per-ab thread counts and offsets.

    Returns:
        The sum of the kernel outputs over all non-empty class pairs, or the
        Python integer 0 when every class pair is empty.

    NOTE(review): only s, p and d primitives are classified; the masks also
    appear to assume that column 0 of the pair index is <= column 1 --
    confirm against `symmetry.get_2c_sym_idx`.
    """
    static_args = obsa.angular_static_args(*[cgto.pgto.angular] * 4)
    l_xyz = jnp.sum(cgto.pgto.angular, 1)
    orig_idx = jnp.argsort(l_xyz)

    ab_idx_counts = symmetry.get_2c_sym_idx(cgto.n_pgtos)
    ab_idx, counts_ab = ab_idx_counts[:, :2], ab_idx_counts[:, 2]
    abab_idx_counts = jnp.hstack(
        [ab_idx, ab_idx, counts_ab[:, None] * counts_ab[:, None]]
    ).astype(int)

    abab_eri_fun = get_abab_fun(static_args)
    abcd_eri_fun = get_4c_contracted_hartree_fun(static_args)

    # Diagonal integrals (ab|ab) used for the screening bound.
    eri_abab = jnp.array(abab_eri_fun(cgto, abab_idx_counts, orig_idx))

    # current support s, p, d
    s_num = jnp.count_nonzero(l_xyz == 0)
    p_num = jnp.count_nonzero(l_xyz == 1)
    sp_num = s_num + p_num

    # Map every (reordered) primitive to the contracted function it is in.
    cgto_seg_idx = jnp.cumsum(jnp.array(cgto.cgto_splits))
    pgto_idx_to_cgto_idx = jnp.array(
        jnp.argmax(orig_idx[:, None] < cgto_seg_idx, axis=-1),
        dtype=jnp.int32,
    )

    rdm1 = get_rdm1(Mo_coeff_spin).sum(0).flatten()

    # Pair classes in the fixed order ss, sp, sd, pp, pd, dd.
    idx_a, idx_b = ab_idx_counts[:, 0], ab_idx_counts[:, 1]
    class_masks = (
        idx_b < s_num,  # ss
        (idx_b >= s_num) & (idx_b < sp_num) & (idx_a < s_num),  # sp
        (idx_b >= sp_num) & (idx_a < s_num),  # sd
        (idx_b < sp_num) & (idx_a >= s_num),  # pp
        (idx_b >= sp_num) & (idx_a >= s_num) & (idx_a < sp_num),  # pd
        idx_a >= sp_num,  # dd
    )

    rank_ab_idx = jnp.arange(ab_idx_counts.shape[0])
    sorted_idx = []
    sorted_eri = []
    for mask in class_masks:
        class_idx = rank_ab_idx[mask]
        # Ascending (ab|ab) inside the class so searchsorted can be used.
        class_idx = class_idx[jnp.argsort(eri_abab[class_idx])]
        sorted_idx.append(class_idx)
        sorted_eri.append(eri_abab[class_idx])

    n_classes = len(class_masks)
    output = 0
    for i in range(n_classes):
        sorted_ab_idx = sorted_idx[i]
        if len(sorted_ab_idx) == 0:
            continue
        # Smallest (cd|cd) that still satisfies the bound for each ab; it
        # only depends on the ab class, so compute it once per i.
        sorted_ab_thres = (eps / jnp.sqrt(sorted_eri[i])) ** 2
        ab_rank = jnp.arange(len(sorted_ab_idx))

        for j in range(i, n_classes):
            sorted_cd_idx = sorted_idx[j]
            if len(sorted_cd_idx) == 0:
                continue
            sorted_eri_cdcd = sorted_eri[j]
            screened_cd_idx_start = jnp.searchsorted(
                sorted_eri_cdcd, sorted_ab_thres
            )
            if i == j:
                # Within one class, cd must not start before ab's own rank,
                # so each unordered (ab, cd) combination is visited once.
                screened_cd_idx_start = jnp.maximum(
                    ab_rank, screened_cd_idx_start
                )

            # Number of surviving cd entries per ab, and in total.
            cdcd_len_list = len(sorted_eri_cdcd) - screened_cd_idx_start
            screened_cnt = jnp.sum(cdcd_len_list)

            # Work partition: each thread handles up to `thread_load` cd's.
            ab_thread_num = jnp.ceil(cdcd_len_list / thread_load)
            thread_num = jnp.sum(ab_thread_num)
            ab_thread_offset = jnp.concatenate(
                (jnp.array([0]), jnp.cumsum(ab_thread_num)[:-1]),
                dtype=jnp.int32,
            )

            output += abcd_eri_fun(
                cgto,
                orig_idx,
                sorted_ab_idx,
                sorted_cd_idx,
                screened_cd_idx_start,
                # start_offset,
                screened_cnt,
                pgto_idx_to_cgto_idx,
                rdm1,
                thread_load,
                thread_num,
                ab_thread_num,
                ab_thread_offset,
            )
    return output
9 changes: 7 additions & 2 deletions d4ft/integral/obara_saika/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def ext_fn(a, b, static_args):
# counts_ab[:, None]]).astype(int)

# pgto_4c_fn = tensorization.tensorize_4c_cgto(eri_fn, s4, cgto=False)
# cgto_4c_fn = tensorization.tensorize_4c_cgto(eri_fn, s4)
cgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(s4)
cgto_4c_fn = tensorization.tensorize_4c_cgto(eri_fn, s4)
# cgto_4c_fn = tensorization.compute_hartree
# cgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(s4)
# cgto_4c_fn = tensorization.tensorize_4c_cgto_range(eri_fn, s4)

# NOTE: these are only needed for prescreening
Expand Down Expand Up @@ -142,4 +143,8 @@ def eri_abcd_fn(cgto: CGTO) -> Tensor4C:
eri_abcd_cgto += eri_abcd_i
return eri_abcd_cgto

# def eri_abcd_fn(cgto: CGTO) -> Tensor4C:
# eri_abcd_cgto = partial(tensorization.compute_hartree, cgto, s4)
# return eri_abcd_cgto

return CGTOSymTensorFns(ovlp_ab_fn, kin_ab_fn, ext_ab_fn, eri_abcd_fn)
Loading

0 comments on commit 3213665

Please sign in to comment.