diff --git a/s2fft/__init__.py b/s2fft/__init__.py index 37e0f1c2..423d3219 100644 --- a/s2fft/__init__.py +++ b/s2fft/__init__.py @@ -10,6 +10,9 @@ import logging from jax.config import config + if config.read("jax_enable_x64") is False: logger = logging.getLogger("s2fft") - logger.warning("JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.") \ No newline at end of file + logger.warning( + "JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L." + ) diff --git a/s2fft/base_transforms/spherical.py b/s2fft/base_transforms/spherical.py index 7b4a4720..4625e7bc 100644 --- a/s2fft/base_transforms/spherical.py +++ b/s2fft/base_transforms/spherical.py @@ -297,18 +297,15 @@ def _compute_inverse_direct( f = np.zeros(samples.f_shape(L, sampling, nside), dtype=np.complex128) for t, theta in enumerate(thetas): - if sampling.lower() == "healpix": phis_ring = samples.phis_ring(t, nside) for el in range(max(L_lower, abs(spin)), L): - dl = recursions.turok.compute_slice(theta, el, L, -spin, reality) elfactor = np.sqrt((2 * el + 1) / (4 * np.pi)) for p, phi in enumerate(phis_ring): - if sampling.lower() != "healpix": entry = (t, p) @@ -459,22 +456,17 @@ def _compute_inverse_sov_fft( m_offset = 1 if sampling in ["mwss", "healpix"] else 0 for t, theta in enumerate(thetas): - phi_ring_offset = ( - samples.p2phi_ring(t, 0, nside) - if sampling.lower() == "healpix" - else 0 + samples.p2phi_ring(t, 0, nside) if sampling.lower() == "healpix" else 0 ) for el in range(max(L_lower, abs(spin)), L): - dl = recursions.turok.compute_slice(theta, el, L, -spin, reality) elfactor = np.sqrt((2 * el + 1) / (4 * np.pi)) m_start_ind = 0 if reality else -el for m in range(m_start_ind, el + 1): - phase_shift = ( np.exp(1j * m * phi_ring_offset) if sampling.lower() == "healpix" @@ -506,9 +498,7 @@ def _compute_inverse_sov_fft( norm="forward", ) else: - f = np.fft.ifft( - np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward" - ) + f = np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward") return f @@ -554,7 +544,6 @@ def _compute_inverse_sov_fft_vectorized( m_offset = 1 if sampling in ["mwss", "healpix"] else 0 for t, theta in enumerate(thetas): - phase_shift = ( samples.ring_phase_shift_hp(L, t, nside, False, reality) if sampling.lower() == "healpix" @@ -562,16 +551,10 @@ def _compute_inverse_sov_fft_vectorized( ) for el in range(max(L_lower, abs(spin)), L): - dl = recursions.turok.compute_slice(theta, el, L, -spin, reality) elfactor = np.sqrt((2 * el + 1) / (4 * np.pi)) m_start_ind = L - 1 if reality else 0 - val = ( - elfactor - * dl[m_start_ind:] - * flm[el, m_start_ind:] - * phase_shift - ) + val = elfactor * dl[m_start_ind:] * flm[el, m_start_ind:] * phase_shift if reality and sampling.lower() == "healpix": ftm[t, m_offset : L - 1 + m_offset] += np.flip(np.conj(val[1:])) @@ -589,9 +572,7 @@ def _compute_inverse_sov_fft_vectorized( norm="forward", ) else: - f = np.fft.ifft( - np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward" - ) + f = np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward") return f @@ -641,18 +622,15 @@ def _compute_forward_direct( phis_ring = samples.phis_equiang(L, sampling) for t, theta in enumerate(thetas): - if sampling.lower() == "healpix": phis_ring = samples.phis_ring(t, nside) for el in range(max(L_lower, abs(spin)), L): - dl = recursions.turok.compute_slice(theta, el, L, -spin, reality) elfactor = np.sqrt((2 * el + 1) / (4 * np.pi)) for p, phi in enumerate(phis_ring): - if sampling.lower() != "healpix": entry = (t, p) else: @@ -660,11 +638,7 @@ def _compute_forward_direct( if reality: flm[el, L - 1] += ( - weights[t] - * (-1) ** spin - * elfactor - * dl[L - 1] - * f[entry] + weights[t] * (-1) ** spin * elfactor * dl[L - 1] * f[entry] ) # m = 0 for m in range(1, el + 1): val = ( @@ -738,7 +712,6 @@ def _compute_forward_sov( ftm = np.zeros((len(thetas), 2 * L - 1), dtype=np.complex128) for t, theta in enumerate(thetas): - if sampling.lower() == "healpix": phis_ring = samples.phis_ring(t, nside) @@ -755,20 +728,14 @@ def _compute_forward_sov( flm = np.zeros(samples.flm_shape(L), dtype=np.complex128) for t, theta in enumerate(thetas): - for el in range(max(L_lower, abs(spin)), L): - dl = recursions.turok.compute_slice(theta, el, L, -spin, reality) elfactor = np.sqrt((2 * el + 1) / (4 * np.pi)) if reality: flm[el, L - 1] += ( - weights[t] - * (-1) ** spin - * elfactor - * dl[L - 1] - * ftm[t, L - 1] + weights[t] * (-1) ** spin * elfactor * dl[L - 1] * ftm[t, L - 1] ) # m = 0 for m in range(1, el + 1): val = ( @@ -852,20 +819,14 @@ def _compute_forward_sov_fft( ftm_temp = ftm_temp[:, :-1] ftm[:, L - 1 + m_offset :] = ftm_temp else: - ftm = np.fft.fftshift( - np.fft.fft(f, axis=1, norm="backward"), axes=1 - ) + ftm = np.fft.fftshift(np.fft.fft(f, axis=1, norm="backward"), axes=1) for t, theta in enumerate(thetas): - phi_ring_offset = ( - samples.p2phi_ring(t, 0, nside) - if sampling.lower() == "healpix" - else 0 + samples.p2phi_ring(t, 0, nside) if sampling.lower() == "healpix" else 0 ) for el in range(max(L_lower, abs(spin)), L): - dl = recursions.turok.compute_slice(theta, el, L, -spin, reality) elfactor = np.sqrt((2 * el + 1) / (4 * np.pi)) @@ -974,12 +935,9 @@ def _compute_forward_sov_fft_vectorized( t = t[:, :-1] ftm[:, L - 1 + m_offset :] = t else: - ftm = np.fft.fftshift( - np.fft.fft(f, axis=1, norm="backward"), axes=1 - ) + ftm = np.fft.fftshift(np.fft.fft(f, axis=1, norm="backward"), axes=1) for t, theta in enumerate(thetas): - phase_shift = ( samples.ring_phase_shift_hp(L, t, nside, True, reality) if sampling.lower() == "healpix" @@ -987,7 +945,6 @@ def _compute_forward_sov_fft_vectorized( ) for el in range(max(L_lower, abs(spin)), L): - dl = recursions.turok.compute_slice(theta, el, L, -spin, reality) elfactor = np.sqrt((2 * el + 1) / (4 * np.pi)) diff --git a/s2fft/base_transforms/wigner.py b/s2fft/base_transforms/wigner.py index f6347790..ba5295bf 100644 --- a/s2fft/base_transforms/wigner.py +++ b/s2fft/base_transforms/wigner.py @@ -76,9 +76,7 @@ def inverse( if reality: f = np.fft.irfft(fban[N - 1 :], 2 * N - 1, axis=ax, norm="forward") else: - f = np.fft.ifft( - np.fft.ifftshift(fban, axes=ax), axis=ax, norm="forward" - ) + f = np.fft.ifft(np.fft.ifftshift(fban, axes=ax), axis=ax, norm="forward") return f diff --git a/s2fft/precompute_transforms/construct.py b/s2fft/precompute_transforms/construct.py index 95d2b136..aaa59fe1 100644 --- a/s2fft/precompute_transforms/construct.py +++ b/s2fft/precompute_transforms/construct.py @@ -1,6 +1,14 @@ +from jax import jit, config + +config.update("jax_enable_x64", True) + import numpy as np +import jax.numpy as jnp +from jax import jit +from functools import partial + from s2fft.sampling import s2_samples as samples -from s2fft.utils import quadrature +from s2fft.utils import quadrature, quadrature_jax from s2fft import recursions from warnings import warn @@ -56,9 +64,9 @@ def spin_spherical_kernel( dl = np.zeros((len(thetas), L, m_dim), dtype=np.float64) for t, theta in enumerate(thetas): for el in range(abs(spin), L): - dl[t, el] = recursions.turok.compute_slice( - theta, el, L, -spin, reality - )[m_start_ind:] + dl[t, el] = recursions.turok.compute_slice(theta, el, L, -spin, reality)[ + m_start_ind: + ] dl[t, el] *= np.sqrt((2 * el + 1) / (4 * np.pi)) if forward: @@ -75,6 +83,84 @@ def spin_spherical_kernel( return dl +def spin_spherical_kernel_jax( + L: int, + spin: int = 0, + reality: bool = False, + sampling: str = "mw", + nside: int = None, + forward: bool = False, +): + r"""Precompute the wigner-d kernel for spin-spherical transform. This can be + drastically faster but comes at a :math:`\mathcal{O}(L^3)` memory overhead, making + it infeasible for :math:`L\geq 512`. + + Args: + L (int): Harmonic band-limit. + + spin (int): Harmonic spin. + + reality (bool, optional): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + Defaults to False. + + sampling (str, optional): Sampling scheme. Supported sampling schemes include + {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + + nside (int): HEALPix Nside resolution parameter. Only required + if sampling="healpix". + + forward (bool, optional): Whether to provide forward or inverse shift. + Defaults to False. + + Returns: + jnp.ndarray: Transform kernel for spin-spherical harmonic transform. + """ + m_start_ind = L - 1 if reality else 0 + + if forward and sampling.lower() in ["mw", "mwss"]: + sampling = "mwss" + thetas = samples.thetas(2 * L, "mwss") + else: + thetas = samples.thetas(L, sampling, nside) + + dl = recursions.price_mcewen.compute_all_slices_jax( + thetas, L, spin, sampling, forward, nside + ) + dl = dl.at[jnp.where(dl != dl)].set(0) + dl = jnp.swapaxes(dl, 0, 2) + dl = jnp.swapaxes(dl, 0, 1) + + # North pole singularity + if sampling.lower() == "mwss": + dl = dl.at[0].set(0) + dl = dl = dl.at[0, :, L - 1 - spin].set(1) + + # South pole singularity + if sampling.lower() in ["mw", "mwss"]: + dl = dl.at[-1].set(0) + dl = dl.at[-1, :, L - 1 + spin].set((-1) ** (jnp.arange(L) - spin)) + + dl = dl[:, :, m_start_ind:] + + scaling = jnp.sqrt((2 * jnp.arange(L) + 1) / (4 * jnp.pi)) + dl = jnp.einsum("...tlm,...l->...tlm", dl, scaling, optimize=True) + + if forward: + weights = quadrature_jax.quad_weights_transform(L, sampling, nside) + dl = jnp.einsum("...tlm, ...t->...tlm", dl, weights, optimize=True) + + if sampling.lower() == "healpix": + dl = jnp.einsum( + "...tlm,...tm->...tlm", + dl, + healpix_phase_shifts(L, nside, forward)[:, m_start_ind:], + optimize=True, + ) + + return dl + + def wigner_kernel( L: int, N: int, @@ -122,9 +208,7 @@ def wigner_kernel( for t, theta in enumerate(thetas): for el in range(abs(n), L): ind = n if reality else N - 1 + n - dl[ind, t, el] = recursions.turok.compute_slice( - theta, el, L, n, False - ) + dl[ind, t, el] = recursions.turok.compute_slice(theta, el, L, n, False) if forward: weights = quadrature.quad_weights_transform(L, sampling, 0, nside) @@ -148,6 +232,95 @@ def wigner_kernel( return dl +def wigner_kernel_jax( + L: int, + N: int, + reality: bool = False, + sampling: str = "mw", + nside: int = None, + forward: bool = False, +): + r"""Precompute the wigner-d kernels required for a Wigner transform. This can be + drastically faster but comes at a :math:`\mathcal{O}(NL^3)` memory overhead, making + it infeasible for :math:`L \geq 512`. + + Args: + L (int): Harmonic band-limit. + + N (int): Directional band-limit. + + reality (bool, optional): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + Defaults to False. + + sampling (str, optional): Sampling scheme. Supported sampling schemes include + {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + + nside (int): HEALPix Nside resolution parameter. Only required + if sampling="healpix". + + forward (bool, optional): Whether to provide forward or inverse shift. + Defaults to False. + + Returns: + jnp.ndarray: Transform kernel for Wigner transform. + """ + n_start_ind = N - 1 if reality else 0 + n_dim = N if reality else 2 * N - 1 + + if forward and sampling.lower() in ["mw", "mwss"]: + sampling = "mwss" + thetas = samples.thetas(2 * L, "mwss") + else: + thetas = samples.thetas(L, sampling, nside) + + dl = jnp.zeros((n_dim, len(thetas), L, 2 * L - 1), dtype=np.float64) + for n in range(n_start_ind - N + 1, N): + ind = n if reality else N - 1 + n + dl_n = recursions.price_mcewen.compute_all_slices_jax( + thetas, L, -n, sampling, forward, nside + ) + dl_n = dl_n.at[jnp.where(dl_n != dl_n)].set(0) + dl_n = jnp.swapaxes(dl_n, 0, 2) + dl_n = jnp.swapaxes(dl_n, 0, 1) + + # North pole singularity + if sampling.lower() == "mwss": + dl_n = dl_n.at[0].set(0) + dl_n = dl_n = dl_n.at[0, :, L - 1 + n].set(1) + + # South pole singularity + if sampling.lower() in ["mw", "mwss"]: + dl_n = dl_n.at[-1].set(0) + dl_n = dl_n.at[-1, :, L - 1 - n].set((-1) ** (jnp.arange(L) + n)) + + # Remove l <= n + dl_n = dl_n.at[:, : abs(n), :].set(0) + dl = dl.at[ind].add(dl_n) + + if forward: + weights = quadrature_jax.quad_weights_transform(L, sampling, nside) + dl = jnp.einsum("...ntlm, ...t->...ntlm", dl, weights, optimize=True) + dl *= 2 * jnp.pi / (2 * N - 1) + + else: + dl = jnp.einsum( + "...ntlm,...l->...ntlm", + dl, + (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2), + optimize=True, + ) + + if sampling.lower() == "healpix": + dl = np.einsum( + "...ntlm,...tm->...ntlm", + dl, + healpix_phase_shifts(L, nside, forward), + ) + + return dl + + def healpix_phase_shifts( L: int, nside: int, diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index 2a83c0e4..056200eb 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -1,4 +1,4 @@ -from jax import jit +from jax import jit import numpy as np import jax.numpy as jnp @@ -60,9 +60,7 @@ def inverse( if method == "numpy": return inverse_transform(flm, kernel, L, sampling, reality, spin, nside) elif method == "jax": - return inverse_transform_jax( - flm, kernel, L, sampling, reality, spin, nside - ) + return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside) else: raise ValueError(f"Method {method} not recognised.") @@ -241,9 +239,7 @@ def forward( if method == "numpy": return forward_transform(f, kernel, L, sampling, reality, spin, nside) elif method == "jax": - return forward_transform_jax( - f, kernel, L, sampling, reality, spin, nside - ) + return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside) else: raise ValueError(f"Method {method} not recognised.") @@ -380,8 +376,7 @@ def forward_transform_jax( if reality: flm = flm.at[:, :m_start_ind].set( jnp.flip( - (-1) ** (jnp.arange(1, L) % 2) - * jnp.conj(flm[:, m_start_ind + 1 :]), + (-1) ** (jnp.arange(1, L) % 2) * jnp.conj(flm[:, m_start_ind + 1 :]), axis=-1, ) ) diff --git a/s2fft/precompute_transforms/wigner.py b/s2fft/precompute_transforms/wigner.py index 5542a538..5cb119ef 100644 --- a/s2fft/precompute_transforms/wigner.py +++ b/s2fft/precompute_transforms/wigner.py @@ -59,9 +59,7 @@ def inverse( if method == "numpy": return inverse_transform(flmn, kernel, L, N, sampling, reality, nside) elif method == "jax": - return inverse_transform_jax( - flmn, kernel, L, N, sampling, reality, nside - ) + return inverse_transform_jax(flmn, kernel, L, N, sampling, reality, nside) else: raise ValueError(f"Method {method} not recognised.") @@ -102,37 +100,25 @@ def inverse_transform( m_offset = 1 if sampling in ["mwss", "healpix"] else 0 n_start_ind = N - 1 if reality else 0 - fnab = np.zeros( - samples.fnab_shape(L, N, sampling, nside), dtype=np.complex128 - ) + fnab = np.zeros(samples.fnab_shape(L, N, sampling, nside), dtype=np.complex128) fnab[n_start_ind:, :, m_offset:] = np.einsum( "...ntlm, ...nlm -> ...ntm", kernel, flmn[n_start_ind:, :, :] ) if sampling.lower() in "healpix": - f = np.zeros( - samples.f_shape(L, N, sampling, nside), dtype=np.complex128 - ) + f = np.zeros(samples.f_shape(L, N, sampling, nside), dtype=np.complex128) for n in range(n_start_ind - N + 1, N): ind = N - 1 + n f[ind] = hp.healpix_ifft(fnab[ind], L, nside, "numpy") if reality: - return np.fft.irfft( - f[n_start_ind:], 2 * N - 1, axis=-2, norm="forward" - ) + return np.fft.irfft(f[n_start_ind:], 2 * N - 1, axis=-2, norm="forward") else: - return np.fft.ifft( - np.fft.ifftshift(f, axes=-2), axis=-2, norm="forward" - ) + return np.fft.ifft(np.fft.ifftshift(f, axes=-2), axis=-2, norm="forward") else: if reality: - fnab = np.fft.ifft( - np.fft.ifftshift(fnab, axes=-1), axis=-1, norm="forward" - ) - return np.fft.irfft( - fnab[n_start_ind:], 2 * N - 1, axis=-3, norm="forward" - ) + fnab = np.fft.ifft(np.fft.ifftshift(fnab, axes=-1), axis=-1, norm="forward") + return np.fft.irfft(fnab[n_start_ind:], 2 * N - 1, axis=-3, norm="forward") else: fnab = np.fft.ifftshift(fnab, axes=(-1, -3)) return np.fft.ifft2(fnab, axes=(-1, -3), norm="forward") @@ -175,9 +161,7 @@ def inverse_transform_jax( m_offset = 1 if sampling in ["mwss", "healpix"] else 0 n_start_ind = N - 1 if reality else 0 - fnab = jnp.zeros( - samples.fnab_shape(L, N, sampling, nside), dtype=jnp.complex128 - ) + fnab = jnp.zeros(samples.fnab_shape(L, N, sampling, nside), dtype=jnp.complex128) fnab = fnab.at[n_start_ind:, :, m_offset:].set( jnp.einsum( "...ntlm, ...nlm -> ...ntm", @@ -188,16 +172,12 @@ def inverse_transform_jax( ) if sampling.lower() in "healpix": - f = jnp.zeros( - samples.f_shape(L, N, sampling, nside), dtype=jnp.complex128 - ) + f = jnp.zeros(samples.f_shape(L, N, sampling, nside), dtype=jnp.complex128) for n in range(n_start_ind - N + 1, N): ind = N - 1 + n f = f.at[ind].set(hp.healpix_ifft(fnab[ind], L, nside, "jax")) if reality: - return jnp.fft.irfft( - f[n_start_ind:], 2 * N - 1, axis=-2, norm="forward" - ) + return jnp.fft.irfft(f[n_start_ind:], 2 * N - 1, axis=-2, norm="forward") else: return jnp.conj( jnp.fft.fft( @@ -216,9 +196,7 @@ def inverse_transform_jax( norm="backward", ) ) - return jnp.fft.irfft( - fnab[n_start_ind:], 2 * N - 1, axis=-3, norm="forward" - ) + return jnp.fft.irfft(fnab[n_start_ind:], 2 * N - 1, axis=-3, norm="forward") else: fnab = jnp.conj(jnp.fft.ifftshift(fnab, axes=(-1, -3))) return jnp.conj(jnp.fft.fft2(fnab, axes=(-1, -3), norm="backward")) @@ -331,9 +309,7 @@ def forward_transform( m_offset = 1 if sampling in ["mwss", "healpix"] else 0 if sampling.lower() in "healpix": - temp = np.zeros( - samples.fnab_shape(L, N, sampling, nside), dtype=np.complex128 - ) + temp = np.zeros(samples.fnab_shape(L, N, sampling, nside), dtype=np.complex128) for n in range(n_start_ind - N + 1, N): ind = n if reality else N - 1 + n temp[N - 1 + n] = hp.healpix_fft(fban[ind], L, nside, "numpy") @@ -346,9 +322,7 @@ def forward_transform( flmn = np.zeros(samples.flmn_shape(L, N), dtype=np.complex128) flmn[n_start_ind:] = np.einsum("...ntlm, ...ntm -> ...nlm", kernel, fban) if reality: - flmn[:n_start_ind] = np.conj( - np.flip(flmn[n_start_ind + 1 :], axis=(-1, -3)) - ) + flmn[:n_start_ind] = np.conj(np.flip(flmn[n_start_ind + 1 :], axis=(-1, -3))) flmn[:n_start_ind] = np.einsum( "...nlm,...m->...nlm", flmn[:n_start_ind], @@ -403,9 +377,7 @@ def forward_transform_jax( if reality: fban = jnp.fft.rfft(jnp.real(f), axis=ax, norm="backward") else: - fban = jnp.fft.fftshift( - jnp.fft.fft(f, axis=ax, norm="backward"), axes=ax - ) + fban = jnp.fft.fftshift(jnp.fft.fft(f, axis=ax, norm="backward"), axes=ax) spins = -jnp.arange(n_start_ind - N + 1, N) if sampling.lower() == "mw": @@ -423,9 +395,7 @@ def forward_transform_jax( ) for n in range(n_start_ind - N + 1, N): ind = n if reality else N - 1 + n - temp = temp.at[N - 1 + n].set( - hp.healpix_fft(fban[ind], L, nside, "jax") - ) + temp = temp.at[N - 1 + n].set(hp.healpix_fft(fban[ind], L, nside, "jax")) fban = temp[n_start_ind:, :, m_offset:] else: diff --git a/s2fft/recursions/price_mcewen.py b/s2fft/recursions/price_mcewen.py index ce5a16a8..a5b7da94 100644 --- a/s2fft/recursions/price_mcewen.py +++ b/s2fft/recursions/price_mcewen.py @@ -191,12 +191,8 @@ def cpi_cp2_loop(m, args): def cpi_cp2_roll_loop(m, args): cpi, cp2 = args - cpi = cpi.at[:, m - L0].set( - jnp.roll(cpi[:, m - L0], (L - m - 1), axis=-1) - ) - cp2 = cp2.at[:, m - L0].set( - jnp.roll(cp2[:, m - L0], (L - m - 1), axis=-1) - ) + cpi = cpi.at[:, m - L0].set(jnp.roll(cpi[:, m - L0], (L - m - 1), axis=-1)) + cp2 = cp2.at[:, m - L0].set(jnp.roll(cp2[:, m - L0], (L - m - 1), axis=-1)) return cpi, cp2 cpi, cp2 = lax.fori_loop(L0, L, cpi_cp2_roll_loop, (cpi, cp2)) @@ -235,9 +231,7 @@ def renorm_m_loop(i, args): log_first_row_iter = jnp.swapaxes(log_first_row_iter, 0, 1) for ind in range(2): lrenorm = lrenorm.at[ind].set( - jnp.where( - i == half_slices[ind], log_first_row_iter, lrenorm[ind] - ) + jnp.where(i == half_slices[ind], log_first_row_iter, lrenorm[ind]) ) return log_first_row_iter, lrenorm @@ -303,9 +297,7 @@ def generate_precomputes_wigner( precomps = [] n_start_ind = 0 if reality else -N + 1 for n in range(n_start_ind, N): - precomps.append( - generate_precomputes(L, -n, sampling, nside, forward, L_lower) - ) + precomps.append(generate_precomputes(L, -n, sampling, nside, forward, L_lower)) return precomps @@ -357,9 +349,7 @@ def generate_precomputes_wigner_jax( captured_repeats = False n_start_ind = 0 if reality else -N + 1 for n in range(n_start_ind, N): - precomps = generate_precomputes_jax( - L, -n, sampling, nside, forward, L_lower - ) + precomps = generate_precomputes_jax(L, -n, sampling, nside, forward, L_lower) lrenorm.append(precomps[0]) vsign.append(precomps[1]) if not captured_repeats: @@ -456,9 +446,7 @@ def compute_all_slices( ) dl_test[sind, :, lind:] = ( - dl_iter[0, :, lind:] - * vsign[sind, lind:] - * np.exp(lrenorm[i, :, lind:]) + dl_iter[0, :, lind:] * vsign[sind, lind:] * np.exp(lrenorm[i, :, lind:]) ) dl_test[sind + sgn, :, lind - 1 :] = ( dl_iter[1, :, lind - 1 :] @@ -501,9 +489,15 @@ def compute_all_slices( return dl_test -@partial(jit, static_argnums=(1, 2)) +@partial(jit, static_argnums=(1, 3, 4, 5)) def compute_all_slices_jax( - beta: jnp.ndarray, L: int, spin: int, precomps=None + beta: jnp.ndarray, + L: int, + spin: int, + sampling: str = "mw", + forward: bool = False, + nside: int = None, + precomps=None, ) -> jnp.ndarray: r"""Compute a particular slice :math:`m^{\prime}`, denoted `mm`, of the complete Wigner-d matrix for all sampled polar angles @@ -544,13 +538,22 @@ def compute_all_slices_jax( ntheta = len(beta) lims = [0, -1] + # Trigonometric constant adopted throughout + c = jnp.cos(beta) + s = jnp.sin(beta) + omc = 1.0 - c + el = jnp.arange(L) + + # Indexing boundaries + half_slices = [el + mm + 1, el - mm + 1] + dl_test = jnp.zeros((2 * L - 1, ntheta, L), dtype=jnp.float64) if precomps is None: - lrenorm, lamb, vsign, cpi, cp2, cs, indices = generate_precomputes( - beta, L, mm + lrenorm, vsign, cpi, cp2, indices = generate_precomputes_jax( + L, spin, sampling, nside, forward, 0, beta ) else: - lrenorm, lamb, vsign, cpi, cp2, cs, indices = precomps + lrenorm, vsign, cpi, cp2, indices = precomps for i in range(2): lind = L - 1 @@ -558,18 +561,23 @@ def compute_all_slices_jax( sgn = (-1) ** (i) dl_iter = jnp.ones((2, ntheta, L), dtype=jnp.float64) + lamb = ( + jnp.einsum("l,t->tl", el + 1, omc, optimize=True) + + jnp.einsum("l,t->tl", 2 - L + el, c, optimize=True) + - half_slices[i] + ) + lamb = jnp.einsum("tl,t->tl", lamb, 1 / s, optimize=True) + dl_iter = dl_iter.at[1, :, lind:].set( jnp.einsum( "l,tl->tl", cpi[0, lind:], - dl_iter[0, :, lind:] * lamb[i, :, lind:], + dl_iter[0, :, lind:] * lamb[:, lind:], ) ) dl_test = dl_test.at[sind, :, lind:].set( - dl_iter[0, :, lind:] - * vsign[sind, lind:] - * jnp.exp(lrenorm[i, :, lind:]) + dl_iter[0, :, lind:] * vsign[sind, lind:] * jnp.exp(lrenorm[i, :, lind:]) ) dl_test = dl_test.at[sind + sgn, :, lind - 1 :].set( @@ -581,13 +589,20 @@ def compute_all_slices_jax( dl_entry = jnp.zeros((ntheta, L), dtype=jnp.float64) def pm_recursion_step(m, args): - dl_test, dl_entry, dl_iter, lamb, lrenorm = args + dl_test, dl_entry, dl_iter, lrenorm, indices, omc, c, s = args index = indices >= L - m - 1 - lamb = lamb.at[i, :, jnp.arange(L)].add(cs) + + lamb = ( + jnp.einsum("l,t->tl", el + 1, omc, optimize=True) + + jnp.einsum("l,t->tl", m - L + el + 1, c, optimize=True) + - half_slices[i] + ) + lamb = jnp.einsum("tl,t->tl", lamb, 1 / s, optimize=True) + dl_entry = jnp.where( index, - jnp.einsum("l,tl->tl", cpi[m - 1], dl_iter[1] * lamb[i]) - - jnp.einsum("l,tl->tl", cp2[m - 1], dl_iter[0]), + jnp.einsum("l,tl->tl", cpi[m - 1], dl_iter[1] * lamb, optimize=True) + - jnp.einsum("l,tl->tl", cp2[m - 1], dl_iter[0], optimize=True), dl_entry, ) dl_entry = dl_entry.at[:, -(m + 1)].set(1) @@ -603,18 +618,15 @@ def pm_recursion_step(m, args): bigi = 1.0 / abs(dl_entry) lbig = jnp.log(abs(dl_entry)) - dl_iter = dl_iter.at[0].set( - jnp.where(index, bigi * dl_iter[1], dl_iter[0]) - ) - dl_iter = dl_iter.at[1].set( - jnp.where(index, bigi * dl_entry, dl_iter[1]) - ) - lrenorm = lrenorm.at[i].set( - jnp.where(index, lrenorm[i] + lbig, lrenorm[i]) - ) - return dl_test, dl_entry, dl_iter, lamb, lrenorm + dl_iter = dl_iter.at[0].set(jnp.where(index, bigi * dl_iter[1], dl_iter[0])) + dl_iter = dl_iter.at[1].set(jnp.where(index, bigi * dl_entry, dl_iter[1])) + lrenorm = lrenorm.at[i].set(jnp.where(index, lrenorm[i] + lbig, lrenorm[i])) + return dl_test, dl_entry, dl_iter, lrenorm, indices, omc, c, s - dl_test, dl_entry, dl_iter, lamb, lrenorm = lax.fori_loop( - 2, L, pm_recursion_step, (dl_test, dl_entry, dl_iter, lamb, lrenorm) + dl_test, dl_entry, dl_iter, lrenorm, indices, omc, c, s = lax.fori_loop( + 2, + L, + pm_recursion_step, + (dl_test, dl_entry, dl_iter, lrenorm, indices, omc, c, s), ) return dl_test diff --git a/s2fft/recursions/risbo.py b/s2fft/recursions/risbo.py index c2ffd13b..754d39da 100644 --- a/s2fft/recursions/risbo.py +++ b/s2fft/recursions/risbo.py @@ -25,12 +25,10 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray: _arg_checks(dl, beta, L, el) if el == 0: - el = 0 dl[el + L - 1, el + L - 1] = 1.0 elif el == 1: - cosb = np.cos(beta) sinb = np.sin(beta) @@ -51,7 +49,6 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray: dl[1 + L - 1, 1 + L - 1] = coshb**2 else: - coshb = -np.cos(beta / 2.0) sinhb = np.sin(beta / 2.0) @@ -61,12 +58,10 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray: j = 2 * el - 1 rj = float(j) # TODO: is this necessary? for k in range(0, j): - sqrt_jmk = np.sqrt(j - k) sqrt_kp1 = np.sqrt(k + 1) for i in range(0, j): - sqrt_jmi = np.sqrt(j - i) sqrt_ip1 = np.sqrt(i + 1) @@ -84,12 +79,10 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray: j = 2 * el rj = float(j) # TODO: is this necessary? for k in range(0, j): - sqrt_jmk = np.sqrt(j - k) sqrt_kp1 = np.sqrt(k + 1) for i in range(0, j): - sqrt_jmi = np.sqrt(j - i) sqrt_ip1 = np.sqrt(i + 1) diff --git a/s2fft/recursions/trapani.py b/s2fft/recursions/trapani.py index ef9c3100..3c9a2af0 100644 --- a/s2fft/recursions/trapani.py +++ b/s2fft/recursions/trapani.py @@ -24,19 +24,15 @@ def init(dl: np.ndarray, L: int, implementation: str = "vectorized") -> np.ndarr """ if implementation.lower() == "loop": - return init_nonjax(dl, L) elif implementation == "vectorized": - return init_nonjax(dl, L) elif implementation == "jax": - return init_jax(dl, L) else: - raise ValueError(f"Implementation {implementation} not supported") @@ -148,7 +144,6 @@ def compute_eighth(dl: np.ndarray, L: int, el: int) -> np.ndarray: # Equation (11) of T&N (2006). for mm in range(el + 1): # 0:el - # m = el-1 case (t2 = 0). m = el - 1 dl[m + (L - 1), mm + (L - 1)] = ( @@ -317,7 +312,10 @@ def compute_quarter_jax(dl: jnp.ndarray, L: int, el: int) -> jnp.ndarray: ) def compute_dl_submatrix_slice(dl_slice_1_dl_slice_2, t1_fact_i_t2_fact_i): - t1_fact_i, t2_fact_i, = t1_fact_i_t2_fact_i + ( + t1_fact_i, + t2_fact_i, + ) = t1_fact_i_t2_fact_i dl_slice_1, dl_slice_2 = dl_slice_1_dl_slice_2 t1 = 2 * mm / t1_fact_i * dl_slice_1 t2 = t2_fact_i * dl_slice_2 @@ -327,7 +325,7 @@ def compute_dl_submatrix_slice(dl_slice_1_dl_slice_2, t1_fact_i_t2_fact_i): _, dl_submatrix = lax.scan( compute_dl_submatrix_slice, (dl[el - 1 + (L - 1), mm + (L - 1)], dl[el + (L - 1), mm + (L - 1)]), - (t1_fact, t2_fact) + (t1_fact, t2_fact), ) dl = dl.at[ms[:, None] + (L - 1), mm[None] + (L - 1)].set(dl_submatrix) @@ -609,19 +607,15 @@ def compute_full( """ if implementation.lower() == "loop": - return compute_full_loop(dl, L, el) elif implementation == "vectorized": - return compute_full_vectorized(dl, L, el) elif implementation == "jax": - return compute_full_jax(dl, L, el) else: - raise ValueError(f"Implementation {implementation} not supported") diff --git a/s2fft/recursions/turok.py b/s2fft/recursions/turok.py index c667042a..7d9ef915 100644 --- a/s2fft/recursions/turok.py +++ b/s2fft/recursions/turok.py @@ -217,9 +217,7 @@ def compute_quarter_slice( if i == 1: for m in range(el + 1): dl[lims[i] + sgn * m] = ( - (-1) ** ((mm - m + el) % 2) - * dl[lims[i] + sgn * m] - * renorm + (-1) ** ((mm - m + el) % 2) * dl[lims[i] + sgn * m] * renorm ) s_ind = 0 if positive_m_only else -el @@ -293,9 +291,7 @@ def compute_quarter(dl: np.ndarray, beta: float, l: int, L: int) -> np.ndarray: for i in range(2, 2 * l + 2): m = l + 1 - i ratio = np.sqrt((m + l + 1) / (l - m)) - log_first_row[i - 1] = ( - log_first_row[i - 2] + np.log(ratio) + np.log(np.abs(t)) - ) + log_first_row[i - 1] = log_first_row[i - 2] + np.log(ratio) + np.log(np.abs(t)) sign[i - 1] = sign[i - 2] * t / np.abs(t) # Initialising coefficients cp(m)= cplus(l-m). @@ -326,9 +322,7 @@ def compute_quarter(dl: np.ndarray, beta: float, l: int, L: int) -> np.ndarray: if dl[index - lp1, m + 1 - lp1] > big: lrenorm[index - 1] = lrenorm[index - 1] - lbig for im in range(1, m + 2): - dl[index - lp1, im - lp1] = ( - dl[index - lp1, im - lp1] * bigi - ) + dl[index - lp1, im - lp1] = dl[index - lp1, im - lp1] * bigi # Use Turok & Bucher recursion to fill horizontal to anti-diagonal (upper left eight) for index in range(l + 2, 2 * l + 1): @@ -345,9 +339,7 @@ def compute_quarter(dl: np.ndarray, beta: float, l: int, L: int) -> np.ndarray: if dl[index - lp1, m + 1 - lp1] > big: lrenorm[index - 1] = lrenorm[index - 1] - lbig for im in range(1, m + 2): - dl[index - lp1, im - lp1] = ( - dl[index - lp1, im - lp1] * bigi - ) + dl[index - lp1, im - lp1] = dl[index - lp1, im - lp1] * bigi # Apply renormalisation for i in range(1, l + 2): diff --git a/s2fft/recursions/turok_jax.py b/s2fft/recursions/turok_jax.py index af094477..eec666b0 100644 --- a/s2fft/recursions/turok_jax.py +++ b/s2fft/recursions/turok_jax.py @@ -164,7 +164,7 @@ def renorm_iteration(m, dl_lrenorm): # an IndexError exception when used with lax.fori_loop lambda x: jnp.where((indices < (m + 1))[::sgn], bigi * x, x), lambda x: x, - dl + dl, ) return dl, lrenorm @@ -178,7 +178,6 @@ def renorm_iteration(m, dl_lrenorm): if i == 1: dl = dl.at[-em].multiply((-1) ** ((mm - em + el + 1) % 2) * renorm) - return jnp.nan_to_num(dl, neginf=0, posinf=0) @@ -279,7 +278,7 @@ def reindex(dl, el, L, mm) -> jnp.ndarray: dl = dl.at[: L - 1].set(jnp.roll(dl, L - el - 1)[: L - 1]) dl = dl.at[L - 1 :].set(jnp.roll(dl, -(L - el - 1))[L - 1 :]) - m = jnp.arange(-L+1, L+1) - dl = dl.at[L-1+m].multiply((-1)**((mm - m)%2)) + m = jnp.arange(-L + 1, L + 1) + dl = dl.at[L - 1 + m].multiply((-1) ** ((mm - m) % 2)) return dl diff --git a/s2fft/sampling/s2_samples.py b/s2fft/sampling/s2_samples.py index 89368e8f..dbc6bec3 100644 --- a/s2fft/sampling/s2_samples.py +++ b/s2fft/sampling/s2_samples.py @@ -32,19 +32,15 @@ def ntheta(L: int = None, sampling: str = "mw", nside: int = None) -> int: ) if sampling.lower() == "mw": - return L elif sampling.lower() == "mwss": - return L + 1 elif sampling.lower() == "dh": - return 2 * L elif sampling.lower() == "healpix": - if nside is None: raise ValueError( f"Sampling scheme sampling={sampling} with nside={nside} not supported" @@ -53,7 +49,6 @@ def ntheta(L: int = None, sampling: str = "mw", nside: int = None) -> int: return 4 * nside - 1 else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") @@ -76,15 +71,12 @@ def ntheta_extension(L: int, sampling: str = "mw") -> int: """ if sampling.lower() == "mw": - return 2 * L - 1 elif sampling.lower() == "mwss": - return 2 * L else: - raise ValueError( f"Sampling scheme sampling={sampling} does not support periodic extension" ) @@ -113,31 +105,24 @@ def nphi_equiang(L: int, sampling: str = "mw") -> int: """ if sampling.lower() == "mw": - return 2 * L - 1 elif sampling.lower() == "mwss": - return 2 * L elif sampling.lower() == "dh": - return 2 * L - 1 elif sampling.lower() == "healpix": - raise ValueError(f"Sampling scheme sampling={sampling} not supported") else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") return 1 -def ftm_shape( - L: int, sampling: str = "mw", nside: int = None -) -> Tuple[int, int]: +def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> Tuple[int, int]: r"""Shape of intermediate array, before/after latitudinal step. Args: @@ -157,15 +142,12 @@ def ftm_shape( """ if sampling.lower() in ["mwss", "healpix"]: - return ntheta(L, sampling, nside), 2 * L elif sampling.lower() in ["mw", "dh"]: - return ntheta(L, sampling, nside), 2 * L - 1 else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") return 1 @@ -213,9 +195,7 @@ def nphi_ring(t: int, nside: int = None) -> int: raise ValueError(f"Ring t={t} not contained by nside={nside}") -def thetas( - L: int = None, sampling: str = "mw", nside: int = None -) -> np.ndarray: +def thetas(L: int = None, sampling: str = "mw", nside: int = None) -> np.ndarray: r"""Compute :math:`\theta` samples for given sampling scheme. Args: @@ -231,9 +211,7 @@ def thetas( Returns: np.ndarray: Array of :math:`\theta` samples for given sampling scheme. """ - t = np.arange(0, ntheta(L=L, sampling=sampling, nside=nside)).astype( - np.float64 - ) + t = np.arange(0, ntheta(L=L, sampling=sampling, nside=nside)).astype(np.float64) return t2theta(t, L, sampling, nside) @@ -272,19 +250,15 @@ def t2theta( ) if sampling.lower() == "mw": - return (2 * t + 1) * np.pi / (2 * L - 1) elif sampling.lower() == "mwss": - return 2 * t * np.pi / (2 * L) elif sampling.lower() == "dh": - return (2 * t + 1) * np.pi / (4 * L) elif sampling.lower() == "healpix": - if nside is None: raise ValueError( f"Sampling scheme sampling={sampling} with nside={nside} not supported" @@ -293,7 +267,6 @@ def t2theta( return _t2theta_healpix(t, nside) else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") @@ -404,23 +377,18 @@ def p2phi_equiang(L: int, p: int, sampling: str = "mw") -> np.ndarray: """ if sampling.lower() == "mw": - return 2 * p * np.pi / (2 * L - 1) elif sampling.lower() == "mwss": - return 2 * p * np.pi / (2 * L) elif sampling.lower() == "dh": - return 2 * p * np.pi / (2 * L - 1) elif sampling.lower() == "healpix": - raise ValueError(f"Sampling scheme sampling={sampling} not supported") else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") @@ -456,9 +424,7 @@ def ring_phase_shift_hp( return np.exp(sign * 1j * np.arange(m_start_ind, L) * phi_offset) -def f_shape( - L: int = None, sampling: str = "mw", nside: int = None -) -> Tuple[int]: +def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int]: r"""Shape of spherical signal. Args: @@ -486,11 +452,9 @@ def f_shape( ) if sampling.lower() == "healpix": - return (12 * nside**2,) else: - return ntheta(L, sampling), nphi_equiang(L, sampling) diff --git a/s2fft/sampling/so3_samples.py b/s2fft/sampling/so3_samples.py index 7a18e69e..2a069ae0 100644 --- a/s2fft/sampling/so3_samples.py +++ b/s2fft/sampling/so3_samples.py @@ -37,19 +37,15 @@ def f_shape( :math:`SO(3)`. """ if sampling in ["mw", "mwss", "dh"]: - return _ngamma(N), _nbeta(L, sampling), _nalpha(L, sampling) elif sampling.lower() == "healpix": - return _ngamma(N), 12 * nside**2 elif sampling.lower() == "healpix": - return 12 * nside**2, _ngamma(N) else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") @@ -94,15 +90,12 @@ def fnab_shape( """ if sampling.lower() in ["mwss", "healpix"]: - return _ngamma(N), samples.ntheta(L, sampling, nside), 2 * L elif sampling.lower() in ["mw", "dh"]: - return _ngamma(N), samples.ntheta(L, sampling, nside), 2 * L - 1 else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") return 1 @@ -138,15 +131,12 @@ def _nalpha(L: int, sampling: str = "mw") -> int: int: Number of :math:`\alpha` samples. """ if sampling.lower() in ["mw", "dh"]: - return 2 * L - 1 elif sampling.lower() == "mwss": - return 2 * L else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") @@ -166,19 +156,15 @@ def _nbeta(L: int, sampling: str = "mw") -> int: int: Number of :math:`\beta` samples. """ if sampling.lower() == "mw": - return L elif sampling.lower() == "mwss": - return L + 1 elif sampling.lower() == "dh": - return 2 * L else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") @@ -247,9 +233,7 @@ def flmn_3d_to_1d(flmn_3d: np.ndarray, L: int, N: int) -> np.ndarray: for n in range(-N + 1, N): for el in range(L): for m in range(-el, el + 1): - flmn_1d[elmn2ind(el, m, n, L, N)] = flmn_3d[ - N - 1 + n, el, L - 1 + m - ] + flmn_1d[elmn2ind(el, m, n, L, N)] = flmn_3d[N - 1 + n, el, L - 1 + m] return flmn_1d @@ -285,8 +269,6 @@ def flmn_1d_to_3d(flmn_1d: np.ndarray, L: int, N: int) -> np.ndarray: for n in range(-N + 1, N): for el in range(L): for m in range(-el, el + 1): - flmn_3d[N - 1 + n, el, L - 1 + m] = flmn_1d[ - elmn2ind(el, m, n, L, N) - ] + flmn_3d[N - 1 + n, el, L - 1 + m] = flmn_1d[elmn2ind(el, m, n, L, N)] return flmn_3d diff --git a/s2fft/transforms/otf_recursions.py b/s2fft/transforms/otf_recursions.py index 279a67e7..4d6b73bb 100644 --- a/s2fft/transforms/otf_recursions.py +++ b/s2fft/transforms/otf_recursions.py @@ -152,9 +152,7 @@ def inverse_latitudinal_step( lbig = np.log(abs(dl_entry[:, L_lower:])) dl_iter[0] = np.where(index, bigi * dl_iter[1], dl_iter[0]) - dl_iter[1] = np.where( - index, bigi * dl_entry[:, L_lower:], dl_iter[1] - ) + dl_iter[1] = np.where(index, bigi * dl_entry[:, L_lower:], dl_iter[1]) lrenorm[i] = np.where(index, lrenorm[i] + lbig, lrenorm[i]) return ftm @@ -312,9 +310,7 @@ def pm_recursion_step(m, args): dl_iter[1] * lamb, optimize=True, ) - - jnp.einsum( - "l,tl->tl", cp2[m - 1], dl_iter[0], optimize=True - ), + - jnp.einsum("l,tl->tl", cp2[m - 1], dl_iter[0], optimize=True), dl_entry[:, L_lower:], ) ) @@ -371,9 +367,7 @@ def eval_recursion_step( ndevices = local_device_count() opsdevice = int(ntheta / ndevices) - ftm = pmap( - eval_recursion_step, in_axes=(0, 0, 1, 1, 0, 0, 0, 0) - )( + ftm = pmap(eval_recursion_step, in_axes=(0, 0, 1, 1, 0, 0, 0, 0))( ftm.reshape(ndevices, opsdevice, ftm.shape[-1]), dl_entry.reshape(ndevices, opsdevice, L), dl_iter.reshape(2, ndevices, opsdevice, L - L_lower), @@ -382,9 +376,7 @@ def eval_recursion_step( omc.reshape(ndevices, opsdevice), c.reshape(ndevices, opsdevice), s.reshape(ndevices, opsdevice), - ).reshape( - ntheta, ftm.shape[-1] - ) + ).reshape(ntheta, ftm.shape[-1]) else: ( @@ -409,8 +401,7 @@ def eval_recursion_step( ftm = ftm.at[-1].set(0) ftm = ftm.at[-1, L - 1 + spin + m_offset].set( jnp.nansum( - (-1) ** abs(jnp.arange(L_lower, L) - spin) - * flm[L_lower:, L - 1 + spin] + (-1) ** abs(jnp.arange(L_lower, L) - spin) * flm[L_lower:, L - 1 + spin] ) ) @@ -572,9 +563,7 @@ def forward_latitudinal_step( lbig = np.log(abs(dl_entry[:, L_lower:])) dl_iter[0] = np.where(index, bigi * dl_iter[1], dl_iter[0]) - dl_iter[1] = np.where( - index, bigi * dl_entry[:, L_lower:], dl_iter[1] - ) + dl_iter[1] = np.where(index, bigi * dl_entry[:, L_lower:], dl_iter[1]) lrenorm[i] = np.where(index, lrenorm[i] + lbig, lrenorm[i]) return flm @@ -751,12 +740,8 @@ def pm_recursion_step(m, args): dl_entry = jnp.where( index, - jnp.einsum( - "l,tl->tl", cpi[m - 1], dl_iter[1] * lamb, optimize=True - ) - - jnp.einsum( - "l,tl->tl", cp2[m - 1], dl_iter[0], optimize=True - ), + jnp.einsum("l,tl->tl", cpi[m - 1], dl_iter[1] * lamb, optimize=True) + - jnp.einsum("l,tl->tl", cp2[m - 1], dl_iter[0], optimize=True), dl_entry, ) dl_entry = jnp.where(indices == L - 1 - m, 1, dl_entry) @@ -766,9 +751,7 @@ def pm_recursion_step(m, args): jnp.nansum( jnp.einsum( "tl, t->tl", - dl_entry - * vsign[sind + sgn * m] - * jnp.exp(lrenorm[i]), + dl_entry * vsign[sind + sgn * m] * jnp.exp(lrenorm[i]), ftm[:, sind + sgn * m + m_offset], optimize=True, ), @@ -847,9 +830,7 @@ def eval_recursion_step( cp2.reshape(L + 1, ndevices, opsdevice), vsign.reshape(2 * L - 1, ndevices, opsdevice), indices.reshape(ntheta, ndevices, opsdevice), - ).reshape( - L - L_lower, 2 * L - 1 - ) + ).reshape(L - L_lower, 2 * L - 1) ) else: @@ -880,7 +861,5 @@ def eval_recursion_step( ) if sampling.lower() == "mwss": - flm = flm.at[L_lower:, L - 1 - spin].add( - ftm_in[0, L - 1 - spin + m_offset] - ) + flm = flm.at[L_lower:, L - 1 - spin].add(ftm_in[0, L - 1 - spin + m_offset]) return flm diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 67c670ff..3a86667e 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -72,9 +72,7 @@ def inverse( recover acceleration by the number of devices. """ if method == "numpy": - return inverse_numpy( - flm, L, spin, nside, sampling, reality, precomps, L_lower - ) + return inverse_numpy(flm, L, spin, nside, sampling, reality, precomps, L_lower) elif method == "jax": return inverse_jax( flm, L, spin, nside, sampling, reality, precomps, spmd, L_lower @@ -178,9 +176,7 @@ def inverse_numpy( norm="forward", ) else: - return np.fft.ifft( - np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward" - ) + return np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward") @partial(jit, static_argnums=(1, 3, 4, 5, 7, 8)) @@ -368,9 +364,7 @@ def forward( recover acceleration by the number of devices. """ if method == "numpy": - return forward_numpy( - f, L, spin, nside, sampling, reality, precomps, L_lower - ) + return forward_numpy(f, L, spin, nside, sampling, reality, precomps, L_lower) elif method == "jax": return forward_jax( f, L, spin, nside, sampling, reality, precomps, spmd, L_lower @@ -452,9 +446,7 @@ def forward_numpy( ftm = np.zeros_like(f).astype(np.complex128) ftm[:, L - 1 + m_offset :] = t else: - ftm = np.fft.fftshift( - np.fft.fft(f, axis=1, norm="backward"), axes=1 - ) + ftm = np.fft.fftshift(np.fft.fft(f, axis=1, norm="backward"), axes=1) # Apply quadrature weights ftm = np.einsum("tm,t->tm", ftm, weights) @@ -590,9 +582,7 @@ def forward_jax( ftm = jnp.zeros_like(f).astype(jnp.complex128) ftm = ftm.at[:, L - 1 + m_offset :].set(t) else: - ftm = jnp.fft.fftshift( - jnp.fft.fft(f, axis=1, norm="backward"), axes=1 - ) + ftm = jnp.fft.fftshift(jnp.fft.fft(f, axis=1, norm="backward"), axes=1) # Apply quadrature weights ftm = jnp.einsum("tm,t->tm", ftm, weights, optimize=True) @@ -654,8 +644,7 @@ def f_bwd(res, glm): if reality: flm = flm.at[..., :m_start_ind].set( jnp.flip( - (-1) ** (jnp.arange(1, L) % 2) - * jnp.conj(flm[..., m_start_ind + 1 :]), + (-1) ** (jnp.arange(1, L) % 2) * jnp.conj(flm[..., m_start_ind + 1 :]), axis=-1, ) ) diff --git a/s2fft/transforms/wigner.py b/s2fft/transforms/wigner.py index b48dee5a..97b780c2 100644 --- a/s2fft/transforms/wigner.py +++ b/s2fft/transforms/wigner.py @@ -77,9 +77,7 @@ def inverse( recover acceleration by the number of devices. """ if method == "numpy": - return inverse_numpy( - flmn, L, N, nside, sampling, reality, precomps, L_lower - ) + return inverse_numpy(flmn, L, N, nside, sampling, reality, precomps, L_lower) elif method == "jax": return inverse_jax( flmn, L, N, nside, sampling, reality, precomps, spmd, L_lower @@ -170,9 +168,7 @@ def inverse_numpy( if reality: f = np.fft.irfft(fban[N - 1 :], 2 * N - 1, axis=ax, norm="forward") else: - f = np.fft.ifft( - np.fft.ifftshift(fban, axes=ax), axis=ax, norm="forward" - ) + f = np.fft.ifft(np.fft.ifftshift(fban, axes=ax), axis=ax, norm="forward") return f @@ -244,9 +240,7 @@ def inverse_jax( precomps = s2fft.generate_precomputes_wigner_jax( L, N, sampling, nside, False, reality, L_lower ) - fban = jnp.zeros( - samples.f_shape(L, N, sampling, nside), dtype=jnp.complex128 - ) + fban = jnp.zeros(samples.f_shape(L, N, sampling, nside), dtype=jnp.complex128) flmn = flmn.at[:, L_lower:].set( jnp.einsum( @@ -285,12 +279,9 @@ def spherical_loop(n, args): return fban, flmn, lrenorm, vsign, spins if spmd: - # TODO: Generalise this to optional device counts. ndevices = local_device_count() - opsdevice = ( - int(N / ndevices) if reality else int((2 * N - 1) / ndevices) - ) + opsdevice = int(N / ndevices) if reality else int((2 * N - 1) / ndevices) def eval_spherical_loop(fban, flmn, lrenorm, vsign, spins): return lax.fori_loop( @@ -405,13 +396,9 @@ def forward( recover acceleration by the number of devices. """ if method == "numpy": - return forward_numpy( - f, L, N, nside, sampling, reality, precomps, L_lower - ) + return forward_numpy(f, L, N, nside, sampling, reality, precomps, L_lower) elif method == "jax": - return forward_jax( - f, L, N, nside, sampling, reality, precomps, spmd, L_lower - ) + return forward_jax(f, L, N, nside, sampling, reality, precomps, spmd, L_lower) else: raise ValueError( f"Implementation {method} not recognised. Should be either numpy or jax." @@ -617,9 +604,7 @@ def spherical_loop(n, args): if spmd: # TODO: Generalise this to optional device counts. ndevices = local_device_count() - opsdevice = ( - int(N / ndevices) if reality else int((2 * N - 1) / ndevices) - ) + opsdevice = int(N / ndevices) if reality else int((2 * N - 1) / ndevices) def eval_spherical_loop(fban, flmn, lrenorm, vsign, spins): return lax.fori_loop( diff --git a/s2fft/utils/healpix_ffts.py b/s2fft/utils/healpix_ffts.py index da2c1c66..03858c75 100644 --- a/s2fft/utils/healpix_ffts.py +++ b/s2fft/utils/healpix_ffts.py @@ -65,9 +65,7 @@ def spectral_folding_jax(fm: jnp.ndarray, nphi: int, L: int) -> jnp.ndarray: ) -def spectral_periodic_extension( - fm: np.ndarray, nphi: int, L: int -) -> np.ndarray: +def spectral_periodic_extension(fm: np.ndarray, nphi: int, L: int) -> np.ndarray: """Extends lower frequency Fourier coefficients onto higher frequency coefficients, i.e. imposed periodicity in Fourier space. @@ -161,9 +159,7 @@ def healpix_fft( raise ValueError(f"Method {method} not recognised.") -def healpix_fft_numpy( - f: np.ndarray, L: int, nside: int, reality: bool -) -> np.ndarray: +def healpix_fft_numpy(f: np.ndarray, L: int, nside: int, reality: bool) -> np.ndarray: """Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions to manually enforce Fourier periodicity. @@ -204,9 +200,7 @@ def healpix_fft_numpy( @partial(jit, static_argnums=(1, 2, 3)) -def healpix_fft_jax( - f: jnp.ndarray, L: int, nside: int, reality: bool -) -> jnp.ndarray: +def healpix_fft_jax(f: jnp.ndarray, L: int, nside: int, reality: bool) -> jnp.ndarray: """ Healpix FFT JAX implementation using jax.numpy/numpy stack Computes the Forward Fast Fourier Transform with spectral back-projection @@ -234,9 +228,7 @@ def healpix_fft_jax( if reality and nphi == 2 * L: fm_chunk = jnp.zeros(nphi, dtype=jnp.complex128) fm_chunk = fm_chunk.at[nphi // 2 :].set( - jnp.fft.rfft( - jnp.real(f[index : index + nphi]), norm="backward" - )[:-1] + jnp.fft.rfft(jnp.real(f[index : index + nphi]), norm="backward")[:-1] ) else: fm_chunk = jnp.fft.fftshift( @@ -305,16 +297,12 @@ def healpix_ifft_numpy( Returns: np.ndarray: HEALPix pixel-space array. """ - f = np.zeros( - samples.f_shape(sampling="healpix", nside=nside), dtype=np.complex128 - ) + f = np.zeros(samples.f_shape(sampling="healpix", nside=nside), dtype=np.complex128) ntheta = ftm.shape[0] index = 0 for t in range(ntheta): nphi = samples.nphi_ring(t, nside) - fm_chunk = ( - ftm[t] if nphi == 2 * L else spectral_folding(ftm[t], nphi, L) - ) + fm_chunk = ftm[t] if nphi == 2 * L else spectral_folding(ftm[t], nphi, L) if reality and nphi == 2 * L: f[index : index + nphi] = np.fft.irfft( fm_chunk[nphi // 2 :], nphi, norm="forward" @@ -356,9 +344,7 @@ def healpix_ifft_jax( for t in range(ntheta): nphi = samples.nphi_ring(t, nside) - fm_chunk = ( - ftm[t] if nphi == 2 * L else spectral_folding_jax(ftm[t], nphi, L) - ) + fm_chunk = ftm[t] if nphi == 2 * L else spectral_folding_jax(ftm[t], nphi, L) if reality and nphi == 2 * L: f = f.at[index : index + nphi].set( jnp.fft.irfft(fm_chunk[nphi // 2 :], nphi, norm="forward") @@ -366,9 +352,7 @@ def healpix_ifft_jax( else: f = f.at[index : index + nphi].set( jnp.conj( - jnp.fft.fft( - jnp.fft.ifftshift(jnp.conj(fm_chunk)), norm="backward" - ) + jnp.fft.fft(jnp.fft.ifftshift(jnp.conj(fm_chunk)), norm="backward") ) ) @@ -395,9 +379,7 @@ def p2phi_rings(t: np.ndarray, nside: int) -> np.ndarray: shift * ((t - nside + 2) % 2) * np.pi / (2 * nside), tt, ) - tt = np.where( - t + 1 > 3 * nside, shift * np.pi / (2 * (4 * nside - t - 1)), tt - ) + tt = np.where(t + 1 > 3 * nside, shift * np.pi / (2 * (4 * nside - t - 1)), tt) tt = np.where(t + 1 < nside, shift * np.pi / (2 * (t + 1)), tt) return tt @@ -422,9 +404,7 @@ def p2phi_rings_jax(t: jnp.ndarray, nside: int) -> jnp.ndarray: shift * ((t - nside + 2) % 2) * jnp.pi / (2 * nside), tt, ) - tt = jnp.where( - t + 1 > 3 * nside, shift * jnp.pi / (2 * (4 * nside - t - 1)), tt - ) + tt = jnp.where(t + 1 > 3 * nside, shift * jnp.pi / (2 * (4 * nside - t - 1)), tt) tt = jnp.where(t + 1 < nside, shift * jnp.pi / (2 * (t + 1)), tt) return tt diff --git a/s2fft/utils/quadrature_jax.py b/s2fft/utils/quadrature_jax.py index 492962bc..787c0946 100644 --- a/s2fft/utils/quadrature_jax.py +++ b/s2fft/utils/quadrature_jax.py @@ -48,9 +48,7 @@ def quad_weights_transform( @partial(jit, static_argnums=(0, 1, 2)) -def quad_weights( - L: int = None, sampling: str = "mw", nside: int = None -) -> jnp.ndarray: +def quad_weights(L: int = None, sampling: str = "mw", nside: int = None) -> jnp.ndarray: r"""Compute quadrature weights for :math:`\theta` and :math:`\phi` integration for various sampling schemes. JAX implementation of :func:`~s2fft.quadrature.quad_weights`. @@ -229,9 +227,7 @@ def quad_weights_mw_theta_only(L: int) -> jnp.ndarray: w = w.at[i + L - 1].set(mw_weights(i)) w *= jnp.exp(-1j * jnp.arange(-(L - 1), L) * jnp.pi / (2 * L - 1)) - wr = jnp.real(jnp.fft.fft(jnp.fft.ifftshift(w), norm="backward")) / ( - 2 * L - 1 - ) + wr = jnp.real(jnp.fft.fft(jnp.fft.ifftshift(w), norm="backward")) / (2 * L - 1) q = wr[:L] q = q.at[: L - 1].add(wr[-1 : L - 1 : -1]) diff --git a/s2fft/utils/resampling.py b/s2fft/utils/resampling.py index eb4d8e9a..606ea3a5 100644 --- a/s2fft/utils/resampling.py +++ b/s2fft/utils/resampling.py @@ -33,9 +33,7 @@ def periodic_extension( f_ext = np.zeros((f.shape[0], ntheta_ext, nphi), dtype=np.complex128) f_ext[:, 0:ntheta, 0:nphi] = f[:, 0:ntheta, 0:nphi] - f_ext = np.fft.fftshift( - np.fft.fft(f_ext, axis=-1, norm="backward"), axes=-1 - ) + f_ext = np.fft.fftshift(np.fft.fft(f_ext, axis=-1, norm="backward"), axes=-1) ind1l = L + m_offset ind1u = 2 * L - 1 + m_offset @@ -48,9 +46,7 @@ def periodic_extension( ], axis=-2, ) - f_ext[:, ind1l:ind1u, m_offset:ind2u] *= (-1) ** np.abs( - np.arange(-(L - 1), L) - ) + f_ext[:, ind1l:ind1u, m_offset:ind2u] *= (-1) ** np.abs(np.arange(-(L - 1), L)) if hasattr(spin, "__len__"): f_ext[:, ind1l:ind1u, m_offset:ind2u] = np.einsum( "nlm,n->nlm", @@ -60,14 +56,10 @@ def periodic_extension( else: f_ext[:, ind1l:ind1u, m_offset:ind2u] *= (-1) ** np.abs(spin) - return np.fft.ifft( - np.fft.ifftshift(f_ext, axes=-1), axis=-1, norm="backward" - ) + return np.fft.ifft(np.fft.ifftshift(f_ext, axes=-1), axis=-1, norm="backward") -def periodic_extension_spatial_mwss( - f: np.ndarray, L: int, spin: int = 0 -) -> np.ndarray: +def periodic_extension_spatial_mwss(f: np.ndarray, L: int, spin: int = 0) -> np.ndarray: r"""Perform period extension of MWSS signal on the sphere in spatial domain, extending :math:`\theta` domain from :math:`[0,\pi]` to :math:`[0,2\pi]`. @@ -95,9 +87,7 @@ def periodic_extension_spatial_mwss( if hasattr(spin, "__len__"): f_ext[:, ntheta:, 0 : 2 * L] = np.einsum( "btp,b->btp", - np.fft.fftshift( - np.flip(f[:, 1 : ntheta - 1, 0 : 2 * L], axis=-2), axes=-1 - ), + np.fft.fftshift(np.flip(f[:, 1 : ntheta - 1, 0 : 2 * L], axis=-2), axes=-1), (-1) ** np.abs(spin), ) else: @@ -157,13 +147,9 @@ def upsample_by_two_mwss_ext(f_ext: np.ndarray, L: int) -> np.ndarray: f_ext = np.fft.fftshift(np.fft.fft(f_ext, axis=-2, norm="forward"), axes=-2) ntheta_ext_up = 2 * ntheta_ext - f_ext_up = np.zeros( - (f_ext.shape[0], ntheta_ext_up, nphi), dtype=np.complex128 - ) + f_ext_up = np.zeros((f_ext.shape[0], ntheta_ext_up, nphi), dtype=np.complex128) f_ext_up[:, L : ntheta_ext + L, :nphi] = f_ext[:, 0:ntheta_ext, :nphi] - return np.fft.ifft( - np.fft.ifftshift(f_ext_up, axes=-2), axis=-2, norm="forward" - ) + return np.fft.ifft(np.fft.ifftshift(f_ext_up, axes=-2), axis=-2, norm="forward") def downsample_by_two_mwss(f_ext: np.ndarray, L: int) -> np.ndarray: @@ -235,11 +221,9 @@ def unextend(f_ext: np.ndarray, L: int, sampling: str = "mw") -> np.ndarray: ) if sampling.lower() == "mw": - f = f_ext[:, 0:L, :] elif sampling.lower() == "mwss": - f = f_ext[:, 0 : L + 1, :] else: @@ -282,9 +266,7 @@ def mw_to_mwss_phi(f_mw: np.ndarray, L: int) -> np.ndarray: np.fft.fft(f_mw, axis=-1, norm="forward"), axes=-1 ) - return np.fft.ifft( - np.fft.ifftshift(f_mwss, axes=-1), axis=-1, norm="forward" - ) + return np.fft.ifft(np.fft.ifftshift(f_mwss, axes=-1), axis=-1, norm="forward") def mw_to_mwss_theta(f_mw: np.ndarray, L: int, spin: int = 0) -> np.ndarray: @@ -311,9 +293,7 @@ def mw_to_mwss_theta(f_mw: np.ndarray, L: int, spin: int = 0) -> np.ndarray: sampling in :math:`\phi`. """ f_mw_ext = periodic_extension(f_mw, L, spin=spin, sampling="mw") - fmp_mwss_ext = np.zeros( - (f_mw_ext.shape[0], 2 * L, 2 * L - 1), dtype=np.complex128 - ) + fmp_mwss_ext = np.zeros((f_mw_ext.shape[0], 2 * L, 2 * L - 1), dtype=np.complex128) fmp_mwss_ext[:, 1:, :] = np.fft.fftshift( np.fft.fft(f_mw_ext, axis=-2, norm="forward"), axes=-2 @@ -353,9 +333,7 @@ def mw_to_mwss(f_mw: np.ndarray, L: int, spin: int = 0) -> np.ndarray: """ if f_mw.ndim == 2: return np.squeeze( - mw_to_mwss_phi( - mw_to_mwss_theta(np.expand_dims(f_mw, 0), L, spin), L - ) + mw_to_mwss_phi(mw_to_mwss_theta(np.expand_dims(f_mw, 0), L, spin), L) ) else: return mw_to_mwss_phi(mw_to_mwss_theta(f_mw, L, spin), L) diff --git a/s2fft/utils/resampling_jax.py b/s2fft/utils/resampling_jax.py index edcc87b5..976f6d20 100644 --- a/s2fft/utils/resampling_jax.py +++ b/s2fft/utils/resampling_jax.py @@ -29,9 +29,7 @@ def mw_to_mwss(f_mw: jnp.ndarray, L: int, spin: int = 0) -> jnp.ndarray: """ if f_mw.ndim == 2: return jnp.squeeze( - mw_to_mwss_phi( - mw_to_mwss_theta(jnp.expand_dims(f_mw, 0), L, spin), L - ) + mw_to_mwss_phi(mw_to_mwss_theta(jnp.expand_dims(f_mw, 0), L, spin), L) ) else: return mw_to_mwss_phi(mw_to_mwss_theta(f_mw, L, spin), L) @@ -70,9 +68,7 @@ def mw_to_mwss_theta(f_mw: jnp.ndarray, L: int, spin: int = 0) -> jnp.ndarray: ) fmp_mwss_ext = fmp_mwss_ext.at[:, 1:, :].set( - jnp.fft.fftshift( - jnp.fft.fft(f_mw_ext, axis=-2, norm="forward"), axes=-2 - ) + jnp.fft.fftshift(jnp.fft.fft(f_mw_ext, axis=-2, norm="forward"), axes=-2) ) fmp_mwss_ext = fmp_mwss_ext.at[:, 1:, :].set( @@ -171,9 +167,7 @@ def periodic_extension( f_ext = jnp.zeros((f.shape[0], ntheta_ext, nphi), dtype=jnp.complex128) f_ext = f_ext.at[:, 0:ntheta, 0:nphi].set(f[:, 0:ntheta, 0:nphi]) - f_ext = jnp.fft.fftshift( - jnp.fft.fft(f_ext, axis=-1, norm="backward"), axes=-1 - ) + f_ext = jnp.fft.fftshift(jnp.fft.fft(f_ext, axis=-1, norm="backward"), axes=-1) f_ext = f_ext.at[ :, @@ -318,14 +312,10 @@ def upsample_by_two_mwss_ext(f_ext: jnp.ndarray, L: int) -> jnp.ndarray: nphi = 2 * L ntheta_ext = 2 * L - f_ext = jnp.fft.fftshift( - jnp.fft.fft(f_ext, axis=-2, norm="forward"), axes=-2 - ) + f_ext = jnp.fft.fftshift(jnp.fft.fft(f_ext, axis=-2, norm="forward"), axes=-2) ntheta_ext_up = 2 * ntheta_ext - f_ext_up = jnp.zeros( - (f_ext.shape[0], ntheta_ext_up, nphi), dtype=jnp.complex128 - ) + f_ext_up = jnp.zeros((f_ext.shape[0], ntheta_ext_up, nphi), dtype=jnp.complex128) f_ext_up = f_ext_up.at[:, L : ntheta_ext + L, :nphi].set( f_ext[:, 0:ntheta_ext, :nphi] ) diff --git a/s2fft/utils/signal_generator.py b/s2fft/utils/signal_generator.py index 745e412a..f500daa9 100644 --- a/s2fft/utils/signal_generator.py +++ b/s2fft/utils/signal_generator.py @@ -33,7 +33,6 @@ def generate_flm( flm = np.zeros(samples.flm_shape(L), dtype=np.complex128) for el in range(max(L_lower, abs(spin)), L): - if reality: flm[el, 0 + L - 1] = rng.uniform() else: @@ -78,9 +77,7 @@ def generate_flmn( flmn = np.zeros(wigner_samples.flmn_shape(L, N), dtype=np.complex128) for n in range(-N + 1, N): - for el in range(max(L_lower, abs(n)), L): - if reality: flmn[N - 1 + n, el, 0 + L - 1] = rng.uniform() flmn[N - 1 - n, el, 0 + L - 1] = (-1) ** n * flmn[ @@ -89,21 +86,15 @@ def generate_flmn( 0 + L - 1, ] else: - flmn[N - 1 + n, el, 0 + L - 1] = ( - rng.uniform() + 1j * rng.uniform() - ) + flmn[N - 1 + n, el, 0 + L - 1] = rng.uniform() + 1j * rng.uniform() for m in range(1, el + 1): - flmn[N - 1 + n, el, m + L - 1] = ( - rng.uniform() + 1j * rng.uniform() - ) + flmn[N - 1 + n, el, m + L - 1] = rng.uniform() + 1j * rng.uniform() if reality: flmn[N - 1 - n, el, -m + L - 1] = (-1) ** (m + n) * np.conj( flmn[N - 1 + n, el, m + L - 1] ) else: - flmn[N - 1 + n, el, -m + L - 1] = ( - rng.uniform() + 1j * rng.uniform() - ) + flmn[N - 1 + n, el, -m + L - 1] = rng.uniform() + 1j * rng.uniform() return flmn diff --git a/setup.py b/setup.py index 6c1d2d35..baee83a5 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ name="s2fft", version="0.0.1", url="https://github.com/astro-informatics/s2fft", - author="Authors & Contributors", + author="Matthew A. Price, Jason D. McEwen & Contributors", license="GNU General Public License v3 (GPLv3)", python_requires=">=3.8", install_requires=requirements, @@ -33,5 +33,5 @@ long_description=long_description, packages=find_packages(), include_package_data=True, - pacakge_data={"s2fft": ["default-logging-config.yaml"]} + pacakge_data={"s2fft": ["default-logging-config.yaml"]}, ) diff --git a/tests/test_logs.py b/tests/test_logs.py index 9c1c9fc7..e742d0a3 100644 --- a/tests/test_logs.py +++ b/tests/test_logs.py @@ -3,7 +3,6 @@ def test_incorrect_log_yaml_path(): - dir_name = "random/incorrect/filepath/" # Check cannot add samples with different ndim. @@ -12,7 +11,6 @@ def test_incorrect_log_yaml_path(): def test_general_logging(): - lg.setup_logging() lg.critical_log("A random critical message") lg.debug_log("A random debug message") diff --git a/tests/test_quadrature.py b/tests/test_quadrature.py index ff6c524c..dfe22aa3 100644 --- a/tests/test_quadrature.py +++ b/tests/test_quadrature.py @@ -8,7 +8,6 @@ @pytest.mark.parametrize("L", [5, 6]) @pytest.mark.parametrize("sampling", ["mw", "mwss"]) def test_quadrature_mw_weights(flm_generator, L: int, sampling: str): - spin = 0 q = quadrature.quad_weights(L, sampling, spin) @@ -29,7 +28,6 @@ def test_quadrature_mw_weights(flm_generator, L: int, sampling: str): def test_quadrature_exceptions(): - L = 10 with pytest.raises(ValueError) as e: diff --git a/tests/test_resampling.py b/tests/test_resampling.py index a878c423..87ecaf99 100644 --- a/tests/test_resampling.py +++ b/tests/test_resampling.py @@ -5,7 +5,6 @@ def test_periodic_extension_invalid_sampling(): - f_dummy = np.zeros((2, 2), dtype=np.complex128) with pytest.raises(ValueError) as e: @@ -20,7 +19,6 @@ def test_periodic_extension_invalid_sampling(): "spin_reality", [(0, True), (0, False), (1, False), (2, False)] ) def test_periodic_extension_mwss(flm_generator, L: int, spin_reality): - (spin, reality) = spin_reality flm = flm_generator(L=L, spin=spin, reality=reality) f = spherical.inverse(flm, L, spin, sampling="mwss") @@ -38,7 +36,6 @@ def test_periodic_extension_mwss(flm_generator, L: int, spin_reality): "spin_reality", [(0, True), (0, False), (1, False), (2, False)] ) def test_mwss_upsample_downsample(flm_generator, L: int, spin_reality): - (spin, reality) = spin_reality flm = flm_generator(L=L, spin=spin, reality=reality) f = spherical.inverse(flm, L, spin, sampling="mwss") @@ -59,7 +56,6 @@ def test_mwss_upsample_downsample(flm_generator, L: int, spin_reality): "spin_reality", [(0, True), (0, False), (1, False), (2, False)] ) def test_unextend(flm_generator, L: int, sampling: str, spin_reality): - (spin, reality) = spin_reality flm = flm_generator(L=L, spin=spin, reality=reality) f = spherical.inverse(flm, L, spin, sampling=sampling) @@ -72,7 +68,6 @@ def test_unextend(flm_generator, L: int, sampling: str, spin_reality): def test_resampling_exceptions(): - f_dummy = np.zeros((2, 2), dtype=np.complex128) with pytest.raises(ValueError) as e: @@ -98,7 +93,6 @@ def test_resampling_exceptions(): "spin_reality", [(0, True), (0, False), (1, False), (2, False)] ) def test_mw_to_mwss_theta(flm_generator, L: int, spin_reality): - (spin, reality) = spin_reality flm = flm_generator(L=L, spin=spin, reality=reality) f_mw = spherical.inverse(flm, L, spin, sampling="mw") diff --git a/tests/test_samples.py b/tests/test_samples.py index c5a8292b..44f18687 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -11,7 +11,6 @@ @pytest.mark.parametrize("L", [15, 16]) @pytest.mark.parametrize("sampling", ["mw", "mwss", "dh"]) def test_samples_n_and_angles(L: int, sampling: str): - # Test ntheta and nphi ntheta = samples.ntheta(L, sampling) nphi = samples.nphi_equiang(L, sampling) @@ -28,17 +27,12 @@ def test_samples_n_and_angles(L: int, sampling: str): np.testing.assert_allclose(phis, phis_ssht, atol=1e-14) # Test direct thetas and phis - np.testing.assert_allclose( - samples.thetas(L, sampling), thetas_ssht, atol=1e-14 - ) - np.testing.assert_allclose( - samples.phis_equiang(L, sampling), phis_ssht, atol=1e-14 - ) + np.testing.assert_allclose(samples.thetas(L, sampling), thetas_ssht, atol=1e-14) + np.testing.assert_allclose(samples.phis_equiang(L, sampling), phis_ssht, atol=1e-14) @pytest.mark.parametrize("ind", [15, 16]) def test_samples_index_conversion(ind: int): - (el, m) = samples.ind2elm(ind) ind_check = samples.elm2ind(el, m) @@ -48,7 +42,6 @@ def test_samples_index_conversion(ind: int): @pytest.mark.parametrize("L", [15, 16]) def test_samples_ncoeff(L: int): - n = 0 for el in range(0, L): for m in range(-el, el + 1): @@ -59,7 +52,6 @@ def test_samples_ncoeff(L: int): @pytest.mark.parametrize("nside", nside_to_test) def test_samples_n_and_angles_hp(nside: int): - ntheta = samples.ntheta(L=0, sampling="healpix", nside=nside) assert ntheta == 4 * nside - 1 @@ -82,7 +74,6 @@ def test_samples_n_and_angles_hp(nside: int): @pytest.mark.parametrize("nside", nside_to_test) def test_hp_ang2pix(nside: int): - for i in range(12 * nside**2): theta, phi = hp.pix2ang(nside, i) j = samples.hp_ang2pix(nside, theta, phi) @@ -90,7 +81,6 @@ def test_hp_ang2pix(nside: int): def test_samples_exceptions(): - L = 10 with pytest.raises(ValueError) as e: diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index ac9ccce5..0f6b19c0 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -58,9 +58,7 @@ def test_transform_inverse_healpix( flm = flm_generator(L=L, reality=True) f_check = base.inverse(flm, L, 0, sampling, nside, reality) - kernel = spin_spherical_kernel( - L, 0, reality, sampling, nside=nside, forward=False - ) + kernel = spin_spherical_kernel(L, 0, reality, sampling, nside=nside, forward=False) f = inverse(flm, L, 0, kernel, sampling, reality, method, nside) np.testing.assert_allclose(f, f_check, atol=1e-12, rtol=1e-12) @@ -84,9 +82,7 @@ def test_transform_forward( f = base.inverse(flm, L, spin, sampling, reality=reality) flm_check = base.forward(f, L, spin, sampling, reality=reality) - kernel = spin_spherical_kernel( - L, spin, reality, sampling, nside=None, forward=True - ) + kernel = spin_spherical_kernel(L, spin, reality, sampling, nside=None, forward=True) flm_recov = forward(f, L, spin, kernel, sampling, reality, method) for i in range(L): for j in range(2 * L - 1): @@ -111,9 +107,7 @@ def test_transform_forward_healpix( f = base.inverse(flm, L, 0, sampling, nside, reality) flm_check = base.forward(f, L, 0, sampling, nside, reality) - kernel = spin_spherical_kernel( - L, 0, reality, sampling, nside=nside, forward=True - ) + kernel = spin_spherical_kernel(L, 0, reality, sampling, nside=nside, forward=True) flm_recov = forward(f, L, 0, kernel, sampling, reality, method, nside) np.testing.assert_allclose(flm_recov, flm_check, atol=1e-12, rtol=1e-12) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6f012171..d942f329 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,6 @@ def test_flm_reindexing_functions(flm_generator): - L = 16 flm_2d = flm_generator(L=L, spin=0, reality=False) @@ -19,7 +18,6 @@ def test_flm_reindexing_functions(flm_generator): def test_flm_reindexing_functions_healpix(flm_generator): - L = 16 flm_2d = flm_generator(L=L, spin=0, reality=True) flm_hp = samples.flm_2d_to_hp(flm_2d, L) diff --git a/tests/test_wigner_custom_grads.py b/tests/test_wigner_custom_grads.py index c365c11c..df3704f8 100644 --- a/tests/test_wigner_custom_grads.py +++ b/tests/test_wigner_custom_grads.py @@ -68,9 +68,7 @@ def test_forward_wigner_custom_gradients( flmn_target = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality) flmn = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality) - f = wigner.inverse_jax( - flmn, L, N, None, sampling, reality, None, False, L_lower - ) + f = wigner.inverse_jax(flmn, L, N, None, sampling, reality, None, False, L_lower) def func(f): flmn = wigner.forward_jax( diff --git a/tests/test_wigner_precompute.py b/tests/test_wigner_precompute.py index 982301e7..84bef0ea 100644 --- a/tests/test_wigner_precompute.py +++ b/tests/test_wigner_precompute.py @@ -82,9 +82,7 @@ def test_inverse_wigner_transform_healpix( kernel = wigner_kernel(L, N, reality, sampling, nside=nside, forward=False) f_check = inverse(flmn, L, N, kernel, sampling, reality, method, nside) - np.testing.assert_allclose( - np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5 - ) + np.testing.assert_allclose(np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5) @pytest.mark.parametrize("nside", nside_to_test) diff --git a/tests/test_wigner_recursions.py b/tests/test_wigner_recursions.py index d376d8dc..6c249c7c 100644 --- a/tests/test_wigner_recursions.py +++ b/tests/test_wigner_recursions.py @@ -101,15 +101,11 @@ def test_trapani_interfaces(): dl_jax = recursions.trapani.init(dl_jax, L, implementation="jax") for el in range(1, L): - dl_loop = recursions.trapani.compute_full( - dl_loop, L, el, implementation="loop" - ) + dl_loop = recursions.trapani.compute_full(dl_loop, L, el, implementation="loop") dl_vect = recursions.trapani.compute_full( dl_vect, L, el, implementation="vectorized" ) - dl_jax = recursions.trapani.compute_full( - dl_jax, L, el, implementation="jax" - ) + dl_jax = recursions.trapani.compute_full(dl_jax, L, el, implementation="jax") np.testing.assert_allclose( dl_loop[ -el + (L - 1) : el + (L - 1) + 1, @@ -137,13 +133,10 @@ def test_trapani_interfaces(): recursions.trapani.init(dl_loop, L, implementation="unexpected") with pytest.raises(ValueError) as e: - recursions.trapani.compute_full( - dl_jax, L, el, implementation="unexpected" - ) + recursions.trapani.compute_full(dl_jax, L, el, implementation="unexpected") def test_trapani_checks(): - # TODO # Check throws exception if arguments wrong @@ -184,7 +177,6 @@ def test_turok_with_ssht(L: int, sampling: str): dl_array = ssht.generate_dl(beta, L) for el in range(L): - dl_turok = recursions.turok.compute_full(beta, el, L) np.testing.assert_allclose(dl_turok, dl_array[el], atol=1e-14) @@ -205,7 +197,6 @@ def test_turok_slice_with_ssht(L: int, spin: int, sampling: str): for el in range(L): if el >= np.abs(spin): - dl_turok = recursions.turok.compute_slice(beta, el, L, -spin) np.testing.assert_allclose( @@ -232,9 +223,7 @@ def test_turok_slice_jax_with_ssht(L: int, spin: int, sampling: str): for el in range(L): if el >= np.abs(spin): print("beta {}, el {}, spin {}".format(beta, el, spin)) - dl_turok = recursions.turok_jax.compute_slice( - beta, el, L, -spin - ) + dl_turok = recursions.turok_jax.compute_slice(beta, el, L, -spin) np.testing.assert_allclose( dl_turok[L - 1 - el : L - 1 + el + 1],