diff --git a/CHANGELOG.md b/CHANGELOG.md index a38061a06..2ec99de6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Methods for pretty-printing `Pattern`: `to_ascii`, `to_unicode`, + `to_latex`. + ### Fixed +- The result of `repr()` for `Pattern`, `Circuit`, `Command`, + `Instruction`, `Plane`, `Axis` and `Sign` is now a valid Python + expression and is more readable. + ### Changed +- The method `Pattern.print_pattern` is now deprecated. + ## [0.3.1] - 2025-04-21 ### Added diff --git a/docs/source/modifier.rst b/docs/source/modifier.rst index 47a1e62ea..f10b7fe6a 100644 --- a/docs/source/modifier.rst +++ b/docs/source/modifier.rst @@ -36,7 +36,11 @@ Pattern Manipulation .. automethod:: perform_pauli_measurements - .. automethod:: print_pattern + .. automethod:: to_ascii + + .. automethod:: to_unicode + + .. automethod:: to_latex .. automethod:: standardize diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index fd27108d9..1d65c47ed 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -30,17 +30,14 @@ For any gate network, we can use the :class:`~graphix.transpiler.Circuit` class the :class:`~graphix.pattern.Pattern` object contains the sequence of commands according to the measurement calculus framework [#Danos2007]_. Let us print the pattern (command sequence) that we generated, ->>> pattern.print_pattern() # show the command sequence (pattern) -N, node = 1 -E, nodes = (0, 1) -M, node = 0, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] -X byproduct, node = 1, domain = [0] +>>> pattern +Pattern(input_nodes=[0], cmds=[N(1), E((0, 1)), M(0), X(1, {0})], output_nodes=[1]) The command sequence represents the following sequence: - * starting with an input qubit :math:`|\psi_{in}\rangle_0`, we first prepare an ancilla qubit :math:`|+\rangle_1` with ['N', 1] command - * We then apply CZ-gate by ['E', (0, 1)] command to create entanglement. - * We measure the qubit 0 in Pauli X basis, by ['M'] command. + * starting with an input qubit :math:`|\psi_{in}\rangle_0`, we first prepare an ancilla qubit :math:`|+\rangle_1` with N(1) command + * We then apply CZ-gate by E((0, 1)) command to create entanglement. + * We measure the qubit 0 in Pauli X basis, by M(0) command. * If the measurement outcome is :math:`s_0 = 1` (i.e. if the qubit is projected to :math:`|-\rangle`, the Pauli X eigenstate with eigenvalue of :math:`(-1)^{s_0} = -1`), the 'X' command is applied to qubit 1 to 'correct' the measurement byproduct (see :doc:`intro`) that ensure deterministic computation. * Tracing out the qubit 0 (since the measurement is destructive), we have :math:`H|\psi_{in}\rangle_1` - the input qubit has teleported to qubit 1, while being transformed by Hadamard gate. @@ -86,19 +83,9 @@ As a more complex example than above, we show measurement patterns and graph sta | | | control: input=0, output=0; target: input=1, output=3 | +------------------------------------------------------------------------------+ -| >>> cnot_pattern.print_pattern() | -| N, node = 0 | -| N, node = 1 | -| N, node = 2 | -| N, node = 3 | -| E, nodes = (1, 2) | -| E, nodes = (0, 2) | -| E, nodes = (2, 3) | -| M, node = 1, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] | -| M, node = 2, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] | -| X byproduct, node = 3, domain = [2] | -| Z byproduct, node = 3, domain = [1] | -| Z byproduct, node = 0, domain = [1] | +| >>> cnot_pattern | +| Pattern(cmds=[N(0), N(1), N(2), N(3), E((1, 2)), E((0, 2)), E((2, 3)), M(1), | +| M(2), X(3, {2}), Z(3, {1}), Z(0, {1})], output_nodes=[0, 3]) | +------------------------------------------------------------------------------+ | **general rotation (an example with Euler angles 0.2pi, 0.15pi and 0.1 pi)** | +------------------------------------------------------------------------------+ @@ -108,18 +95,10 @@ As a more complex example than above, we show measurement patterns and graph sta | | | input = 0, output = 4 | +------------------------------------------------------------------------------+ -|>>> euler_rot_pattern.print_pattern() | -| N, node = 0 | -| N, node = 1 | -| N, node = 2 | -| N, node = 3 | -| N, node = 4 | -| M, node = 0, plane = XY, angle(pi) = -0.2, s-domain = [], t_domain = [] | -| M, node = 1, plane = XY, angle(pi) = -0.15, s-domain = [0], t_domain = [] | -| M, node = 2, plane = XY, angle(pi) = -0.1, s-domain = [1], t_domain = [] | -| M, node = 3, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] | -| Z byproduct, node = 4, domain = [0,2] | -| X byproduct, node = 4, domain = [1,3] | +|>>> euler_rot_pattern | +| Pattern(cmds=[N(0), N(1), N(2), N(3), N(4), M(0, angle=-0.2), | +| M(1, angle=-0.15, s_domain={0}), M(2, angle=-0.1, s_domain={1}), | +| M(3), Z(4, domain={0, 2}), X(4, domain={1, 3})], output_nodes=[4]) | +------------------------------------------------------------------------------+ @@ -144,33 +123,8 @@ As an example, let us prepare a pattern to rotate two qubits in :math:`|+\rangle This produces a rather long and complicated command sequence. ->>> pattern.print_pattern() # show the command sequence (pattern) -N, node = 2 -N, node = 3 -E, nodes = (0, 2) -E, nodes = (2, 3) -M, node = 0, plane = XY, angle(pi) = -0.2975038024267561, s-domain = [], t_domain = [] -M, node = 2, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] -X byproduct, node = 3, domain = [2] -Z byproduct, node = 3, domain = [0] -N, node = 4 -N, node = 5 -E, nodes = (1, 4) -E, nodes = (4, 5) -M, node = 1, plane = XY, angle(pi) = -0.14788446865973076, s-domain = [], t_domain = [] -M, node = 4, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] -X byproduct, node = 5, domain = [4] -Z byproduct, node = 5, domain = [1] -N, node = 6 -N, node = 7 -E, nodes = (5, 6) -E, nodes = (3, 6) -E, nodes = (6, 7) -M, node = 5, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] -M, node = 6, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] -X byproduct, node = 7, domain = [6] -Z byproduct, node = 7, domain = [5] -Z byproduct, node = 3, domain = [5] +>>> pattern +Pattern(input_nodes=[0, 1], cmds=[N(2), N(3), E((0, 2)), E((2, 3)), M(0, angle=-0.08131311068764493), M(2), X(3, {2}), Z(3, {0}), N(4), N(5), E((1, 4)), E((4, 5)), M(1, angle=-0.2242107876075538), M(4), X(5, {4}), Z(5, {1}), N(6), N(7), E((5, 6)), E((3, 6)), E((6, 7)), M(5), M(6), X(7, {6}), Z(7, {5}), Z(3, {5})], output_nodes=[3, 7]) .. figure:: ./../imgs/pattern_visualization_2.png :scale: 60 % @@ -190,30 +144,8 @@ These can be called with :meth:`~graphix.pattern.Pattern.standardize` and :meth: >>> pattern.standardize() >>> pattern.shift_signals() ->>> pattern.print_pattern() -N, node = 2 -N, node = 3 -N, node = 4 -N, node = 5 -N, node = 6 -N, node = 7 -E, nodes = (0, 2) -E, nodes = (2, 3) -E, nodes = (1, 4) -E, nodes = (4, 5) -E, nodes = (5, 6) -E, nodes = (6, 3) -E, nodes = (6, 7) -M, node = 0, plane = XY, angle(pi) = -0.2975038024267561, s-domain = [], t_domain = [] -M, node = 2, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] -M, node = 1, plane = XY, angle(pi) = -0.14788446865973076, s-domain = [], t_domain = [] -M, node = 4, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] -M, node = 5, plane = XY, angle(pi) = 0, s-domain = [4], t_domain = [] -M, node = 6, plane = XY, angle(pi) = 0, s-domain = [], t_domain = [] -X byproduct, node = 3, domain = [2] -X byproduct, node = 7, domain = [2, 4, 6] -Z byproduct, node = 3, domain = [0, 1, 5] -Z byproduct, node = 7, domain = [1, 5] +>>> pattern +Pattern(input_nodes=[0, 1], cmds=[N(2), N(3), N(4), N(5), N(6), N(7), E((0, 2)), E((2, 3)), E((1, 4)), E((4, 5)), E((5, 6)), E((3, 6)), E((6, 7)), M(0, angle=-0.22152331776994327), M(2), M(1, angle=-0.18577010991028864), M(4), M(5, s_domain={4}), M(6), Z(3, {0, 1, 5}), Z(7, {1, 5}), X(3, {2}), X(7, {2, 4, 6})], output_nodes=[3, 7]) .. figure:: ./../imgs/pattern_visualization_3.png :scale: 60 % @@ -250,18 +182,8 @@ We can call this in a line by calling :meth:`~graphix.pattern.Pattern.perform_pa We get an updated measurement pattern without Pauli measurements as follows: >>> pattern.perform_pauli_measurements() ->>> pattern.print_pattern() -N, node = 3 -N, node = 7 -E, nodes = (0, 3) -E, nodes = (1, 3) -E, nodes = (1, 7) -M, node = 0, plane = XY, angle(pi) = -0.2975038024267561, s-domain = [], t_domain = [], Clifford index = 6 -M, node = 1, plane = XY, angle(pi) = -0.14788446865973076, s-domain = [], t_domain = [], Clifford index = 6 -X byproduct, node = 3, domain = [2] -X byproduct, node = 7, domain = [2, 4, 6] -Z byproduct, node = 3, domain = [0, 1, 5] -Z byproduct, node = 7, domain = [1, 5] +>>> pattern +Pattern(input_nodes=[0, 1], cmds=[N(3), N(7), E((0, 3)), E((1, 3)), E((1, 7)), M(0, Plane.YZ, 0.2907266109187514), M(1, Plane.YZ, 0.01258854060311348), C(3, Clifford.I), C(7, Clifford.I), Z(3, {0, 1, 5}), Z(7, {1, 5}), X(3, {2}), X(7, {2, 4, 6})], output_nodes=[3, 7]) Notice that all measurements with angle=0 (Pauli X measurements) disappeared - this means that a part of quantum computation was `classically` (and efficiently) preprocessed such that we only need much smaller quantum resource. @@ -290,18 +212,8 @@ We exploit this fact to minimize the `space` of the pattern, which is crucial fo We can simply call :meth:`~graphix.pattern.Pattern.minimize_space()` to reduce the `space`: >>> pattern.minimize_space() ->>> pattern.print_pattern(lim=20) -N, node = 3 -E, nodes = (0, 3) -M, node = 0, plane = XY, angle(pi) = -0.2975038024267561, s-domain = [], t_domain = [], Clifford index = 6 -E, nodes = (1, 3) -N, node = 7 -E, nodes = (1, 7) -M, node = 1, plane = XY, angle(pi) = -0.14788446865973076, s-domain = [], t_domain = [], Clifford index = 6 -X byproduct, node = 3, domain = [2] -X byproduct, node = 7, domain = [2, 4, 6] -Z byproduct, node = 3, domain = [0, 1, 5] -Z byproduct, node = 7, domain = [1, 5] +>>> pattern +Pattern(input_nodes=[0, 1], cmds=[N(3), E((0, 3)), M(0, Plane.YZ, 0.11120090987081546), E((1, 3)), N(7), E((1, 7)), M(1, Plane.YZ, 0.230565199664617), C(3, Clifford.I), C(7, Clifford.I), Z(3, {0, 1, 5}), Z(7, {1, 5}), X(3, {2}), X(7, {2, 4, 6})], output_nodes=[3, 7]) With the original measurement pattern, the simulation should have proceeded as follows, with maximum of four qubits on the memory. diff --git a/examples/deutsch_jozsa.py b/examples/deutsch_jozsa.py index 72b12938d..1f2d5a8fa 100644 --- a/examples/deutsch_jozsa.py +++ b/examples/deutsch_jozsa.py @@ -58,7 +58,7 @@ # Now let us transpile into MBQC measurement pattern and inspect the pattern sequence and graph state pattern = circuit.transpile().pattern -pattern.print_pattern(lim=15) +print(pattern.to_ascii(left_to_right=True, limit=15)) pattern.draw_graph(flow_from_pattern=False) # %% @@ -68,13 +68,19 @@ pattern.standardize() pattern.shift_signals() -pattern.print_pattern(lim=15) +print(pattern.to_ascii(left_to_right=True, limit=15)) # %% # Now we preprocess all Pauli measurements pattern.perform_pauli_measurements() -pattern.print_pattern(lim=16, target=[CommandKind.N, CommandKind.M, CommandKind.C]) +print( + pattern.to_ascii( + left_to_right=True, + limit=16, + target=[CommandKind.N, CommandKind.M, CommandKind.C], + ) +) pattern.draw_graph(flow_from_pattern=True) # %% diff --git a/examples/rotation.py b/examples/rotation.py index fe62436b4..07b127f1f 100644 --- a/examples/rotation.py +++ b/examples/rotation.py @@ -46,7 +46,7 @@ # This returns :class:`~graphix.pattern.Pattern` object containing measurement pattern: pattern = circuit.transpile().pattern -pattern.print_pattern(lim=10) +print(pattern.to_ascii(left_to_right=True, limit=10)) # %% # We can plot the graph state to run the above pattern. diff --git a/graphix/command.py b/graphix/command.py index b5f3c32fd..8bc1b01b9 100644 --- a/graphix/command.py +++ b/graphix/command.py @@ -18,6 +18,7 @@ # Ruff suggests to move this import to a type-checking block, but dataclass requires it here from graphix.parameter import ExpressionOrFloat # noqa: TC001 from graphix.pauli import Pauli +from graphix.pretty_print import DataclassPrettyPrintMixin from graphix.states import BasicStates, State Node = int @@ -44,8 +45,8 @@ def __init_subclass__(cls) -> None: utils.check_kind(cls, {"CommandKind": CommandKind, "Clifford": Clifford}) -@dataclasses.dataclass -class N(_KindChecker): +@dataclasses.dataclass(repr=False) +class N(_KindChecker, DataclassPrettyPrintMixin): """Preparation command.""" node: Node @@ -53,8 +54,8 @@ class N(_KindChecker): kind: ClassVar[Literal[CommandKind.N]] = dataclasses.field(default=CommandKind.N, init=False) -@dataclasses.dataclass -class M(_KindChecker): +@dataclasses.dataclass(repr=False) +class M(_KindChecker, DataclassPrettyPrintMixin): """Measurement command. By default the plane is set to 'XY', the angle to 0, empty domains and identity vop.""" node: Node @@ -80,16 +81,16 @@ def clifford(self, clifford_gate: Clifford) -> M: ) -@dataclasses.dataclass -class E(_KindChecker): +@dataclasses.dataclass(repr=False) +class E(_KindChecker, DataclassPrettyPrintMixin): """Entanglement command.""" nodes: tuple[Node, Node] kind: ClassVar[Literal[CommandKind.E]] = dataclasses.field(default=CommandKind.E, init=False) -@dataclasses.dataclass -class C(_KindChecker): +@dataclasses.dataclass(repr=False) +class C(_KindChecker, DataclassPrettyPrintMixin): """Clifford command.""" node: Node @@ -97,8 +98,8 @@ class C(_KindChecker): kind: ClassVar[Literal[CommandKind.C]] = dataclasses.field(default=CommandKind.C, init=False) -@dataclasses.dataclass -class X(_KindChecker): +@dataclasses.dataclass(repr=False) +class X(_KindChecker, DataclassPrettyPrintMixin): """X correction command.""" node: Node @@ -106,8 +107,8 @@ class X(_KindChecker): kind: ClassVar[Literal[CommandKind.X]] = dataclasses.field(default=CommandKind.X, init=False) -@dataclasses.dataclass -class Z(_KindChecker): +@dataclasses.dataclass(repr=False) +class Z(_KindChecker, DataclassPrettyPrintMixin): """Z correction command.""" node: Node @@ -115,8 +116,8 @@ class Z(_KindChecker): kind: ClassVar[Literal[CommandKind.Z]] = dataclasses.field(default=CommandKind.Z, init=False) -@dataclasses.dataclass -class S(_KindChecker): +@dataclasses.dataclass(repr=False) +class S(_KindChecker, DataclassPrettyPrintMixin): """S command.""" node: Node @@ -124,7 +125,7 @@ class S(_KindChecker): kind: ClassVar[Literal[CommandKind.S]] = dataclasses.field(default=CommandKind.S, init=False) -@dataclasses.dataclass +@dataclasses.dataclass(repr=False) class T(_KindChecker): """T command.""" diff --git a/graphix/fundamentals.py b/graphix/fundamentals.py index 55c4711cb..e6d7ea986 100644 --- a/graphix/fundamentals.py +++ b/graphix/fundamentals.py @@ -12,6 +12,7 @@ from graphix.ops import Ops from graphix.parameter import cos_sin +from graphix.pretty_print import EnumPrettyPrintMixin if TYPE_CHECKING: import numpy as np @@ -28,7 +29,7 @@ SupportsComplexCtor = Union[SupportsComplex, SupportsFloat, SupportsIndex, complex] -class Sign(Enum): +class Sign(EnumPrettyPrintMixin, Enum): """Sign, plus or minus.""" PLUS = 1 @@ -111,7 +112,7 @@ def __complex__(self) -> complex: return complex(self.value) -class ComplexUnit(Enum): +class ComplexUnit(EnumPrettyPrintMixin, Enum): """ Complex unit: 1, -1, j, -j. @@ -165,7 +166,7 @@ def __complex__(self) -> complex: return ret def __str__(self) -> str: - """Return a string representation of the unit.""" + """Return a human-readable representation of the unit.""" result = "1j" if self.is_imag else "1" if self.sign == Sign.MINUS: result = "-" + result @@ -213,7 +214,7 @@ def matrix(self) -> npt.NDArray[np.complex128]: typing_extensions.assert_never(self) -class Axis(Enum): +class Axis(EnumPrettyPrintMixin, Enum): """Axis: `X`, `Y` or `Z`.""" X = enum.auto() @@ -232,7 +233,7 @@ def matrix(self) -> npt.NDArray[np.complex128]: typing_extensions.assert_never(self) -class Plane(Enum): +class Plane(EnumPrettyPrintMixin, Enum): # TODO: Refactor using match """Plane: `XY`, `YZ` or `XZ`.""" diff --git a/graphix/instruction.py b/graphix/instruction.py index 645e89f23..b7c987b5f 100644 --- a/graphix/instruction.py +++ b/graphix/instruction.py @@ -2,17 +2,35 @@ from __future__ import annotations -import dataclasses import enum +import math import sys +from dataclasses import dataclass, field from enum import Enum -from typing import ClassVar, Literal, Union +from typing import ClassVar, Literal, SupportsFloat, Union from graphix import utils from graphix.fundamentals import Plane # Ruff suggests to move this import to a type-checking block, but dataclass requires it here from graphix.parameter import ExpressionOrFloat # noqa: TC001 +from graphix.pretty_print import DataclassPrettyPrintMixin, OutputFormat, angle_to_str + + +def repr_angle(angle: ExpressionOrFloat) -> str: + """ + Return the representation string of an angle in radians. + + This is used for pretty-printing instructions with `angle` parameters. + Delegates to :func:`pretty_print.angle_to_str`. + """ + # Non-float-supporting objects are returned as-is + if not isinstance(angle, SupportsFloat): + return str(angle) + + # Convert to float, express in π units, and format in ASCII/plain mode + pi_units = float(angle) / math.pi + return angle_to_str(pi_units, OutputFormat.ASCII) class InstructionKind(Enum): @@ -45,149 +63,149 @@ def __init_subclass__(cls) -> None: utils.check_kind(cls, {"InstructionKind": InstructionKind, "Plane": Plane}) -@dataclasses.dataclass -class CCX(_KindChecker): +@dataclass(repr=False) +class CCX(_KindChecker, DataclassPrettyPrintMixin): """Toffoli circuit instruction.""" target: int controls: tuple[int, int] - kind: ClassVar[Literal[InstructionKind.CCX]] = dataclasses.field(default=InstructionKind.CCX, init=False) + kind: ClassVar[Literal[InstructionKind.CCX]] = field(default=InstructionKind.CCX, init=False) -@dataclasses.dataclass -class RZZ(_KindChecker): +@dataclass(repr=False) +class RZZ(_KindChecker, DataclassPrettyPrintMixin): """RZZ circuit instruction.""" target: int control: int - angle: ExpressionOrFloat + angle: ExpressionOrFloat = field(metadata={"repr": repr_angle}) # FIXME: Remove `| None` from `meas_index` # - `None` makes codes messy/type-unsafe meas_index: int | None = None - kind: ClassVar[Literal[InstructionKind.RZZ]] = dataclasses.field(default=InstructionKind.RZZ, init=False) + kind: ClassVar[Literal[InstructionKind.RZZ]] = field(default=InstructionKind.RZZ, init=False) -@dataclasses.dataclass -class CNOT(_KindChecker): +@dataclass(repr=False) +class CNOT(_KindChecker, DataclassPrettyPrintMixin): """CNOT circuit instruction.""" target: int control: int - kind: ClassVar[Literal[InstructionKind.CNOT]] = dataclasses.field(default=InstructionKind.CNOT, init=False) + kind: ClassVar[Literal[InstructionKind.CNOT]] = field(default=InstructionKind.CNOT, init=False) -@dataclasses.dataclass -class SWAP(_KindChecker): +@dataclass(repr=False) +class SWAP(_KindChecker, DataclassPrettyPrintMixin): """SWAP circuit instruction.""" targets: tuple[int, int] - kind: ClassVar[Literal[InstructionKind.SWAP]] = dataclasses.field(default=InstructionKind.SWAP, init=False) + kind: ClassVar[Literal[InstructionKind.SWAP]] = field(default=InstructionKind.SWAP, init=False) -@dataclasses.dataclass -class H(_KindChecker): +@dataclass(repr=False) +class H(_KindChecker, DataclassPrettyPrintMixin): """H circuit instruction.""" target: int - kind: ClassVar[Literal[InstructionKind.H]] = dataclasses.field(default=InstructionKind.H, init=False) + kind: ClassVar[Literal[InstructionKind.H]] = field(default=InstructionKind.H, init=False) -@dataclasses.dataclass -class S(_KindChecker): +@dataclass(repr=False) +class S(_KindChecker, DataclassPrettyPrintMixin): """S circuit instruction.""" target: int - kind: ClassVar[Literal[InstructionKind.S]] = dataclasses.field(default=InstructionKind.S, init=False) + kind: ClassVar[Literal[InstructionKind.S]] = field(default=InstructionKind.S, init=False) -@dataclasses.dataclass -class X(_KindChecker): +@dataclass(repr=False) +class X(_KindChecker, DataclassPrettyPrintMixin): """X circuit instruction.""" target: int - kind: ClassVar[Literal[InstructionKind.X]] = dataclasses.field(default=InstructionKind.X, init=False) + kind: ClassVar[Literal[InstructionKind.X]] = field(default=InstructionKind.X, init=False) -@dataclasses.dataclass -class Y(_KindChecker): +@dataclass(repr=False) +class Y(_KindChecker, DataclassPrettyPrintMixin): """Y circuit instruction.""" target: int - kind: ClassVar[Literal[InstructionKind.Y]] = dataclasses.field(default=InstructionKind.Y, init=False) + kind: ClassVar[Literal[InstructionKind.Y]] = field(default=InstructionKind.Y, init=False) -@dataclasses.dataclass -class Z(_KindChecker): +@dataclass(repr=False) +class Z(_KindChecker, DataclassPrettyPrintMixin): """Z circuit instruction.""" target: int - kind: ClassVar[Literal[InstructionKind.Z]] = dataclasses.field(default=InstructionKind.Z, init=False) + kind: ClassVar[Literal[InstructionKind.Z]] = field(default=InstructionKind.Z, init=False) -@dataclasses.dataclass -class I(_KindChecker): +@dataclass(repr=False) +class I(_KindChecker, DataclassPrettyPrintMixin): """I circuit instruction.""" target: int - kind: ClassVar[Literal[InstructionKind.I]] = dataclasses.field(default=InstructionKind.I, init=False) + kind: ClassVar[Literal[InstructionKind.I]] = field(default=InstructionKind.I, init=False) -@dataclasses.dataclass -class M(_KindChecker): +@dataclass(repr=False) +class M(_KindChecker, DataclassPrettyPrintMixin): """M circuit instruction.""" target: int plane: Plane - angle: ExpressionOrFloat - kind: ClassVar[Literal[InstructionKind.M]] = dataclasses.field(default=InstructionKind.M, init=False) + angle: ExpressionOrFloat = field(metadata={"repr": repr_angle}) + kind: ClassVar[Literal[InstructionKind.M]] = field(default=InstructionKind.M, init=False) -@dataclasses.dataclass -class RX(_KindChecker): +@dataclass(repr=False) +class RX(_KindChecker, DataclassPrettyPrintMixin): """X rotation circuit instruction.""" target: int - angle: ExpressionOrFloat + angle: ExpressionOrFloat = field(metadata={"repr": repr_angle}) meas_index: int | None = None - kind: ClassVar[Literal[InstructionKind.RX]] = dataclasses.field(default=InstructionKind.RX, init=False) + kind: ClassVar[Literal[InstructionKind.RX]] = field(default=InstructionKind.RX, init=False) -@dataclasses.dataclass -class RY(_KindChecker): +@dataclass(repr=False) +class RY(_KindChecker, DataclassPrettyPrintMixin): """Y rotation circuit instruction.""" target: int - angle: ExpressionOrFloat + angle: ExpressionOrFloat = field(metadata={"repr": repr_angle}) meas_index: int | None = None - kind: ClassVar[Literal[InstructionKind.RY]] = dataclasses.field(default=InstructionKind.RY, init=False) + kind: ClassVar[Literal[InstructionKind.RY]] = field(default=InstructionKind.RY, init=False) -@dataclasses.dataclass -class RZ(_KindChecker): +@dataclass(repr=False) +class RZ(_KindChecker, DataclassPrettyPrintMixin): """Z rotation circuit instruction.""" target: int - angle: ExpressionOrFloat + angle: ExpressionOrFloat = field(metadata={"repr": repr_angle}) meas_index: int | None = None - kind: ClassVar[Literal[InstructionKind.RZ]] = dataclasses.field(default=InstructionKind.RZ, init=False) + kind: ClassVar[Literal[InstructionKind.RZ]] = field(default=InstructionKind.RZ, init=False) -@dataclasses.dataclass +@dataclass class _XC(_KindChecker): """X correction circuit instruction. Used internally by the transpiler.""" target: int domain: set[int] - kind: ClassVar[Literal[InstructionKind._XC]] = dataclasses.field(default=InstructionKind._XC, init=False) + kind: ClassVar[Literal[InstructionKind._XC]] = field(default=InstructionKind._XC, init=False) -@dataclasses.dataclass +@dataclass class _ZC(_KindChecker): """Z correction circuit instruction. Used internally by the transpiler.""" target: int domain: set[int] - kind: ClassVar[Literal[InstructionKind._ZC]] = dataclasses.field(default=InstructionKind._ZC, init=False) + kind: ClassVar[Literal[InstructionKind._ZC]] = field(default=InstructionKind._ZC, init=False) if sys.version_info >= (3, 10): diff --git a/graphix/pattern.py b/graphix/pattern.py index bfe4081f6..654f5112f 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -7,6 +7,7 @@ import copy import dataclasses +import warnings from collections.abc import Iterator from copy import deepcopy from dataclasses import dataclass @@ -24,12 +25,13 @@ from graphix.gflow import find_flow, find_gflow, get_layers from graphix.graphsim import GraphState from graphix.measurements import Domains, PauliMeasurement +from graphix.pretty_print import OutputFormat, pattern_to_str from graphix.simulator import PatternSimulator from graphix.states import BasicStates from graphix.visualization import GraphVisualizer if TYPE_CHECKING: - from collections.abc import Iterator, Mapping + from collections.abc import Container, Iterable, Iterator, Mapping from graphix.parameter import ExpressionOrSupportsFloat, Parameter from graphix.sim.base_backend import State @@ -81,22 +83,41 @@ class Pattern: total number of nodes in the resource state """ - def __init__(self, input_nodes: list[int] | None = None) -> None: + def __init__( + self, + input_nodes: Iterable[int] | None = None, + cmds: Iterable[Command] | None = None, + output_nodes: Iterable[int] | None = None, + ) -> None: """ Construct a pattern. - :param input_nodes: optional, list of input qubits + Parameters + ---------- + input_nodes : Iterable[int] | None + Optional. List of input qubits. + cmds : Iterable[Command] | None + Optional. List of initial commands. + output_nodes : Iterable[int] | None + Optional. List of output qubits. """ - if input_nodes is None: - input_nodes = [] self.results = {} # measurement results from the graph state simulator - self.__input_nodes = list(input_nodes) # input nodes (list() makes our own copy of the list) - self.__n_node = len(input_nodes) # total number of nodes in the graph state + if input_nodes is None: + self.__input_nodes = [] + else: + self.__input_nodes = list(input_nodes) # input nodes (list() makes our own copy of the list) + self.__n_node = len(self.__input_nodes) # total number of nodes in the graph state self._pauli_preprocessed = False # flag for `measure_pauli` preprocessing completion self.__seq: list[Command] = [] - # output nodes are initially input nodes, since none are measured yet - self.__output_nodes = list(input_nodes) + # output nodes are initially a copy input nodes, since none are measured yet + self.__output_nodes = list(self.__input_nodes) + + if cmds is not None: + self.extend(cmds) + + if output_nodes is not None: + self.reorder_output_nodes(output_nodes) def add(self, cmd: Command) -> None: """Add command to the end of the pattern. @@ -117,7 +138,7 @@ def add(self, cmd: Command) -> None: self.__output_nodes.remove(cmd.node) self.__seq.append(cmd) - def extend(self, cmds: list[Command]) -> None: + def extend(self, cmds: Iterable[Command]) -> None: """Add a list of commands. :param cmds: list of commands @@ -170,47 +191,77 @@ def n_node(self): """Count of nodes that are either `input_nodes` or prepared with `N` commands.""" return self.__n_node - def reorder_output_nodes(self, output_nodes: list[int]): + def reorder_output_nodes(self, output_nodes: Iterable[int]) -> None: """Arrange the order of output_nodes. Parameters ---------- - output_nodes: list of int + output_nodes: iterable of int output nodes order determined by user. each index corresponds to that of logical qubits. """ output_nodes = list(output_nodes) # make our own copy (allow iterators to be passed) assert_permutation(self.__output_nodes, output_nodes) self.__output_nodes = output_nodes - def reorder_input_nodes(self, input_nodes: list[int]): + def reorder_input_nodes(self, input_nodes: Iterable[int]): """Arrange the order of input_nodes. Parameters ---------- - input_nodes: list of int + input_nodes: iterable of int input nodes order determined by user. each index corresponds to that of logical qubits. """ + input_nodes = list(input_nodes) # make our own copy (allow iterators to be passed) assert_permutation(self.__input_nodes, input_nodes) - self.__input_nodes = list(input_nodes) + self.__input_nodes = input_nodes - # TODO: This is not an evaluable representation. Should be __str__? def __repr__(self) -> str: """Return a representation string of the pattern.""" - return ( - f"graphix.pattern.Pattern object with {len(self.__seq)} commands and {len(self.output_nodes)} output qubits" - ) + arguments = [] + if self.__input_nodes: + arguments.append(f"input_nodes={self.__input_nodes}") + if self.__seq: + arguments.append(f"cmds={self.__seq}") + if self.__output_nodes: + arguments.append(f"output_nodes={self.__output_nodes}") + return f"Pattern({', '.join(arguments)})" + + def __str__(self) -> str: + """Return a human-readable string of the pattern.""" + return self.to_ascii() def __eq__(self, other: Pattern) -> bool: """Return `True` if the two patterns are equal, `False` otherwise.""" return ( self.__seq == other.__seq - and self.input_nodes == other.input_nodes - and self.output_nodes == other.output_nodes + and self.__input_nodes == other.__input_nodes + and self.__output_nodes == other.__output_nodes ) - def print_pattern(self, lim=40, target: list[CommandKind] | None = None) -> None: + def to_ascii( + self, left_to_right: bool = False, limit: int = 40, target: Container[command.CommandKind] | None = None + ) -> str: + """Return the ASCII string representation of the pattern.""" + return pattern_to_str(self, OutputFormat.ASCII, left_to_right, limit, target) + + def to_latex( + self, left_to_right: bool = False, limit: int = 40, target: Container[command.CommandKind] | None = None + ) -> str: + """Return a string containing the LaTeX representation of the pattern.""" + return pattern_to_str(self, OutputFormat.LaTeX, left_to_right, limit, target) + + def to_unicode( + self, left_to_right: bool = False, limit: int = 40, target: Container[command.CommandKind] | None = None + ) -> str: + """Return the Unicode string representation of the pattern.""" + return pattern_to_str(self, OutputFormat.Unicode, left_to_right, limit, target) + + def print_pattern(self, lim: int = 40, target: Container[CommandKind] | None = None) -> None: """Print the pattern sequence (Pattern.seq). + This method is deprecated. + See :meth:`to_ascii`, :meth:`to_latex`, :meth:`to_unicode` and :func:`graphix.pretty_print.pattern_to_str`. + Parameters ---------- lim: int, optional @@ -218,49 +269,12 @@ def print_pattern(self, lim=40, target: list[CommandKind] | None = None) -> None target : list of CommandKind, optional show only specified commands, e.g. [CommandKind.M, CommandKind.X, CommandKind.Z] """ - nmax = min(lim, len(self.__seq)) - if target is None: - target = [ - CommandKind.N, - CommandKind.E, - CommandKind.M, - CommandKind.X, - CommandKind.Z, - CommandKind.C, - ] - count = 0 - i = -1 - while count < nmax: - i += 1 - if i == len(self.__seq): - break - cmd = self.__seq[i] - if cmd.kind == CommandKind.N and (CommandKind.N in target): - count += 1 - print(f"N, node = {cmd.node}") - elif cmd.kind == CommandKind.E and (CommandKind.E in target): - count += 1 - print(f"E, nodes = {cmd.nodes}") - elif cmd.kind == CommandKind.M and (CommandKind.M in target): - count += 1 - print( - f"M, node = {cmd.node}, plane = {cmd.plane}, angle(pi) = {cmd.angle}, " - f"s_domain = {cmd.s_domain}, t_domain = {cmd.t_domain}" - ) - elif cmd.kind == CommandKind.X and (CommandKind.X in target): - count += 1 - print(f"X byproduct, node = {cmd.node}, domain = {cmd.domain}") - elif cmd.kind == CommandKind.Z and (CommandKind.Z in target): - count += 1 - print(f"Z byproduct, node = {cmd.node}, domain = {cmd.domain}") - elif cmd.kind == CommandKind.C and (CommandKind.C in target): - count += 1 - print(f"Clifford, node = {cmd.node}, Clifford = {cmd.clifford}") - - if len(self.__seq) > i + 1: - print( - f"{len(self.__seq) - lim} more commands truncated. Change lim argument of print_pattern() to show more" - ) + warnings.warn( + "Method `print_pattern` is deprecated. Use one of the methods `to_ascii`, `to_latex`, `to_unicode`, or the function `graphix.pretty_print.pattern_to_str`.", + DeprecationWarning, + stacklevel=1, + ) + print(pattern_to_str(self, OutputFormat.ASCII, left_to_right=True, limit=lim, target=target)) def standardize(self, method="direct") -> None: """Execute standardization of the pattern. @@ -1771,7 +1785,8 @@ def cmd_to_qasm3(cmd): def assert_permutation(original: list[int], user: list[int]) -> None: """Check that the provided `user` node list is a permutation from `original`.""" node_set = set(user) - assert node_set == set(original), f"{node_set} != {set(original)}" + if node_set != set(original): + raise ValueError(f"{node_set} != {set(original)}") for node in user: if node in node_set: node_set.remove(node) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py new file mode 100644 index 000000000..b28fe6038 --- /dev/null +++ b/graphix/pretty_print.py @@ -0,0 +1,286 @@ +"""Pretty-printing utilities.""" + +from __future__ import annotations + +import dataclasses +import enum +import math +import string +from dataclasses import MISSING +from enum import Enum +from fractions import Fraction +from typing import TYPE_CHECKING, SupportsFloat + +from graphix import command + +if TYPE_CHECKING: + from collections.abc import Container + + # these live only in the stub package, not at runtime + from _typeshed import DataclassInstance + + from graphix.command import Node + from graphix.pattern import Pattern + + +class OutputFormat(Enum): + """Enumeration of the output format for pretty-printing.""" + + ASCII = enum.auto() + LaTeX = enum.auto() + Unicode = enum.auto() + + +def angle_to_str(angle: float, output: OutputFormat, max_denominator: int = 1000) -> str: + r""" + Return a string representation of an angle given in units of π. + + - If the angle is a "simple" fraction of π (within the given max_denominator and a small tolerance), + it returns a fractional string, e.g. "π/2", "2π", or "-3π/4". + - Otherwise, it returns the angle in radians (angle * π) formatted to two decimal places. + + Parameters + ---------- + angle : float + The angle in multiples of π (e.g., 0.5 means π/2). + output : OutputFormat + Desired formatting style: Unicode (π symbol), LaTeX (\pi), or ASCII ("pi"). + max_denominator : int, optional + Maximum denominator for detecting a simple fraction (default: 1000). + + Returns + ------- + str + The formatted angle. + """ + frac = Fraction(angle).limit_denominator(max_denominator) + + if not math.isclose(angle, float(frac)): + rad = angle * math.pi + + return f"{rad:.2f}" + + num, den = frac.numerator, frac.denominator + sign = "-" if num < 0 else "" + num = abs(num) + + if output == OutputFormat.LaTeX: + pi = r"\pi" + + def mkfrac(num: str, den: str) -> str: + return rf"\frac{{{num}}}{{{den}}}" + else: + pi = "π" if output == OutputFormat.Unicode else "pi" + + def mkfrac(num: str, den: str) -> str: + return f"{num}/{den}" + + if den == 1: + if num == 0: + return "0" + if num == 1: + return f"{sign}{pi}" + return f"{sign}{num}{pi}" + + den_str = f"{den}" + num_str = pi if num == 1 else f"{num}{pi}" + return f"{sign}{mkfrac(num_str, den_str)}" + + +def domain_to_str(domain: set[Node]) -> str: + """Return the string representation of a domain.""" + return f"{{{','.join(str(node) for node in domain)}}}" + + +SUBSCRIPTS = str.maketrans(string.digits, "₀₁₂₃₄₅₆₇₈₉") +SUPERSCRIPTS = str.maketrans(string.digits, "⁰¹²³⁴⁵⁶⁷⁸⁹") + + +def command_to_str(cmd: command.Command, output: OutputFormat) -> str: + """Return the string representation of a command according to the given format. + + Parameters + ---------- + cmd: Command + The command to pretty print. + output: OutputFormat + The expected format. + """ + # Circumvent circular import + from graphix.fundamentals import Plane + + out = [cmd.kind.name] + + if cmd.kind == command.CommandKind.E: + u, v = cmd.nodes + if output == OutputFormat.LaTeX: + out.append(f"_{{{u},{v}}}") + elif output == OutputFormat.Unicode: + u_subscripts = str(u).translate(SUBSCRIPTS) + v_subscripts = str(v).translate(SUBSCRIPTS) + out.append(f"{u_subscripts}₋{v_subscripts}") + else: + out.append(f"({u},{v})") + elif cmd.kind == command.CommandKind.T: + pass + else: + # All other commands have a field `node` to print, together + # with some other arguments and/or domains. + arguments = [] + if cmd.kind == command.CommandKind.M: + if cmd.plane != Plane.XY: + arguments.append(cmd.plane.name) + # We use `SupportsFloat` since `isinstance(cmd.angle, float)` + # is `False` if `cmd.angle` is an integer. + if isinstance(cmd.angle, SupportsFloat): + angle = float(cmd.angle) + if not math.isclose(angle, 0.0): + arguments.append(angle_to_str(angle, output)) + else: + # If the angle is a symbolic expression, we can only delegate the printing + # TODO: We should have a mean to specify the format + arguments.append(str(cmd.angle * math.pi)) + elif cmd.kind == command.CommandKind.C: + arguments.append(str(cmd.clifford)) + # Use of `==` here for mypy + command_domain = ( + cmd.domain + if cmd.kind == command.CommandKind.X # noqa: PLR1714 + or cmd.kind == command.CommandKind.Z + or cmd.kind == command.CommandKind.S + else None + ) + if output == OutputFormat.LaTeX: + out.append(f"_{{{cmd.node}}}") + if arguments: + out.append(f"^{{{','.join(arguments)}}}") + elif output == OutputFormat.Unicode: + node_subscripts = str(cmd.node).translate(SUBSCRIPTS) + out.append(f"{node_subscripts}") + if arguments: + out.append(f"({','.join(arguments)})") + else: + arguments = [str(cmd.node), *arguments] + if command_domain: + arguments.append(domain_to_str(command_domain)) + command_domain = None + out.append(f"({','.join(arguments)})") + if cmd.kind == command.CommandKind.M and (cmd.s_domain or cmd.t_domain): + out = ["[", *out, "]"] + if cmd.t_domain: + if output == OutputFormat.LaTeX: + t_domain_str = f"{{}}_{{{','.join(str(node) for node in cmd.t_domain)}}}" + elif output == OutputFormat.Unicode: + t_domain_subscripts = [str(node).translate(SUBSCRIPTS) for node in cmd.t_domain] + t_domain_str = "₊".join(t_domain_subscripts) + else: + t_domain_str = f"{{{','.join(str(node) for node in cmd.t_domain)}}}" + out = [t_domain_str, *out] + command_domain = cmd.s_domain + if command_domain: + if output == OutputFormat.LaTeX: + domain_str = f"^{{{','.join(str(node) for node in command_domain)}}}" + elif output == OutputFormat.Unicode: + domain_superscripts = [str(node).translate(SUPERSCRIPTS) for node in command_domain] + domain_str = "⁺".join(domain_superscripts) + else: + domain_str = f"{{{','.join(str(node) for node in command_domain)}}}" + out.append(domain_str) + return f"{''.join(out)}" + + +def pattern_to_str( + pattern: Pattern, + output: OutputFormat, + left_to_right: bool = False, + limit: int = 40, + target: Container[command.CommandKind] | None = None, +) -> str: + """Return the string representation of a pattern according to the given format. + + Parameters + ---------- + pattern: Pattern + The pattern to pretty print. + output: OutputFormat + The expected format. + left_to_right: bool + Optional. If `True`, the first command will appear on the beginning of + the resulting string. If `False` (the default), the first command will + appear at the end of the string. + """ + separator = r"\," if output == OutputFormat.LaTeX else " " + command_list = list(pattern) + if target is not None: + command_list = [command for command in command_list if command.kind in target] + if not left_to_right: + command_list.reverse() + truncated = len(command_list) > limit + short_command_list = command_list[: limit - 1] if truncated else command_list + result = separator.join(command_to_str(command, output) for command in short_command_list) + if output == OutputFormat.LaTeX: + result = f"\\({result}\\)" + if truncated: + return f"{result}...({len(command_list) - limit + 1} more commands)" + return result + + +class DataclassPrettyPrintMixin: + """ + Mixin for a concise, eval-friendly `repr` of dataclasses. + + Compared to the default dataclass `repr`: + - Class variables are omitted (dataclasses.fields only returns actual fields). + - Fields whose values equal their defaults are omitted. + - Field names are only shown when preceding fields have been omitted, ensuring positional listings when possible. + + Use with `@dataclass(repr=False)` on the target class. + """ + + def __repr__(self: DataclassInstance) -> str: + """Return a representation string for a dataclass.""" + cls_name = type(self).__name__ + arguments = [] + saw_omitted = False + for field in dataclasses.fields(self): + value = getattr(self, field.name) + if field.default is not MISSING or field.default_factory is not MISSING: + default = field.default_factory() if field.default_factory is not MISSING else field.default + if value == default: + saw_omitted = True + continue + custom_repr = field.metadata.get("repr") + value_str = custom_repr(value) if custom_repr else repr(value) + if saw_omitted: + arguments.append(f"{field.name}={value_str}") + else: + arguments.append(value_str) + arguments_str = ", ".join(arguments) + return f"{cls_name}({arguments_str})" + + +class EnumPrettyPrintMixin: + """ + Mixin to provide a concise, eval-friendly repr for Enum members. + + Compared to the default ``, this mixin's `__repr__` + returns `ClassName.MEMBER_NAME`, which can be evaluated in Python (assuming the + enum class is in scope) to retrieve the same member. + """ + + def __repr__(self) -> str: + """ + Return a representation string of an Enum member. + + Returns + ------- + str + A string in the form `ClassName.MEMBER_NAME`. + """ + # Equivalently (as of Python 3.12), `str(value)` also produces + # "ClassName.MEMBER_NAME", but we build it explicitly here for + # clarity. + if not isinstance(self, Enum): + msg = "EnumMixin can only be used with Enum classes." + raise TypeError(msg) + return f"{self.__class__.__name__}.{self.name}" diff --git a/graphix/transpiler.py b/graphix/transpiler.py index cb4546869..3817e160f 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -11,10 +11,12 @@ from typing import TYPE_CHECKING, Callable import numpy as np +from typing_extensions import assert_never from graphix import command, instruction, parameter from graphix.command import CommandKind, E, M, N, X, Z from graphix.fundamentals import Plane +from graphix.instruction import Instruction, InstructionKind from graphix.ops import Ops from graphix.parameter import ExpressionOrSupportsFloat, Parameter from graphix.pattern import Pattern @@ -22,7 +24,7 @@ from graphix.sim.statevec import Data, Statevec if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Iterable, Mapping, Sequence @dataclasses.dataclass @@ -67,7 +69,7 @@ class Circuit: List containing the gate sequence applied. """ - def __init__(self, width: int): + def __init__(self, width: int, instr: Iterable[Instruction] | None = None) -> None: """ Construct a circuit. @@ -75,10 +77,59 @@ def __init__(self, width: int): ---------- width : int number of logical qubits for the gate network + instr : list[instruction.Instruction] | None + Optional. List of initial instructions. """ self.width = width - self.instruction: list[instruction.Instruction] = [] + self.instruction: list[Instruction] = [] self.active_qubits = set(range(width)) + if instr is not None: + self.extend(instr) + + def add(self, instr: Instruction) -> None: + """Add an instruction to the circuit.""" + if instr.kind == InstructionKind.CCX: + self.ccx(instr.controls[0], instr.controls[1], instr.target) + elif instr.kind == InstructionKind.RZZ: + self.rzz(instr.control, instr.target, instr.angle) + elif instr.kind == InstructionKind.CNOT: + self.cnot(instr.control, instr.target) + elif instr.kind == InstructionKind.SWAP: + self.swap(instr.targets[0], instr.targets[1]) + elif instr.kind == InstructionKind.H: + self.h(instr.target) + elif instr.kind == InstructionKind.S: + self.s(instr.target) + elif instr.kind == InstructionKind.X: + self.x(instr.target) + elif instr.kind == InstructionKind.Y: + self.y(instr.target) + elif instr.kind == InstructionKind.Z: + self.z(instr.target) + elif instr.kind == InstructionKind.I: + self.i(instr.target) + elif instr.kind == InstructionKind.M: + self.m(instr.target, instr.plane, instr.angle) + elif instr.kind == InstructionKind.RX: + self.rx(instr.target, instr.angle) + elif instr.kind == InstructionKind.RY: + self.ry(instr.target, instr.angle) + elif instr.kind == InstructionKind.RZ: + self.rz(instr.target, instr.angle) + # Use of `==` here for mypy + elif instr.kind == InstructionKind._XC or instr.kind == InstructionKind._ZC: # noqa: PLR1714 + raise ValueError(f"Unsupported instruction: {instr}") + else: + assert_never(instr.kind) + + def extend(self, instrs: Iterable[Instruction]) -> None: + """Add instructions to the circuit.""" + for instr in instrs: + self.add(instr) + + def __repr__(self) -> str: + """Return a representation of the Circuit.""" + return f"Circuit(width={self.width}, instr={self.instruction})" def cnot(self, control: int, target: int): """Apply a CNOT gate. diff --git a/pyproject.toml b/pyproject.toml index 07ff37fad..c341dd3fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,6 +144,7 @@ files = [ "graphix/ops.py", "graphix/parameter.py", "graphix/pauli.py", + "graphix/pretty_print.py", "graphix/pyzx.py", "graphix/rng.py", "graphix/states.py", @@ -162,6 +163,7 @@ files = [ "tests/test_kraus.py", "tests/test_parameter.py", "tests/test_pauli.py", + "tests/test_pretty_print.py", "tests/test_pyzx.py", "tests/test_rng.py", ] @@ -186,6 +188,7 @@ include = [ "graphix/ops.py", "graphix/parameter.py", "graphix/pauli.py", + "graphix/pretty_print.py", "graphix/pyzx.py", "graphix/rng.py", "graphix/states.py", @@ -204,6 +207,7 @@ include = [ "tests/test_kraus.py", "tests/test_parameter.py", "tests/test_pauli.py", + "tests/test_pretty_print.py", "tests/test_pyzx.py", "tests/test_rng.py", ] diff --git a/tests/test_pattern.py b/tests/test_pattern.py index bc8976e84..58650824a 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -49,13 +49,27 @@ def choice(self, _outcomes: list[Outcome]) -> Outcome: class TestPattern: - # this fails without behaviour modification def test_manual_generation(self) -> None: pattern = Pattern() pattern.add(N(node=0)) pattern.add(N(node=1)) pattern.add(M(node=0)) + def test_init(self) -> None: + pattern = Pattern(input_nodes=[1, 0], cmds=[N(node=2), M(node=1)], output_nodes=[2, 0]) + assert pattern.input_nodes == [1, 0] + assert pattern.output_nodes == [2, 0] + with pytest.raises(ValueError): + Pattern(input_nodes=[1, 0], cmds=[N(node=2), M(node=1)], output_nodes=[0, 1, 2]) + + def test_eq(self) -> None: + pattern1 = Pattern(input_nodes=[1, 0], cmds=[N(node=2), M(node=1)], output_nodes=[2, 0]) + pattern2 = Pattern(input_nodes=[1, 0], cmds=[N(node=2), M(node=1)], output_nodes=[2, 0]) + assert pattern1 == pattern2 + pattern1 = Pattern(input_nodes=[1, 0], cmds=[N(node=2), M(node=1)]) + pattern2 = Pattern(input_nodes=[1, 0], cmds=[N(node=2), M(node=1)], output_nodes=[2, 0]) + assert pattern1 != pattern2 + def test_standardize(self, fx_rng: Generator) -> None: nqubits = 2 depth = 1 diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py new file mode 100644 index 000000000..afef03dc2 --- /dev/null +++ b/tests/test_pretty_print.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import math + +import pytest +from numpy.random import PCG64, Generator + +from graphix import command, instruction +from graphix.clifford import Clifford +from graphix.fundamentals import Plane +from graphix.pattern import Pattern +from graphix.pretty_print import OutputFormat, pattern_to_str +from graphix.random_objects import rand_circuit +from graphix.transpiler import Circuit + + +def test_circuit_repr() -> None: + circuit = Circuit(width=3, instr=[instruction.H(0), instruction.RX(1, math.pi), instruction.CCX(0, (1, 2))]) + assert repr(circuit) == "Circuit(width=3, instr=[H(0), RX(1, pi), CCX(0, (1, 2))])" + + +def j_alpha() -> Pattern: + return Pattern(input_nodes=[1], cmds=[command.N(2), command.E((1, 2)), command.M(1), command.X(2, domain={1})]) + + +def test_pattern_repr_j_alpha() -> None: + p = j_alpha() + assert repr(p) == "Pattern(input_nodes=[1], cmds=[N(2), E((1, 2)), M(1), X(2, {1})], output_nodes=[2])" + + +def test_pattern_pretty_print_j_alpha() -> None: + p = j_alpha() + assert str(p) == "X(2,{1}) M(1) E(1,2) N(2)" + assert p.to_unicode() == "X₂¹ M₁ E₁₋₂ N₂" + assert p.to_latex() == r"\(X_{2}^{1}\,M_{1}\,E_{1,2}\,N_{2}\)" + + +def example_pattern() -> Pattern: + return Pattern( + cmds=[ + command.N(1), + command.N(2), + command.N(3), + command.N(10), + command.N(4), + command.E((1, 2)), + command.C(1, Clifford.H), + command.M(1, Plane.YZ, 0.5), + command.M(2, Plane.XZ, -0.25), + command.M(10, Plane.XZ, -0.25), + command.M(3, Plane.XY, 0.1, s_domain={1, 10}, t_domain={2}), + command.M(4, s_domain={1}, t_domain={2, 3}), + ] + ) + + +def test_pattern_repr_example() -> None: + p = example_pattern() + assert ( + repr(p) + == "Pattern(cmds=[N(1), N(2), N(3), N(10), N(4), E((1, 2)), C(1, Clifford.H), M(1, Plane.YZ, 0.5), M(2, Plane.XZ, -0.25), M(10, Plane.XZ, -0.25), M(3, angle=0.1, s_domain={1, 10}, t_domain={2}), M(4, s_domain={1}, t_domain={2, 3})])" + ) + + +def test_pattern_pretty_print_example() -> None: + p = example_pattern() + assert ( + str(p) + == "{2,3}[M(4)]{1} {2}[M(3,pi/10)]{1,10} M(10,XZ,-pi/4) M(2,XZ,-pi/4) M(1,YZ,pi/2) C(1,H) E(1,2) N(4) N(10) N(3) N(2) N(1)" + ) + assert p.to_unicode() == "₂₊₃[M₄]¹ ₂[M₃(π/10)]¹⁺¹⁰ M₁₀(XZ,-π/4) M₂(XZ,-π/4) M₁(YZ,π/2) C₁(H) E₁₋₂ N₄ N₁₀ N₃ N₂ N₁" + assert ( + p.to_latex() + == r"\({}_{2,3}[M_{4}]^{1}\,{}_{2}[M_{3}^{\frac{\pi}{10}}]^{1,10}\,M_{10}^{XZ,-\frac{\pi}{4}}\,M_{2}^{XZ,-\frac{\pi}{4}}\,M_{1}^{YZ,\frac{\pi}{2}}\,C_{1}^{H}\,E_{1,2}\,N_{4}\,N_{10}\,N_{3}\,N_{2}\,N_{1}\)" + ) + assert ( + pattern_to_str(p, output=OutputFormat.ASCII, limit=9, left_to_right=True) + == "N(1) N(2) N(3) N(10) N(4) E(1,2) C(1,H) M(1,YZ,pi/2)...(4 more commands)" + ) + + +@pytest.mark.parametrize("jumps", range(1, 11)) +@pytest.mark.parametrize("output", list(OutputFormat)) +def test_pattern_pretty_print_random(fx_bg: PCG64, jumps: int, output: OutputFormat) -> None: + rng = Generator(fx_bg.jumped(jumps)) + rand_pat = rand_circuit(5, 5, rng=rng).transpile().pattern + pattern_to_str(rand_pat, output) diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py index ca94e7e7a..8168a7b3a 100644 --- a/tests/test_transpiler.py +++ b/tests/test_transpiler.py @@ -130,3 +130,22 @@ def simulate_and_measure() -> int: nb_shots = 10000 count = sum(1 for _ in range(nb_shots) if simulate_and_measure()) assert abs(count - nb_shots / 2) < nb_shots / 20 + + def test_add_extend(self) -> None: + circuit = Circuit(3) + circuit.ccx(0, 1, 2) + circuit.rzz(0, 1, 2) + circuit.cnot(0, 1) + circuit.swap(0, 1) + circuit.h(0) + circuit.s(0) + circuit.x(0) + circuit.y(0) + circuit.z(0) + circuit.i(0) + circuit.m(0, Plane.XY, 0.5) + circuit.rx(1, 0.5) + circuit.ry(2, 0.5) + circuit.rz(1, 0.5) + circuit2 = Circuit(3, instr=circuit.instruction) + assert circuit.instruction == circuit2.instruction