From b2e83dc31a884c3cb6c246c1e9970d3cb27215fb Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Tue, 1 Oct 2024 15:23:24 -0700 Subject: [PATCH] TN: tweaks for simple update, gauging and symmray compat --- quimb/tensor/tensor_arbgeom.py | 28 ++++++++++++++++++---------- quimb/tensor/tensor_arbgeom_tebd.py | 2 +- quimb/tensor/tensor_core.py | 16 ++++++++++++---- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/quimb/tensor/tensor_arbgeom.py b/quimb/tensor/tensor_arbgeom.py index 9ed9b1fc..48ddaf85 100644 --- a/quimb/tensor/tensor_arbgeom.py +++ b/quimb/tensor/tensor_arbgeom.py @@ -645,6 +645,8 @@ def normalize_simple(self, gauges, **contract_opts): for ix, g in gauges.items(): gauges[ix] = g / do("linalg.norm", g) + nfactor = 1.0 + # normalize sites for site in self.sites: tn_site = self.select(site) @@ -654,6 +656,9 @@ def normalize_simple(self, gauges, **contract_opts): all, **contract_opts ) ** 0.5 tn_site /= lnorm + nfactor *= lnorm + + return nfactor def gauge_product_boundary_vector( @@ -1030,7 +1035,7 @@ def gate_simple_( renorm=True, smudge=1e-12, power=1.0, - **gate_opts + **gate_opts, ): """Apply a gate to this vector tensor network at sites ``where``, using simple update style gauging of the tensors first, as supplied in @@ -1063,15 +1068,16 @@ def gate_simple_( if isinstance(where, int): where = (where,) - if len(where) == 1: - # single site gate + site_tags = tuple(map(self.site_tag, where)) + tids = self._get_tids_from_tags(site_tags, "any") + + if len(tids) == 1: + # gate acts on a single tensor return self.gate_(G, where, contract=True) gate_opts.setdefault("absorb", None) gate_opts.setdefault("contract", "reduce-split") - - site_tags = tuple(map(self.site_tag, where)) - tn_where = self.select_any(site_tags) + tn_where = self._select_tids(tids) with tn_where.gauge_simple_temp( gauges, @@ -1085,7 +1091,7 @@ def gate_simple_( # inner ungauging is performed by tracking the new singular values (((_, ix), s),) = info.items() if renorm: - s = s / do("max", s) + s = s / do("linalg.norm", s) gauges[ix] = s return self @@ -1104,7 +1110,7 @@ def gate_fit_local_( tuple(map(self.site_tag, where)), "any" ) if len(tids) == 2: - tids = self._get_string_between_tids(*tids) + tids = self.get_path_between_tids(*tids).tids k = self._select_local_tids( tids, @@ -1133,6 +1139,8 @@ def local_expectation_cluster( max_distance=0, fillin=False, gauges=None, + smudge=0.0, + power=1.0, optimize="auto", max_bond=None, rehearse=False, @@ -1208,7 +1216,7 @@ def local_expectation_cluster( ) if len(tids) == 2: - tids = self._get_string_between_tids(*tids) + tids = self.get_path_between_tids(*tids).tids k = self._select_local_tids( tids, @@ -1219,7 +1227,7 @@ def local_expectation_cluster( if gauges is not None: # gauge the region with simple update style bond gauges - k.gauge_simple_insert(gauges) + k.gauge_simple_insert(gauges, smudge=smudge, power=power) if max_bond is not None: return k.local_expectation( diff --git a/quimb/tensor/tensor_arbgeom_tebd.py b/quimb/tensor/tensor_arbgeom_tebd.py index f6092d50..7bec8643 100644 --- a/quimb/tensor/tensor_arbgeom_tebd.py +++ b/quimb/tensor/tensor_arbgeom_tebd.py @@ -582,7 +582,7 @@ def evolve(self, steps, tau=None, progbar=None): if progbar is None: progbar = self.progbar - pbar = Progbar(total=steps, disable=progbar is not True) + pbar = Progbar(total=steps, disable=not progbar) try: for i in range(steps): diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index 30ec273e..31f17a6c 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -3277,8 +3277,9 @@ def _tensor_network_gate_inds_basic( tl, tr = tn._inds_get(ixl, ixr) bnds_l, (bix,), bnds_r = group_inds(tl, tr) - if len(bnds_l) <= 2 and len(bnds_r) <= 2: - # reduce split is likely redundant + if (len(bnds_l) <= 2) or (len(bnds_r) <= 2): + # reduce split is likely redundant (i.e. contracting pair and splitting + # just as cheap as performing QR reductions) contract = "split" if contract == "split": @@ -6801,6 +6802,13 @@ def gauge_all_simple( if not gauges_supplied: gauges = {} + _sval_mapper = { + (True, True): lambda s: s, + (True, False): lambda s: s + smudge, + (False, True): lambda s: s**power, + (False, False): lambda s: (s + smudge) ** power, + }[(power == 1.0, smudge == 0.0)] + # for retrieving singular values info = {} @@ -6837,7 +6845,7 @@ def gauge_all_simple( for t, ixs in ((t1, lix), (t2, rix)): for ix in ixs: try: - s = (gauges[ix] + smudge)**power + s = _sval_mapper(gauges[ix]) except KeyError: continue t.multiply_index_diagonal_(ix, s) @@ -6854,7 +6862,7 @@ def gauge_all_simple( ) s = info["singular_values"] - smax = do("max", s) + smax = do("linalg.norm", s) new_gauge = s / smax nfact = do("log10", smax) + nfact