From f4ff318401848949d9cee25daabfeac5d4f92ec4 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Tue, 17 Sep 2024 16:22:58 -0700 Subject: [PATCH] add some symmray support --- quimb/tensor/tensor_core.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index cb8d2548..9e14e174 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -903,6 +903,13 @@ def tensor_multifuse(ts, inds, gauges=None): # contract into a single gauge gauges[inds[0]] = functools.reduce(lambda x, y: do("kron", x, y), gs) + if hasattr(ts[0].data, "align_axes"): + arrays = [t.data for t in ts] + axes = [tuple(map(t.inds.index, inds)) for t in ts] + arrays = do("align_axes", *arrays, axes) + for t, a in zip(ts, arrays): + t.modify(data=a) + # index fusing for t in ts: t.fuse_({inds[0]: inds}) @@ -4178,6 +4185,19 @@ def conj(self, mangle_inner=False, inplace=False): append = None if mangle_inner is True else str(mangle_inner) tn.mangle_inner_(append) + if hasattr(next(iter(tn.tensor_map.values())), "phase_flip"): + # need to phase dual outer indices + outer_inds = tn.outer_inds() + for t in tn: + data = t.data + dual_outer_axs = tuple( + ax + for ax, ix in enumerate(t.inds) + if (ix in outer_inds) and not data.indices[ax].dual + ) + if dual_outer_axs: + t.modify(data=data.phase_flip(*dual_outer_axs)) + return tn conj_ = functools.partialmethod(conj, inplace=True)