Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

black format and add GPU precompute kernels #166

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion s2fft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

import logging
from jax.config import config

if config.read("jax_enable_x64") is False:
logger = logging.getLogger("s2fft")
logger.warning("JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.")
logger.warning(
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L."
)
61 changes: 9 additions & 52 deletions s2fft/base_transforms/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,18 +297,15 @@ def _compute_inverse_direct(
f = np.zeros(samples.f_shape(L, sampling, nside), dtype=np.complex128)

for t, theta in enumerate(thetas):

if sampling.lower() == "healpix":
phis_ring = samples.phis_ring(t, nside)

for el in range(max(L_lower, abs(spin)), L):

dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)

elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))

for p, phi in enumerate(phis_ring):

if sampling.lower() != "healpix":
entry = (t, p)

Expand Down Expand Up @@ -459,22 +456,17 @@ def _compute_inverse_sov_fft(
m_offset = 1 if sampling in ["mwss", "healpix"] else 0

for t, theta in enumerate(thetas):

phi_ring_offset = (
samples.p2phi_ring(t, 0, nside)
if sampling.lower() == "healpix"
else 0
samples.p2phi_ring(t, 0, nside) if sampling.lower() == "healpix" else 0
)

for el in range(max(L_lower, abs(spin)), L):

dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)

elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))

m_start_ind = 0 if reality else -el
for m in range(m_start_ind, el + 1):

phase_shift = (
np.exp(1j * m * phi_ring_offset)
if sampling.lower() == "healpix"
Expand Down Expand Up @@ -506,9 +498,7 @@ def _compute_inverse_sov_fft(
norm="forward",
)
else:
f = np.fft.ifft(
np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward"
)
f = np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward")

return f

Expand Down Expand Up @@ -554,24 +544,17 @@ def _compute_inverse_sov_fft_vectorized(
m_offset = 1 if sampling in ["mwss", "healpix"] else 0

for t, theta in enumerate(thetas):

phase_shift = (
samples.ring_phase_shift_hp(L, t, nside, False, reality)
if sampling.lower() == "healpix"
else 1.0
)

for el in range(max(L_lower, abs(spin)), L):

dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)
elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
m_start_ind = L - 1 if reality else 0
val = (
elfactor
* dl[m_start_ind:]
* flm[el, m_start_ind:]
* phase_shift
)
val = elfactor * dl[m_start_ind:] * flm[el, m_start_ind:] * phase_shift
if reality and sampling.lower() == "healpix":
ftm[t, m_offset : L - 1 + m_offset] += np.flip(np.conj(val[1:]))

Expand All @@ -589,9 +572,7 @@ def _compute_inverse_sov_fft_vectorized(
norm="forward",
)
else:
f = np.fft.ifft(
np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward"
)
f = np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward")

return f

Expand Down Expand Up @@ -641,30 +622,23 @@ def _compute_forward_direct(
phis_ring = samples.phis_equiang(L, sampling)

for t, theta in enumerate(thetas):

if sampling.lower() == "healpix":
phis_ring = samples.phis_ring(t, nside)

for el in range(max(L_lower, abs(spin)), L):

dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)

elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))

for p, phi in enumerate(phis_ring):

if sampling.lower() != "healpix":
entry = (t, p)
else:
entry = samples.hp_ang2pix(nside, theta, phi)

if reality:
flm[el, L - 1] += (
weights[t]
* (-1) ** spin
* elfactor
* dl[L - 1]
* f[entry]
weights[t] * (-1) ** spin * elfactor * dl[L - 1] * f[entry]
) # m = 0
for m in range(1, el + 1):
val = (
Expand Down Expand Up @@ -738,7 +712,6 @@ def _compute_forward_sov(

ftm = np.zeros((len(thetas), 2 * L - 1), dtype=np.complex128)
for t, theta in enumerate(thetas):

if sampling.lower() == "healpix":
phis_ring = samples.phis_ring(t, nside)

Expand All @@ -755,20 +728,14 @@ def _compute_forward_sov(
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)

for t, theta in enumerate(thetas):

for el in range(max(L_lower, abs(spin)), L):

dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)

elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))

if reality:
flm[el, L - 1] += (
weights[t]
* (-1) ** spin
* elfactor
* dl[L - 1]
* ftm[t, L - 1]
weights[t] * (-1) ** spin * elfactor * dl[L - 1] * ftm[t, L - 1]
) # m = 0
for m in range(1, el + 1):
val = (
Expand Down Expand Up @@ -852,20 +819,14 @@ def _compute_forward_sov_fft(
ftm_temp = ftm_temp[:, :-1]
ftm[:, L - 1 + m_offset :] = ftm_temp
else:
ftm = np.fft.fftshift(
np.fft.fft(f, axis=1, norm="backward"), axes=1
)
ftm = np.fft.fftshift(np.fft.fft(f, axis=1, norm="backward"), axes=1)

for t, theta in enumerate(thetas):

phi_ring_offset = (
samples.p2phi_ring(t, 0, nside)
if sampling.lower() == "healpix"
else 0
samples.p2phi_ring(t, 0, nside) if sampling.lower() == "healpix" else 0
)

for el in range(max(L_lower, abs(spin)), L):

dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)

elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
Expand Down Expand Up @@ -974,20 +935,16 @@ def _compute_forward_sov_fft_vectorized(
t = t[:, :-1]
ftm[:, L - 1 + m_offset :] = t
else:
ftm = np.fft.fftshift(
np.fft.fft(f, axis=1, norm="backward"), axes=1
)
ftm = np.fft.fftshift(np.fft.fft(f, axis=1, norm="backward"), axes=1)

for t, theta in enumerate(thetas):

phase_shift = (
samples.ring_phase_shift_hp(L, t, nside, True, reality)
if sampling.lower() == "healpix"
else 1.0
)

for el in range(max(L_lower, abs(spin)), L):

dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)

elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
Expand Down
4 changes: 1 addition & 3 deletions s2fft/base_transforms/wigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def inverse(
if reality:
f = np.fft.irfft(fban[N - 1 :], 2 * N - 1, axis=ax, norm="forward")
else:
f = np.fft.ifft(
np.fft.ifftshift(fban, axes=ax), axis=ax, norm="forward"
)
f = np.fft.ifft(np.fft.ifftshift(fban, axes=ax), axis=ax, norm="forward")

return f

Expand Down
Loading
Loading