Skip to content

Commit

Permalink
fix inv_softplus
Browse files Browse the repository at this point in the history
  • Loading branch information
Szkered committed Aug 29, 2023
1 parent 910777b commit 10bc431
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 26 deletions.
14 changes: 0 additions & 14 deletions d4ft/integral/gto/cgto.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,20 +540,6 @@ def to_hk(
center = self.pgto.center

if "exponent" in optimizable_params or "coeff" in optimizable_params:
# NOTE: we want to have some activation function here to make sure
# that exponent > 0. However softplus is not good as inv_softplus
# makes some exponent goes inf
# exponent = jax.nn.softplus(
# hk.get_parameter(
# "exponent",
# self.pgto.exponent.shape,
# init=make_constant_fn(inv_softplus(self.pgto.exponent))
# )
# )
# coeff = hk.get_parameter(
# "coeff", self.coeff.shape, init=make_constant_fn(self.coeff)
# )

coeff, exponent = reparameterize(
self,
optim_exp="exponent" in optimizable_params,
Expand Down
10 changes: 2 additions & 8 deletions d4ft/solver/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,6 @@ def cgto_direct_opt(

cgto_tensor_fns, pyscf_mol, cgto = build_mf_cgto(cfg)

# if cfg.intor_cfg.incore:
# cgto_e_tensors = cgto_tensor_fns.get_incore_tensors(cgto)
# else:
# cgto_e_tensors = None

if cfg.method_cfg.name == "KS":
dg = DifferentiableGrids(pyscf_mol)
dg.level = cfg.intor_cfg.quad_level
Expand All @@ -181,17 +176,16 @@ def H_factory() -> Tuple[Callable, Hamiltonian]:
else:
cgto_hk = cgto
if cfg.intor_cfg.incore:
ovlp = get_ovlp_incore(cgto, cgto_e_tensors)
cgto_intor = get_cgto_intor(
cgto_hk,
cgto_tensor_fns,
cgto_e_tensors=cgto_e_tensors,
intor=cfg.intor_cfg.intor,
)
ovlp = get_ovlp_incore(cgto, cgto_e_tensors)
else:
cgto_intor = get_cgto_intor(
cgto_hk,
cgto_tensor_fns,
cgto_tensor_fns=cgto_tensor_fns,
intor=cfg.intor_cfg.intor,
)
ovlp = get_ovlp(cgto_hk, cgto_tensor_fns)
Expand Down
12 changes: 8 additions & 4 deletions d4ft/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jaxtyping import Array, Num
from ml_collections import ConfigDict

Expand All @@ -40,10 +41,13 @@ def compose(f: Callable, g: Callable) -> Callable:


def inv_softplus(y: Num[Array, "*s"]) -> Num[Array, "*s"]:
if y < 20: # This threshold is arbitrary and can be adjusted.
return jnp.log(jnp.exp(y) - 1.)
else: # For large y, softplus(y) ~= y
return y
r"""Inverse of softplus.
We want to have some activation function here to make sure
that exponent > 0. However softplus is not good as inv_softplus
makes some exponent goes inf. For large y, softplus(y) ~= y.
"""
return lax.cond(y < 20, lambda y: jnp.log(jnp.exp(y) - 1.), lambda y: y, y)


def save_cfg(cfg: ConfigDict, save_path: Union[str, Path]) -> None:
Expand Down

0 comments on commit 10bc431

Please sign in to comment.