Skip to content

Commit

Permalink
approx_spectral: add info tracking and plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 15, 2024
1 parent e582714 commit d4aa80d
Show file tree
Hide file tree
Showing 3 changed files with 3,055 additions and 40 deletions.
2,939 changes: 2,919 additions & 20 deletions docs/calculating quantities.ipynb

Large diffs are not rendered by default.

152 changes: 132 additions & 20 deletions quimb/linalg/approx_spectral.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
"""Use stochastic Lanczos quadrature to approximate spectral function sums of
any operator which has an efficient representation of action on a vector.
"""

import functools
from math import sqrt, log2, exp, inf, nan
import random
import warnings
from math import exp, inf, log2, nan, sqrt

import numpy as np
import scipy.linalg as scla
from scipy.ndimage import uniform_filter1d

from ..core import ptr, prod, vdot, njit, dot, subtract_update_, divide_update_
from ..utils import int2tup, find_library, raise_cant_find_library_function
from ..gen.rand import randn, rand_rademacher, rand_phase, seed_rand
from ..core import divide_update_, dot, njit, prod, ptr, subtract_update_, vdot
from ..gen.rand import rand_phase, rand_rademacher, randn, seed_rand
from ..linalg.mpi_launcher import get_mpi_pool
from ..utils import (
default_to_neutral_style,
find_library,
format_number_with_error,
int2tup,
raise_cant_find_library_function,
)
from ..utils import progbar as Progbar

if find_library("cotengra") and find_library("autoray"):
from ..tensor.tensor_core import Tensor
from ..tensor.tensor_1d import MatrixProductOperator
from ..tensor.tensor_approx_spectral import construct_lanczos_tridiag_MPO
from ..tensor.tensor_core import Tensor
else:
reqs = "[cotengra,autoray]"
Tensor = raise_cant_find_library_function(reqs)
Expand Down Expand Up @@ -216,9 +224,7 @@ def random_rect(
# already normalized

elif dist == "gaussian":
V = randn(
shape, scale=1 / (prod(shape) ** 0.5 * 2**0.5), dtype=dtype
)
V = randn(shape, scale=1 / (prod(shape) ** 0.5 * 2**0.5), dtype=dtype)
if norm:
V /= norm_fro(V)

Expand Down Expand Up @@ -308,7 +314,6 @@ def construct_lanczos_tridiag(
Q = np.copy(q).reshape(-1, 1)

for j in range(1, K + 1):

r = dot(A, q)
subtract_update_(r, beta[j], v)
alpha[j] = inner(q, r)
Expand Down Expand Up @@ -470,17 +475,15 @@ def calc_est_fit(estimates, conv_n, tau):
return est, err


def calc_est_window(estimates, mean_ests, conv_n):
def calc_est_window(estimates, conv_n):
"""Make estimate from mean of last ``m`` samples, following:
1. Take between ``conv_n`` and 12 estimates.
2. Pair the estimates as they are alternate upper/lower bounds
3. Compute the standard error on the paired estimates.
"""
m_est = min(max(conv_n, len(estimates) // 8), 12)

est = sum(estimates[-m_est:]) / len(estimates[-m_est:])
mean_ests.append(est)

if len(estimates) > conv_n:
# check for convergence using variance of paired last m estimates
Expand Down Expand Up @@ -511,6 +514,7 @@ def single_random_estimate(
*,
seed=None,
v0_opts=None,
info=None,
**lanczos_opts,
):
# choose normal (any LinearOperator) or MPO lanczos tridiag construction
Expand All @@ -520,8 +524,10 @@ def single_random_estimate(
lanc_fn = construct_lanczos_tridiag
lanczos_opts["bsz"] = bsz

estimates_raw = []
estimates_window = []
estimates_fit = []
estimates = []
mean_ests = []

# the number of samples to check standard deviation convergence with
conv_n = 6 # 3 pairs
Expand All @@ -537,45 +543,47 @@ def single_random_estimate(
v0_opts=v0_opts,
**lanczos_opts,
):

try:
Tl, Tv = lanczos_tridiag_eig(alpha, beta, check_finite=False)
Gf = scaling * calc_trace_fn_tridiag(Tl, Tv, f=f, pos=pos)
except scla.LinAlgError: # pragma: no cover
warnings.warn("Approx Spectral Gf tri-eig didn't converge.")
estimates.append(np.nan)
estimates_raw.append(np.nan)
continue

k = alpha.size
estimates.append(Gf)
estimates_raw.append(Gf)

# check for break-down convergence (e.g. found entire subspace)
# in which case latest estimate should be accurate
if abs(beta[-1]) < beta_tol:
if verbosity >= 2:
print(f"k={k}: Beta breadown, returning {Gf}.")
est = Gf
estimates.append(est)
break

# compute an estimate and error using a window of the last few results
win_est, win_err = calc_est_window(estimates, mean_ests, conv_n)
win_est, win_err = calc_est_window(estimates_raw, conv_n)
estimates_window.append(win_est)

# try and compute an estimate and error using exponential fit
fit_est, fit_err = calc_est_fit(mean_ests, conv_n, tau)
fit_est, fit_err = calc_est_fit(estimates_window, conv_n, tau)
estimates_fit.append(fit_est)

# take whichever has lowest error
est, err = min(
(win_est, win_err),
(fit_est, fit_err),
key=lambda est_err: est_err[1],
)
estimates.append(est)
converged = err < tau * (abs(win_est) + tol_scale)

if verbosity >= 2:
if verbosity >= 3:
print(f"est_win={win_est}, err_win={win_err}")
print(f"est_fit={fit_est}, err_fit={fit_err}")

print(f"k={k}: Gf={Gf}, Est={est}, Err={err}")
if converged:
print(f"k={k}: Converged to tau {tau}.")
Expand All @@ -586,6 +594,16 @@ def single_random_estimate(
if verbosity >= 1:
print(f"k={k}: Returning estimate {est}.")

if info is not None:
if "estimates_raw" in info:
info["estimates_raw"].append(estimates_raw)
if "estimates_window" in info:
info["estimates_window"].append(estimates_window)
if "estimates_fit" in info:
info["estimates_fit"].append(estimates_fit)
if "estimates" in info:
info["estimates"].append(estimates)

return est


Expand Down Expand Up @@ -626,13 +644,68 @@ def get_equivalent_real_dtype(dtype):
raise ValueError(f"dtype {dtype} not understood.")


@default_to_neutral_style
def plot_approx_spectral_info(info):
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator

fig, axs = plt.subplots(
ncols=2,
figsize=(8, 4),
sharey=True,
gridspec_kw={"width_ratios": [3, 1]},
)
plt.subplots_adjust(wspace=0.0)

Z = info["estimate"]

alpha = len(info["estimates_raw"])**-(1 / 6)

# plot the raw kyrlov runs
for x in info["estimates_raw"]:
axs[0].plot(x, ".-", alpha=alpha, lw=1 / 2, zorder=-10, markersize=1)
axs[0].axhline(Z - info["error"], color="grey", linestyle="--")
axs[0].axhline(Z + info["error"], color="grey", linestyle="--")
axs[0].axhline(Z, color="black", linestyle="--")
axs[0].set_rasterization_zorder(-5)
axs[0].set_xlabel("krylov iteration (offset)")
axs[0].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[0].set_ylabel("$Tr[f(x)]$ approximation")

# plot the overall final samples
axs[1].hist(
info["samples"],
bins=round(len(info["samples"])**0.5),
orientation="horizontal",
color=(0.2, 0.6, 1.0),
)
axs[1].axhline(Z - info["error"], color="grey", linestyle="--")
axs[1].axhline(Z + info["error"], color="grey", linestyle="--")
axs[1].axhline(Z, color="black", linestyle="--")
axs[1].set_xlabel("sample count")
axs[1].set_title(
"estimate ≈ " + format_number_with_error(Z, info["error"]),
ha="right",
)

# plot the correlation between raw and fitted estimates
iax = axs[0].inset_axes((0.03, 0.6, 0.3, 0.3))
iax.set_aspect("equal")
x = [es[-1] for es in info["estimates"]]
y = [es[-1] for es in info["estimates_raw"]]
iax.scatter(x, y, marker=".", alpha=alpha, color=(0.3, 0.7, 0.3), s=1)

return fig, axs


def approx_spectral_function(
A,
f,
tol=1e-2,
*,
bsz=1,
R=1024,
R_min=3,
tol_scale=1,
tau=1e-4,
k_min=10,
Expand All @@ -646,6 +719,8 @@ def approx_spectral_function(
verbosity=0,
single_precision="AUTO",
info=None,
progbar=False,
plot=False,
**lanczos_opts,
):
"""Approximate a spectral function, that is, the quantity ``Tr(f(A))``.
Expand All @@ -671,6 +746,8 @@ def approx_spectral_function(
Increasing this should increase accuracy as ``sqrt(R)``. Cost of
algorithm thus scales linearly with ``R``. If ``tol`` is non-zero, this
is the maximum number of repeats.
R_min : int, optional
The minimum number of repeats to perform. Default: 3.
tau : float, optional
The relative tolerance required for a single lanczos run to converge.
This needs to be small enough that each estimate with a single random
Expand Down Expand Up @@ -742,6 +819,18 @@ def approx_spectral_function(
if verbosity:
print(f"LANCZOS f(A) CALC: tol={tol}, tau={tau}, R={R}, bsz={bsz}")

if plot:
# need to store all the info
if info is None:
info = {}
info.setdefault('estimate', None)
info.setdefault('error', None)
info.setdefault('samples', None)
info.setdefault('estimates_raw', [])
info.setdefault('estimates_window', [])
info.setdefault('estimates_fit', [])
info.setdefault('estimates', [])

# generate repeat estimates
kwargs = {
"A": A,
Expand All @@ -755,6 +844,7 @@ def approx_spectral_function(
"k_min": k_min,
"tol_scale": tol_scale,
"verbosity": verbosity,
"info": info,
**lanczos_opts,
}

Expand All @@ -773,6 +863,11 @@ def gen_results():
for f in fs:
yield f.result()

if progbar:
pbar = Progbar(total=R)
else:
pbar = None

# iterate through estimates, waiting for convergence
results = gen_results()
estimate = None
Expand All @@ -784,7 +879,7 @@ def gen_results():
print(f"Repeat {len(samples)}: estimate is {samples[-1]}")

# wait a few iterations before checking error on mean breakout
if len(samples) >= 3:
if len(samples) >= R_min:
estimate, err, converged = calc_stats(
samples, mean_p, mean_s, tol, tol_scale
)
Expand All @@ -795,6 +890,18 @@ def gen_results():
print(f"Repeat {len(samples)}: converged to tol {tol}")
break

if pbar:
if len(samples) < R_min:
estimate, err, _ = calc_stats(
samples, mean_p, mean_s, tol, tol_scale
)
pbar.set_description(format_number_with_error(estimate, err))

if pbar:
pbar.update()
if pbar:
pbar.close()

if mpi:
# deal with remaining futures
extra_futures = []
Expand Down Expand Up @@ -822,6 +929,11 @@ def gen_results():
info["samples"] = samples
if "error" in info:
info["error"] = err
if "estimate" in info:
info["estimate"] = estimate

if plot:
info["fig"], info["axs"] = plot_approx_spectral_info(info)

return estimate

Expand Down
4 changes: 4 additions & 0 deletions tests/test_linalg/test_approx_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,10 @@ def test_approx_spectral_subspaces_with_heis_partition(self, bsz):
approx_Z = tr_exp_approx(-beta * h, bsz=bsz)
assert_allclose(actual_Z, approx_Z, rtol=3e-2)

def test_approx_spectral_plot(self):
X = rand_herm(1000, sparse=True)
approx_spectral_function(X, lambda x: abs(x), plot=True)


# ------------------------ Test specific quantities ------------------------- #

Expand Down

0 comments on commit d4aa80d

Please sign in to comment.