Skip to content

Commit

Permalink
Merge pull request #33 from Ciela-Institute/tomethod
Browse files Browse the repository at this point in the history
adding "to" methods to all classes
  • Loading branch information
AlexaVillaume authored Nov 22, 2023
2 parents 20d4de0 + 8a33fa1 commit 334bc78
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 5 deletions.
5 changes: 5 additions & 0 deletions lighthouse/SSP/basic_ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def forward(self, metalicity, Tage, alpha) -> torch.Tensor:

return spectrum

def to(self, dtype=None, device=None):
self.isochrone.to(dtype=dtype, device=device)
self.imf.to(dtype=dtype, device=device)
self.sas.to(dtype=dtype, device=device)

if __name__ == "__main__":
from isochrone import MIST
from initial_mass_function import Kroupa
Expand Down
6 changes: 5 additions & 1 deletion lighthouse/initial_mass_function/initial_mass_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@ class Initial_Mass_Function(ABC):

@abstractmethod
def get_weight(self, mass) -> Tensor:
pass
...

@abstractmethod
def to(self, dtype=None, device=None):
...
4 changes: 4 additions & 0 deletions lighthouse/initial_mass_function/kroupa.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def get_weight(self, mass, alpha) -> torch.Tensor:

return weight

def to(self, dtype=None, device=None):
pass


if __name__ == "__main__":

import matplotlib.pyplot as plt
Expand Down
5 changes: 5 additions & 0 deletions lighthouse/isochrone/MIST_Isochrone.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def get_isochrone(self, metallicity, age, *args, low_m_limit = 0.08, high_m_limi

return dict((p, isochrone[i]) for i, p in enumerate(self.param_order))

def to(self, dtype=None, device=None):
self.isochrone_grid.to(dtype=dtype, device=device)
self.metallicities.to(dtype=dtype, device=device)
self.ages.to(dtype=dtype, device=device)

if __name__=='__main__':
test = MIST()

Expand Down
6 changes: 5 additions & 1 deletion lighthouse/isochrone/isochrone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@ class Isochrone(ABC):

@abstractmethod
def get_isochrone(self, metalicity, Tage, *args, low_m_limit = 0.08, high_m_limit = 100) -> dict:
pass #phase, stellar_mass, Teff, logg, logL
... #phase, stellar_mass, Teff, logg, logL

@abstractmethod
def to(self, dtype=None, device=None):
...
12 changes: 10 additions & 2 deletions lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ def __init__(self):
self.wavelength = torch.tensor(coeffs.to_numpy()[:,0], dtype = torch.float64)
self.reference[name] = torch.tensor(coeffs.to_numpy()[:,1], dtype = torch.float64)
self.coefficients[name] = torch.tensor(coeffs.to_numpy()[:,2:], dtype = torch.float64)

def get_spectrum(self, teff, logg, feh) -> torch.Tensor:

"""
Setting up some boundaries
"""
Expand Down Expand Up @@ -76,6 +75,15 @@ def get_spectrum(self, teff, logg, feh) -> torch.Tensor:

return flux

def to(self, dtype=None, device=None):
self.wavelength.to(dtype=dtype, device=device)
for key in self.polynomial_powers:
self.polynomial_powers[key].to(dtype=dtype, device=device)
self.bounds.to(dtype=dtype, device=device)
for name in self.reference:
self.reference[name].to(dtype=dtype, device=device)
self.coefficients[name].to(dtype=dtype, device=device)



if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
class Stellar_Atmosphere_Spectrum(ABC):
@abstractmethod
def get_spectrum(self, logg, Z, Teff) -> Tensor:
pass
...

@abstractmethod
def to(self, dtype=None, device=None):
...

0 comments on commit 334bc78

Please sign in to comment.