diff --git a/docs/calculating quantities.ipynb b/docs/calculating quantities.ipynb index 76890ee6..b1efbd9f 100644 --- a/docs/calculating quantities.ipynb +++ b/docs/calculating quantities.ipynb @@ -52,7 +52,7 @@ { "data": { "text/plain": [ - "54.321684719154966" + "54.32566715771646" ] }, "execution_count": 1, @@ -76,7 +76,7 @@ { "data": { "text/plain": [ - "54.00283979240904" + "54.45773513531054" ] }, "execution_count": 2, @@ -111,7 +111,7 @@ { "data": { "text/plain": [ - "5.743478523097124" + "5.7437225898123465" ] }, "execution_count": 3, @@ -127,11 +127,9 @@ }, { "cell_type": "markdown", - "metadata": { - "raw_mimetype": "text/restructuredtext" - }, + "metadata": {}, "source": [ - "Under the hood, this method makes use of {mod}`quimb.tensor` functionality, which allows [various tensor contraction backends](http://optimized-einsum.readthedocs.io/en/latest/backends.html) to be used (see {func}`~quimb.tensor.tensor_core.set_tensor_linop_backend`). These types of computation are particularly suited to the GPU and therefore if [cupy](https://cupy.chainer.org) is installed it will be used automatically:" + "To inspect the process you can supply an `info` dict, and optionally also have the trials plotted like so:" ] }, { @@ -140,15 +138,2913 @@ "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "2.21 s ± 194 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "1.217(12)e+11: 40%|############################4 | 411/1024 [00:34<00:50, 12.03it/s]\n" ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-05-15T16:38:59.856614\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" } ], "source": [ - "%timeit qu.logneg_subsys_approx(psi, dims, sysa=0, sysb=2, backend='numpy')" + "%config InlineBackend.figure_formats = ['svg']\n", + "\n", + "# estimate a 2D partition function\n", + "beta = 2.0\n", + "H = qu.ham_heis_2D(4, 4, bz=1.7, sparse=True)\n", + "\n", + "# if plot=False, the desired keys should be filled before supplying info\n", + "info = {}\n", + "\n", + "Z = qu.approx_spectral_function(\n", + " H, f=lambda x: qu.exp(-beta * x), \n", + " tol=1e-2, \n", + " info=info, \n", + " progbar=True,\n", + " plot=True,\n", + ")" ] }, { @@ -157,24 +3053,27 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "111 ms ± 6.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] + "data": { + "text/plain": [ + "dict_keys(['estimate', 'error', 'samples', 'estimates_raw', 'estimates_window', 'estimates_fit', 'estimates', 'fig', 'axs'])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "%timeit qu.logneg_subsys_approx(psi, dims, sysa=0, sysb=2, backend='cupy')" + "info.keys()" ] } ], "metadata": { "celltoolbar": "Raw Cell Format", "kernelspec": { - "display_name": "Python 3.10.8 ('numpy')", + "display_name": "Python [conda env:cupy]", "language": "python", - "name": "python3" + "name": "conda-env-cupy-py" }, "language_info": { "codemirror_mode": { @@ -186,7 +3085,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8 (main, Nov 4 2022, 13:48:29) [GCC 11.2.0]" + "version": "3.11.9" }, "vscode": { "interpreter": { @@ -195,5 +3094,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/quimb/linalg/approx_spectral.py b/quimb/linalg/approx_spectral.py index 44d6a857..ddb13f12 100644 --- a/quimb/linalg/approx_spectral.py +++ b/quimb/linalg/approx_spectral.py @@ -1,24 +1,32 @@ """Use stochastic Lanczos quadrature to approximate spectral function sums of any operator which has an efficient representation of action on a vector. """ + import functools -from math import sqrt, log2, exp, inf, nan import random import warnings +from math import exp, inf, log2, nan, sqrt import numpy as np import scipy.linalg as scla from scipy.ndimage import uniform_filter1d -from ..core import ptr, prod, vdot, njit, dot, subtract_update_, divide_update_ -from ..utils import int2tup, find_library, raise_cant_find_library_function -from ..gen.rand import randn, rand_rademacher, rand_phase, seed_rand +from ..core import divide_update_, dot, njit, prod, ptr, subtract_update_, vdot +from ..gen.rand import rand_phase, rand_rademacher, randn, seed_rand from ..linalg.mpi_launcher import get_mpi_pool +from ..utils import ( + default_to_neutral_style, + find_library, + format_number_with_error, + int2tup, + raise_cant_find_library_function, +) +from ..utils import progbar as Progbar if find_library("cotengra") and find_library("autoray"): - from ..tensor.tensor_core import Tensor from ..tensor.tensor_1d import MatrixProductOperator from ..tensor.tensor_approx_spectral import construct_lanczos_tridiag_MPO + from ..tensor.tensor_core import Tensor else: reqs = "[cotengra,autoray]" Tensor = raise_cant_find_library_function(reqs) @@ -216,9 +224,7 @@ def random_rect( # already normalized elif dist == "gaussian": - V = randn( - shape, scale=1 / (prod(shape) ** 0.5 * 2**0.5), dtype=dtype - ) + V = randn(shape, scale=1 / (prod(shape) ** 0.5 * 2**0.5), dtype=dtype) if norm: V /= norm_fro(V) @@ -308,7 +314,6 @@ def construct_lanczos_tridiag( Q = np.copy(q).reshape(-1, 1) for j in range(1, K + 1): - r = dot(A, q) subtract_update_(r, beta[j], v) alpha[j] = inner(q, r) @@ -470,7 +475,7 @@ def calc_est_fit(estimates, conv_n, tau): return est, err -def calc_est_window(estimates, mean_ests, conv_n): +def calc_est_window(estimates, conv_n): """Make estimate from mean of last ``m`` samples, following: 1. Take between ``conv_n`` and 12 estimates. @@ -478,9 +483,7 @@ def calc_est_window(estimates, mean_ests, conv_n): 3. Compute the standard error on the paired estimates. """ m_est = min(max(conv_n, len(estimates) // 8), 12) - est = sum(estimates[-m_est:]) / len(estimates[-m_est:]) - mean_ests.append(est) if len(estimates) > conv_n: # check for convergence using variance of paired last m estimates @@ -511,6 +514,7 @@ def single_random_estimate( *, seed=None, v0_opts=None, + info=None, **lanczos_opts, ): # choose normal (any LinearOperator) or MPO lanczos tridiag construction @@ -520,8 +524,10 @@ def single_random_estimate( lanc_fn = construct_lanczos_tridiag lanczos_opts["bsz"] = bsz + estimates_raw = [] + estimates_window = [] + estimates_fit = [] estimates = [] - mean_ests = [] # the number of samples to check standard deviation convergence with conv_n = 6 # 3 pairs @@ -537,17 +543,16 @@ def single_random_estimate( v0_opts=v0_opts, **lanczos_opts, ): - try: Tl, Tv = lanczos_tridiag_eig(alpha, beta, check_finite=False) Gf = scaling * calc_trace_fn_tridiag(Tl, Tv, f=f, pos=pos) except scla.LinAlgError: # pragma: no cover warnings.warn("Approx Spectral Gf tri-eig didn't converge.") - estimates.append(np.nan) + estimates_raw.append(np.nan) continue k = alpha.size - estimates.append(Gf) + estimates_raw.append(Gf) # check for break-down convergence (e.g. found entire subspace) # in which case latest estimate should be accurate @@ -555,13 +560,16 @@ def single_random_estimate( if verbosity >= 2: print(f"k={k}: Beta breadown, returning {Gf}.") est = Gf + estimates.append(est) break # compute an estimate and error using a window of the last few results - win_est, win_err = calc_est_window(estimates, mean_ests, conv_n) + win_est, win_err = calc_est_window(estimates_raw, conv_n) + estimates_window.append(win_est) # try and compute an estimate and error using exponential fit - fit_est, fit_err = calc_est_fit(mean_ests, conv_n, tau) + fit_est, fit_err = calc_est_fit(estimates_window, conv_n, tau) + estimates_fit.append(fit_est) # take whichever has lowest error est, err = min( @@ -569,13 +577,13 @@ def single_random_estimate( (fit_est, fit_err), key=lambda est_err: est_err[1], ) + estimates.append(est) converged = err < tau * (abs(win_est) + tol_scale) if verbosity >= 2: if verbosity >= 3: print(f"est_win={win_est}, err_win={win_err}") print(f"est_fit={fit_est}, err_fit={fit_err}") - print(f"k={k}: Gf={Gf}, Est={est}, Err={err}") if converged: print(f"k={k}: Converged to tau {tau}.") @@ -586,6 +594,16 @@ def single_random_estimate( if verbosity >= 1: print(f"k={k}: Returning estimate {est}.") + if info is not None: + if "estimates_raw" in info: + info["estimates_raw"].append(estimates_raw) + if "estimates_window" in info: + info["estimates_window"].append(estimates_window) + if "estimates_fit" in info: + info["estimates_fit"].append(estimates_fit) + if "estimates" in info: + info["estimates"].append(estimates) + return est @@ -626,6 +644,60 @@ def get_equivalent_real_dtype(dtype): raise ValueError(f"dtype {dtype} not understood.") +@default_to_neutral_style +def plot_approx_spectral_info(info): + from matplotlib import pyplot as plt + from matplotlib.ticker import MaxNLocator + + fig, axs = plt.subplots( + ncols=2, + figsize=(8, 4), + sharey=True, + gridspec_kw={"width_ratios": [3, 1]}, + ) + plt.subplots_adjust(wspace=0.0) + + Z = info["estimate"] + + alpha = len(info["estimates_raw"])**-(1 / 6) + + # plot the raw kyrlov runs + for x in info["estimates_raw"]: + axs[0].plot(x, ".-", alpha=alpha, lw=1 / 2, zorder=-10, markersize=1) + axs[0].axhline(Z - info["error"], color="grey", linestyle="--") + axs[0].axhline(Z + info["error"], color="grey", linestyle="--") + axs[0].axhline(Z, color="black", linestyle="--") + axs[0].set_rasterization_zorder(-5) + axs[0].set_xlabel("krylov iteration (offset)") + axs[0].xaxis.set_major_locator(MaxNLocator(integer=True)) + axs[0].set_ylabel("$Tr[f(x)]$ approximation") + + # plot the overall final samples + axs[1].hist( + info["samples"], + bins=round(len(info["samples"])**0.5), + orientation="horizontal", + color=(0.2, 0.6, 1.0), + ) + axs[1].axhline(Z - info["error"], color="grey", linestyle="--") + axs[1].axhline(Z + info["error"], color="grey", linestyle="--") + axs[1].axhline(Z, color="black", linestyle="--") + axs[1].set_xlabel("sample count") + axs[1].set_title( + "estimate ≈ " + format_number_with_error(Z, info["error"]), + ha="right", + ) + + # plot the correlation between raw and fitted estimates + iax = axs[0].inset_axes((0.03, 0.6, 0.3, 0.3)) + iax.set_aspect("equal") + x = [es[-1] for es in info["estimates"]] + y = [es[-1] for es in info["estimates_raw"]] + iax.scatter(x, y, marker=".", alpha=alpha, color=(0.3, 0.7, 0.3), s=1) + + return fig, axs + + def approx_spectral_function( A, f, @@ -633,6 +705,7 @@ def approx_spectral_function( *, bsz=1, R=1024, + R_min=3, tol_scale=1, tau=1e-4, k_min=10, @@ -646,6 +719,8 @@ def approx_spectral_function( verbosity=0, single_precision="AUTO", info=None, + progbar=False, + plot=False, **lanczos_opts, ): """Approximate a spectral function, that is, the quantity ``Tr(f(A))``. @@ -671,6 +746,8 @@ def approx_spectral_function( Increasing this should increase accuracy as ``sqrt(R)``. Cost of algorithm thus scales linearly with ``R``. If ``tol`` is non-zero, this is the maximum number of repeats. + R_min : int, optional + The minimum number of repeats to perform. Default: 3. tau : float, optional The relative tolerance required for a single lanczos run to converge. This needs to be small enough that each estimate with a single random @@ -742,6 +819,18 @@ def approx_spectral_function( if verbosity: print(f"LANCZOS f(A) CALC: tol={tol}, tau={tau}, R={R}, bsz={bsz}") + if plot: + # need to store all the info + if info is None: + info = {} + info.setdefault('estimate', None) + info.setdefault('error', None) + info.setdefault('samples', None) + info.setdefault('estimates_raw', []) + info.setdefault('estimates_window', []) + info.setdefault('estimates_fit', []) + info.setdefault('estimates', []) + # generate repeat estimates kwargs = { "A": A, @@ -755,6 +844,7 @@ def approx_spectral_function( "k_min": k_min, "tol_scale": tol_scale, "verbosity": verbosity, + "info": info, **lanczos_opts, } @@ -773,6 +863,11 @@ def gen_results(): for f in fs: yield f.result() + if progbar: + pbar = Progbar(total=R) + else: + pbar = None + # iterate through estimates, waiting for convergence results = gen_results() estimate = None @@ -784,7 +879,7 @@ def gen_results(): print(f"Repeat {len(samples)}: estimate is {samples[-1]}") # wait a few iterations before checking error on mean breakout - if len(samples) >= 3: + if len(samples) >= R_min: estimate, err, converged = calc_stats( samples, mean_p, mean_s, tol, tol_scale ) @@ -795,6 +890,18 @@ def gen_results(): print(f"Repeat {len(samples)}: converged to tol {tol}") break + if pbar: + if len(samples) < R_min: + estimate, err, _ = calc_stats( + samples, mean_p, mean_s, tol, tol_scale + ) + pbar.set_description(format_number_with_error(estimate, err)) + + if pbar: + pbar.update() + if pbar: + pbar.close() + if mpi: # deal with remaining futures extra_futures = [] @@ -822,6 +929,11 @@ def gen_results(): info["samples"] = samples if "error" in info: info["error"] = err + if "estimate" in info: + info["estimate"] = estimate + + if plot: + info["fig"], info["axs"] = plot_approx_spectral_info(info) return estimate diff --git a/tests/test_linalg/test_approx_spectral.py b/tests/test_linalg/test_approx_spectral.py index b90d396c..4320a470 100644 --- a/tests/test_linalg/test_approx_spectral.py +++ b/tests/test_linalg/test_approx_spectral.py @@ -306,6 +306,10 @@ def test_approx_spectral_subspaces_with_heis_partition(self, bsz): approx_Z = tr_exp_approx(-beta * h, bsz=bsz) assert_allclose(actual_Z, approx_Z, rtol=3e-2) + def test_approx_spectral_plot(self): + X = rand_herm(1000, sparse=True) + approx_spectral_function(X, lambda x: abs(x), plot=True) + # ------------------------ Test specific quantities ------------------------- #