Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Szkered committed Aug 15, 2023
1 parent 425c146 commit a9e26d4
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 31 deletions.
211 changes: 182 additions & 29 deletions d4ft/integral/gto/cgto.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@
from __future__ import annotations # forward declaration

import math
from typing import Callable, NamedTuple, Optional, Sequence, Tuple, Union
from typing import (
Callable,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
Literal,
)

import haiku as hk
import jax
Expand Down Expand Up @@ -101,6 +110,94 @@ def get_cgto_segment_id(cgto_splits: tuple) -> Int[Array, "n_pgtos"]:
return seg_id


def reparameterize(
cgto: CGTO, optim_exp: bool, optim_coeff: bool
) -> Tuple[Float[Array, "*n_pgtos"], Float[Array, "*n_pgtos"]]:
pgtos = []
coeffs = []

# iter atoms
for i, element_basis in enumerate(cgto.basis):
coord = cgto.atom_coords[i]
# iter shells
for cgto_i in element_basis:
total_angular = cgto_i[0] # l/shell
assert not isinstance(
cgto_i[1], float
), "basis with kappa is not supported yet"
pgtos_i = cgto_i[1:] # [[exp, c_1, c_2, ..], [exp, c_1, c_2, ..], ... ]]
n_coeffs = len(pgtos_i[0][1:])
# TODO: do not create separate PGTO for each c_1, c_2, ...
# iter contractions
for cid in range(1, 1 + n_coeffs): # 0-idx is exponent
# iter cartesian monomials for the given shell
pgtos_monomials = []
coeffs_monomials = []
pgtos_sph = []
coeffs_sph = []
for angular in get_cart_angular_vec(total_angular):
# iter PGTOs in the given CGTO. If n_coeffs > 1, we have
# the general contraction, i.e. each PGTO appears in multiple
# CGTOs.
pgto_i_ = []
coeffs_i = []
for pgto_i in pgtos_i:
# NOTE: we want to have some activation function here to make sure
# that exponent > 0. However softplus is not good as inv_softplus
# makes some exponent goes inf
exponent = pgto_i[0]
if optim_exp:
exponent = jax.nn.softplus(
hk.get_parameter(
"exponent",
shape=(),
init=make_constant_fn(inv_softplus(jnp.array(exponent)))
)
)
coeff = pgto_i[cid]
if optim_coeff:
coeff = hk.get_parameter(
"coeff", shape=(), init=make_constant_fn(jnp.array(coeff))
)
pgto_i_.append(PGTO(angular, coord, exponent))
coeffs_i.append(coeff)

# pgto_i = PGTO.apply(np.stack, pgto_i_)
pgto_i = PGTO.stack(pgto_i_)
pgtos_monomials.append(pgto_i)

# normalize each PGTO in cartesian coordinate
N = pgto_i.norm_inv()
normalized_coeffs_i = normalize_cgto_coeff(
pgto_i,
jnp.array(coeffs_i) * N
) / N
coeffs_monomials.append(normalized_coeffs_i)

# convert to spherical
r_l_inv = 1 / racah_norm(total_angular)
shell = Shell(total_angular)
prefacs = REAL_SOLID_HARMONICS_PREFAC[shell]
for mpl, sph in enumerate(MONOMIALS_TO_REAL_SOLID_HARMONICS[shell]):
m = mpl - total_angular # magnetic quantum number
p_m = prefacs[abs(m)]
for monomial_idx, m_prefac in sph:
pgtos_sph.append(pgtos_monomials[monomial_idx])
coeffs_sph.append(
coeffs_monomials[monomial_idx] * m_prefac * p_m * r_l_inv
)
# pgto_sph_ = PGTO.apply(np.concatenate, pgtos_sph)
pgto_sph_ = PGTO.concat(pgtos_sph)
pgtos.append(pgto_sph_)
coeffs.append(jnp.concatenate(coeffs_sph))

# pgto = PGTO.apply(np.concatenate, pgtos)
pgto = PGTO.concat(pgtos)
coeff = jnp.concatenate(coeffs)

return coeff, pgto.exponent


def build_cgto_from_mol(mol: Mol) -> CGTO:
"""Build CGTO from the basis information in Mol.
Expand Down Expand Up @@ -183,7 +280,7 @@ def build_cgto_from_mol(mol: Mol) -> CGTO:
pgto_i_.append(PGTO(angular, coord, exponent))
coeffs_i.append(coeff)

pgto_i = PGTO.apply(np.stack, pgto_i_)
pgto_i = PGTO.stack(pgto_i_)
pgtos_monomials.append(pgto_i)

# normalize each PGTO in cartesian coordinate
Expand Down Expand Up @@ -229,9 +326,18 @@ def build_cgto_from_mol(mol: Mol) -> CGTO:
f"there are {sum(cgto_splits)} (non-unique) PGTOs in spherical form"
)

# map basis into d4ft format

basis = []
for i, element in enumerate(mol.elements):
element_basis = []
for cgto_i in mol.basis[element]:
element_basis.append(cgto_i)
basis.append(element_basis)

cgto = CGTO(
pgto, pgto.norm_inv(), jnp.array(coeffs), cgto_splits, cgto_seg_id,
jnp.array(atom_splits), mol.atom_charges, mol.nocc
jnp.array(atom_splits), mol.atom_charges, mol.nocc, basis
)

return cgto
Expand Down Expand Up @@ -270,6 +376,25 @@ def at(self, i: Union[slice, int]) -> PGTO:
"""get one PGTO out of the batch"""
return PGTO(self.angular[i], self.center[i], self.exponent[i])

@staticmethod
def stack(pgtos: Sequence[PGTO]) -> PGTO:
"""Return the stack of a sequence of PGTO singletons.
Note that the angular needs to be a numpy array to avoid tracing error.
"""
angular, center, exponent = zip(*pgtos)
return PGTO(np.stack(angular), jnp.stack(center), jnp.stack(exponent))

@staticmethod
def concat(pgtos: Sequence[PGTO]) -> PGTO:
"""Return the concatenation of a sequence of PGTO singletons.
Note that the angular needs to be a numpy array to avoid tracing error.
"""
angular, center, exponent = zip(*pgtos)
return PGTO(
np.concatenate(angular), jnp.concatenate(center),
jnp.concatenate(exponent)
)

@staticmethod
def apply(f: Callable, pgtos: Sequence[PGTO]) -> PGTO:
"""Return the concatenation of a sequence of PGTO singletons.
Expand Down Expand Up @@ -324,9 +449,15 @@ class CGTO(NamedTuple):
"""Atom segment lengths. e.g. [15, 15] for O2 in sto-3g.
Useful for copying atom centers to each GTO when doing basis optimization."""
charge: Int[Array, "*n_atoms"]
"""charges of the atoms"""
"""Charges of the atoms"""
nocc: Int[Array, "2 nao"]
"""occupation mask for alpha and beta spin"""
"""Cccupation mask for alpha and beta spin"""
# elements: Sequence[str]
# """List of atoms in the systems"""
# atom_coords: Float[np.ndarray, "n_atoms 3"]
# """List of atoms coordinates in the systems"""
basis: Sequence[Sequence[Tuple[int, Sequence[Sequence[float]]]]]
"""basis in PySCF format"""

@property
def n_pgtos(self) -> int:
Expand Down Expand Up @@ -383,34 +514,56 @@ def from_mol(mol: Mol) -> CGTO:
def from_cart(cgto_cart: CGTO) -> CGTO:
return build_cgto_sph_from_mol(cgto_cart)

def to_hk(self) -> CGTO:
def to_hk(
self,
optimizable_params: Sequence[Literal[
"center",
"exponent",
"coeff",
]] = ["coeff"],
) -> CGTO:
"""Convert optimizable parameters to hk.Params. Must be haiku transformed.
Can be used for basis optimization.
"""
center_init = self.atom_coords
center = hk.get_parameter(
"center", center_init.shape, init=make_constant_fn(center_init)
)
center_rep = jnp.repeat(
center,
jnp.array(self.atom_splits),
axis=0,
total_repeat_length=self.n_pgtos
)
# NOTE: we want to have some activation function here to make sure
# that exponent > 0. However softplus is not good as inv_softplus
# makes some exponent goes inf
exponent = jax.nn.softplus(
hk.get_parameter(
"exponent",
self.pgto.exponent.shape,
init=make_constant_fn(inv_softplus(self.pgto.exponent))
if "center" in optimizable_params:
center_init = self.atom_coords
center_param = hk.get_parameter(
"center", center_init.shape, init=make_constant_fn(center_init)
)
)
coeff = hk.get_parameter(
"coeff", self.coeff.shape, init=make_constant_fn(self.coeff)
)
pgto = PGTO(self.pgto.angular, center_rep, exponent)
center = jnp.repeat(
center_param,
jnp.array(self.atom_splits),
axis=0,
total_repeat_length=self.n_pgtos
)
else:
center = self.pgto.center

if "exponent" in optimizable_params or "coeff" in optimizable_params:
# NOTE: we want to have some activation function here to make sure
# that exponent > 0. However softplus is not good as inv_softplus
# makes some exponent goes inf
# exponent = jax.nn.softplus(
# hk.get_parameter(
# "exponent",
# self.pgto.exponent.shape,
# init=make_constant_fn(inv_softplus(self.pgto.exponent))
# )
# )
# coeff = hk.get_parameter(
# "coeff", self.coeff.shape, init=make_constant_fn(self.coeff)
# )

coeff, exponent = reparameterize(
self,
optim_exp="exponent" in optimizable_params,
optim_coeff="coeff" in optimizable_params,
)
else:
coeff = self.coeff
exponent = self.pgto.exponent

pgto = PGTO(self.pgto.angular, center, exponent)
return self._replace(pgto=pgto, coeff=coeff)

# TODO: instead of using occupation mask, we can orthogonalize a non-square
Expand Down
4 changes: 2 additions & 2 deletions d4ft/solver/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def incore_cgto_direct_opt_dft(
def H_factory() -> Tuple[Callable, Hamiltonian]:
"""Auto-grad scope"""
# TODO: out-of-core + basis optimization
# cgto_hk = cgto.to_hk()
cgto_hk = cgto
cgto_hk = cgto.to_hk(["coeff"])
# cgto_hk = cgto
cgto_intor = get_cgto_intor(
cgto_hk, intor="obsa", incore_energy_tensors=incore_e_tensors
)
Expand Down

0 comments on commit a9e26d4

Please sign in to comment.