Skip to content

Commit

Permalink
refactor config system
Browse files Browse the repository at this point in the history
  • Loading branch information
Szkered committed Aug 18, 2023
1 parent d1b1ccc commit 41ab795
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 93 deletions.
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ D4FT also provides examples for standard algorithms, similar to the "train" scri
## Calculating the ground state energy of Oxygen molecule
Let's calculate the ground state energy of Oxygen molecule with direct minimization DFT:
``` shell
python main.py --run direct --config.mol_cfg.mol O2
python main.py --run direct --config.sys_cfg.mol O2
```

and you should see the following log after the calculation has converged:
Expand All @@ -48,7 +48,7 @@ where each component of the ground state energy is printed.
## Benchmarking against PySCF
Now let's test the accuracy of the calculated ground state energy against well-established open-source QC library PySCF. D4FT provides a thin wrapper around PySCF's API: to run the same calculation above of the Oxygen molecule but with PySCF, run:
``` shell
python main.py --run pyscf --config.mol_cfg.mol O2
python main.py --run pyscf --config.sys_cfg.mol O2
```
This will call PySCF to perform SCF calculation with the setting stated in `d4ft/config.py`. Two sets of energy will be printed:
1. the energy calculated with PySCF's integral engine `libcint`, which uses Rys quadrature:
Expand Down Expand Up @@ -77,7 +77,7 @@ where `1e energy` is the sum of kinetic and external potential energy. We see th
## Calculate energy barrier for reaction

``` shell
python main.py --run reaction --reaction hf_h_hfhts --config.gd_cfg.gd_cfg.lr_decay cosine
python main.py --run reaction --reaction hf_h_hfhts --config.solver_cfg.solver_cfg.lr_decay cosine
```
This calculate the ground state energy for each system, then compute the energy barrier:
``` shell
Expand All @@ -104,12 +104,12 @@ D4FT uses [ml_collections](https://github.com/google/ml_collections) to manage c

The configuration used for the calculation will be printed to the console at the start of the run. For example when you run the calculation for Oxygen above using the default configuration, you should see the following:
``` shell
algo_cfg: !!python/object:config_config.AlgoConfig
method_cfg: !!python/object:config_config.MethodConfig
__pydantic_initialised__: true
restricted: false
rng_seed: 137
xc_type: lda_x
gd_cfg: !!python/object:config_config.GDConfig
solver_cfg: !!python/object:config_config.GDConfig
__pydantic_initialised__: true
converge_threshold: 0.0001
epochs: 4000
Expand All @@ -124,32 +124,32 @@ intor_cfg: !!python/object:config_config.IntorConfig
incore: true
intor: obsa
quad_level: 1
mol_cfg: !!python/object:config_config.MoleculeConfig
sys_cfg: !!python/object:config_config.MoleculeConfig
__pydantic_initialised__: true
basis: sto-3g
charge: 0
geometry_source: cccdbd
mol: o2
spin: -1
scf_cfg: !!python/object:config_config.SCFConfig
solver_cfg: !!python/object:config_config.SCFConfig
__pydantic_initialised__: true
epochs: 100
momentum: 0.5
```

All configuration stated in `d4ft/config.py` can be overridden by providing an appropriate flag (of the form `--config.<cfg_field>`). For example, to change the basis set to `6-31g`, use the flag `--config.mol_cfg.basis 6-31g`. You can directly change the
All configuration stated in `d4ft/config.py` can be overridden by providing an appropriate flag (of the form `--config.<cfg_field>`). For example, to change the basis set to `6-31g`, use the flag `--config.sys_cfg.basis 6-31g`. You can directly change the
`d4ft/config.py` file, or specify a custom config file by supplying the flag `--config <your config file path>`.

## Specifying spin multiplicity
By default all electrons are maximally paired, so the spin is 0 or 1. To specify the spin multiplicity, use the flag `--config.mol_cfg.spin`, for example
By default all electrons are maximally paired, so the spin is 0 or 1. To specify the spin multiplicity, use the flag `--config.sys_cfg.spin`, for example
``` shell
python main.py --run direct --config.mol_cfg.mol O2 --config.mol_cfg.spin 2
python main.py --run direct --config.sys_cfg.mol O2 --config.sys_cfg.spin 2
```

## Specifying XC functional
D4FT uses [`jax-xc`](https://github.com/sail-sg/jax_xc) for XC functional. Use the flag `--config.algo_cfg.xc_type` to specify XC functional to use, for example:
D4FT uses [`jax-xc`](https://github.com/sail-sg/jax_xc) for XC functional. Use the flag `--config.method_cfg.xc_type` to specify XC functional to use, for example:
``` shell
python main.py --run direct --config.mol_cfg.mol O2 --config.algo_cfg.xc_type lda_x
python main.py --run direct --config.sys_cfg.mol O2 --config.method_cfg.xc_type lda_x
```


Expand All @@ -161,7 +161,7 @@ O 0.0000 0.0000 0.0000;
O 0.0000 0.0000 1.2075;
"""
```
For geometries not cached in the above file, D4FT will query the `cccdbd` website, and you shall see the following logs (using `--config.mol_cfg.mol ch4` in this example):
For geometries not cached in the above file, D4FT will query the `cccdbd` website, and you shall see the following logs (using `--config.sys_cfg.mol ch4` in this example):
``` shell
I0630 11:12:49.016396 140705043318592 cccdbd.py:108] **** Posting formula
I0630 11:12:50.397949 140705043318592 cccdbd.py:116] **** Fetching data
Expand All @@ -179,7 +179,7 @@ H 0.0000 0.0000 0.7414;
then pass it through the config flag as follows

``` shell
--config.mol_cfg.mol <path_to_geometry_file>
--config.sys_cfg.mol <path_to_geometry_file>
```

# Using the D4FT API directly
Expand All @@ -200,8 +200,8 @@ logging.set_verbosity(logging.INFO)

# load the default configuration, then override it
cfg = get_config()
cfg.mol_cfg.mol = 'H2'
cfg.mol_cfg.basis = '6-31g'
cfg.sys_cfg.mol = 'H2'
cfg.sys_cfg.basis = '6-31g'

# Calculation
e_total, _, _ = incore_cgto_direct_opt(cfg)
Expand All @@ -216,7 +216,7 @@ We have benchmarked the calculation against well known open-sourced quantum chem
To run systems from `refdata` benchmark sets,

``` shell
python main.py --benchmark bh76 --use_f64 --config.mol_cfg.basis <basis> --config.algo_cfg.xc_type <xc> --save --config.mol_cfg.geometry_source refdata --pyscf --config.save_dir <path>
python main.py --benchmark bh76 --use_f64 --config.sys_cfg.basis <basis> --config.method_cfg.xc_type <xc> --save --config.sys_cfg.geometry_source refdata --pyscf --config.save_dir <path>
```

To visualize the run:
Expand Down
91 changes: 72 additions & 19 deletions d4ft/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from pathlib import Path
from typing import Literal
from typing import Literal, Union

from ml_collections import ConfigDict
from pydantic.config import ConfigDict as PydanticConfigDict
Expand All @@ -25,6 +25,7 @@
@dataclass(config=pydantic_config)
class GDConfig:
"""Config for direct minimization with gradient descent solver."""
name: Literal["GD"] = "GD"
lr: float = 1e-2
"""learning rate"""
lr_decay: Literal["none", "piecewise", "cosine"] = "none"
Expand All @@ -47,6 +48,7 @@ class GDConfig:
@dataclass(config=pydantic_config)
class SCFConfig:
"""Config for self-consistent field solver."""
name: Literal["SCF"] = "SCF"
momentum: float = 0.5
"""fock matrix update momentum"""
epochs: int = 100
Expand All @@ -68,6 +70,7 @@ class IntorConfig:
@dataclass(config=pydantic_config)
class MoleculeConfig:
"""Config for molecule"""
name: Literal["MOL"] = "MOL"
mol: str = "O2"
"""name of the molecule, or the path to the geometry file, which
specifies the geometry in the format
Expand All @@ -87,52 +90,102 @@ class MoleculeConfig:


@dataclass(config=pydantic_config)
class AlgoConfig:
"""Config for Algorithms."""
algo: Literal["HF", "KS"] = "KS"
"""Which algorithm to use. HF for Hartree-Fock, KS for Kohn-Sham DFT."""
class CrystalConfig:
"""Config for crystal"""
name: Literal["CRY"] = "CRY"
direct_lattice_dim: str = "1x1x1"
"""Dimension of the direct lattice, i.e. the number of k points
(crystal momenta) in each spatial direction. Format is N1xN2xN3."""
reciprocal_lattice_dim: str = "1x1x1"
"""Dimension of the reciprocal lattice, i.e. the number of reciprocal lattice
vectors in each spatial direction. Format is N1xN2xN3."""
energy_cutoff: float = 300.
"""kinetic energy (of G points) cutoff for the plane wave basis set.
Unit is Hartree"""


@dataclass(config=pydantic_config)
class HFConfig:
"""Config for Hartree-Fock theory."""
name: Literal["HF"] = "HF"
restricted: bool = False
"""Whether to run restricted calculation, i.e. enforcing symmetry by using the
same coefficients for both spins"""
rng_seed: int = 137
"""PRNG seed"""


@dataclass(config=pydantic_config)
class KSDFTConfig:
"""Config for Kohn-Sham Density functional theory."""
name: Literal["KS"] = "KS"
xc_type: str = "lda_x"
"""Name of the xc functional to use. To mix two XC functional, use the
syntax a*xc_name_1+b*xc_name_2 where a, b are numbers."""
restricted: bool = False
"""Whether to run restricted calculation, i.e. enforcing symmetry by using the
same coefficients for both spins"""
rng_seed: int = 137
"""PRNG seed"""


class D4FTConfig(ConfigDict):
algo_cfg: AlgoConfig
method_cfg: Union[HFConfig, KSDFTConfig]
"""which QC method to use"""
solver_cfg: Union[GDConfig, SCFConfig]
"""which solver to use"""
intor_cfg: IntorConfig
mol_cfg: MoleculeConfig
gd_cfg: GDConfig
scf_cfg: SCFConfig
"""integration engine config"""
sys_cfg: Union[MoleculeConfig, CrystalConfig]
"""config for the system to simulate"""
uuid: str
save_dir: str

def __init__(self, config_string: str) -> None:
method, solver, sys = config_string.split("-")

if method.lower() == "hf":
method_cls = HFConfig
elif method.lower() == "ks":
method_cls = KSDFTConfig
else:
raise ValueError(f"Unknown method {method}")

if solver.lower() == "gd":
solver_cls = GDConfig
elif solver.lower() == "scf":
solver_cls = SCFConfig
else:
raise ValueError(f"Unknown solver {solver}")

if sys.lower() == "mol":
sys_cls = MoleculeConfig
elif sys.lower() == "crystal":
sys_cls = CrystalConfig
else:
raise ValueError(f"Unknown system {sys}")

super().__init__(
{
"algo_cfg": AlgoConfig(),
"method_cfg": method_cls(),
"solver_cfg": solver_cls(),
"intor_cfg": IntorConfig(),
"mol_cfg": MoleculeConfig(),
"gd_cfg": GDConfig(),
"scf_cfg": SCFConfig(),
"sys_cfg": sys_cls(),
"uuid": "",
"save_dir": "_exp",
}
)

def validate(self, spin: int, charge: int) -> None:
if self.algo_cfg.restricted and self.mol_cfg.mol not in ["bh76_h", "h"]:
if self.method_cfg.restricted and self.sys_cfg.mol not in ["bh76_h", "h"]:
assert spin == 0 and charge == 0, \
"RESTRICTED only supports closed-shell molecules"

def get_save_dir(self) -> Path:
return Path(f"{self.save_dir}/{self.uuid}/{self.mol_cfg.mol}")
return Path(f"{self.save_dir}/{self.uuid}/{self.sys_cfg.mol}")

def get_core_cfg_str(self) -> str:
return "+".join([self.mol_cfg.basis, self.algo_cfg.xc_type])
return "+".join([self.sys_cfg.basis, self.method_cfg.xc_type])

def save(self):
save_path = self.get_save_dir().parent
Expand All @@ -141,12 +194,12 @@ def save(self):
f.write(str(self))


def get_config(config_string: str = "") -> D4FTConfig:
def get_config(config_string: str = "KS-GD-MOL") -> D4FTConfig:
"""Return the default configurations.
Args:
config_string: currently only set the type of algorithm. Available values:
"gd", "scf".
config_string: set the method, solver and sys for the D4FTConfig. Format is
method-solver-sys, and the default is KS-GD-MOL.
NOTE: for distributed setup, might need to move the dataclass definition
into this function.
Expand Down
4 changes: 2 additions & 2 deletions d4ft/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
dft_cgto,
incore_cgto_direct_opt,
incore_cgto_scf,
incore_hf_cgto,
incore_mf_cgto,
)

__all__ = [
"dft_cgto",
"incore_cgto_direct_opt",
"incore_cgto_scf",
"incore_hf_cgto",
"incore_mf_cgto",
]
Loading

0 comments on commit 41ab795

Please sign in to comment.