diff --git a/d4ft/integral/gto/cgto.py b/d4ft/integral/gto/cgto.py
index 3a5bf94..a5bea82 100644
--- a/d4ft/integral/gto/cgto.py
+++ b/d4ft/integral/gto/cgto.py
@@ -14,7 +14,16 @@
 from __future__ import annotations  # forward declaration
 
 import math
-from typing import Callable, NamedTuple, Optional, Sequence, Tuple, Union
+from typing import (
+  Callable,
+  Mapping,
+  NamedTuple,
+  Optional,
+  Sequence,
+  Tuple,
+  Union,
+  Literal,
+)
 
 import haiku as hk
 import jax
@@ -101,6 +110,117 @@ def get_cgto_segment_id(cgto_splits: tuple) -> Int[Array, "n_pgtos"]:
   return seg_id
 
 
+def reparameterize(
+  cgto: CGTO, optim_exp: bool, optim_coeff: bool
+) -> Tuple[Float[Array, "*n_pgtos"], Float[Array, "*n_pgtos"]]:
+  """Rebuild contraction coefficients and exponents as haiku parameters.
+
+  Walks ``cgto.basis`` atom by atom, shell by shell, and returns the
+  normalized spherical coefficients together with the exponents. When
+  ``optim_exp`` / ``optim_coeff`` is set, the corresponding raw value is
+  wrapped in ``hk.get_parameter``, so this must run inside a haiku transform.
+
+  Parameter names are made unique per primitive. Haiku hands back the
+  existing parameter when a name is requested twice in one scope, so a fixed
+  name such as "exponent" would tie every primitive to one shared scalar.
+  The same primitive is deliberately given the same name across cartesian
+  monomials (and, for exponents, across contractions).
+
+  Args:
+    cgto: provides ``basis`` (PySCF format) and ``atom_coords``.
+    optim_exp: make the exponents haiku parameters (kept > 0 via softplus).
+    optim_coeff: make the contraction coefficients haiku parameters.
+
+  Returns:
+    A ``(coeff, exponent)`` pair of concatenated arrays.
+  """
+  pgtos = []
+  coeffs = []
+
+  # iter atoms
+  for i, element_basis in enumerate(cgto.basis):
+    coord = cgto.atom_coords[i]
+    # iter shells
+    for shell_idx, cgto_i in enumerate(element_basis):
+      total_angular = cgto_i[0]  # l/shell
+      assert not isinstance(
+        cgto_i[1], float
+      ), "basis with kappa is not supported yet"
+      pgtos_i = cgto_i[1:]  # [[exp, c_1, c_2, ..], [exp, c_1, c_2, ..], ... ]]
+      n_coeffs = len(pgtos_i[0][1:])
+      # TODO: do not create separate PGTO for each c_1, c_2, ...
+      # iter contractions
+      for cid in range(1, 1 + n_coeffs):  # 0-idx is exponent
+        # iter cartesian monomials for the given shell
+        pgtos_monomials = []
+        coeffs_monomials = []
+        pgtos_sph = []
+        coeffs_sph = []
+        for angular in get_cart_angular_vec(total_angular):
+          # iter PGTOs in the given CGTO. If n_coeffs > 1, we have
+          # the general contraction, i.e. each PGTO appears in multiple
+          # CGTOs.
+          pgto_i_ = []
+          coeffs_i = []
+          for prim_idx, pgto_i in enumerate(pgtos_i):
+            # NOTE: we want to have some activation function here to make sure
+            # that exponent > 0. However softplus is not good as inv_softplus
+            # makes some exponent goes inf
+            exponent = pgto_i[0]
+            if optim_exp:
+              exponent = jax.nn.softplus(
+                hk.get_parameter(
+                  f"exponent_{i}_{shell_idx}_{prim_idx}",
+                  shape=(),
+                  init=make_constant_fn(inv_softplus(jnp.array(exponent)))
+                )
+              )
+            coeff = pgto_i[cid]
+            if optim_coeff:
+              coeff = hk.get_parameter(
+                f"coeff_{i}_{shell_idx}_{cid}_{prim_idx}",
+                shape=(),
+                init=make_constant_fn(jnp.array(coeff))
+              )
+            pgto_i_.append(PGTO(angular, coord, exponent))
+            coeffs_i.append(coeff)
+
+          # pgto_i = PGTO.apply(np.stack, pgto_i_)
+          pgto_i = PGTO.stack(pgto_i_)
+          pgtos_monomials.append(pgto_i)
+
+          # normalize each PGTO in cartesian coordinate
+          N = pgto_i.norm_inv()
+          normalized_coeffs_i = normalize_cgto_coeff(
+            pgto_i,
+            jnp.array(coeffs_i) * N
+          ) / N
+          coeffs_monomials.append(normalized_coeffs_i)
+
+        # convert to spherical
+        r_l_inv = 1 / racah_norm(total_angular)
+        shell = Shell(total_angular)
+        prefacs = REAL_SOLID_HARMONICS_PREFAC[shell]
+        for mpl, sph in enumerate(MONOMIALS_TO_REAL_SOLID_HARMONICS[shell]):
+          m = mpl - total_angular  # magnetic quantum number
+          p_m = prefacs[abs(m)]
+          for monomial_idx, m_prefac in sph:
+            pgtos_sph.append(pgtos_monomials[monomial_idx])
+            coeffs_sph.append(
+              coeffs_monomials[monomial_idx] * m_prefac * p_m * r_l_inv
+            )
+        # pgto_sph_ = PGTO.apply(np.concatenate, pgtos_sph)
+        pgto_sph_ = PGTO.concat(pgtos_sph)
+        pgtos.append(pgto_sph_)
+        coeffs.append(jnp.concatenate(coeffs_sph))
+
+  # pgto = PGTO.apply(np.concatenate, pgtos)
+  pgto = PGTO.concat(pgtos)
+  coeff = jnp.concatenate(coeffs)
+
+  return coeff, pgto.exponent
+
+
 def build_cgto_from_mol(mol: Mol) -> CGTO:
   """Build CGTO from the basis information in Mol.
 
@@ -183,7 +303,7 @@ def build_cgto_from_mol(mol: Mol) -> CGTO:
             pgto_i_.append(PGTO(angular, coord, exponent))
             coeffs_i.append(coeff)
 
-          pgto_i = PGTO.apply(np.stack, pgto_i_)
+          pgto_i = PGTO.stack(pgto_i_)
           pgtos_monomials.append(pgto_i)
 
           # normalize each PGTO in cartesian coordinate
@@ -229,9 +349,18 @@ def build_cgto_from_mol(mol: Mol) -> CGTO:
     f"there are {sum(cgto_splits)} (non-unique) PGTOs in spherical form"
   )
 
+  # map basis into d4ft format
+
+  basis = []
+  for i, element in enumerate(mol.elements):
+    element_basis = []
+    for cgto_i in mol.basis[element]:
+      element_basis.append(cgto_i)
+    basis.append(element_basis)
+
   cgto = CGTO(
     pgto, pgto.norm_inv(), jnp.array(coeffs), cgto_splits, cgto_seg_id,
-    jnp.array(atom_splits), mol.atom_charges, mol.nocc
+    jnp.array(atom_splits), mol.atom_charges, mol.nocc, basis
   )
   return cgto
 
@@ -270,6 +399,25 @@ def at(self, i: Union[slice, int]) -> PGTO:
     """get one PGTO out of the batch"""
     return PGTO(self.angular[i], self.center[i], self.exponent[i])
 
+  @staticmethod
+  def stack(pgtos: Sequence[PGTO]) -> PGTO:
+    """Return the stack of a sequence of PGTO singletons.
+    Note that the angular needs to be a numpy array to avoid tracing error.
+    """
+    angular, center, exponent = zip(*pgtos)
+    return PGTO(np.stack(angular), jnp.stack(center), jnp.stack(exponent))
+
+  @staticmethod
+  def concat(pgtos: Sequence[PGTO]) -> PGTO:
+    """Return the concatenation of a sequence of PGTO batches.
+    Note that the angular needs to be a numpy array to avoid tracing error.
+    """
+    angular, center, exponent = zip(*pgtos)
+    return PGTO(
+      np.concatenate(angular), jnp.concatenate(center),
+      jnp.concatenate(exponent)
+    )
+
   @staticmethod
   def apply(f: Callable, pgtos: Sequence[PGTO]) -> PGTO:
     """Return the concatenation of a sequence of PGTO singletons.
@@ -324,9 +472,15 @@ class CGTO(NamedTuple):
   """Atom segment lengths. e.g. [15, 15] for O2 in sto-3g.
   Useful for copying atom centers to each GTO when doing basis optimization."""
   charge: Int[Array, "*n_atoms"]
-  """charges of the atoms"""
+  """Charges of the atoms"""
   nocc: Int[Array, "2 nao"]
-  """occupation mask for alpha and beta spin"""
+  """Occupation mask for alpha and beta spin"""
+  # elements: Sequence[str]
+  # """List of atoms in the systems"""
+  # atom_coords: Float[np.ndarray, "n_atoms 3"]
+  # """List of atoms coordinates in the systems"""
+  basis: Sequence[Sequence[Tuple[int, Sequence[Sequence[float]]]]]
+  """basis in PySCF format"""
 
   @property
   def n_pgtos(self) -> int:
@@ -383,34 +537,60 @@ def from_mol(mol: Mol) -> CGTO:
   def from_cart(cgto_cart: CGTO) -> CGTO:
     return build_cgto_sph_from_mol(cgto_cart)
 
-  def to_hk(self) -> CGTO:
+  def to_hk(
+    self,
+    optimizable_params: Sequence[Literal[
+      "center",
+      "exponent",
+      "coeff",
+    ]] = ("coeff",),
+  ) -> CGTO:
     """Convert optimizable parameters to hk.Params.
     Must be haiku transformed. Can be used for basis optimization.
+
+    Args:
+      optimizable_params: which of "center", "exponent" and "coeff" become
+        haiku parameters; anything not listed is left as a constant.
     """
-    center_init = self.atom_coords
-    center = hk.get_parameter(
-      "center", center_init.shape, init=make_constant_fn(center_init)
-    )
-    center_rep = jnp.repeat(
-      center,
-      jnp.array(self.atom_splits),
-      axis=0,
-      total_repeat_length=self.n_pgtos
-    )
-    # NOTE: we want to have some activation function here to make sure
-    # that exponent > 0. However softplus is not good as inv_softplus
-    # makes some exponent goes inf
-    exponent = jax.nn.softplus(
-      hk.get_parameter(
-        "exponent",
-        self.pgto.exponent.shape,
-        init=make_constant_fn(inv_softplus(self.pgto.exponent))
+    if "center" in optimizable_params:
+      center_init = self.atom_coords
+      center_param = hk.get_parameter(
+        "center", center_init.shape, init=make_constant_fn(center_init)
       )
-    )
-    coeff = hk.get_parameter(
-      "coeff", self.coeff.shape, init=make_constant_fn(self.coeff)
-    )
-    pgto = PGTO(self.pgto.angular, center_rep, exponent)
+      center = jnp.repeat(
+        center_param,
+        jnp.array(self.atom_splits),
+        axis=0,
+        total_repeat_length=self.n_pgtos
+      )
+    else:
+      center = self.pgto.center
+
+    if "exponent" in optimizable_params or "coeff" in optimizable_params:
+      # NOTE: we want to have some activation function here to make sure
+      # that exponent > 0. However softplus is not good as inv_softplus
+      # makes some exponent goes inf
+      # exponent = jax.nn.softplus(
+      #   hk.get_parameter(
+      #     "exponent",
+      #     self.pgto.exponent.shape,
+      #     init=make_constant_fn(inv_softplus(self.pgto.exponent))
+      #   )
+      # )
+      # coeff = hk.get_parameter(
+      #   "coeff", self.coeff.shape, init=make_constant_fn(self.coeff)
+      # )
+
+      coeff, exponent = reparameterize(
+        self,
+        optim_exp="exponent" in optimizable_params,
+        optim_coeff="coeff" in optimizable_params,
+      )
+    else:
+      coeff = self.coeff
+      exponent = self.pgto.exponent
+
+    pgto = PGTO(self.pgto.angular, center, exponent)
     return self._replace(pgto=pgto, coeff=coeff)
 
   # TODO: instead of using occupation mask, we can orthogonalize a non-square
diff --git a/d4ft/solver/drivers.py b/d4ft/solver/drivers.py
index ec1b45d..7e4dd00 100644
--- a/d4ft/solver/drivers.py
+++ b/d4ft/solver/drivers.py
@@ -143,8 +143,8 @@ def incore_cgto_direct_opt_dft(
   def H_factory() -> Tuple[Callable, Hamiltonian]:
     """Auto-grad scope"""
     # TODO: out-of-core + basis optimization
-    # cgto_hk = cgto.to_hk()
-    cgto_hk = cgto
+    cgto_hk = cgto.to_hk(["coeff"])
+    # cgto_hk = cgto
     cgto_intor = get_cgto_intor(
       cgto_hk, intor="obsa", incore_energy_tensors=incore_e_tensors
     )