diff --git a/README.md b/README.md index 7a3b7b1..89484c5 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ D4FT also provides examples for standard algorithms, similar to the "train" scri ## Calculating the ground state energy of Oxygen molecule Let's calculate the ground state energy of Oxygen molecule with direct minimization DFT: ``` shell -python main.py --run direct --config.mol_cfg.mol O2 +python main.py --run direct --config.sys_cfg.mol O2 ``` and you should see the following log after the calculation has converged: @@ -48,7 +48,7 @@ where each component of the ground state energy is printed. ## Benchmarking against PySCF Now let's test the accuracy of the calculated ground state energy against well-established open-source QC library PySCF. D4FT provides a thin wrapper around PySCF's API: to run the same calculation above of the Oxygen molecule but with PySCF, run: ``` shell -python main.py --run pyscf --config.mol_cfg.mol O2 +python main.py --run pyscf --config.sys_cfg.mol O2 ``` This will call PySCF to perform SCF calculation with the setting stated in `d4ft/config.py`. Two sets of energy will be printed: 1. the energy calculated with PySCF's integral engine `libcint`, which uses Rys quadrature: @@ -77,7 +77,7 @@ where `1e energy` is the sum of kinetic and external potential energy. We see th ## Calculate energy barrier for reaction ``` shell -python main.py --run reaction --reaction hf_h_hfhts --config.gd_cfg.gd_cfg.lr_decay cosine +python main.py --run reaction --reaction hf_h_hfhts --config.solver_cfg.solver_cfg.lr_decay cosine ``` This calculate the ground state energy for each system, then compute the energy barrier: ``` shell @@ -104,12 +104,12 @@ D4FT uses [ml_collections](https://github.com/google/ml_collections) to manage c The configuration used for the calculation will be printed to the console at the start of the run. For example when you run the calculation for Oxygen above using the default configuration, you should see the following: ``` shell -algo_cfg: !!python/object:config_config.AlgoConfig +method_cfg: !!python/object:config_config.MethodConfig __pydantic_initialised__: true restricted: false rng_seed: 137 xc_type: lda_x -gd_cfg: !!python/object:config_config.GDConfig +solver_cfg: !!python/object:config_config.GDConfig __pydantic_initialised__: true converge_threshold: 0.0001 epochs: 4000 @@ -124,32 +124,32 @@ intor_cfg: !!python/object:config_config.IntorConfig incore: true intor: obsa quad_level: 1 -mol_cfg: !!python/object:config_config.MoleculeConfig +sys_cfg: !!python/object:config_config.MoleculeConfig __pydantic_initialised__: true basis: sto-3g charge: 0 geometry_source: cccdbd mol: o2 spin: -1 -scf_cfg: !!python/object:config_config.SCFConfig +solver_cfg: !!python/object:config_config.SCFConfig __pydantic_initialised__: true epochs: 100 momentum: 0.5 ``` -All configuration stated in `d4ft/config.py` can be overridden by providing an appropriate flag (of the form `--config.`). For example, to change the basis set to `6-31g`, use the flag `--config.mol_cfg.basis 6-31g`. You can directly change the +All configuration stated in `d4ft/config.py` can be overridden by providing an appropriate flag (of the form `--config.`). For example, to change the basis set to `6-31g`, use the flag `--config.sys_cfg.basis 6-31g`. You can directly change the `d4ft/config.py` file, or specify a custom config file by supplying the flag `--config `. ## Specifying spin multiplicity -By default all electrons are maximally paired, so the spin is 0 or 1. To specify the spin multiplicity, use the flag `--config.mol_cfg.spin`, for example +By default all electrons are maximally paired, so the spin is 0 or 1. To specify the spin multiplicity, use the flag `--config.sys_cfg.spin`, for example ``` shell -python main.py --run direct --config.mol_cfg.mol O2 --config.mol_cfg.spin 2 +python main.py --run direct --config.sys_cfg.mol O2 --config.sys_cfg.spin 2 ``` ## Specifying XC functional -D4FT uses [`jax-xc`](https://github.com/sail-sg/jax_xc) for XC functional. Use the flag `--config.algo_cfg.xc_type` to specify XC functional to use, for example: +D4FT uses [`jax-xc`](https://github.com/sail-sg/jax_xc) for XC functional. Use the flag `--config.method_cfg.xc_type` to specify XC functional to use, for example: ``` shell -python main.py --run direct --config.mol_cfg.mol O2 --config.algo_cfg.xc_type lda_x +python main.py --run direct --config.sys_cfg.mol O2 --config.method_cfg.xc_type lda_x ``` @@ -161,7 +161,7 @@ O 0.0000 0.0000 0.0000; O 0.0000 0.0000 1.2075; """ ``` -For geometries not cached in the above file, D4FT will query the `cccdbd` website, and you shall see the following logs (using `--config.mol_cfg.mol ch4` in this example): +For geometries not cached in the above file, D4FT will query the `cccdbd` website, and you shall see the following logs (using `--config.sys_cfg.mol ch4` in this example): ``` shell I0630 11:12:49.016396 140705043318592 cccdbd.py:108] **** Posting formula I0630 11:12:50.397949 140705043318592 cccdbd.py:116] **** Fetching data @@ -179,7 +179,7 @@ H 0.0000 0.0000 0.7414; then pass it through the config flag as follows ``` shell ---config.mol_cfg.mol +--config.sys_cfg.mol ``` # Using the D4FT API directly @@ -200,8 +200,8 @@ logging.set_verbosity(logging.INFO) # load the default configuration, then override it cfg = get_config() -cfg.mol_cfg.mol = 'H2' -cfg.mol_cfg.basis = '6-31g' +cfg.sys_cfg.mol = 'H2' +cfg.sys_cfg.basis = '6-31g' # Calculation e_total, _, _ = incore_cgto_direct_opt(cfg) @@ -216,7 +216,7 @@ We have benchmarked the calculation against well known open-sourced quantum chem To run systems from `refdata` benchmark sets, ``` shell -python main.py --benchmark bh76 --use_f64 --config.mol_cfg.basis --config.algo_cfg.xc_type --save --config.mol_cfg.geometry_source refdata --pyscf --config.save_dir +python main.py --benchmark bh76 --use_f64 --config.sys_cfg.basis --config.method_cfg.xc_type --save --config.sys_cfg.geometry_source refdata --pyscf --config.save_dir ``` To visualize the run: diff --git a/d4ft/config.py b/d4ft/config.py index 730dea7..d16b487 100644 --- a/d4ft/config.py +++ b/d4ft/config.py @@ -13,7 +13,7 @@ # limitations under the License. from pathlib import Path -from typing import Literal +from typing import Literal, Union from ml_collections import ConfigDict from pydantic.config import ConfigDict as PydanticConfigDict @@ -25,6 +25,7 @@ @dataclass(config=pydantic_config) class GDConfig: """Config for direct minimization with gradient descent solver.""" + name: Literal["GD"] = "GD" lr: float = 1e-2 """learning rate""" lr_decay: Literal["none", "piecewise", "cosine"] = "none" @@ -47,6 +48,7 @@ class GDConfig: @dataclass(config=pydantic_config) class SCFConfig: """Config for self-consistent field solver.""" + name: Literal["SCF"] = "SCF" momentum: float = 0.5 """fock matrix update momentum""" epochs: int = 100 @@ -68,6 +70,7 @@ class IntorConfig: @dataclass(config=pydantic_config) class MoleculeConfig: """Config for molecule""" + name: Literal["MOL"] = "MOL" mol: str = "O2" """name of the molecule, or the path to the geometry file, which specifies the geometry in the format @@ -87,52 +90,102 @@ class MoleculeConfig: @dataclass(config=pydantic_config) -class AlgoConfig: - """Config for Algorithms.""" - algo: Literal["HF", "KS"] = "KS" - """Which algorithm to use. HF for Hartree-Fock, KS for Kohn-Sham DFT.""" +class CrystalConfig: + """Config for crystal""" + name: Literal["CRY"] = "CRY" + direct_lattice_dim: str = "1x1x1" + """Dimension of the direct lattice, i.e. the number of k points + (crystal momenta) in each spatial direction. Format is N1xN2xN3.""" + reciprocal_lattice_dim: str = "1x1x1" + """Dimension of the reciprocal lattice, i.e. the number of reciprocal lattice + vectors in each spatial direction. Format is N1xN2xN3.""" + energy_cutoff: float = 300. + """kinetic energy (of G points) cutoff for the plane wave basis set. + Unit is Hartree""" + + +@dataclass(config=pydantic_config) +class HFConfig: + """Config for Hartree-Fock theory.""" + name: Literal["HF"] = "HF" restricted: bool = False """Whether to run restricted calculation, i.e. enforcing symmetry by using the same coefficients for both spins""" + rng_seed: int = 137 + """PRNG seed""" + + +@dataclass(config=pydantic_config) +class KSDFTConfig: + """Config for Kohn-Sham Density functional theory.""" + name: Literal["KS"] = "KS" xc_type: str = "lda_x" """Name of the xc functional to use. To mix two XC functional, use the syntax a*xc_name_1+b*xc_name_2 where a, b are numbers.""" + restricted: bool = False + """Whether to run restricted calculation, i.e. enforcing symmetry by using the + same coefficients for both spins""" rng_seed: int = 137 """PRNG seed""" class D4FTConfig(ConfigDict): - algo_cfg: AlgoConfig + method_cfg: Union[HFConfig, KSDFTConfig] + """which QC method to use""" + solver_cfg: Union[GDConfig, SCFConfig] + """which solver to use""" intor_cfg: IntorConfig - mol_cfg: MoleculeConfig - gd_cfg: GDConfig - scf_cfg: SCFConfig + """integration engine config""" + sys_cfg: Union[MoleculeConfig, CrystalConfig] + """config for the system to simulate""" uuid: str save_dir: str def __init__(self, config_string: str) -> None: + method, solver, sys = config_string.split("-") + + if method.lower() == "hf": + method_cls = HFConfig + elif method.lower() == "ks": + method_cls = KSDFTConfig + else: + raise ValueError(f"Unknown method {method}") + + if solver.lower() == "gd": + solver_cls = GDConfig + elif solver.lower() == "scf": + solver_cls = SCFConfig + else: + raise ValueError(f"Unknown solver {solver}") + + if sys.lower() == "mol": + sys_cls = MoleculeConfig + elif sys.lower() == "crystal": + sys_cls = CrystalConfig + else: + raise ValueError(f"Unknown system {sys}") + super().__init__( { - "algo_cfg": AlgoConfig(), + "method_cfg": method_cls(), + "solver_cfg": solver_cls(), "intor_cfg": IntorConfig(), - "mol_cfg": MoleculeConfig(), - "gd_cfg": GDConfig(), - "scf_cfg": SCFConfig(), + "sys_cfg": sys_cls(), "uuid": "", "save_dir": "_exp", } ) def validate(self, spin: int, charge: int) -> None: - if self.algo_cfg.restricted and self.mol_cfg.mol not in ["bh76_h", "h"]: + if self.method_cfg.restricted and self.sys_cfg.mol not in ["bh76_h", "h"]: assert spin == 0 and charge == 0, \ "RESTRICTED only supports closed-shell molecules" def get_save_dir(self) -> Path: - return Path(f"{self.save_dir}/{self.uuid}/{self.mol_cfg.mol}") + return Path(f"{self.save_dir}/{self.uuid}/{self.sys_cfg.mol}") def get_core_cfg_str(self) -> str: - return "+".join([self.mol_cfg.basis, self.algo_cfg.xc_type]) + return "+".join([self.sys_cfg.basis, self.method_cfg.xc_type]) def save(self): save_path = self.get_save_dir().parent @@ -141,12 +194,12 @@ def save(self): f.write(str(self)) -def get_config(config_string: str = "") -> D4FTConfig: +def get_config(config_string: str = "KS-GD-MOL") -> D4FTConfig: """Return the default configurations. Args: - config_string: currently only set the type of algorithm. Available values: - "gd", "scf". + config_string: set the method, solver and sys for the D4FTConfig. Format is + method-solver-sys, and the default is KS-GD-MOL. NOTE: for distributed setup, might need to move the dataclass definition into this function. diff --git a/d4ft/solver/__init__.py b/d4ft/solver/__init__.py index f64ad0b..d612a83 100644 --- a/d4ft/solver/__init__.py +++ b/d4ft/solver/__init__.py @@ -17,12 +17,12 @@ dft_cgto, incore_cgto_direct_opt, incore_cgto_scf, - incore_hf_cgto, + incore_mf_cgto, ) __all__ = [ "dft_cgto", "incore_cgto_direct_opt", "incore_cgto_scf", - "incore_hf_cgto", + "incore_mf_cgto", ] diff --git a/d4ft/solver/drivers.py b/d4ft/solver/drivers.py index 472b01e..a4d4f8d 100644 --- a/d4ft/solver/drivers.py +++ b/d4ft/solver/drivers.py @@ -47,10 +47,13 @@ from d4ft.xc import get_lda_vxc, get_xc_functional, get_xc_intor -def incore_hf_cgto(cfg: D4FTConfig): +def incore_mf_cgto(cfg: D4FTConfig): + """Build the CGTO basis with in-core intor for the mean-field calculations + (i.e. HF and KS-DFT). For KS-DFT we also need to build the grids for the + numerical integration of the XC functional""" pyscf_mol = get_pyscf_mol( - cfg.mol_cfg.mol, cfg.mol_cfg.basis, cfg.mol_cfg.spin, cfg.mol_cfg.charge, - cfg.mol_cfg.geometry_source + cfg.sys_cfg.mol, cfg.sys_cfg.basis, cfg.sys_cfg.spin, cfg.sys_cfg.charge, + cfg.sys_cfg.geometry_source ) mol = Mol.from_pyscf_mol(pyscf_mol) cfg.validate(mol.spin, mol.charge) @@ -61,10 +64,13 @@ def incore_hf_cgto(cfg: D4FTConfig): s4 = obsa.angular_static_args(*[cgto.pgto.angular] * 4) incore_e_tensors = incore_int_sym(cgto, s2, s4) - dg = DifferentiableGrids(pyscf_mol) - dg.level = cfg.intor_cfg.quad_level - # TODO: test geometry optimization - grids_and_weights = dg.build(pyscf_mol.atom_coords()) + if cfg.method_cfg.name == "KS": + dg = DifferentiableGrids(pyscf_mol) + dg.level = cfg.intor_cfg.quad_level + # TODO: test geometry optimization + grids_and_weights = dg.build(pyscf_mol.atom_coords()) + else: + grids_and_weights = None return incore_e_tensors, pyscf_mol, cgto, grids_and_weights @@ -75,25 +81,25 @@ def incore_cgto_scf(cfg: D4FTConfig) -> None: NOTE: since jax-xc doesn't have vxc yet the vxc here is fixed to LDA """ - key = jax.random.PRNGKey(cfg.algo_cfg.rng_seed) - incore_e_tensors, _, cgto, grids_and_weights = incore_hf_cgto(cfg) + key = jax.random.PRNGKey(cfg.method_cfg.rng_seed) + incore_e_tensors, _, cgto, grids_and_weights = incore_mf_cgto(cfg) ovlp = get_ovlp(cgto, incore_e_tensors) vxc_fn = get_lda_vxc( - grids_and_weights, cgto, polarized=not cfg.algo_cfg.restricted + grids_and_weights, cgto, polarized=not cfg.method_cfg.restricted ) cgto_fock_fn = get_cgto_fock_fn(cgto, incore_e_tensors, vxc_fn) cgto_fock_jit = jax.jit(cgto_fock_fn) # get initial mo_coeff - mo_coeff_fn = partial(cgto.get_mo_coeff, restricted=cfg.algo_cfg.restricted) + mo_coeff_fn = partial(cgto.get_mo_coeff, restricted=cfg.method_cfg.restricted) mo_coeff_fn = hk.without_apply_rng(hk.transform(mo_coeff_fn)) params = mo_coeff_fn.init(key) mo_coeff = mo_coeff_fn.apply(params, apply_spin_mask=False) - polarized = not cfg.algo_cfg.restricted - xc_func = get_xc_functional(cfg.algo_cfg.xc_type, polarized) + polarized = not cfg.method_cfg.restricted + xc_func = get_xc_functional(cfg.method_cfg.xc_type, polarized) xc_fn = get_xc_intor(grids_and_weights, cgto, xc_func, polarized) kin_fn, ext_fn, har_fn = get_cgto_intor( cgto, intor="obsa", incore_energy_tensors=incore_e_tensors @@ -110,7 +116,7 @@ def energy_fn(mo_coeff): energies = Energies(e_total, e_kin, e_ext, e_har, e_xc, e_nuc) return energies - transpose_axis = (1, 0) if cfg.algo_cfg.restricted else (0, 2, 1) + transpose_axis = (1, 0) if cfg.method_cfg.restricted else (0, 2, 1) @jax.jit def scf_iter(fock): @@ -120,9 +126,11 @@ def scf_iter(fock): fock = jnp.eye(cgto.nao) # initial guess logger = RunLogger() - for step in range(cfg.scf_cfg.epochs): + for step in range(cfg.solver_cfg.epochs): new_fock = cgto_fock_jit(mo_coeff) - fock = (1 - cfg.scf_cfg.momentum) * new_fock + cfg.scf_cfg.momentum * fock + fock = ( + 1 - cfg.solver_cfg.momentum + ) * new_fock + cfg.solver_cfg.momentum * fock e_orb, mo_coeff = scf_iter(fock) logging.info(f"{e_orb=}") residual = jnp.eye(cgto.nao) - mo_coeff[0].T @ ovlp @ mo_coeff[0] @@ -139,12 +147,9 @@ def incore_cgto_direct_opt( ) -> float: """Solve for ground state of a molecular system with direct optimization DFT, where CGTO basis are used and the energy tensors are precomputed/incore.""" - key = jax.random.PRNGKey(cfg.algo_cfg.rng_seed) + key = jax.random.PRNGKey(cfg.method_cfg.rng_seed) - polarized = not cfg.algo_cfg.restricted - xc_func = get_xc_functional(cfg.algo_cfg.xc_type, polarized) - - incore_e_tensors, pyscf_mol, cgto, grids_and_weights = incore_hf_cgto(cfg) + incore_e_tensors, pyscf_mol, cgto, grids_and_weights = incore_mf_cgto(cfg) def H_factory() -> Tuple[Callable, Hamiltonian]: """Auto-grad scope""" @@ -159,21 +164,25 @@ def H_factory() -> Tuple[Callable, Hamiltonian]: ) mo_coeff_fn = partial( cgto_hk.get_mo_coeff, - restricted=cfg.algo_cfg.restricted, + restricted=cfg.method_cfg.restricted, ortho_fn=qr_factor, ovlp_sqrt_inv=sqrt_inv(ovlp), ) - xc_fn = get_xc_intor(grids_and_weights, cgto_hk, xc_func, polarized) - return dft_cgto(cgto_hk, cgto_intor, xc_fn, mo_coeff_fn) - # e_total = scipy_opt(cfg.gd_cfg, H_factory, key) + if cfg.method_cfg.name == "KS": + polarized = not cfg.method_cfg.restricted + xc_func = get_xc_functional(cfg.method_cfg.xc_type, polarized) + xc_fn = get_xc_intor(grids_and_weights, cgto_hk, xc_func, polarized) + return dft_cgto(cgto_hk, cgto_intor, xc_fn, mo_coeff_fn) + + # e_total = scipy_opt(cfg.solver_cfg, H_factory, key) # breakpoint() H_transformed = hk.multi_transform(H_factory) params = H_transformed.init(key) H = Hamiltonian(*H_transformed.apply) - logger, traj = sgd(cfg.gd_cfg, H, params, key) + logger, traj = sgd(cfg.solver_cfg, H, params, key) min_e_step = logger.data_df.e_total.astype(float).idxmin() logging.info(f"lowest total energy: \n {logger.data_df.iloc[min_e_step]}") @@ -183,8 +192,8 @@ def H_factory() -> Tuple[Callable, Hamiltonian]: # rdm1 = get_rdm1(traj[-1].mo_coeff) # scf_mo_coeff = pyscf_wrapper( # pyscf_mol, - # cfg.algo_cfg.restricted, - # cfg.algo_cfg.xc_type, + # cfg.method_cfg.restricted, + # cfg.method_cfg.xc_type, # cfg.intor_cfg.quad_level, # algo="KS", # rdm1=rdm1, @@ -216,7 +225,7 @@ def H_factory() -> Tuple[Callable, Hamiltonian]: def incore_cgto_pyscf_benchmark(cfg: D4FTConfig) -> RunLogger: - incore_e_tensors, pyscf_mol, cgto, grids_and_weights = incore_hf_cgto(cfg) + incore_e_tensors, pyscf_mol, cgto, grids_and_weights = incore_mf_cgto(cfg) return pyscf_benchmark( cfg, pyscf_mol, cgto, incore_e_tensors, grids_and_weights ) @@ -235,17 +244,17 @@ def pyscf_benchmark( cgto_intor = get_cgto_intor( cgto, intor="obsa", incore_energy_tensors=incore_e_tensors ) - polarized = not cfg.algo_cfg.restricted - xc_func = get_xc_functional(cfg.algo_cfg.xc_type, polarized) + polarized = not cfg.method_cfg.restricted + xc_func = get_xc_functional(cfg.method_cfg.xc_type, polarized) xc_fn = get_xc_intor(grids_and_weights, cgto, xc_func, polarized) # solve for ground state with PySCF and get the mo_coeff atom_mf, mo_coeff = pyscf_wrapper( pyscf_mol, - cfg.algo_cfg.restricted, - cfg.algo_cfg.xc_type, + cfg.method_cfg.restricted, + cfg.method_cfg.xc_type, cfg.intor_cfg.quad_level, - algo=cfg.algo_cfg.algo + algo=cfg.method_cfg.algo ) # add spin and apply occupation mask diff --git a/d4ft/solver/sgd.py b/d4ft/solver/sgd.py index 3a339a8..b6d39fe 100644 --- a/d4ft/solver/sgd.py +++ b/d4ft/solver/sgd.py @@ -28,7 +28,8 @@ def scipy_opt( - gd_cfg: GDConfig, H: Hamiltonian, params: hk.Params, key: jax.random.KeyArray + solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, + key: jax.random.KeyArray ) -> float: energy_fn_jit = jax.jit(lambda mo_coeff: H.energy_fn(mo_coeff, key)[0]) import jaxopt @@ -38,7 +39,8 @@ def scipy_opt( def sgd( - gd_cfg: GDConfig, H: Hamiltonian, params: hk.Params, key: jax.random.KeyArray + solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, + key: jax.random.KeyArray ) -> Tuple[RunLogger, Trajectory]: @jax.jit @@ -77,9 +79,9 @@ def meta_step(state: TrainingState, meta_state: TrainingState): ), new_state, energies, mo_grads # init state - opt_states = get_optimizer(gd_cfg, params, key) + opt_states = get_optimizer(solver_cfg, params, key) optimizer, state = opt_states["main"] - if gd_cfg.meta_opt != "none": + if solver_cfg.meta_opt != "none": meta_opt, meta_state = opt_states["meta"] # GD loop @@ -87,9 +89,9 @@ def meta_step(state: TrainingState, meta_state: TrainingState): converged = False logger = RunLogger() e_total_std = 0. - for step in range(gd_cfg.epochs): + for step in range(solver_cfg.epochs): - if gd_cfg.meta_opt == "none": + if solver_cfg.meta_opt == "none": new_state, energies, mo_grads = update(state) else: meta_state, new_state, energies, mo_grads = meta_step(state, meta_state) @@ -105,14 +107,14 @@ def meta_step(state: TrainingState, meta_state: TrainingState): state = new_state - if step < gd_cfg.hist_len: # don't check for convergence + if step < solver_cfg.hist_len: # don't check for convergence continue # check convergence e_total_std = jnp.stack( - [t.energies.e_total for t in traj[-gd_cfg.hist_len:]] + [t.energies.e_total for t in traj[-solver_cfg.hist_len:]] ).std() - if e_total_std < gd_cfg.converge_threshold: + if e_total_std < solver_cfg.converge_threshold: converged = True break diff --git a/run_main.sh b/run_main.sh index e63649f..c40083c 100644 --- a/run_main.sh +++ b/run_main.sh @@ -1,5 +1,5 @@ mode="direct" mol="o2" basis="cc-pvdz" -python3 main.py --run $mode --config.mol_cfg.mol $mol \ - --config.mol_cfg.basis $basis --use_f64 --config.gd_cfg.optimizer rmsprop --config.gd_cfg.meta_opt adam \ No newline at end of file +python3 main.py --run $mode --config.sys_cfg.mol $mol \ + --config.sys_cfg.basis $basis --use_f64 --config.solver_cfg.optimizer rmsprop --config.solver_cfg.meta_opt adam \ No newline at end of file diff --git a/tests/solver_test.py b/tests/solver_test.py index 45a2b5e..9c7cbfe 100644 --- a/tests/solver_test.py +++ b/tests/solver_test.py @@ -31,8 +31,8 @@ def test_incore_sgd( self, system: str, energy_bounds: Tuple[float, float] ) -> None: cfg = get_config() - cfg.mol_cfg.mol = system - cfg.mol_cfg.basis = '6-31g' + cfg.sys_cfg.mol = system + cfg.sys_cfg.basis = '6-31g' e_total = incore_cgto_direct_opt(cfg, basis_optim=False) upper_bound, lower_bound = energy_bounds self.assertTrue(e_total < upper_bound and e_total > lower_bound) diff --git a/tests/xc_test.py b/tests/xc_test.py index c4b3861..df902a8 100644 --- a/tests/xc_test.py +++ b/tests/xc_test.py @@ -39,14 +39,14 @@ class XCTest(parameterized.TestCase): ) def test_xc_grad(self, xc_name: str) -> None: cfg = get_config() - key = jax.random.PRNGKey(cfg.algo_cfg.rng_seed) - cfg.mol_cfg.mol = "h2" - cfg.mol_cfg.basis = "sto-3g" + key = jax.random.PRNGKey(cfg.method_cfg.rng_seed) + cfg.sys_cfg.mol = "h2" + cfg.sys_cfg.basis = "sto-3g" # build system pyscf_mol = get_pyscf_mol( - cfg.mol_cfg.mol, cfg.mol_cfg.basis, cfg.mol_cfg.spin, cfg.mol_cfg.charge, - cfg.mol_cfg.geometry_source + cfg.sys_cfg.mol, cfg.sys_cfg.basis, cfg.sys_cfg.spin, cfg.sys_cfg.charge, + cfg.sys_cfg.geometry_source ) mol = Mol.from_pyscf_mol(pyscf_mol) cgto = CGTO.from_mol(mol) @@ -63,12 +63,12 @@ def test_xc_grad(self, xc_name: str) -> None: # function maps mo coefficients to xc energy mo_coeff_fn = partial( cgto.get_mo_coeff, - restricted=cfg.algo_cfg.restricted, + restricted=cfg.method_cfg.restricted, ortho_fn=qr_factor, ovlp_sqrt_inv=sqrt_inv(ovlp), ) - polarized = not cfg.algo_cfg.restricted - xc_func = get_xc_functional(cfg.algo_cfg.xc_type, polarized) + polarized = not cfg.method_cfg.restricted + xc_func = get_xc_functional(cfg.method_cfg.xc_type, polarized) xc_fn = get_xc_intor(grids_and_weights, cgto, xc_func, polarized) mo_xc_fn = hk.without_apply_rng(hk.transform(compose(xc_fn, mo_coeff_fn)))