Skip to content

Commit

Permalink
fix: dataclass mutable default
Browse files Browse the repository at this point in the history
Use `field(default_factory=...)` for function calls in dataclasses.
  • Loading branch information
LoganAMorrison committed Apr 17, 2024
1 parent 1308b14 commit e78f980
Show file tree
Hide file tree
Showing 21 changed files with 476 additions and 550 deletions.
23 changes: 16 additions & 7 deletions examples/vector_mediator_gev/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from hazma.vector_mediator import KineticMixingGeV


def photon_spectra(dm_mass: float, vm_mass: float, n: int):
def photon_spectra(dm_mass: float, vm_mass: float, n: int) -> None:
"""Generate the photon spectrum."""

model = KineticMixingGeV(
mx=dm_mass,
mv=vm_mass,
Expand Down Expand Up @@ -40,9 +39,8 @@ def photon_spectra(dm_mass: float, vm_mass: float, n: int):
plt.show()


def positron_spectra(dm_mass: float, vm_mass: float, n: int):
def positron_spectra(dm_mass: float, vm_mass: float, n: int) -> None:
"""Generate the photon spectrum."""

model = KineticMixingGeV(
mx=dm_mass,
mv=vm_mass,
Expand Down Expand Up @@ -72,9 +70,20 @@ def positron_spectra(dm_mass: float, vm_mass: float, n: int):
plt.show()


def neutrino_spectra(dm_mass: float, vm_mass: float, n: int, flavor: str):
"""Generate the neutrino spectrum."""

def neutrino_spectra(dm_mass: float, vm_mass: float, n: int, flavor: str) -> None:
"""Generate the neutrino spectrum.
Parameters
----------
dm_mass: float
Mass of the dark matter in MeV.
vm_mass: float
Mass of the vector mediator in MeV.
n: int
Number of photon energies between the minimum and maximum energies.
flavor: str
The neutrino flavor.
"""
model = KineticMixingGeV(
mx=dm_mass,
mv=vm_mass,
Expand Down
84 changes: 44 additions & 40 deletions hazma/form_factors/vector/_eta_gamma.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""
Module for computing the form factor V-eta-gamma.
"""
"""Module for computing the form factor V-eta-gamma."""

# pylint: disable=invalid-name

from dataclasses import InitVar, dataclass, field
from typing import Tuple, Union, overload
from typing import overload

import numpy as np

Expand All @@ -17,23 +15,31 @@
from ._utils import MPI_GEV, ComplexArray, RealArray


@dataclass(frozen=True)
@dataclass
class VectorFormFactorEtaGammaFitData:
r"""Storage class for the eta-photon form-factor."""

masses: RealArray = np.array([0.77526, 0.78284, 1.01952, 1.465, 1.70])
widths: RealArray = np.array([0.1491, 0.00868, 0.00421, 0.40, 0.30])
amps: RealArray = np.array([0.0861, 0.00824, 0.0158, 0.0147, 0.0])
phases: RealArray = np.array([0.0, 11.3, 170.0, 61.0, 0.0]) * np.pi / 180.0
masses: RealArray = field(
default_factory=lambda: np.array([0.77526, 0.78284, 1.01952, 1.465, 1.70])
)
widths: RealArray = field(
default_factory=lambda: np.array([0.1491, 0.00868, 0.00421, 0.40, 0.30])
)
amps: RealArray = field(
default_factory=lambda: np.array([0.0861, 0.00824, 0.0158, 0.0147, 0.0])
)
phases: RealArray = field(
default_factory=lambda: np.array([0.0, 11.3, 170.0, 61.0, 0.0]) * np.pi / 180.0
)


@dataclass
class VectorFormFactorEtaGamma(VectorFormFactorPA):
r""" "Class for computing the eta-photon form factor.
"""Class for computing the eta-photon form factor.
Attributes
----------
fsp_masses: Tuple[float]
fsp_masses: tuple[float, float]
Final state particle masses (only eta in this case.)
fit_data: VectorFormFactorEtaGammaFitData
Fit information used to compute the form-factor.
Expand All @@ -50,25 +56,31 @@ class VectorFormFactorEtaGamma(VectorFormFactorPA):
cross_section
Compute the dark matter annihilation cross section into an eta and
photon.
"""
fsp_masses: Tuple[float, float] = (parameters.eta_mass, 0.0)

fsp_masses: tuple[float, float] = (parameters.eta_mass, 0.0)
fit_data: VectorFormFactorEtaGammaFitData = field(init=False)

masses: InitVar[RealArray] = np.array([0.77526, 0.78284, 1.01952, 1.465, 1.70])
widths: InitVar[RealArray] = np.array([0.1491, 0.00868, 0.00421, 0.40, 0.30])
amps: InitVar[RealArray] = np.array([0.0861, 0.00824, 0.0158, 0.0147, 0.0])
phases: InitVar[RealArray] = np.array([0.0, 11.3, 170.0, 61.0, 0.0]) * np.pi / 180.0
masses: InitVar[RealArray] = field(
default=np.array([0.77526, 0.78284, 1.01952, 1.465, 1.70])
)
widths: InitVar[RealArray] = field(
default=np.array([0.1491, 0.00868, 0.00421, 0.40, 0.30])
)
amps: InitVar[RealArray] = field(
default=np.array([0.0861, 0.00824, 0.0158, 0.0147, 0.0])
)
phases: InitVar[RealArray] = field(
default=np.array([0.0, 11.3, 170.0, 61.0, 0.0]) * np.pi / 180.0
)

def __post_init__(self, masses, widths, amps, phases):
self.fit_data = VectorFormFactorEtaGammaFitData(
masses=masses, widths=widths, amps=amps, phases=phases
)

def __form_factor(self, s: RealArray, couplings: Couplings) -> ComplexArray:
"""
Compute the form factor for V-eta-gamma at given squared center of mass
energ(ies).
"""Compute the form factor for V-eta-gamma at given squared center of mass energ(ies).
Parameters
----------
Expand Down Expand Up @@ -121,18 +133,16 @@ def __form_factor(self, s: RealArray, couplings: Couplings) -> ComplexArray:
@overload
def form_factor( # pylint: disable=arguments-differ
self, *, q: float, couplings: Couplings
) -> complex:
...
) -> complex: ...

@overload
def form_factor( # pylint: disable=arguments-differ
self, *, q: RealArray, couplings: Couplings
) -> ComplexArray:
...
) -> ComplexArray: ...

def form_factor( # pylint: disable=arguments-differ
self, *, q: Union[float, RealArray], couplings: Couplings
) -> Union[complex, ComplexArray]:
self, *, q: float | RealArray, couplings: Couplings
) -> complex | ComplexArray:
r"""Compute the eta-photon form factor.
Parameters
Expand Down Expand Up @@ -166,17 +176,15 @@ def form_factor( # pylint: disable=arguments-differ
@overload
def integrated_form_factor( # pylint: disable=arguments-differ
self, q: float, couplings: Couplings
) -> float:
...
) -> float: ...

@overload
def integrated_form_factor( # pylint: disable=arguments-differ
self, q: RealArray, couplings: Couplings
) -> RealArray:
...
) -> RealArray: ...

def integrated_form_factor( # pylint: disable=arguments-differ
self, q: Union[float, RealArray], couplings: Couplings
self, q: float | RealArray, couplings: Couplings
) -> RealOrRealArray:
r"""Compute the eta-photon form-factor integrated over phase-space.
Expand All @@ -201,17 +209,15 @@ def integrated_form_factor( # pylint: disable=arguments-differ
@overload
def width( # pylint: disable=arguments-differ
self, mv: float, couplings: Couplings
) -> float:
...
) -> float: ...

@overload
def width( # pylint: disable=arguments-differ
self, mv: RealArray, couplings: Couplings
) -> RealArray:
...
) -> RealArray: ...

def width( # pylint: disable=arguments-differ
self, mv: Union[float, RealArray], couplings: Couplings
self, mv: float | RealArray, couplings: Couplings
) -> RealOrRealArray:
r"""Compute the partial decay width of a massive vector into an eta and
photon.
Expand Down Expand Up @@ -243,8 +249,7 @@ def cross_section( # pylint: disable=arguments-differ,too-many-arguments
gvxx: float,
wv: float,
couplings: Couplings,
) -> float:
...
) -> float: ...

@overload
def cross_section( # pylint: disable=arguments-differ,too-many-arguments
Expand All @@ -255,8 +260,7 @@ def cross_section( # pylint: disable=arguments-differ,too-many-arguments
gvxx: float,
wv: float,
couplings: Couplings,
) -> RealArray:
...
) -> RealArray: ...

def cross_section( # pylint: disable=arguments-differ,too-many-arguments
self,
Expand Down
49 changes: 21 additions & 28 deletions hazma/form_factors/vector/_eta_omega.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,26 @@
"""

from dataclasses import InitVar, dataclass, field
from typing import overload, Tuple
from typing import overload

import numpy as np

from hazma import parameters
from hazma.utils import RealOrRealArray, ComplexOrComplexArray, ComplexArray, RealArray
from hazma.utils import ComplexArray, ComplexOrComplexArray, RealArray, RealOrRealArray

from ._utils import breit_wigner_fw
from ._two_body import VectorFormFactorPV, Couplings
from ._base import vector_couplings_to_isospin
from ._two_body import Couplings, VectorFormFactorPV
from ._utils import breit_wigner_fw

META = parameters.eta_mass
MOMEGA = parameters.omega_mass


@dataclass(frozen=True)
@dataclass
class VectorFormFactorEtaOmegaFitData:
r"""Storage class for the eta-omega vector form-factor. See arXiv:1911.11147
for details on the default values.
"""Storage class for the eta-omega vector form-factor.
See arXiv:1911.11147 for details on the default values.
"""

# w', w''' parameters
Expand Down Expand Up @@ -56,14 +57,14 @@ class VectorFormFactorEtaOmega(VectorFormFactorPV):
omega.
"""

fsp_masses: Tuple[float, float] = field(init=False, default=(META, MOMEGA))
fsp_masses: tuple[float, float] = field(init=False, default=(META, MOMEGA))
fit_data: VectorFormFactorEtaOmegaFitData = field(init=False)

# w', w''' parameters
masses: InitVar[RealArray] = np.array([1.43, 1.67])
widths: InitVar[RealArray] = np.array([0.215, 0.113])
amps: InitVar[RealArray] = np.array([0.0862, 0.0648])
phases: InitVar[RealArray] = np.exp(1j * np.array([0.0, np.pi]))
masses: InitVar[RealArray] = field(default=np.array([1.43, 1.67]))
widths: InitVar[RealArray] = field(default=np.array([0.215, 0.113]))
amps: InitVar[RealArray] = field(default=np.array([0.0862, 0.0648]))
phases: InitVar[RealArray] = field(default=np.exp(1j * np.array([0.0, np.pi])))

def __post_init__(
self,
Expand Down Expand Up @@ -111,14 +112,12 @@ def __form_factor(self, *, s: RealArray, couplings: Couplings):
@overload
def form_factor( # pylint: disable=arguments-differ
self, q: float, couplings: Couplings
) -> complex:
...
) -> complex: ...

@overload
def form_factor( # pylint: disable=arguments-differ
self, q: RealArray, couplings: Couplings
) -> ComplexArray:
...
) -> ComplexArray: ...

def form_factor( # pylint: disable=arguments-differ
self, q: RealOrRealArray, couplings: Couplings
Expand Down Expand Up @@ -154,14 +153,12 @@ def form_factor( # pylint: disable=arguments-differ
@overload
def integrated_form_factor( # pylint: disable=arguments-differ
self, q: float, couplings: Couplings
) -> float:
...
) -> float: ...

@overload
def integrated_form_factor( # pylint: disable=arguments-differ
self, q: RealArray, couplings: Couplings
) -> RealArray:
...
) -> RealArray: ...

def integrated_form_factor( # pylint: disable=arguments-differ
self, q: RealOrRealArray, couplings: Couplings
Expand All @@ -185,14 +182,12 @@ def integrated_form_factor( # pylint: disable=arguments-differ
@overload
def width( # pylint: disable=arguments-differ
self, mv: float, couplings: Couplings
) -> float:
...
) -> float: ...

@overload
def width( # pylint: disable=arguments-differ
self, mv: RealArray, couplings: Couplings
) -> RealArray:
...
) -> RealArray: ...

def width( # pylint: disable=arguments-differ
self, mv: RealOrRealArray, couplings: Couplings
Expand Down Expand Up @@ -223,8 +218,7 @@ def cross_section( # pylint: disable=arguments-differ,too-many-arguments
gvxx: float,
wv: float,
couplings: Couplings,
) -> float:
...
) -> float: ...

@overload
def cross_section( # pylint: disable=arguments-differ,too-many-arguments
Expand All @@ -235,8 +229,7 @@ def cross_section( # pylint: disable=arguments-differ,too-many-arguments
gvxx: float,
wv: float,
couplings: Couplings,
) -> RealArray:
...
) -> RealArray: ...

def cross_section( # pylint: disable=arguments-differ,too-many-arguments
self,
Expand Down
Loading

0 comments on commit e78f980

Please sign in to comment.