Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DetectorDatabase is now implemented with a translation-invariant key system #376

Merged
merged 8 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 39 additions & 45 deletions src/tqec/compile/detectors/database.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import hashlib
from dataclasses import dataclass, field
from functools import cached_property
from typing import Sequence

import numpy
import numpy.typing as npt

from tqec.circuit.generation import generate_circuit_from_instantiation
from tqec.circuit.measurement_map import MeasurementRecordsMap
from tqec.circuit.moment import Moment
Expand All @@ -20,12 +20,8 @@
from tqec.templates.subtemplates import SubTemplateType


def _NUMPY_ARRAY_HASHER(arr: npt.NDArray[numpy.int_]) -> int:
    """Hash the raw buffer of ``arr`` with MD5 and return the digest as an integer.

    Only the bytes exposed by ``arr.data`` are fed to the hash function; the
    shape of the array is not part of the digest.
    """
    raw_bytes = arr.data.tobytes()
    digest = hashlib.md5(raw_bytes, usedforsecurity=False)
    return int(digest.hexdigest(), 16)


@dataclass(frozen=True)
class DetectorDatabaseKey:
class _DetectorDatabaseKey:
"""Immutable type used as a key in the database of detectors.

This class represents a "situation" for which we might be able to compute
Expand All @@ -43,25 +39,19 @@ class DetectorDatabaseKey:

## Implementation details

This class stores data types that are not efficiently hashable (i.e., not in
constant time) when considering their values:

- `self.subtemplates` is a raw array of integers without any guarantee on the
stored values except that they are positive.
- `self.plaquettes_by_timestep` contains :class:`Plaquette` instances, each
containing a :class:`ScheduledCircuit` instance, ultimately containing a
`stim.Circuit`. Hashing a quantum circuit cannot be performed in constant
time.

For `self.subtemplates`, we hash the `shape` of the array as well as the
hash of the array's data. This is a constant time operation, because
we only consider spatially local detectors at the moment and that
restriction makes sub-templates that are of constant size (w.r.t the number
of qubits).

For `self.plaquettes_by_timestep`, we rely on the hash implementation of
:class:`Plaquettes`. It is up to :class:`Plaquettes` to implement hash
efficiently.
This class uses a surjective representation to compare (`__eq__`) and hash
(`__hash__`) its instances. This representation is computed and cached using
the :meth:`_DetectorDatabaseKey.plaquette_names` property that basically
uses the provided subtemplates to build a nested tuple data-structure with
the same shape as `self.subtemplates` (3 dimensions, the first one being the
number of time steps, the next 2 ones being of odd and equal size and
depending on the radius used to build subtemplates) storing in each of its
entries the corresponding plaquette name.

This intermediate data-structure is not the most memory efficient one, but
it has the advantage of being easy to construct, trivially invariant to
plaquette re-indexing and easy to hash (with some care to NOT use Python's
default `hash` due to its absence of stability across different runs).
"""

subtemplates: Sequence[SubTemplateType]
Expand All @@ -79,24 +69,28 @@ def __post_init__(self) -> None:
def num_timeslices(self) -> int:
return len(self.subtemplates)

def __hash__(self) -> int:
return hash(
(
tuple(st.shape for st in self.subtemplates),
tuple(_NUMPY_ARRAY_HASHER(st) for st in self.subtemplates),
tuple(self.plaquettes_by_timestep),
)
@cached_property
def plaquette_names(self) -> tuple[tuple[tuple[str, ...], ...], ...]:
    """Nested tuples of plaquette names mirroring ``self.subtemplates``.

    For each timeslice, every index stored in the subtemplate is replaced by
    the ``name`` of the plaquette found at that index in the plaquette
    collection of the same timeslice. The result is computed on first access
    and cached on the instance afterwards.
    """
    names_per_timeslice = []
    for subtemplate, collection in zip(
        self.subtemplates, self.plaquettes_by_timestep
    ):
        rows = []
        for row in subtemplate:
            rows.append(tuple(collection[index].name for index in row))
        names_per_timeslice.append(tuple(rows))
    return tuple(names_per_timeslice)

def reliable_hash(self) -> int:
    """Return a hash of the key that is stable across interpreter runs.

    The digest is computed with MD5 over the UTF-8 encoding of every plaquette
    name in ``self.plaquette_names``, visited timeslice by timeslice and row
    by row. It therefore does not rely on Python's built-in ``hash`` for
    strings.

    Returns:
        the MD5 digest interpreted as a non-negative integer.
    """
    # MD5 is only used here as a fast, reproducible fingerprint. Flagging it
    # as non-security usage keeps the call working on interpreters whose
    # crypto backend refuses MD5 for security purposes (e.g. FIPS mode). The
    # resulting digest is unchanged.
    hasher = hashlib.md5(usedforsecurity=False)
    for timeslice in self.plaquette_names:
        for row in timeslice:
            for name in row:
                hasher.update(name.encode())
    return int(hasher.hexdigest(), 16)

def __hash__(self) -> int:
    """Delegate hashing to :meth:`reliable_hash`."""
    stable_digest = self.reliable_hash()
    return stable_digest

def __eq__(self, rhs: object) -> bool:
return (
isinstance(rhs, DetectorDatabaseKey)
and len(self.subtemplates) == len(rhs.subtemplates)
and all(
bool(numpy.all(self_st == rhs_st))
for self_st, rhs_st in zip(self.subtemplates, rhs.subtemplates)
)
and self.plaquettes_by_timestep == rhs.plaquettes_by_timestep
isinstance(rhs, _DetectorDatabaseKey)
and self.plaquette_names == rhs.plaquette_names
)

def circuit(self, plaquette_increments: Displacement) -> ScheduledCircuit:
Expand Down Expand Up @@ -138,7 +132,7 @@ class DetectorDatabase:
computation.
"""

mapping: dict[DetectorDatabaseKey, frozenset[Detector]] = field(
mapping: dict[_DetectorDatabaseKey, frozenset[Detector]] = field(
default_factory=dict
)
frozen: bool = False
Expand Down Expand Up @@ -169,7 +163,7 @@ def add_situation(
"""
if self.frozen:
raise TQECException("Cannot add a situation to a frozen database.")
key = DetectorDatabaseKey(subtemplates, plaquettes_by_timestep)
key = _DetectorDatabaseKey(subtemplates, plaquettes_by_timestep)
self.mapping[key] = (
frozenset([detectors]) if isinstance(detectors, Detector) else detectors
)
Expand All @@ -195,7 +189,7 @@ def remove_situation(
"""
if self.frozen:
raise TQECException("Cannot remove a situation to a frozen database.")
key = DetectorDatabaseKey(subtemplates, plaquettes_by_timestep)
key = _DetectorDatabaseKey(subtemplates, plaquettes_by_timestep)
del self.mapping[key]

def get_detectors(
Expand All @@ -220,7 +214,7 @@ def get_detectors(
detectors associated with the provided situation or `None` if the
situation is not in the database.
"""
key = DetectorDatabaseKey(subtemplates, plaquettes_by_timestep)
key = _DetectorDatabaseKey(subtemplates, plaquettes_by_timestep)
return self.mapping.get(key)

def freeze(self) -> None:
Expand Down
31 changes: 21 additions & 10 deletions src/tqec/compile/detectors/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tqec.circuit.coordinates import StimCoordinates
from tqec.circuit.measurement import Measurement
from tqec.circuit.qubit import GridQubit
from tqec.compile.detectors.database import DetectorDatabase, DetectorDatabaseKey
from tqec.compile.detectors.database import DetectorDatabase, _DetectorDatabaseKey
from tqec.compile.detectors.detector import Detector
from tqec.compile.specs.library._utils import _build_plaquettes_for_rotated_surface_code
from tqec.exceptions import TQECException
Expand Down Expand Up @@ -71,7 +71,6 @@
QubitTemplate().instantiate(k=10), manhattan_radius=2
).subtemplates.values()
),
axis=0,
)
)

Expand Down Expand Up @@ -118,40 +117,40 @@


def test_detector_database_key_creation() -> None:
DetectorDatabaseKey((SUBTEMPLATES[0],), (PLAQUETTE_COLLECTIONS[0],))
DetectorDatabaseKey(SUBTEMPLATES[1:5], PLAQUETTE_COLLECTIONS[1:5])
_DetectorDatabaseKey((SUBTEMPLATES[0],), (PLAQUETTE_COLLECTIONS[0],))
_DetectorDatabaseKey(SUBTEMPLATES[1:5], PLAQUETTE_COLLECTIONS[1:5])
with pytest.raises(
TQECException,
match="^DetectorDatabaseKey can only store an equal number of "
"subtemplates and plaquettes. Got 4 subtemplates and 3 plaquettes.$",
):
DetectorDatabaseKey(SUBTEMPLATES[1:5], PLAQUETTE_COLLECTIONS[1:4])
_DetectorDatabaseKey(SUBTEMPLATES[1:5], PLAQUETTE_COLLECTIONS[1:4])


def test_detector_database_key_num_timeslices() -> None:
for i in range(min(len(PLAQUETTE_COLLECTIONS), len(SUBTEMPLATES))):
assert (
DetectorDatabaseKey(
_DetectorDatabaseKey(
SUBTEMPLATES[:i], PLAQUETTE_COLLECTIONS[:i]
).num_timeslices
== i
)


def test_detector_database_key_hash() -> None:
dbkey = DetectorDatabaseKey(SUBTEMPLATES[1:5], PLAQUETTE_COLLECTIONS[1:5])
dbkey = _DetectorDatabaseKey(SUBTEMPLATES[1:5], PLAQUETTE_COLLECTIONS[1:5])
assert hash(dbkey) == hash(dbkey)
# This is a value that has been pre-computed locally. It is hard-coded here
# to check that the hash of a dbkey is reliable and does not change depending
# on the Python interpreter, Python version, host OS, process ID, ...
assert hash(dbkey) == 6635855037027289589
assert hash(dbkey) == 1085786788918911944

dbkey = DetectorDatabaseKey(SUBTEMPLATES[:1], PLAQUETTE_COLLECTIONS[:1])
dbkey = _DetectorDatabaseKey(SUBTEMPLATES[:1], PLAQUETTE_COLLECTIONS[:1])
assert hash(dbkey) == hash(dbkey)
# This is a value that has been pre-computed locally. It is hard-coded here
# to check that the hash of a dbkey is reliable and does not change depending
# on the Python interpreter, Python version, host OS, process ID, ...
assert hash(dbkey) == -8009786746945676048
assert hash(dbkey) == 1699471538780763110


def test_detector_database_creation() -> None:
Expand Down Expand Up @@ -210,3 +209,15 @@ def test_detector_database_freeze() -> None:
detectors2 = db.get_detectors(SUBTEMPLATES[:4], PLAQUETTE_COLLECTIONS[:4])
assert detectors2 is not None
assert detectors2 == DETECTORS[1]


def test_detector_database_translation_invariance() -> None:
    """Check that a lookup succeeds after uniformly shifting plaquette indices.

    A situation is stored with its original indices, then queried with the
    subtemplate entries and the plaquette collection indices both shifted by
    the same constant. The lookup must return the detectors that were stored.
    """
    shift = 36
    database = DetectorDatabase()
    database.add_situation(SUBTEMPLATES[:1], PLAQUETTE_COLLECTIONS[:1], DETECTORS[0])

    shifted_subtemplate = SUBTEMPLATES[0] + shift
    shifted_plaquettes = PLAQUETTE_COLLECTIONS[0].map_indices(lambda i: i + shift)
    found = database.get_detectors((shifted_subtemplate,), (shifted_plaquettes,))
    assert found is not None
    assert found == DETECTORS[0]
18 changes: 11 additions & 7 deletions src/tqec/plaquette/frozendefaultdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,21 @@ def __eq__(self, other: object) -> bool:
return (
isinstance(other, FrozenDefaultDict)
and (
not operator.xor(
self._default_factory is None, other._default_factory is None
(self._default_factory is None and other._default_factory is None)
or (
self._default_factory is not None
and other._default_factory is not None
and (self._default_factory() == other._default_factory())
)
)
and (
self._default_factory is not None
and other._default_factory is not None
and (self._default_factory() == other._default_factory())
)
and self._dict == other._dict
)

def has_default_factory(self) -> bool:
    """Return ``True`` when ``self._default_factory`` is set (i.e. not ``None``)."""
    factory = self._default_factory
    return factory is not None

def map_keys(self, callable: Callable[[K], K]) -> FrozenDefaultDict[K, V]:
    """Build a new instance whose keys are ``callable(key)`` for each current key.

    Values and the default factory are carried over unchanged. If ``callable``
    maps two keys to the same result, only the value visited last while
    iterating over ``self.items()`` is kept.
    """
    remapped = {}
    for key, value in self.items():
        remapped[callable(key)] = value
    return FrozenDefaultDict(remapped, default_factory=self._default_factory)
5 changes: 4 additions & 1 deletion src/tqec/plaquette/plaquette.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hashlib
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Mapping
from typing import Callable, Mapping

from typing_extensions import override

Expand Down Expand Up @@ -152,6 +152,9 @@ def with_updated_plaquettes(
) -> Plaquettes:
return Plaquettes(self.collection | plaquettes_to_update)

def map_indices(self, callable: Callable[[int], int]) -> Plaquettes:
    """Return a new :class:`Plaquettes` with every index passed through ``callable``.

    The re-indexing itself is delegated to ``self.collection.map_keys``.
    """
    reindexed = self.collection.map_keys(callable)
    return Plaquettes(reindexed)

def __eq__(self, rhs: object) -> bool:
    """Compare with ``rhs``: equal only to a :class:`Plaquettes` with an equal collection."""
    if not isinstance(rhs, Plaquettes):
        return False
    return self.collection == rhs.collection

Expand Down
Loading