Skip to content

Commit

Permalink
Hartree kernel in cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonLeeJSL committed Nov 24, 2023
1 parent fbe1c01 commit dbb3a52
Show file tree
Hide file tree
Showing 10 changed files with 394 additions and 91 deletions.
6 changes: 6 additions & 0 deletions d4ft/integral/gto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,15 @@ py_library(
# Tensorization helpers for GTO integrals. The `data` entries make the native
# ERI extension (and the CUDA binaries it is listed with) available at runtime
# to tensorization.py, which imports the extension module directly.
py_library(
    name = "tensorization",
    srcs = ["tensorization.py"],
    data = [
        "//d4ft/native/obara_saika:eri_kernel.so",
        "@cuda//:bin",
    ],
    deps = [
        ":cgto",
        # Same-package target; written with the leading colon to match ":cgto".
        ":symmetry",
        "//d4ft:types",
        "//d4ft/native/xla:custom_call",
    ],
)

Expand Down
56 changes: 56 additions & 0 deletions d4ft/integral/gto/tensorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@
from d4ft.integral.gto import symmetry
from d4ft.integral.gto.cgto import CGTO, PGTO
from d4ft.types import IdxCount2C, IdxCount4C
from d4ft.native.obara_saika.eri_kernel import _Hartree_32, _Hartree_64
from d4ft.native.xla.custom_call import CustomCallMeta

# Wrap the two native Hartree entry points (imported from the eri_kernel
# extension) as custom-call classes, one per floating-point width.
Hartree_64 = CustomCallMeta("Hartree_64", (_Hartree_64,), {})
Hartree_32 = CustomCallMeta("Hartree_32", (_Hartree_32,), {})

# The precision is chosen once, when this module is imported: the 64-bit
# variant is instantiated if x64 mode is enabled in the JAX config at that
# moment, otherwise the 32-bit one.
hartree = (Hartree_64 if jax.config.jax_enable_x64 else Hartree_32)()

def tensorize_2c_cgto(f: Callable, static_args, cgto: bool = True):
"""2c centers tensorization with provided index set,
Expand Down Expand Up @@ -105,6 +113,54 @@ def tensorize(

return tensorize

def tensorize_4c_cgto_cuda(static_args, cgto: bool = True):
  """4c centers tensorization with provided index set, where the per-index
  values are produced by the native `hartree` custom call rather than by a
  Python-level integral function.

  Used for incore/precompute.

  Args:
    static_args: object providing the `min_a`, `min_c`, `max_ab`, `max_cd`,
      `max_xyz`, `max_yz` and `max_z` attributes forwarded to the kernel.
    cgto: if True, contract the tensor into cgto basis; if False, return the
      raw kernel output for every row of `idx_counts`.

  Returns:
    A jitted function `tensorize(gtos, idx_counts, cgto_seg_id, n_segs)`.
  """

  @partial(jax.jit, static_argnames=["n_segs"])
  def tensorize(
    gtos: CGTO,
    idx_counts: IdxCount4C,
    cgto_seg_id,
    n_segs: int,
  ):
    """Evaluate the 4c tensor for the given index set.

    Args:
      gtos: basis set; its primitive angular momenta, centers and exponents
        are handed to the kernel.
      idx_counts: integer array whose first four columns are the (a, b, c, d)
        indices and whose fifth column is the multiplicity of each entry.
      cgto_seg_id: segment id of each entry, used by the final segment sum.
      n_segs: number of output segments (static under jit).
    """
    n_pgtos = gtos.n_pgtos

    # Kernel inputs. Angular momenta and centers are passed transposed
    # relative to how they are stored on `gtos.pgto`.
    # NOTE(review): the original comment said this layout significantly
    # reduces computing time - the layout the kernel expects is not visible
    # here, confirm against eri_kernel.
    n = jnp.array(gtos.pgto.angular.T, dtype=jnp.int32)
    r = jnp.array(gtos.pgto.center.T)
    z = jnp.array(gtos.pgto.exponent)
    min_a = jnp.array(static_args.min_a, dtype=jnp.int32)
    min_c = jnp.array(static_args.min_c, dtype=jnp.int32)
    max_ab = jnp.array(static_args.max_ab, dtype=jnp.int32)
    max_cd = jnp.array(static_args.max_cd, dtype=jnp.int32)
    Ms = jnp.array(
      [static_args.max_xyz + 1, static_args.max_yz + 1, static_args.max_z + 1],
      dtype=jnp.int32,
    )

    abcd_idx = idx_counts[:, :4]
    t_abcd = hartree(
      jnp.array([n_pgtos], dtype=jnp.int32),
      jnp.array(abcd_idx, dtype=jnp.int32),
      n, r, z, min_a, min_c, max_ab, max_cd, Ms,
    )
    if not cgto:
      return t_abcd

    # Only the contraction needs the per-center coefficients, so gather them
    # after the early return. The gathered primitive parameters themselves
    # are not used because the kernel indexes n/r/z directly.
    _, coeffs_abcd = zip(
      *[
        gtos.map_pgto_params(lambda gto_param, i=i: gto_param[abcd_idx[:, i]])
        for i in range(4)
      ]
    )

    # Weight every entry by its normalization constants, its multiplicity and
    # the four contraction coefficients, then sum entries that share a
    # segment id.
    counts_abcd_i = idx_counts[:, 4]
    N_abcd = gtos.N[abcd_idx].prod(-1) * counts_abcd_i
    abcd = jnp.einsum("k,k,k,k,k,k->k", t_abcd, N_abcd, *coeffs_abcd)
    return jax.ops.segment_sum(abcd, cgto_seg_id, n_segs)

  return tensorize


def tensorize_4c_cgto_range(f: Callable, static_args, cgto: bool = True):
"""Currently not used.
Expand Down
4 changes: 2 additions & 2 deletions d4ft/integral/obara_saika/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def ext_fn(a, b, static_args):
# counts_ab[:, None]]).astype(int)

# pgto_4c_fn = tensorization.tensorize_4c_cgto(eri_fn, s4, cgto=False)
cgto_4c_fn = tensorization.tensorize_4c_cgto(eri_fn, s4)

# cgto_4c_fn = tensorization.tensorize_4c_cgto(eri_fn, s4)
cgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(s4)
# cgto_4c_fn = tensorization.tensorize_4c_cgto_range(eri_fn, s4)

# NOTE: these are only needed for prescreening
Expand Down
50 changes: 40 additions & 10 deletions d4ft/native/obara_saika/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@pip_requirements//:requirements.bzl", "requirement")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(default_visibility = ["//visibility:public"])

Expand Down Expand Up @@ -42,8 +44,21 @@ cc_library(
],
)


# Header-only target: exposes eri_kernel.h (no srcs) together with the hemi
# and XLA specs dependencies listed below.
cc_library(
name = "eri_kernel_h",
hdrs = ["eri_kernel.h"],
copts = [
"--std=c++17",
],
deps = [
"@hemi",
"//d4ft/native/xla:specs",
],
)

cuda_library(
name = "eri_kernel",
name = "eri_kernel_cu",
srcs = ["eri_kernel.cu"],
hdrs = ["eri_kernel.h"],
copts = [
Expand All @@ -56,18 +71,19 @@ cuda_library(
deps = [
":eri",
"@hemi",
"//d4ft/native/xla:specs",
],
)

cc_binary(
name = "eri_test",
srcs = ["eri_test.cc"],
deps = [
":eri_kernel",
"@cuda//:cudart_static",
"@hemi",
],
)
# cc_binary(
# name = "eri_test",
# srcs = ["eri_test.cc"],
# deps = [
# ":eri_kernel",
# "@cuda//:cudart_static",
# "@hemi",
# ],
# )

cc_binary(
name = "boys_test",
Expand All @@ -77,3 +93,17 @@ cc_binary(
"@cuda//:cudart_static",
],
)

# Python extension module built from eri_kernel.cc. It links the CUDA library
# target (:eri_kernel_cu) and the XLA custom-call header target.
# NOTE(review): the tensorization py_library lists
# "//d4ft/native/obara_saika:eri_kernel.so" in its data, which presumably is
# the shared object produced by this rule - confirm.
pybind_extension(
name = "eri_kernel",
srcs = [
"eri_kernel.cc",
],
copts = [
"--std=c++17",
],
deps = [
":eri_kernel_cu",
"//d4ft/native/xla:custom_call_h",
],
)
70 changes: 35 additions & 35 deletions d4ft/native/obara_saika/eri.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@
#define MAX_AB 4

template <typename FLOAT>
HEMI_DEV_CALLABLE void vertical_0_0_c_0(size_t i, const FLOAT* pq,
HEMI_DEV_CALLABLE void vertical_0_0_c_0(int i, const FLOAT* pq,
const FLOAT* qc, FLOAT I[][MAX_XYZ + 1],
FLOAT rho, FLOAT eta,
const size_t* max_cd,
const size_t* Ms) {
const int* max_cd,
const int* Ms) {
FLOAT* I_0_cm2 = I[MAX_CD];
FLOAT* I_0_cm1 = I[0];
FLOAT wq_i = rho * pq[i] / eta;
for (size_t j = 0; j < max_cd[i]; ++j) {
for (int j = 0; j < max_cd[i]; ++j) {
FLOAT cm1 = (FLOAT)j;
FLOAT* I_0_c = I[j + 1];
for (size_t k = 0; k <= Ms[i]; ++k) {
for (int k = 0; k <= Ms[i]; ++k) {
FLOAT I_mp1_k =
wq_i * I_0_cm1[k] + (-cm1 / 2 / eta * rho / eta) * I_0_cm2[k];
I_0_c[k] = qc[i] * I_0_cm1[k] + (cm1 / 2 / eta) * I_0_cm2[k];
Expand All @@ -56,31 +56,31 @@ HEMI_DEV_CALLABLE void vertical_0_0_c_0(size_t i, const FLOAT* pq,

template <typename FLOAT>
HEMI_DEV_CALLABLE void vertical_a_0_c_0(
size_t i, FLOAT I[][MAX_XYZ + 1], const FLOAT* ab, const FLOAT* cd,
int i, FLOAT I[][MAX_XYZ + 1], const FLOAT* ab, const FLOAT* cd,
const FLOAT* pa, const FLOAT* pq, FLOAT rho, FLOAT zeta, FLOAT eta,
const size_t* na, const size_t* nb, const size_t* nc, const size_t* nd,
const size_t* Ms, const size_t* min_a, const size_t* min_c,
const size_t* max_ab, const size_t* max_cd, FLOAT* out) {
const int* na, const int* nb, const int* nc, const int* nd,
const int* Ms, const int* min_a, const int* min_c,
const int* max_ab, const int* max_cd, FLOAT* out) {
FLOAT cache[MAX_CD + 1][MAX_XYZ + 1] = {0};
FLOAT wa[MAX_AB + 1], wc[MAX_CD + 1];
for (size_t j = 0; j <= max_ab[i]; ++j) {
for (int j = 0; j <= max_ab[i]; ++j) {
FLOAT mask = (FLOAT)(j >= na[i] && j <= na[i] + nb[i]);
wa[j] = mask * HEMI_CONSTANT(comb)[nb[i]][j - na[i]] *
pow(ab[i], nb[i] - j + na[i]);
wa[j] = mask * HEMI_CONSTANT(comb)[nb[i]][std::max(j - na[i],0)] *
pow(ab[i], std::max(nb[i] - j + na[i],0));
}
for (size_t j = 0; j <= max_cd[i]; ++j) {
FLOAT mask = (FLOAT)(j >= nc[i] && j <= nc[i] + nd[i]);
wc[j] = mask * HEMI_CONSTANT(comb)[nd[i]][j - nc[i]] *
pow(cd[i], nd[i] - j + nc[i]);
for (int j = 0; j <= max_cd[i]; ++j) {
FLOAT mask = (FLOAT)(j >= nc[i] && j <= nc[i] + nd[i]); // here
wc[j] = mask * HEMI_CONSTANT(comb)[nd[i]][std::max(j - nc[i],0)] *
pow(cd[i], std::max(nd[i] - j + nc[i], 0));
}
FLOAT(*I_am2)[MAX_XYZ + 1] = cache;
FLOAT(*I_am1)[MAX_XYZ + 1] = I;
FLOAT(*I_a)[MAX_XYZ + 1] = I_am2;
for (size_t j = 0; j <= max_ab[i]; ++j) {
for (int j = 0; j <= max_ab[i]; ++j) {
FLOAT am1 = (FLOAT)j;
FLOAT wp_i = -rho * pq[i] / zeta;
for (size_t k = 0; k <= max_cd[i]; ++k) {
for (size_t l = 0; l <= Ms[i]; ++l) {
for (int k = 0; k <= max_cd[i]; ++k) {
for (int l = 0; l <= Ms[i]; ++l) {
FLOAT I_mp1_kl =
wp_i * I_am1[k][l] + (-am1 / 2 / zeta * rho / zeta) * I_am2[k][l];
// inplace write
Expand All @@ -90,15 +90,15 @@ HEMI_DEV_CALLABLE void vertical_a_0_c_0(
}
}
}
for (size_t k = 0; k < max_cd[i]; ++k) {
for (size_t l = 1; l <= Ms[i]; ++l) {
for (int k = 0; k < max_cd[i]; ++k) {
for (int l = 1; l <= Ms[i]; ++l) {
I_a[k + 1][l - 1] += (k + 1) * I_am1[k][l] / 2 / (zeta + eta);
}
}
if (j >= min_a[i]) {
FLOAT(*I_j)[MAX_XYZ + 1] = I_am1;
for (size_t k = min_c[i]; k <= max_cd[i]; ++k) {
for (size_t l = 0; l <= Ms[i + 1]; ++l) {
for (int k = min_c[i]; k <= max_cd[i]; ++k) {
for (int l = 0; l <= Ms[i]; ++l) { // Ms[i+1] -> Ms[i]
out[l] += wa[j] * I_j[k][l] * wc[k];
}
}
Expand Down Expand Up @@ -138,7 +138,7 @@ HEMI_DEV_CALLABLE FLOAT T(FLOAT rho, FLOAT* pq) {
template <typename FLOAT>
HEMI_DEV_CALLABLE FLOAT K(FLOAT z1, FLOAT z2, FLOAT* r1, FLOAT* r2) {
FLOAT d_squared = 0;
for (size_t i = 0; i < 3; ++i) {
for (int i = 0; i < 3; ++i) {
d_squared += (r1[i] - r2[i]) * (r1[i] - r2[i]);
}
return std::sqrt((FLOAT)2.) * std::pow(M_PI, (FLOAT)(5. / 4.)) / (z1 + z2) *
Expand All @@ -147,23 +147,23 @@ HEMI_DEV_CALLABLE FLOAT K(FLOAT z1, FLOAT z2, FLOAT* r1, FLOAT* r2) {

template <typename FLOAT>
HEMI_DEV_CALLABLE FLOAT
eri(size_t nax, size_t nay, size_t naz, size_t nbx, size_t nby, size_t nbz,
size_t ncx, size_t ncy, size_t ncz, size_t ndx, size_t ndy, size_t ndz,
eri(int nax, int nay, int naz, int nbx, int nby, int nbz,
int ncx, int ncy, int ncz, int ndx, int ndy, int ndz,
FLOAT rax, FLOAT ray, FLOAT raz, FLOAT rbx, FLOAT rby, FLOAT rbz, FLOAT rcx,
FLOAT rcy, FLOAT rcz, FLOAT rdx, FLOAT rdy, FLOAT rdz, FLOAT za, FLOAT zb,
FLOAT zc, FLOAT zd, const size_t* min_a, const size_t* min_c,
const size_t* max_ab, const size_t* max_cd, const size_t* Ms) {
size_t na[3] = {nax, nay, naz};
size_t nb[3] = {nbx, nby, nbz};
size_t nc[3] = {ncx, ncy, ncz};
size_t nd[3] = {ndx, ndy, ndz};
FLOAT zc, FLOAT zd, const int* min_a, const int* min_c,
const int* max_ab, const int* max_cd, const int* Ms) {
int na[3] = {nax, nay, naz};
int nb[3] = {nbx, nby, nbz};
int nc[3] = {ncx, ncy, ncz};
int nd[3] = {ndx, ndy, ndz};
FLOAT ra[3] = {rax, ray, raz};
FLOAT rb[3] = {rbx, rby, rbz};
FLOAT rc[3] = {rcx, rcy, rcz};
FLOAT rd[3] = {rdx, rdy, rdz};

FLOAT rp_[3], rq_[3], pa_[3], pb_[3], qc_[3], qd_[3], ab_[3], cd_[3], pq_[3];
for (size_t i = 0; i < 3; ++i) {
for (int i = 0; i < 3; ++i) {
rp_[i] = rp(ra[i], rb[i], za, zb);
rq_[i] = rp(rc[i], rd[i], zc, zd);
pa_[i] = rp_[i] - ra[i];
Expand Down Expand Up @@ -194,7 +194,7 @@ eri(size_t nax, size_t nay, size_t naz, size_t nbx, size_t nby, size_t nbz,
// reset I to zero
std::memset(I, 0, sizeof(I));
// set I[0] to out;
for (size_t i = 0; i <= Ms[1]; ++i) {
for (int i = 0; i <= Ms[1]; ++i) {
I[0][i] = out[i];
}
// set out[:MAX_Z+1] = {0};
Expand All @@ -205,7 +205,7 @@ eri(size_t nax, size_t nay, size_t naz, size_t nbx, size_t nby, size_t nbz,
// reset I to zero
std::memset(I, 0, sizeof(I));
// set I[0] to out;
for (size_t i = 0; i <= Ms[2]; ++i) {
for (int i = 0; i <= Ms[2]; ++i) {
I[0][i] = out[i];
}
// set out[:1] = {0};
Expand Down
11 changes: 11 additions & 0 deletions d4ft/native/obara_saika/eri_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "eri_kernel.h"

#include "d4ft/native/xla/custom_call.h"


// Entry point of the `eri_kernel` Python extension module.
// Registers the single- and double-precision Hartree kernels on module `m`.
// NOTE(review): REGISTER_XLA_FUNCTION is presumably provided by
// d4ft/native/xla/custom_call.h; the Python side imports `_Hartree_32` and
// `_Hartree_64`, so the macro appears to prefix the exported name with an
// underscore - confirm against the macro definition.
PYBIND11_MODULE(eri_kernel, m) {
// Template for registering a parent class (unused here):
// py::class_<Parent>(m, "Parent").def(py::init<>());
REGISTER_XLA_FUNCTION(m, Hartree_32);
REGISTER_XLA_FUNCTION(m, Hartree_64);
// Template for registering a member function (unused here):
// REGISTER_XLA_MEMBER(m, Parent, ExampleMember);
}
Loading

0 comments on commit dbb3a52

Please sign in to comment.