Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Szkered committed Aug 10, 2023
1 parent 769d940 commit d983ff3
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 20 deletions.
2 changes: 1 addition & 1 deletion d4ft/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class MoleculeConfig:
For example H2:
H 0.0000 0.0000 0.0000;
H 0.0000 0.0000 0.7414;"""
basis: str = "sto-3g" #"sto-3g"
basis: str = "sto-3g"
"""name of the atomic basis set"""
spin: int = -1
"""number of unpaired electrons. -1 means all electrons are
Expand Down
23 changes: 12 additions & 11 deletions d4ft/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from enum import Enum

import numpy as np


Expand Down Expand Up @@ -53,24 +54,24 @@ def racah_normalization(l: int):
:math:`R(l)=\sqrt{\frac{4\pi}{2l+1}}`
The solid harmonic :math:`\mathcal{Y}_{lm}=r^l*Y_{lm}` uses the same
normalization, and also the real solid harmonics s_{lm}, since they are obtained via
unitary tranformation from the solid harmonic.
normalization, and also the real solid harmonics s_{lm}, since they are
obtained via unitary tranformation from the solid harmonic.
"""
return np.sqrt((4 * np.pi) / (2 * l + 1))


# TODO: change this to a function that takes l and m as input
# get the R(l) part from racah_normalization, and the other prefactor
# can be tabulated against l and m.
REAL_SOLID_SPH_CART_PREFAC = [ # lm
0.282094791773878143, #0 00: 1/R(0)
0.488602511902919921, #1 1{1,2,3}: 1/R(1)
1.092548430592079070, #2 2{1,2}: 1/R(2) * np.sqrt(3)
0.315391565252520002, #3 2{0}: 1/R(2) * 0.5
0.746352665180230782 / 2, #4 3{0}: 1/R(3) * 0.5
0.590043589926643510, #5 3{3}: 1/R(3) * 0.5 * np.sqrt(5/2)
0.457045799464465739, #6 3{1}: 1/R(3) * 0.5 * np.sqrt(3/2)
1.445305721320277020, #7 3{2}: 1/R(3) * 0.5 * np.sqrt(15)
REAL_SOLID_SPH_CART_PREFAC = [ # i lm
0.282094791773878143, # 0 00: 1/R(0)
0.488602511902919921, # 1 1{1,2,3}: 1/R(1)
1.092548430592079070, # 2 2{1,2}: 1/R(2) * np.sqrt(3)
0.315391565252520002, # 3 2{0}: 1/R(2) * 0.5
0.746352665180230782 / 2, # 4 3{0}: 1/R(3) * 0.5
0.590043589926643510, # 5 3{3}: 1/R(3) * 0.5 * np.sqrt(5/2)
0.457045799464465739, # 6 3{1}: 1/R(3) * 0.5 * np.sqrt(3/2)
1.445305721320277020, # 7 3{2}: 1/R(3) * 0.5 * np.sqrt(15)
]
r"""The prefactor of the real solid harmonics under Cartesian coordinate,
mulitplied by inverse of Racah's normalization for total angular momentum of
Expand Down
16 changes: 9 additions & 7 deletions d4ft/integral/gto/cgto.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations # forward declaration

import math
from typing import Callable, NamedTuple, Optional, Tuple, Union, Sequence
from typing import Callable, NamedTuple, Optional, Sequence, Tuple, Union

import haiku as hk
import jax
Expand All @@ -25,8 +25,8 @@
from jaxtyping import Array, Float, Int

from d4ft.constants import (
SHELL_TO_ANGULAR_VEC,
REAL_SOLID_SPH_CART_PREFAC,
SHELL_TO_ANGULAR_VEC,
Shell,
)
from d4ft.system.mol import Mol
Expand Down Expand Up @@ -123,9 +123,9 @@ def build_cgto_from_mol(mol: Mol) -> CGTO:
Reference:
- https://pyscf.org/user/gto.html#basis-set
- https://theochem.github.io/horton/2.0.1/tech_ref_gaussian_basis.html,
- https://onlinelibrary.wiley.com/iucr/itc/Bb/ch1o2v0001/table1o2o7o1/,
- https://github.com/sunqm/libcint/blob/747d6c0dd838d20abdc9a4c9e4c62d196a855bc0/src/cart2sph.c
- https://theochem.github.io/horton/2.0.1/tech_ref_gaussian_basis.html
- https://onlinelibrary.wiley.com/iucr/itc/Bb/ch1o2v0001/table1o2o7o1/
- https://github.com/sunqm/libcint/blob/master/src/cart2sph.c
Returns:
all translated GTOs.
Expand Down Expand Up @@ -211,7 +211,8 @@ def build_cgto_sph_from_mol(cgto_cart: CGTO) -> CGTO:

# s shell: same as cartesian
if shell == 0:
# TODO: replace this with cgto_cart.pgto.at(slice(cgto_ptr, cgto_ptr+n_pgtos))
# TODO: replace this with
# cgto_cart.pgto.at(slice(cgto_ptr, cgto_ptr+n_pgtos))
for j in range(n_pgtos):
pgto.append(
(
Expand Down Expand Up @@ -635,7 +636,8 @@ class PGTO(NamedTuple):
.. math::
PGTO_{nlm}(\vb{r})
=N_n(r_x-c_x)^{n_x} (r_y-c_y)^{n_y} (r_z-c_z)^{n_z} \exp(-\alpha \norm{\vb{r}-\vb{c}}^2)
=N_n(r_x-c_x)^{n_x} (r_y-c_y)^{n_y} (r_z-c_z)^{n_z}
\exp(-\alpha \norm{\vb{r}-\vb{c}}^2)
where N is the normalization factor
Expand Down
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,6 @@

# -- Options for reading type annotation
napoleon_google_docstring = True
napoleon_numpy_docstring = False # set this to True if you also use numpy-style docstrings
# set this to True if you also use numpy-style docstrings
napoleon_numpy_docstring = False
napoleon_include_init_with_doc = True

0 comments on commit d983ff3

Please sign in to comment.