
[Suggestion] : Remove strict requirement on backends. #229

Open
ASKabalan opened this issue Sep 27, 2024 · 7 comments

@ASKabalan
Collaborator

Currently, we have to install PyTorch to use the JAX backend of s2fft.

I think it would be nice to be able to conditionally activate backends depending on whether the corresponding package is available.

For example, a user could use s2fft with only NumPy if they do not have JAX or PyTorch installed.
Trying to use a backend whose package is missing would then give a runtime error instead of an import-time error.

The same goes for pyssht.
I don't know the best practice for this, but a try/except around the import plus a global boolean should do the trick.
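A minimal sketch of that try/except-plus-global-flag pattern, assuming one module-level flag per optional backend. The flag names (`JAX_AVAILABLE`, `TORCH_AVAILABLE`, `PYSSHT_AVAILABLE`) and the `available_backends` helper are hypothetical, not existing s2fft attributes.

```python
# Hypothetical availability flags: each optional backend is imported inside a
# try/except so that a missing package sets a flag instead of failing at import.
try:
    import jax  # noqa: F401
    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False

try:
    import torch  # noqa: F401
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

try:
    import pyssht  # noqa: F401
    PYSSHT_AVAILABLE = True
except ImportError:
    PYSSHT_AVAILABLE = False


def available_backends() -> list[str]:
    """List the backends usable in this environment (NumPy assumed always present)."""
    backends = ["numpy"]
    if JAX_AVAILABLE:
        backends.append("jax")
    if TORCH_AVAILABLE:
        backends.append("torch")
    if PYSSHT_AVAILABLE:
        backends.append("pyssht")
    return backends
```

A function that needs, say, PyTorch would then check `TORCH_AVAILABLE` when it is called and raise an error if the flag is False, so the failure happens at call time rather than at import time.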

@jasonmcewen
Contributor

Thanks for the suggestion @ASKabalan! At present we only have PyTorch support for the precompute approach, but we do plan to add PyTorch support for the on-the-fly transforms as well, so it would be good to have a nice setup like the one you're suggesting.

@matt-graham do you have any thoughts on this or know of best practices?

@matt-graham
Collaborator

Hi @jasonmcewen. This is related to what I suggested in #224 - we can easily make all of PyTorch, pyssht and healpy optional dependencies and guard the imports from these packages within try: ... except ImportError: ... logic, either at a module level or within the relevant functions themselves, and raise an informative error if a user tries to use a function for which the relevant optional dependencies are not available.
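Guarding the import inside the relevant function itself could look like the sketch below. The helper `require_optional`, the example function `healpy_nside_to_npix` and the extras group names in the error message are hypothetical; `importlib.import_module` is the standard-library call that does the actual import, and `healpy.nside2npix` is an existing healpy function.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, extra: str) -> ModuleType:
    """Import an optional dependency, or raise an ImportError saying how to get it.

    `extra` is the name of the (hypothetical) pip extras group providing it.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            f"This function requires the optional dependency '{module_name}'. "
            f"Install it with: pip install s2fft[{extra}]"
        ) from err


def healpy_nside_to_npix(nside: int) -> int:
    # Hypothetical example function: healpy is only imported when this is called,
    # so importing the surrounding module works even if healpy is missing.
    hp = require_optional("healpy", "healpy")
    return hp.nside2npix(nside)
```

Deferring the import this way keeps module import cheap and means only users who actually call a healpy-backed function ever see the error.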

I would say dropping hard requirements on all external packages which are not 'core' to the package would be good practice for a library like s2fft. It avoids forcing users who only want to rely on a subset of the functionality to install unnecessary dependencies and, as I mentioned in #224, extends the range of systems that users can install the package on. The overhead for users who wish to use the optional features is minimal, as we can add relevant optional dependency groups so that, for example, they could just do pip install s2fft[all] to install all optional dependencies, or pip install s2fft[torch] to install just the extra dependencies required for PyTorch support, and so on.

@matt-graham
Collaborator

matt-graham commented Oct 1, 2024

Just noticed the suggestion to also apply this to JAX. While we could do this for JAX too, a lot of the modules (perhaps the majority?) make multiple imports from the jax package namespace, so guarding all of these imports could get a bit unwieldy. What might help reduce our explicit imports from the JAX and PyTorch APIs (and the duplication of functions) is to use their support for the array API, by doing something like

import array_api_compat

...

def spectral_periodic_extension(fm, L: int):
    # Look up the array namespace (NumPy, JAX, PyTorch, ...) matching the input.
    xp = array_api_compat.array_namespace(fm)
    nphi = fm.shape[0]
    return xp.concatenate(
        (
            fm[-xp.arange(L - nphi // 2, 0, -1) % nphi],
            fm,
            fm[xp.arange(L - (nphi + 1) // 2) % nphi],
        )
    )

This would make the function compatible with all of NumPy, JAX and PyTorch arrays without requiring an explicit import from any of them and also avoid having multiple _jax, _torch variants of functions.
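A runnable check of the function with the `xp` namespace used consistently. Falling back to plain NumPy when `array_api_compat` is not installed is an assumption added only so the sketch runs anywhere; with `array_api_compat` installed, passing a JAX or PyTorch array should instead dispatch to that library's namespace.

```python
import numpy as np

try:
    import array_api_compat
except ImportError:
    array_api_compat = None  # assumption: fall back to plain NumPy for this sketch


def spectral_periodic_extension(fm, L: int):
    # Resolve the array namespace from the input array (NumPy if no compat layer).
    xp = np if array_api_compat is None else array_api_compat.array_namespace(fm)
    nphi = fm.shape[0]
    return xp.concatenate(
        (
            fm[-xp.arange(L - nphi // 2, 0, -1) % nphi],  # tail wrapped to the front
            fm,
            fm[xp.arange(L - (nphi + 1) // 2) % nphi],  # head wrapped to the back
        )
    )


fm = np.arange(5)  # nphi = 5 samples
out = spectral_periodic_extension(fm, L=4)
print(out.tolist())  # [3, 4, 0, 1, 2, 3, 4, 0]
```

Because nphi // 2 + (nphi + 1) // 2 == nphi, the output has exactly 2L entries whenever nphi <= 2L.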

@jasonmcewen
Contributor

> ... they could just do pip install s2fft[all] to install all optional dependencies, or pip install s2fft[torch] to just install extra dependencies required for PyTorch support and so on.

Ok, fair enough @matt-graham! Does this mean users would need to install from source rather than PyPI tho? Perhaps that is not a big issue anyway.

@jasonmcewen
Contributor

> What might help with reducing our explicit imports from JAX and PyTorch APIs (and duplication of functions) is to use their support for the array API ...

Maybe. This seems like quite a big revision tho. Let's discuss further...

@matt-graham
Collaborator

> Ok, fair enough @matt-graham! Does this mean users would need to install from source rather than PyPI tho? Perhaps that is not a big issue anyway.

No, the extras (optional dependencies) syntax works for packages installed both from a local directory and from a package index like PyPI. It's the same syntax that JAX uses to install from PyPI with CUDA support:

pip install jax[cuda12]

@jasonmcewen
Contributor

Ok, sounds like it makes a lot of sense then.
