Skip to content

Commit

Permalink
TN: tweaks for simple update, gauging and symmray compat
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 1, 2024
1 parent 8dec023 commit b2e83dc
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
28 changes: 18 additions & 10 deletions quimb/tensor/tensor_arbgeom.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,8 @@ def normalize_simple(self, gauges, **contract_opts):
for ix, g in gauges.items():
gauges[ix] = g / do("linalg.norm", g)

nfactor = 1.0

# normalize sites
for site in self.sites:
tn_site = self.select(site)
Expand All @@ -654,6 +656,9 @@ def normalize_simple(self, gauges, **contract_opts):
all, **contract_opts
) ** 0.5
tn_site /= lnorm
nfactor *= lnorm

return nfactor


def gauge_product_boundary_vector(
Expand Down Expand Up @@ -1030,7 +1035,7 @@ def gate_simple_(
renorm=True,
smudge=1e-12,
power=1.0,
**gate_opts
**gate_opts,
):
"""Apply a gate to this vector tensor network at sites ``where``, using
simple update style gauging of the tensors first, as supplied in
Expand Down Expand Up @@ -1063,15 +1068,16 @@ def gate_simple_(
if isinstance(where, int):
where = (where,)

if len(where) == 1:
# single site gate
site_tags = tuple(map(self.site_tag, where))
tids = self._get_tids_from_tags(site_tags, "any")

if len(tids) == 1:
# gate acts on a single tensor
return self.gate_(G, where, contract=True)

gate_opts.setdefault("absorb", None)
gate_opts.setdefault("contract", "reduce-split")

site_tags = tuple(map(self.site_tag, where))
tn_where = self.select_any(site_tags)
tn_where = self._select_tids(tids)

with tn_where.gauge_simple_temp(
gauges,
Expand All @@ -1085,7 +1091,7 @@ def gate_simple_(
# inner ungauging is performed by tracking the new singular values
(((_, ix), s),) = info.items()
if renorm:
s = s / do("max", s)
s = s / do("linalg.norm", s)
gauges[ix] = s

return self
Expand All @@ -1104,7 +1110,7 @@ def gate_fit_local_(
tuple(map(self.site_tag, where)), "any"
)
if len(tids) == 2:
tids = self._get_string_between_tids(*tids)
tids = self.get_path_between_tids(*tids).tids

k = self._select_local_tids(
tids,
Expand Down Expand Up @@ -1133,6 +1139,8 @@ def local_expectation_cluster(
max_distance=0,
fillin=False,
gauges=None,
smudge=0.0,
power=1.0,
optimize="auto",
max_bond=None,
rehearse=False,
Expand Down Expand Up @@ -1208,7 +1216,7 @@ def local_expectation_cluster(
)

if len(tids) == 2:
tids = self._get_string_between_tids(*tids)
tids = self.get_path_between_tids(*tids).tids

k = self._select_local_tids(
tids,
Expand All @@ -1219,7 +1227,7 @@ def local_expectation_cluster(

if gauges is not None:
# gauge the region with simple update style bond gauges
k.gauge_simple_insert(gauges)
k.gauge_simple_insert(gauges, smudge=smudge, power=power)

if max_bond is not None:
return k.local_expectation(
Expand Down
2 changes: 1 addition & 1 deletion quimb/tensor/tensor_arbgeom_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def evolve(self, steps, tau=None, progbar=None):
if progbar is None:
progbar = self.progbar

pbar = Progbar(total=steps, disable=progbar is not True)
pbar = Progbar(total=steps, disable=not progbar)

try:
for i in range(steps):
Expand Down
16 changes: 12 additions & 4 deletions quimb/tensor/tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3277,8 +3277,9 @@ def _tensor_network_gate_inds_basic(
tl, tr = tn._inds_get(ixl, ixr)
bnds_l, (bix,), bnds_r = group_inds(tl, tr)

if len(bnds_l) <= 2 and len(bnds_r) <= 2:
# reduce split is likely redundant
if (len(bnds_l) <= 2) or (len(bnds_r) <= 2):
# reduce split is likely redundant (i.e. contracting pair and splitting
# just as cheap as performing QR reductions)
contract = "split"

if contract == "split":
Expand Down Expand Up @@ -6801,6 +6802,13 @@ def gauge_all_simple(
if not gauges_supplied:
gauges = {}

_sval_mapper = {
(True, True): lambda s: s,
(True, False): lambda s: s + smudge,
(False, True): lambda s: s**power,
(False, False): lambda s: (s + smudge) ** power,
}[(power == 1.0, smudge == 0.0)]

# for retrieving singular values
info = {}

Expand Down Expand Up @@ -6837,7 +6845,7 @@ def gauge_all_simple(
for t, ixs in ((t1, lix), (t2, rix)):
for ix in ixs:
try:
s = (gauges[ix] + smudge)**power
s = _sval_mapper(gauges[ix])
except KeyError:
continue
t.multiply_index_diagonal_(ix, s)
Expand All @@ -6854,7 +6862,7 @@ def gauge_all_simple(
)

s = info["singular_values"]
smax = do("max", s)
smax = do("linalg.norm", s)
new_gauge = s / smax
nfact = do("log10", smax) + nfact

Expand Down

0 comments on commit b2e83dc

Please sign in to comment.