From 8a33fa14226aef1f9364abdb1c1c14cb830c3f19 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 5 Oct 2023 18:49:18 -0400 Subject: [PATCH] adding to methods to all classes --- lighthouse/SSP/basic_ssp.py | 5 +++++ .../initial_mass_function/initial_mass_function.py | 6 +++++- lighthouse/initial_mass_function/kroupa.py | 4 ++++ lighthouse/isochrone/MIST_Isochrone.py | 5 +++++ lighthouse/isochrone/isochrone.py | 6 +++++- .../polynomial_evaluator.py | 12 ++++++++++-- .../stellar_atmosphere_spectrum.py | 5 ++++- 7 files changed, 38 insertions(+), 5 deletions(-) diff --git a/lighthouse/SSP/basic_ssp.py b/lighthouse/SSP/basic_ssp.py index c71df06..7422d3c 100644 --- a/lighthouse/SSP/basic_ssp.py +++ b/lighthouse/SSP/basic_ssp.py @@ -61,6 +61,11 @@ def forward(self, metalicity, Tage, alpha) -> torch.Tensor: return spectrum + def to(self, dtype=None, device=None): + self.isochrone.to(dtype=dtype, device=device) + self.imf.to(dtype=dtype, device=device) + self.sas.to(dtype=dtype, device=device) + if __name__ == "__main__": from isochrone import MIST from initial_mass_function import Kroupa diff --git a/lighthouse/initial_mass_function/initial_mass_function.py b/lighthouse/initial_mass_function/initial_mass_function.py index 6ce1552..16dc14d 100644 --- a/lighthouse/initial_mass_function/initial_mass_function.py +++ b/lighthouse/initial_mass_function/initial_mass_function.py @@ -7,4 +7,8 @@ class Initial_Mass_Function(ABC): @abstractmethod def get_weight(self, mass) -> Tensor: - pass + ... + + @abstractmethod + def to(self, dtype=None, device=None): + ... diff --git a/lighthouse/initial_mass_function/kroupa.py b/lighthouse/initial_mass_function/kroupa.py index b0cbd72..5ada94c 100644 --- a/lighthouse/initial_mass_function/kroupa.py +++ b/lighthouse/initial_mass_function/kroupa.py @@ -20,6 +20,10 @@ def get_weight(self, mass, alpha) -> torch.Tensor: return weight + def to(self, dtype=None, device=None): + pass + + if __name__ == "__main__": import matplotlib.pyplot as plt diff --git a/lighthouse/isochrone/MIST_Isochrone.py b/lighthouse/isochrone/MIST_Isochrone.py index dbf29a7..c722489 100644 --- a/lighthouse/isochrone/MIST_Isochrone.py +++ b/lighthouse/isochrone/MIST_Isochrone.py @@ -34,6 +34,11 @@ def get_isochrone(self, metallicity, age, *args, low_m_limit = 0.08, high_m_limi return dict((p, isochrone[i]) for i, p in enumerate(self.param_order)) + def to(self, dtype=None, device=None): + self.isochrone_grid.to(dtype=dtype, device=device) + self.metallicities.to(dtype=dtype, device=device) + self.ages.to(dtype=dtype, device=device) + if __name__=='__main__': test = MIST() diff --git a/lighthouse/isochrone/isochrone.py b/lighthouse/isochrone/isochrone.py index 7baf7b8..b434fb4 100644 --- a/lighthouse/isochrone/isochrone.py +++ b/lighthouse/isochrone/isochrone.py @@ -6,4 +6,8 @@ class Isochrone(ABC): @abstractmethod def get_isochrone(self, metalicity, Tage, *args, low_m_limit = 0.08, high_m_limit = 100) -> dict: - pass #phase, stellar_mass, Teff, logg, logL + ... #phase, stellar_mass, Teff, logg, logL + + @abstractmethod + def to(self, dtype=None, device=None): + ... diff --git a/lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py b/lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py index 3fc30da..6530ac6 100644 --- a/lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py +++ b/lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py @@ -38,9 +38,8 @@ def __init__(self): self.wavelength = torch.tensor(coeffs.to_numpy()[:,0], dtype = torch.float64) self.reference[name] = torch.tensor(coeffs.to_numpy()[:,1], dtype = torch.float64) self.coefficients[name] = torch.tensor(coeffs.to_numpy()[:,2:], dtype = torch.float64) - + def get_spectrum(self, teff, logg, feh) -> torch.Tensor: - """ Setting up some boundaries """ @@ -76,6 +75,15 @@ def get_spectrum(self, teff, logg, feh) -> torch.Tensor: return flux + def to(self, dtype=None, device=None): + self.wavelength.to(dtype=dtype, device=device) + for key in self.polynomial_powers: + self.polynomial_powers[key].to(dtype=dtype, device=device) + self.bounds.to(dtype=dtype, device=device) + for name in self.reference: + self.reference[name].to(dtype=dtype, device=device) + self.coefficients[name].to(dtype=dtype, device=device) + if __name__ == "__main__": diff --git a/lighthouse/stellar_atmosphere_spectrum/stellar_atmosphere_spectrum.py b/lighthouse/stellar_atmosphere_spectrum/stellar_atmosphere_spectrum.py index ad3937c..d6aaca3 100644 --- a/lighthouse/stellar_atmosphere_spectrum/stellar_atmosphere_spectrum.py +++ b/lighthouse/stellar_atmosphere_spectrum/stellar_atmosphere_spectrum.py @@ -6,6 +6,9 @@ class Stellar_Atmosphere_Spectrum(ABC): @abstractmethod def get_spectrum(self, logg, Z, Teff) -> Tensor: - pass + ... + @abstractmethod + def to(self, dtype=None, device=None): + ...