diff --git a/docs/notebooks/bose_hubbard.ipynb b/docs/notebooks/bose_hubbard.ipynb new file mode 100644 index 0000000..808ec29 --- /dev/null +++ b/docs/notebooks/bose_hubbard.ipynb @@ -0,0 +1,3353 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6c56b043", + "metadata": {}, + "source": [ + "# Bose-Hubbard model" + ] + }, + { + "cell_type": "markdown", + "id": "22cefce0", + "metadata": {}, + "source": [ + "This notebook builds a **Bose–Hubbard** Hamiltonian over a graph and minimizes its energy with a **photonic variational circuit** in Optyx. We’ll:\n", + " \n", + "1) prepare a photonic ansatz, \n", + "2) define creation/annihilation and number operators on a single photonic mode, \n", + "3) assemble the Hamiltonian \\(H\\) from a NetworkX graph using *function syntax*, \n", + "4) evaluate $E(\\boldsymbol\\theta)=\\langle\\psi|H|\\psi\\rangle$ and its gradients, \n", + "5) run an optimiser with a decaying learning rate, and \n", + "6) plot convergence" + ] + }, + { + "cell_type": "markdown", + "id": "11e739e8", + "metadata": {}, + "source": [ + "## Ansatz\n", + "Define the variational ansatz:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "12d61180-27de-4d3a-bad6-73a0e94c8b16", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2025-10-14T17:04:54.650833\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.10.5, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from optyx import photonic, classical\n", + "\n", + "circuit = photonic.Create(1, 1, 1) >> photonic.ansatz(3, 4) >> photonic.Id(2) @ classical.Select(1)\n", + "circuit.draw()" + ] + }, + { + "cell_type": "markdown", + "id": "685b16c4", + "metadata": {}, + "source": [ + "## The model" + ] + }, + { + "cell_type": "markdown", + "id": "a49ddabf", + "metadata": {}, + "source": [ + "Define the creation/annihilation channels for a single photonic mode; we’ll reuse these to place operators on specific lattice sites.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5736ad24", + "metadata": {}, + "outputs": [], + "source": [ + "from optyx import Channel\n", + "from optyx.core.diagram import mode\n", + "from optyx.core.zw import W, Create, Select\n", + "\n", + "creation_op = Channel(\n", + " \"a†\",\n", + " Create(1) @ mode >> W(2).dagger()\n", + ")\n", + "\n", + "annihilation_op = Channel(\n", + " \"a\",\n", + " W(2) >> Select(1) @ mode\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9e720322", + "metadata": {}, + "source": [ + "We’ll construct the Bose–Hubbard Hamiltonian $H(t,U,\\mu)$ on an arbitrary graph using function syntax so each term edits only its intended wire(s).\n" + ] + }, + { + "cell_type": "markdown", + "id": "77676e33", + "metadata": {}, + "source": [ + "![Bose Hubbard model](./bose_hubbard.png \"Bose Hubbard model\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fbf6443", + "metadata": {}, + "outputs": [], + "source": [ + "import networkx as nx\n", + "from optyx import Diagram, qmode\n", + "from optyx.photonic import NumOp, Scalar\n", + "\n", + "def bose_hubbard_from_graph(\n", + " graph: nx.Graph,\n", + " t: float,\n", + " mu: float,\n", + " U: float\n", + "):\n", + " nodes = sorted(graph.nodes())\n", + " idx = {u: i for i, u in enumerate(nodes)}\n", + " N = len(nodes)\n", + "\n", + " H = None\n", + "\n", + " # hopping: -t (a_i^dagger a_j + a_j^dagger a_i)\n", + " for u, v in graph.edges():\n", + " i, j = idx[u], idx[v]\n", + "\n", + " @Diagram.from_callable(dom=qmode**N, cod=qmode**N)\n", + " def hop_ij(*in_wires):\n", + " # a_i^dagger a_j\n", + " out = list(in_wires)\n", + " out[i] = creation_op(out[i])\n", + " out[j] = annihilation_op(out[j])\n", + " Scalar(-t)()\n", + " return tuple(out)\n", + "\n", + " @Diagram.from_callable(dom=qmode**N, cod=qmode**N)\n", + " def hop_ji(*in_wires):\n", + " # a_j^dagger a_i\n", + " out = list(in_wires)\n", + " out[j] = creation_op(out[j])\n", + " out[i] = annihilation_op(out[i])\n", + " Scalar(-t)()\n", + " return tuple(out)\n", + "\n", + " H = (hop_ij + hop_ji) if H is None else (H + hop_ij + hop_ji)\n", + "\n", + " # onsite interaction: (U/2) a_i^dagger a_i^dagger a_i a_i\n", + " for u in nodes:\n", + " i = idx[u]\n", + "\n", + " @Diagram.from_callable(dom=qmode**N, cod=qmode**N)\n", + " def quartic_i(*in_wires, i=i):\n", + " out = list(in_wires)\n", + " w = out[i]\n", + " w = creation_op(w)\n", + " w = creation_op(w)\n", + " w = annihilation_op(w)\n", + " w = annihilation_op(w)\n", + " out[i] = w\n", + " Scalar(U/2)()\n", + " return tuple(out)\n", + "\n", + " H = quartic_i if H is None else (H + quartic_i)\n", + "\n", + " # -mu n_i\n", + " for u in nodes:\n", + " i = idx[u]\n", + "\n", + " @Diagram.from_callable(dom=qmode**N, cod=qmode**N)\n", + " def n_i(*in_wires, i=i):\n", + " out = list(in_wires)\n", + " out[i] = NumOp()(out[i])\n", + " Scalar(-mu)()\n", + " return tuple(out)\n", + "\n", + " H = n_i if H is None else (H + n_i)\n", + "\n", + " return H\n" + ] + }, + { + "cell_type": "markdown", + "id": "770d29fb", + "metadata": {}, + "source": [ + "Assemble hopping, on-site interaction, and chemical-potential terms into one Diagram representing the full Hamiltonian on $qmode^{\\otimes N}$. Start with a 2-site chain and some parameters:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "900d0c48", + "metadata": {}, + "outputs": [], + "source": [ + "import networkx as nx\n", + "\n", + "graph = nx.path_graph(2) # 2 sites\n", + "\n", + "t, U, mu = 0.10, 4.0, 2.0\n", + "\n", + "hamiltonian = bose_hubbard_from_graph(graph, t, mu, U)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "bb26b7d4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2025-10-14T17:04:54.764399\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.10.5, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "hamiltonian.draw()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8a02308c-aba4-4a39-9325-faee4c1f0c2d", + "metadata": {}, + "outputs": [], + "source": [ + "expectation = circuit >> hamiltonian >> circuit.dagger()" + ] + }, + { + "cell_type": "markdown", + "id": "c76edfb7", + "metadata": {}, + "source": [ + "## Optimisation" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f4e5462d-464d-4b09-8574-6867522b61ea", + "metadata": {}, + "outputs": [], + "source": [ + "from optyx.core.backends import PermanentBackend\n", + "\n", + "\n", + "def to_float(x):\n", + " if isinstance(x, complex):\n", + " assert x.imag < 1e-8, x\n", + " return x.real\n", + " return x\n", + "\n", + "free_syms = list(expectation.free_symbols)\n", + "\n", + "f_exp = lambda xs: to_float(\n", + " expectation.lambdify(*free_syms)(*xs)\n", + " .eval(PermanentBackend())\n", + " .tensor\n", + " .array\n", + ")\n", + "\n", + "def d_f_exp(xs):\n", + " return [\n", + " expectation.grad(s).lambdify(*free_syms)(*xs)\n", + " .eval(PermanentBackend())\n", + " .tensor\n", + " .array\n", + " for s in free_syms\n", + " ]" + ] + }, + { + "cell_type": "markdown", + "id": "5adb53ee", + "metadata": {}, + "source": [ + "Run a short gradient-descent loop with a decaying learning rate to drive the energy down." + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "24355282-4c9c-40b3-b643-461981f9475f", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "xs = []\n", + "fxs = []\n", + "dfxs = []\n", + "\n", + "def optimize(x0):\n", + " x = x0\n", + " lr = 5\n", + " steps = 10\n", + " for _ in tqdm(range(steps)):\n", + " fx = f_exp(x)\n", + " dfx = d_f_exp(x)\n", + "\n", + " xs.append(x[::])\n", + " fxs.append(fx)\n", + " dfxs.append(dfx)\n", + " for i, dfxx in enumerate(dfx):\n", + " x[i] = to_float(x[i] - lr * dfxx)\n", + "\n", + " lr *= 0.2*(i**(1/6)) # make lr smaller with each step\n", + "\n", + " xs.append(x[::])\n", + " fxs.append(f_exp(x))\n", + " dfxs.append(d_f_exp(x))" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "cde44850", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 10/10 [05:31<00:00, 33.16s/it]\n" + ] + } + ], + "source": [ + "optimize([2]*len(free_syms))" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "088c8fa8-76b3-45bd-a5dd-b337be0d33b1", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2025-10-14T17:53:51.051493\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.10.5, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "sns.set_theme()\n", + "\n", + "fig, axs = plt.subplots(1, 1, figsize=(10, 7))\n", + "\n", + "axs.plot(range(len(xs)), fxs, c=\"#0072B2\", marker='o')\n", + "axs.set_xlabel('Iteration', fontsize=18)\n", + "axs.set_ylabel('Expected Energy', fontsize=18)\n", + "axs.grid(True)\n", + "axs.tick_params(axis='both', which='major', labelsize=16)\n", + "axs.tick_params(axis='both', which='minor', labelsize=16)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92dfc684", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/bose_hubbard.png b/docs/notebooks/bose_hubbard.png new file mode 100644 index 0000000..ea43a2e Binary files /dev/null and b/docs/notebooks/bose_hubbard.png differ diff --git a/docs/notebooks/bosonic-vqe-2.ipynb b/docs/notebooks/bosonic-vqe-2.ipynb index d028b85..dc860e5 100644 --- a/docs/notebooks/bosonic-vqe-2.ipynb +++ b/docs/notebooks/bosonic-vqe-2.ipynb @@ -17,7 +17,7 @@ " \n", " \n", " \n", - " 2025-09-05T20:09:47.097250\n", + " 2025-10-09T17:20:26.561086\n", " image/svg+xml\n", " \n", " \n", @@ -46,62 +46,62 @@ "L 223.2 7.2 \n", "L 7.2 7.2 \n", "z\n", - "\" clip-path=\"url(#pd1650474ac)\" style=\"fill: #ffffff; stroke: #ffffff; stroke-linejoin: miter\"/>\n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: #ffffff; stroke: #ffffff; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: none; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p605fbb3927)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", @@ -878,7 +878,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -911,12 +911,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2025-09-05T20:09:47.853222\n", + " 2025-10-09T17:20:27.356367\n", " image/svg+xml\n", " \n", " \n", @@ -931,161 +931,209 @@ " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", + "\" clip-path=\"url(#pe741f2f76b)\" style=\"fill: #ffffff; stroke: #ffffff; stroke-linejoin: miter\"/>\n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", + "\" clip-path=\"url(#pe741f2f76b)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", - " \n", + "\" clip-path=\"url(#pe741f2f76b)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", - " \n", + "\" clip-path=\"url(#pe741f2f76b)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\" clip-path=\"url(#pe741f2f76b)\" style=\"fill: #ffffff; stroke: #000000; stroke-linejoin: miter\"/>\n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1235,7 +1283,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1245,7 +1293,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1255,7 +1303,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1265,7 +1313,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1275,7 +1323,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1285,7 +1333,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1295,7 +1343,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1305,7 +1353,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1315,7 +1363,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1325,7 +1373,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1335,7 +1383,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1345,7 +1393,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1355,7 +1403,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1365,7 +1413,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1374,70 +1422,69 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1717,14 +1776,14 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" ], "text/plain": [ - "
" + "
" ] }, "metadata": {}, @@ -1770,12 +1829,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "f4e5462d-464d-4b09-8574-6867522b61ea", "metadata": {}, "outputs": [], "source": [ - "from optyx.core.backends import DiscopyBackend\n", + "from optyx.core.backends import PermanentBackend\n", "\n", "def to_float(x):\n", " if isinstance(x, complex):\n", @@ -1785,18 +1844,26 @@ "\n", "free_syms = list(expectation.free_symbols)\n", "\n", - "f_exp = lambda xs: to_float(expectation.lambdify(*free_syms)(*xs).eval().tensor.array)\n", + "f_exp = lambda xs: to_float(\n", + " expectation.lambdify(*free_syms)(*xs)\n", + " .eval(PermanentBackend())\n", + " .tensor\n", + " .array\n", + ")\n", "\n", "def d_f_exp(xs):\n", " return [\n", - " expectation.grad(s).lambdify(*free_syms)(*xs).eval().tensor.array\n", + " expectation.grad(s).lambdify(*free_syms)(*xs)\n", + " .eval(PermanentBackend())\n", + " .tensor\n", + " .array\n", " for s in free_syms\n", " ]" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "24355282-4c9c-40b3-b643-461981f9475f", "metadata": {}, "outputs": [], @@ -1827,7 +1894,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "e96227b0-db8e-4871-8c12-b1d034a4bc86", "metadata": {}, "outputs": [ @@ -1835,7 +1902,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 10/10 [04:56<00:00, 29.62s/it]\n" + "100%|██████████| 10/10 [03:00<00:00, 18.01s/it]\n" ] } ], diff --git a/optyx/classical.py b/optyx/classical.py index 920f102..8955d15 100644 --- a/optyx/classical.py +++ b/optyx/classical.py @@ -141,7 +141,7 @@ diagram, path ) -from optyx import ( +from optyx.core.channel import ( bit, mode, qmode, diff --git a/optyx/core/backends.py b/optyx/core/backends.py index 552933f..f3bd54d 100644 --- a/optyx/core/backends.py +++ b/optyx/core/backends.py @@ -43,7 +43,7 @@ >>> diag = Create(1, 1) >> BS >>> backend = QuimbBackend() >>> result = diag.eval(backend) ->>> np.round(result.prob((2, 0)), 1) +>>> np.round(result.single_prob((2, 0)), 1) 0.5 **Compressed contraction (hyper-optimiser reused across calls)** @@ -52,21 +52,22 @@ >>> opt = ReusableHyperCompressedOptimizer(max_repeats=32) >>> backend = QuimbBackend(hyperoptimiser=opt) >>> result = diag.eval(backend) ->>> np.round(result.prob((2, 0)), 1) +>>> np.round(result.single_prob((2, 0)), 1) 0.5 **Unitary circuit simulation with Perceval** >>> from optyx.core.backends import PercevalBackend >>> backend = PercevalBackend() ->>> result = BS.eval(backend) ->>> np.round(result.prob((2, 0)), 1) +>>> result = (Create(1, 1) >> BS).eval(backend) +>>> np.round(result.single_prob((2, 0)), 1) 0.5 """ +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Literal, Sequence from collections import defaultdict from enum import Enum from cotengra import ( @@ -79,18 +80,49 @@ import numpy as np import perceval as pcvl from quimb.tensor import TensorNetwork -from optyx.core.channel import Diagram -from optyx.core.channel import Ty, mode, bit +from optyx.core.channel import Diagram, Ty, mode, bit +from optyx.core.path import Matrix from optyx.utils.utils import preprocess_quimb_tensors_safe +@dataclass +class PercevalEvalConfig: + """ + Configuration for Perceval backend evaluation. + - task: The Perceval task to perform. Allowed values are + "probs" (default), "amps", "single_amp", "single_prob". + - state: A `perceval.BasicState` or a sequence of + non-negative integers (occupation numbers). Defaults + to a bosonic product state |11...1> if the + diagram does not include any photon creations. + Either the creations for all input ports are specified + by the diagram (`Create(...)`) or the user must provide + the `state` argument covering all input ports. + - effect: Required if task is "single_amp" or "single_prob". + A sequence of non-negative integers (occupation numbers) + specifying the output configuration for which to compute + the amplitude or probability. + """ + task: Literal[ + "probs", "amps", "single_amp", "single_prob" + ] = "probs" + state: pcvl.BasicState | Sequence[int] | None = None + effect: pcvl.BasicState | Sequence[int] | None = None + + +ProbDist = dict[tuple[int, ...], float] +Amps = dict[tuple[int, ...], complex] + + class StateType(Enum): """ Enum to represent the type of state represented by the result tensor. """ - AMP = "amp" # pure-state amplitudes - DM = "dm" # density matrix - PROB = "prob" # classical probability distribution + AMP = "amp" # pure-state amplitudes + DM = "dm" # density matrix + PROB = "prob" # classical probability distribution + SINGLE_PROB = "SINGLE_PROB" # single ||^2 probability + SINGLE_AMP = "SINGLE_AMP" # single amplitude @dataclass(frozen=True) @@ -99,7 +131,7 @@ class EvalResult: Class to encapsulate the result of an evaluation of a diagram. """ _tensor: discopy_tensor.Box - output_types: Ty + output_types: Ty | None state_type: StateType @property @@ -113,27 +145,31 @@ def tensor(self) -> discopy_tensor.Box: return self._tensor @property - def density_matrix(self) -> discopy_tensor.Box: + def density_matrix(self) -> np.ndarray: """ Get the density matrix from the result tensor. Returns: - tensor.Box: The density matrix. + np.ndarray: The density matrix. """ if len(self.tensor.dom) != 0: raise ValueError( - "Result tensor must represent a state with inputs." + "Result tensor must represent a state with no inputs. " + + f"Current domain: {self.tensor.dom}" ) if self.state_type not in {StateType.AMP, StateType.DM}: raise TypeError( - "Cannot get density matrix from probability distribution." + f"Cannot get density matrix from {self.state_type}." ) if self.state_type is StateType.AMP: density_matrix = self.tensor.dagger() >> self.tensor - return density_matrix + return density_matrix.array return self.tensor.array - def amplitudes(self, normalise=True) -> dict[tuple[int, ...], float]: + def amplitudes( + self, + normalise=True + ) -> Amps: """ Get the amplitudes from the result tensor. Returns: @@ -142,12 +178,12 @@ def amplitudes(self, normalise=True) -> dict[tuple[int, ...], float]: """ if self.state_type != StateType.AMP: raise TypeError( - ("Cannot get amplitudes from density " + - "matrix or probability distribution.") + (f"Cannot get amplitudes from {self.state_type}.") ) if len(self.tensor.dom) != 0: raise ValueError( - "Result tensor must represent a state with inputs." + "Result tensor must represent a state with no inputs. " + + f"Current domain: {self.tensor.dom}" ) dic = self._convert_array_to_dict(self.tensor.array) @@ -157,7 +193,10 @@ def amplitudes(self, normalise=True) -> dict[tuple[int, ...], float]: for key, value in dic.items()} return dic - def prob_dist(self, round_digits: int = None) -> dict: + def prob_dist( + self, + round_digits: int = None + ) -> ProbDist: """ Get the probability distribution from the result tensor. @@ -167,22 +206,40 @@ def prob_dist(self, round_digits: int = None) -> dict: """ if len(self.tensor.dom) != 0: raise ValueError( - "Result tensor must represent a state with inputs." + "Result tensor must represent a state with no inputs. " + + f"Current domain: {self.tensor.dom}" ) if self.state_type is StateType.AMP: - return self._prob_dist_pure(round_digits) - if self.state_type is StateType.DM: - return self._prob_dist_mixed(round_digits) - if self.state_type is StateType.PROB: - return self._convert_array_to_dict( - self.tensor.array, - round_digits=round_digits + values = self._prob_dist_pure() + elif self.state_type is StateType.DM: + values = self._prob_dist_mixed() + elif self.state_type is StateType.PROB: + values = self._convert_array_to_dict( + self.tensor.array + ) + else: + raise ValueError( + f"Unsupported state_type type: {self.state_type}. " + + "Must be StateType.AMP, StateType.DM, " + + "or StateType.PROB." ) - raise ValueError("Unsupported state_type type. " + - "Must be StateType.AMP, StateType.DM, " + - "or StateType.PROB.") - def prob(self, occupation: tuple) -> float: + if not np.allclose( + np.sum(list(values.values())), 1, atol=1e-12 + ): + total = float(np.sum(list(values.values()))) + if total == 0: + raise ValueError("The probability distribution sums to zero.") + + prob = {k: v / total for k, v in values.items()} + else: + prob = values + + if round_digits is None: + return prob + return {k: round(v, round_digits) for k, v in prob.items()} + + def single_prob(self, occupation: tuple) -> float: """ Get the probability of a specific occupation configuration. @@ -192,13 +249,31 @@ def prob(self, occupation: tuple) -> float: Returns: float: The probability of the specified occupation configuration. """ + if self.state_type == StateType.SINGLE_PROB: + return float(self.tensor.array) + prob_dist = self.prob_dist() return prob_dist.get(occupation, 0.0) + def single_amplitude(self, occupation: tuple) -> complex: + """ + Get the amplitude of a specific occupation configuration. + + Args: + occupation: The occupation configuration to query. + + Returns: + complex: The amplitude of the specified occupation configuration. + """ + if self.state_type == StateType.SINGLE_AMP: + return complex(self.tensor.array) + + dic = self.amplitudes(normalise=False) + return dic.get(occupation, 0.0) + def _convert_array_to_dict( self, - array: np.ndarray, - round_digits: int = None) -> dict: + array: np.ndarray) -> dict: """ Return a dict that maps multi-indices - values for all non-zero entries of an array. @@ -211,13 +286,9 @@ def _convert_array_to_dict( nz_vals = array.flat[nz_flat] nz_multi = np.vstack(np.unravel_index(nz_flat, array.shape)).T - if round_digits is not None: - return {tuple(idx): np.round(val, round_digits) for - idx, val in zip(nz_multi, nz_vals)} - return {tuple(idx): val for idx, val in zip(nz_multi, nz_vals)} - def _prob_dist_pure(self, round_digits: int = None) -> dict: + def _prob_dist_pure(self) -> ProbDist: """ Get the probability distribution for a pure state. @@ -227,20 +298,21 @@ def _prob_dist_pure(self, round_digits: int = None) -> dict: """ values = self._convert_array_to_dict( - self.tensor.array, - round_digits=round_digits + self.tensor.array ) - sum_ = np.sum(np.abs(list(values.values())) ** 2) - return {key: (abs(value) ** 2)/sum_ for key, value in values.items()} - def _prob_dist_mixed( - self, - round_digits: int | None = None) -> dict[tuple[int, ...], float]: + return {k: abs(v) ** 2 for k, v in values.items()} + + def _prob_dist_mixed(self) -> ProbDist: """ Get the probability distribution from a mixed state. This method computes the probability distribution by aggregating occupation configurations based on the output types of the tensor. + Assumes the output types contain at least one 'bit' or 'mode' + These will be treated as measured registers while 'qubit' and 'qmode' + are treated as unmeasured and traced out. + Args: round_digits: Optional number of digits to round the probabilities. @@ -251,10 +323,12 @@ def _prob_dist_mixed( if not any(t in {bit, mode} for t in self.output_types): raise ValueError( - "Output types must contain at least one 'bit' or 'mode'." + "Output types must contain at least one 'bit' or 'mode'." + + "These will be treated as measured registers. " + + f"Current output types: {self.output_types}" ) - values = self._convert_array_to_dict(self.tensor.array, round_digits) + values = self._convert_array_to_dict(self.tensor.array) mask_flat = np.concatenate( [[1] if t in {bit, mode} else [0, 0] for t in self.output_types] ) @@ -270,16 +344,16 @@ def _prob_dist_mixed( zip(key, mask_flat) if not m) if all(occs_unmeasured[i] == occs_unmeasured[i + 1] for i in range(0, len(occs_unmeasured) - 1, 2)): - probs[occ_measured] += amp + + val = float(np.real_if_close(amp)) + if val < 0 and abs(val) < 1e-12: + val = 0.0 + probs[occ_measured] += val for occ in all_measured: probs.setdefault(occ, 0.0) - sum_ = np.sum(list(probs.values())) - prob = { - key: value / sum_ - for key, value in probs.items() - } - return prob + + return probs # pylint: disable=too-few-public-methods @@ -292,7 +366,7 @@ class AbstractBackend(ABC): def _get_matrix( self, diagram: Diagram - ) -> np.ndarray: + ) -> Matrix: """ Get the matrix representation of the diagram. @@ -303,7 +377,7 @@ def _get_matrix( np.ndarray: The matrix representation of the diagram. """ try: - return diagram.to_path().array + return diagram.to_path() except NotImplementedError as error: raise NotImplementedError( "The diagram cannot be converted to a matrix. " + @@ -361,12 +435,13 @@ class QuimbBackend(AbstractBackend): def __init__( self, - hyperoptimiser: Union[ - HyperOptimizer, - ReusableHyperOptimizer, - HyperCompressedOptimizer, - ReusableHyperCompressedOptimizer - ] = None, + hyperoptimiser: ( + HyperOptimizer | + ReusableHyperOptimizer | + HyperCompressedOptimizer | + ReusableHyperCompressedOptimizer | + None + ) = None, contraction_params: dict = None): """ Initialize the Quimb backend. @@ -395,7 +470,7 @@ def eval( tensor_diagram = self._get_discopy_tensor(diagram) - if hasattr(diagram, 'terms'): + if hasattr(tensor_diagram, 'terms'): results = sum( self._process_term(term) for term in tensor_diagram.terms ) @@ -418,12 +493,12 @@ def eval( state_type=state_type ) - def _process_term(self, term: Diagram) -> np.ndarray: + def _process_term(self, term: discopy_tensor.Diagram) -> np.ndarray: """ Process a term in a sum of diagrams. Args: - term (Diagram): The term to process. + term (discopy.tensor.Diagram): The term to process. Returns: np.ndarray: The processed term as a numpy array. @@ -453,7 +528,9 @@ def _process_term(self, term: Diagram) -> np.ndarray: "Unsupported hyperoptimiser type. " + "Use ReusableHyperOptimizer, HyperOptimizer, " + "ReusableHyperCompressedOptimizer, or " + - "HyperCompressedOptimizer." + "HyperCompressedOptimizer. " + + f"Got: {type(self.hyperoptimiser)}" + ) if is_approx: @@ -528,51 +605,327 @@ def eval( """ Evaluate the diagram using Perceval. Works only for unitary operations. - If no `perceval_state` is provided in `extra`, - it defaults to a bosonic product state. Args: diagram (Diagram): The diagram to evaluate. - **extra: Additional arguments for the evaluation, - including 'perceval_state'. + **extra: Additional arguments for the evaluation: + - config: Configuration for Perceval evaluation, + see :class:`PercevalEvalConfig`. Returns: - The result of the evaluation. + The result of the evaluation (EvalResult). """ - if extra: + if hasattr( + diagram, + "terms" + ): + array = 0 + for term in diagram.terms: + arr, output_types, return_type = \ + self._process_term( + term, extra + ) + array += arr + else: + array, output_types, return_type = \ + self._process_term( + diagram, extra + ) + + if array.shape == (1,): + cod = discopy_tensor.Dim(1) + else: + cod = discopy_tensor.Dim(*array.shape) + + return EvalResult( + discopy_tensor.Box( + "Result", + discopy_tensor.Dim(1), + cod, + array + ), + output_types=output_types, + state_type=return_type + ) + + def _get_state_from_creations( + self, + creations, + external_perceval_state + ): + return ( + external_perceval_state * + pcvl.BasicState(creations) + ) + + def _get_effect_from_selections( + self, + selections, + external_perceval_effect + ): + return ( + external_perceval_effect * + pcvl.BasicState(selections) + ) + + def _post_select_vacuum( + self, + dist, + m_orig, + k_extra, + task + ): + """Keep only entries where extra (ancilla) + modes are all 0, then drop them.""" + if task in ("probs", "amps"): + if k_extra <= 0: + return dist + return { + k[:m_orig]: v + for k, v in dist.items() + if all(x == 0 for x in k[m_orig:]) + } + return dist + + def _process_state(self, perceval_state): + if not isinstance(perceval_state, pcvl.BasicState): try: - perceval_state: pcvl.StateVector = extra["perceval_state"] - except KeyError as error: + perceval_state = pcvl.BasicState(list(perceval_state)) + except Exception as e: raise TypeError( - "PercevalBackend.eval requires " + - "a 'perceval_state=' keyword." - ) from error - else: - perceval_state = pcvl.StateVector( - [1] * len(diagram.dom) + "perceval_state must be a perceval.BasicState" + " or a sequence of non-negative " + + "integers (occupation numbers). " + + f"Got: {type(perceval_state)}" + ) from e + return perceval_state + + def _process_effect(self, perceval_effect): + if perceval_effect is None: + return None + if not isinstance(perceval_effect, pcvl.BasicState): + try: + perceval_effect = pcvl.BasicState(list(perceval_effect)) + except ValueError as e: + raise ValueError( + "perceval_effect must be a perceval.BasicState" + " or a sequence of non-negative " + + "integers (occupation numbers). " + + f"Got: {type(perceval_effect)}" + ) from e + return perceval_effect + + def _dilate( + self, + matrix, + perceval_state + ): + warnings.warn( + "The provided matrix is not unitary. " + "PercevalBackend expects a unitary matrix. " + "Dilation will be used. " + "This can impact performance.", + UserWarning, + stacklevel=2 + ) + current_n_create = len(matrix.creations) + matrix = matrix.dilate() + pad_zeros = len(matrix.creations) - current_n_create + + perceval_state = perceval_state * pcvl.BasicState( + [0] * pad_zeros + ) + + return matrix, perceval_state + + def _process_io( + self, + term, + matrix, + extra + ): + """ + Process the input and output states/effects for the diagram. + + Args: + term (discopy.tensor.Diagram): The term to process. + matrix (Matrix): The matrix representation of the diagram. + extra (dict): Additional arguments for the evaluation. + Returns: + perceval_state (pcvl.BasicState): The processed input state. + perceval_effect (pcvl.BasicState | None): + The processed output effect. + task (str): The Perceval task to perform. + """ + cfg: PercevalEvalConfig = extra.get("config", PercevalEvalConfig()) + task = cfg.task + state_provided = cfg.state is not None + effect_provided = cfg.effect is not None + is_dom_closed = len(term.dom) == 0 + + if not state_provided and not is_dom_closed: + raise ValueError( + "External 'state' not provided but " + + "the diagram has open input modes. " + + "Provide a 'state' or close all input modes with a state. " + + f"Open input modes: {term.dom}" ) - tensor_diagram = self._get_discopy_tensor(diagram) - sim = pcvl.Simulator(self.perceval_backend) - matrix = self._get_matrix(diagram) + external_perceval_state = pcvl.BasicState([]) + if state_provided: + external_perceval_state = self._process_state(cfg.state) - if not np.allclose( - np.eye(matrix.shape[0]), - matrix.dot(matrix.conj().T) + if external_perceval_state.m != matrix.dom: + raise ValueError( + "The provided 'state' does not match the " + + "number of input modes of the diagram. " + + f"Provided state has {external_perceval_state.m} " + + f"modes but diagram has {matrix.dom} input modes." + ) + + perceval_state = self._process_state( + self._get_state_from_creations( + matrix.creations, + external_perceval_state + ) + ) + + perceval_effect = None + if effect_provided: + perceval_effect = self._process_effect(cfg.effect) + + if perceval_effect.m != matrix.cod: + raise ValueError( + "The provided 'effect' does not match the number " + + "of output modes of the diagram. " + + f"Provided effect has {perceval_effect.m} " + + f"modes but diagram has {matrix.cod} output modes." + ) + + # convert post-selections to pcvl effects and post-selections + if matrix.cod == 0: + sel0 = matrix.selections[0] + m = Matrix( + matrix.array.copy(), + matrix.dom, + 1, + creations=list(matrix.creations), + selections=list(matrix.selections[1:]) + ) + perceval_effect = pcvl.BasicState([sel0]) + matrix = m + + if ( + perceval_effect is not None and + task in ("amps", "probs") ): raise ValueError( - "The provided diagram does not represent a unitary operation." + f"An 'effect' was provided but task='{task}'. " + "Use task='single_amp' or task='single_prob' " + + "when conditioning on an effect." ) - perceval_circuit = self._umatrix_to_perceval_circuit(matrix) + if task in ("single_amp", "single_prob"): + if perceval_effect is None: + raise ValueError( + "The 'perceval_effect' argument must be provided for " + + "task 'single_amp' or 'single_prob'." + ) + + return perceval_state, perceval_effect, task + + def _prepare_simulation( + self, + matrix + ): + """ + Prepare the Perceval simulator with the given matrix. + """ + selections = matrix.selections + + sim = pcvl.Simulator(self.perceval_backend) + postselect_conditions = [ + f"{str([i + matrix.cod])} == {s}" + for i, s in enumerate(selections) + ] + sim.set_postselection( + pcvl.PostSelect(str.join(" & ", postselect_conditions)) + ) + perceval_circuit = self._umatrix_to_perceval_circuit(matrix.array) sim.set_circuit(perceval_circuit) - result = sim.probs(perceval_state) - result = {tuple(k): v for k, v in result.items()} + return sim + + def _simulate( + self, + sim, + perceval_state, + perceval_effect, + task + ): + + if task in ("single_amp", "single_prob"): + if task == "single_prob": + result = float( + sim.probability(perceval_state, perceval_effect) + ) + + else: + result = complex( + sim.prob_amplitude(perceval_state, perceval_effect) + ) + else: + if task == "probs": + result = sim.probs(perceval_state) + result = {tuple(k): float(v) for k, v in result.items()} + else: + sv = sim.evolve(perceval_state) + result = {tuple(k): complex(v) for k, v in sv} + + return result + + def _get_output_params( + self, + term, + task + ): + """ + Get the output parameters for the diagram. - array = np.zeros(tensor_diagram.cod.inside) + Args: + term (discopy.tensor.Diagram): The term to process. + task (str): The Perceval task to perform. + Returns: + output_shape (tuple): The shape of the output array. + output_types (Ty | None): The output types of the diagram. + """ + + if task in ("single_amp", "single_prob"): + return (1,), None + return self._get_discopy_tensor(term).cod.inside, term.cod + + def _get_array_from_result( + self, + result, + output_shape, + task + ): + """ + Get the output array from the simulation result. + + Args: + result (dict | float | complex): The simulation result. + output_shape (tuple): The shape of the output array. + task (str): The Perceval task to perform. + Returns: + np.ndarray: The output array. + """ + array = np.zeros( + output_shape, + dtype=float if task in ("single_prob", "probs") else complex + ) - if result: + if task in ("probs", "amps"): configs = np.fromiter( (i for key in result for i in key), dtype=int, @@ -581,19 +934,165 @@ def eval( coeffs = np.fromiter( result.values(), - dtype=float, + dtype=float if task == "probs" else complex, count=len(result) ) array[tuple(configs.T)] = coeffs + else: + array[0] = result + return array + + def _process_term( + self, + term, + extra + ): + """ + Process a term in a sum of diagrams. + + Args: + term (discopy.tensor.Diagram): The term to process. + """ + matrix = self._get_matrix(term) + + perceval_state, perceval_effect, task = self._process_io( + term, + matrix, + extra + ) + + if task == "amps": + return_type = StateType.AMP + elif task == "probs": + return_type = StateType.PROB + elif task == "single_amp": + return_type = StateType.SINGLE_AMP + elif task == "single_prob": + return_type = StateType.SINGLE_PROB + else: + raise ValueError( + "Invalid task. Allowed values are" + + " 'probs', 'amps', 'single_amp', 'single_prob'. " + + f"Got: {task}" + ) + + # pylint: disable=protected-access + if not matrix._umatrix_is_unitary(): + matrix, perceval_state = self._dilate( + matrix, perceval_state + ) + + sim = self._prepare_simulation( + matrix + ) + + result = self._simulate( + sim, + perceval_state, + perceval_effect, + task + ) + + result = self._post_select_vacuum( + result, + len(term.dom), + matrix.dom - len(term.dom), + task + ) + + output_shape, output_types = self._get_output_params( + term, + task + ) + + array = self._get_array_from_result( + result, + output_shape, + task + ) + + return array, output_types, return_type + + +class PermanentBackend(AbstractBackend): + """ + Backend implementation using optyx' Path module to compute matrix + permanents. + """ + + def eval( + self, + diagram: Diagram, + **extra: Any + ): + """ + Evaluate the diagram using the Permanent/Path backend. + Works only for LO circuits. + + Args: + diagram (Diagram): The diagram to evaluate. + **extra: Additional arguments for the evaluation, + including 'n_photons' for optyx.core.path.Matrix.eval. + """ + + n_photons = extra.get( + "n_photons", + 0 + ) + + def check_creations(matrix): + if ( + len(matrix.creations) == 0 and + n_photons == 0 + ): + raise ValueError( + "The diagram does not include any photon creations. " + + "n_photons must be greater than 0." + ) + + if hasattr( + diagram, + "terms" + ): + result = 0 + dims = [] + for term in diagram.terms: + matrix = self._get_matrix(term) + check_creations(matrix) + result_matrix = matrix.eval( + n_photons=n_photons, + as_tensor=True + ) + result += result_matrix.array + dims.append( + ( + result_matrix.dom.inside, # list + result_matrix.cod.inside # list + ) + ) + else: + matrix = self._get_matrix(diagram) + check_creations(matrix) + result_matrix = matrix.eval(n_photons=n_photons, as_tensor=True) + result = result_matrix.array + dims = [(result_matrix.dom.inside, result_matrix.cod.inside)] + + norm = lambda x: list(x) if isinstance( # noqa: E731 + x, (list, tuple) + ) else [x] + dom_lists, cod_lists = zip(*((norm(d), norm(c)) for d, c in dims)) + max_dom_dims = [max(vals) for vals in zip(*dom_lists)] + max_cod_dims = [max(vals) for vals in zip(*cod_lists)] + return EvalResult( discopy_tensor.Box( "Result", - tensor_diagram.dom**0, - tensor_diagram.cod, - array + discopy_tensor.Dim(*tuple(max_dom_dims)), + discopy_tensor.Dim(*tuple(max_cod_dims)), + result ), output_types=diagram.cod, - state_type=StateType.PROB + state_type=StateType.AMP ) diff --git a/optyx/core/path.py b/optyx/core/path.py index 2fdb68b..4f27061 100644 --- a/optyx/core/path.py +++ b/optyx/core/path.py @@ -434,7 +434,7 @@ def prob_with_perceval( ... (n_photons=1).round(1)== Probabilities[complex]( ... [0.9+0.j, 0.1+0.j, 0.1+0.j, 0.9+0.j], dom=2, cod=2) """ - if not self._umatrix_is_is_unitary(): + if not self._umatrix_is_unitary(): self = self.dilate() circ = self._umatrix_to_perceval_circuit() @@ -491,7 +491,7 @@ def _to_perceval_post_select(self) -> pcvl.PostSelect: ] return pcvl.PostSelect(" & ".join(post_str)) - def _umatrix_is_is_unitary(self) -> bool: + def _umatrix_is_unitary(self) -> bool: m = self.umatrix.array return np.allclose(np.eye(m.shape[0]), m.dot(m.conj().T)) diff --git a/optyx/core/zx.py b/optyx/core/zx.py index 4834aed..00686b0 100644 --- a/optyx/core/zx.py +++ b/optyx/core/zx.py @@ -128,15 +128,15 @@ def move(scan, source, target): swaps = Id(diagram.Bit(len(scan))) return scan, swaps - def make_wires_adjacent(scan, diagram, inputs): + def make_wires_adjacent(scan, dgrm, inputs): if not inputs: - return scan, diagram, len(scan) + return scan, dgrm, len(scan) offset = scan.index(inputs[0]) for i, _ in enumerate(inputs[1:]): source, target = scan.index(inputs[i + 1]), offset + i + 1 scan, swaps = move(scan, source, target) - diagram = diagram >> swaps - return scan, diagram, offset + dgrm = dgrm >> swaps + return scan, dgrm, offset missing_boundary = any( graph.type(node) == VertexType.BOUNDARY # noqa: E721 @@ -238,11 +238,17 @@ def to_pyzx(self): graph.set_inputs(graph.inputs() + (node,)) graph.set_position(node, i, 0) for row, (box, offset) in enumerate(zip(self.boxes, self.offsets)): - if isinstance(box, Spider): - node = graph.add_vertex( - (VertexType.Z if isinstance(box, Z) else VertexType.X), - phase=box.phase * 2 if box.phase else None, - ) + if isinstance(box, diagram.Spider): + if isinstance(box, Spider): + node = graph.add_vertex( + (VertexType.Z if isinstance(box, Z) else VertexType.X), + phase=box.phase * 2 if box.phase else None, + ) + else: + node = graph.add_vertex( + VertexType.Z, + phase=box.phase * 2 if box.phase else None, + ) graph.set_position(node, offset, row + 1) for i, _ in enumerate(box.dom): source, hadamard = scan[offset + i] @@ -264,21 +270,6 @@ def to_pyzx(self): elif box == H: node, hadamard = scan[offset] scan[offset] = (node, not hadamard) - elif isinstance(box, diagram.Spider): - node = graph.add_vertex( - VertexType.Z, - phase=box.phase * 2 if box.phase else None, - ) - graph.set_position(node, offset, row + 1) - for i, _ in enumerate(box.dom): - source, hadamard = scan[offset + i] - etype = EdgeType.HADAMARD if hadamard else EdgeType.SIMPLE - graph.add_edge((source, node), etype) - scan = ( - scan[:offset] - + len(box.cod) * [(node, False)] - + scan[offset + len(box.dom):] - ) else: raise NotImplementedError for i, _ in enumerate(self.cod): diff --git a/optyx/photonic.py b/optyx/photonic.py index 8f3e17c..375900f 100644 --- a/optyx/photonic.py +++ b/optyx/photonic.py @@ -262,7 +262,7 @@ from optyx.classical import ClassicalFunction, DiscardMode from optyx.utils.utils import matrix_to_zw -from optyx import ( +from optyx.core.channel import ( bit, mode, qmode, diff --git a/optyx/qubits.py b/optyx/qubits.py index 1700f68..1a7deef 100644 --- a/optyx/qubits.py +++ b/optyx/qubits.py @@ -311,7 +311,7 @@ diagram, zx ) -from optyx import ( +from optyx.core.channel import ( bit, qubit, Measure as MeasureChannel, diff --git a/test/test_backends.py b/test/test_backends.py index 02ef527..040abe5 100644 --- a/test/test_backends.py +++ b/test/test_backends.py @@ -1,11 +1,33 @@ import pytest -from optyx import photonic, qubits, classical +from optyx import photonic, qubits, classical, mode, qubit, qmode, bit from cotengra import ReusableHyperCompressedOptimizer -from optyx.core.backends import QuimbBackend, PercevalBackend, DiscopyBackend +from optyx.core.backends import ( + QuimbBackend, + PercevalBackend, + DiscopyBackend, + EvalResult, + StateType, + PermanentBackend, + PercevalEvalConfig +) import numpy as np import math from itertools import chain import perceval as pcvl +import discopy.tensor as discopy_tensor + +unitary_circuit = photonic.BS +non_unitary_circuit = ( + photonic.MZI(0.23, 0.51) >> + photonic.NumOp() @ photonic.NumOp() +) + +bs_state = lambda nmodes: tuple(1 for _ in range(nmodes)) + +def _compare_prob_for_outcome(diagram, outcome): + res_q = diagram.eval() + d = res_q.prob_dist() + return float(d.get(outcome, 0.0)) @pytest.mark.skip(reason="Helper function for testing") def chip_mzi(w, l): @@ -114,7 +136,8 @@ class TestPercevalBackend: def test_pure_circuit(self, circuit): backend = PercevalBackend() perceval_state = pcvl.BasicState(get_state(circuit)) - result_perceval = circuit.eval(backend, perceval_state=perceval_state) + config = PercevalEvalConfig(state=perceval_state) + result_perceval = circuit.eval(backend, config=config) state = photonic.Create(*get_state(circuit)) diagram = state >> circuit @@ -125,6 +148,14 @@ def test_pure_circuit(self, circuit): result_perceval.prob_dist(), ) + config = PercevalEvalConfig(state=perceval_state, task="amps") + result_perceval = circuit.eval(backend, config=config) + + assert dict_allclose( + result_quimb.amplitudes(), + result_perceval.amplitudes(), + ) + class TestDiscopyBackend: # compare with matrix.probs @pytest.mark.parametrize("circuit", PURE_CIRCUITS_TO_TEST) @@ -221,6 +252,159 @@ def test_probs_sum_to_one(self, circuit): total_prob = sum(prob_dist.values()) assert math.isclose(total_prob, 1.0) + def test_prob_lookup_and_missing_key(self): + v = np.array([1/np.sqrt(2), 1j/np.sqrt(2)], dtype=complex) + box = discopy_tensor.Tensor(v, discopy_tensor.Dim(1), discopy_tensor.Dim(2)) + ev = EvalResult(_tensor=box, output_types=(mode,), state_type=StateType.AMP) + + p = ev.prob_dist() + assert ev.single_prob((0,)) == pytest.approx(p[(0,)], rel=1e-12) + assert ev.single_prob((1,)) == pytest.approx(p[(1,)], rel=1e-12) + assert ev.single_prob((2,)) == 0.0 + + def test_density_matrix_from_amp_is_outer_product(self): + v = np.array([1/np.sqrt(3), np.sqrt(2/3)], dtype=complex) + box = discopy_tensor.Tensor(v, discopy_tensor.Dim(1), discopy_tensor.Dim(2)) + ev = EvalResult(_tensor=box, output_types=(mode,), state_type=StateType.AMP) + dm = ev.density_matrix + assert dm.shape == (2, 2) + assert np.allclose(dm, np.outer(v.conj(), v), atol=1e-12) + + def test_mixed_interleaved_partial_trace_numeric_hygiene(self): + dm = np.zeros((2, 2, 2), dtype=complex) + dm[0, 0, 0] = 0.499999999999 + 1e-14j + dm[0, 1, 1] = 0.100000000001 - 1e-14j + dm[1, 0, 0] = 0.2000000000005 + 1e-14j + dm[1, 1, 1] = 0.2000000000005 - 1e-14j + #off-diagonals that should be ignored by the diagonal check + dm[0, 0, 1] = -1e-13 + dm[1, 1, 0] = -1e-13 + + box = discopy_tensor.Tensor(dm, discopy_tensor.Dim(1), discopy_tensor.Dim(*dm.shape)) + #1st axis measured (mode), 2ns is quantum (unmeasured, bra/ket) + ev = EvalResult(_tensor=box, output_types=(mode, object()), state_type=StateType.DM) + probs = ev.prob_dist() + + assert pytest.approx(probs[(0,)], rel=1e-12) == 0.6 + assert pytest.approx(probs[(1,)], rel=1e-12) == 0.4 + assert pytest.approx(sum(probs.values()), rel=1e-6) == 1.0 + + def test_prob_branch_rounding_and_sum(self): + P = np.array([[1/3, 1/6], + [1/6, 1/3]], dtype=float) + box = discopy_tensor.Tensor(P, discopy_tensor.Dim(1), discopy_tensor.Dim(*P.shape)) + ev = EvalResult(_tensor=box, output_types=(mode, mode), state_type=StateType.PROB) + + d = ev.prob_dist(round_digits=4) + assert pytest.approx(sum(d.values()), rel=1e-12) == 1.0 + print(d) + assert d[(0, 0)] == round(float(P[0, 0]), 4) + assert d[(1, 1)] == round(float(P[1, 1]), 4) + + def test_density_matrix_from_amp_complex_phase(self): + v = np.array([np.exp(1j*0.2)/np.sqrt(3), + np.exp(-1j*0.3)*np.sqrt(2/3)], dtype=complex) + box = discopy_tensor.Tensor(v, discopy_tensor.Dim(1), discopy_tensor.Dim(2)) + ev = EvalResult(_tensor=box, output_types=(mode,), state_type=StateType.AMP) + dm = ev.density_matrix + assert dm.shape == (2, 2) + assert np.allclose(dm, dm.conj().T, atol=1e-12) + assert np.allclose(dm, np.outer(v.conj(), v), atol=1e-12) + eig = np.linalg.eigvalsh(dm) + assert np.all(eig >= -1e-12) + + def test_density_matrix_passthrough_for_DM(self): + dm = np.array([[0.7, 0.1j], + [-0.1j, 0.3]], dtype=complex) + box = discopy_tensor.Tensor(dm, discopy_tensor.Dim(1), discopy_tensor.Dim(*dm.shape)) + ev = EvalResult(_tensor=box, output_types=(mode,), state_type=StateType.DM) + out = ev.density_matrix + assert out.shape == (2, 2) + assert np.allclose(out, dm, atol=1e-12) + + def test_mixed_simple_single_measured_single_unmeasured(self): + dm = np.zeros((2, 2, 2), dtype=complex) + dm[0, 0, 0] = 0.6 + 1e-14j + dm[1, 1, 1] = 0.4 - 1e-14j + dm[0, 1, 0] = -1e-13 + dm[1, 0, 1] = -1e-13 + + box = discopy_tensor.Tensor(dm, discopy_tensor.Dim(1), discopy_tensor.Dim(*dm.shape)) + ev = EvalResult(_tensor=box, output_types=(mode, qubit), state_type=StateType.DM) + probs = ev.prob_dist() + assert pytest.approx(probs[(0,)], rel=1e-12) == 0.6 + assert pytest.approx(probs[(1,)], rel=1e-12) == 0.4 + assert pytest.approx(sum(probs.values()), rel=1e-12) == 1.0 + + def test_mixed_two_measured_one_unmeasured_matches_manual_trace(self): + P = np.array([[0.1, 0.2], + [0.3, 0.4]], dtype=float) + W = np.array([0.6, 0.4], dtype=float) + dm = np.zeros((2, 2, 2, 2), dtype=complex) + + for k in range(2): + for l in range(2): + dm[k, l, 0, 0] = P[k, l] * W[0] + dm[k, l, 1, 1] = P[k, l] * W[1] + dm[0, 0, 0, 1] = 1e-13 + dm[1, 1, 1, 0] = -1e-13 + + box = discopy_tensor.Tensor(dm, discopy_tensor.Dim(1), discopy_tensor.Dim(*dm.shape)) + ev = EvalResult(_tensor=box, output_types=(mode, bit, qubit), state_type=StateType.DM) + + probs = ev.prob_dist() + manual = {(k, l): (dm[k, l, 0, 0].real + dm[k, l, 1, 1].real) for k in range(2) for l in range(2)} + Z = sum(manual.values()) + manual = {k: v / Z for k, v in manual.items()} + + for k in range(2): + for l in range(2): + assert pytest.approx(probs[(k, l)], rel=1e-12) == manual[(k, l)] + assert pytest.approx(sum(probs.values()), rel=1e-12) == 1.0 + + def test_mixed_round_digits_applied_before_normalization(self): + dm = np.zeros((2, 2, 2), dtype=complex) + dm[0, 0, 0] = 0.333333333333 + dm[1, 1, 1] = 0.666666666667 + box = discopy_tensor.Tensor(dm, discopy_tensor.Dim(1), discopy_tensor.Dim(*dm.shape)) + ev = EvalResult(_tensor=box, output_types=(mode, qubit), state_type=StateType.DM) + d = ev.prob_dist(round_digits=6) + assert pytest.approx(sum(d.values()), rel=1e-12) == 1.0 + assert d[(0,)] == pytest.approx(round(0.333333333333, 6) / (round(0.333333333333, 6) + round(0.666666666667, 6)), rel=1e-12) + assert d[(1,)] == pytest.approx(round(0.666666666667, 6) / (round(0.333333333333, 6) + round(0.666666666667, 6)), rel=1e-12) + + def test_mixed_requires_at_least_one_measured_type(self): + dm = np.eye(4, dtype=complex).reshape(2, 2, 2, 2) / 4.0 + box = discopy_tensor.Tensor(dm, discopy_tensor.Dim(1), discopy_tensor.Dim(*dm.shape)) + ev = EvalResult(_tensor=box, output_types=(qubit, qmode), state_type=StateType.DM) + with pytest.raises(ValueError): + _ = ev.prob_dist() + + def test_mixed_two_unmeasured_pairs_partial_trace(self): + + dm = np.zeros((2, 2, 2, 2, 2), dtype=float) + + w_m = {0: 0.55, 1: 0.45} + w_u1 = np.array([0.6, 0.4]) + w_u2 = np.array([0.7, 0.3]) + + for m in (0, 1): + for r1 in (0, 1): + for r2 in (0, 1): + dm[m, r1, r1, r2, r2] = w_m[m] * w_u1[r1] * w_u2[r2] + + dm[0, 0, 1, 0, 0] = 1e-13 + dm[1, 1, 0, 1, 0] = -1e-13 + + box = discopy_tensor.Tensor(dm, discopy_tensor.Dim(1), discopy_tensor.Dim(*dm.shape)) + ev = EvalResult(_tensor=box, output_types=(mode, object(), object()), state_type=StateType.DM) + probs = ev.prob_dist() + + assert pytest.approx(sum(probs.values()), rel=1e-12) == 1.0 + assert pytest.approx(probs[(0,)], rel=1e-12) == 0.55 + assert pytest.approx(probs[(1,)], rel=1e-12) == 0.45 + + class TestExceptions: # EvalResult # density matrix/amps/prob from not a state @@ -240,7 +424,8 @@ def test_eval_result_not_density_matrix(self, circuit): with pytest.raises(TypeError): backend = PercevalBackend() perceval_state = pcvl.BasicState(get_state(circuit)) - result_perceval = circuit.eval(backend, perceval_state=perceval_state) + config = PercevalEvalConfig(state=perceval_state) + result_perceval = circuit.eval(backend, config=config) dm = result_perceval.density_matrix @@ -265,4 +450,289 @@ def test_abstract_backend_not_lo(self, circuit): with pytest.raises(AssertionError): backend = PercevalBackend() perceval_state = pcvl.BasicState(get_state(circuit)) - result_perceval = circuit.eval(backend, perceval_state=perceval_state) \ No newline at end of file + config = PercevalEvalConfig(state=perceval_state) + result_perceval = circuit.eval(backend, config=config) + + def test_mixed_all_off_diagonal_mass_raises_zero_total(self): + dm = np.zeros((2, 2, 2), dtype=float) + dm[0, 0, 1] = 0.6 + dm[1, 1, 0] = 0.4 + + box = discopy_tensor.Tensor(dm, discopy_tensor.Dim(1), discopy_tensor.Dim(*dm.shape)) + ev = EvalResult(_tensor=box, output_types=(mode, object()), state_type=StateType.DM) + + with pytest.raises(ValueError): + _ = ev.prob_dist() + + @pytest.mark.parametrize( + "circuit", + MIXED_CIRCUITS_TO_TEST + CIRCUITS_WITH_DISCARDS_TO_TEST, + ) + def test_non_lo_circuits_raise_notimplemented(self, circuit): + backend = PermanentBackend() + with pytest.raises(AssertionError): + _ = circuit.eval(backend) + + @pytest.mark.parametrize("circuit", PURE_CIRCUITS_TO_TEST) + def test_zero_photons_no_creations_raises_valueerror(self, circuit): + backend = PermanentBackend() + with pytest.raises(ValueError): + _ = circuit.eval(backend, n_photons=0) + +class TestPermanentBackendVsQuimb: + @pytest.mark.parametrize("circuit", PURE_CIRCUITS_TO_TEST) + def test_permanent_amp_matches_quimb_with_create(self, circuit): + state_occ = get_state(circuit) + state = photonic.Create(*state_occ) + diagram = state >> circuit + + result_quimb = diagram.eval() + + backend = PermanentBackend() + result_perm = diagram.eval(backend) + + assert dict_allclose(result_perm.amplitudes(), result_quimb.amplitudes()) + assert dict_allclose(result_perm.prob_dist(), result_quimb.prob_dist()) + assert np.allclose(result_perm.tensor.array, result_quimb.tensor.array, atol=1e-12) + + @pytest.mark.parametrize("circuit", PURE_CIRCUITS_TO_TEST) + def test_permanent_prob_matches_quimb_with_create(self, circuit): + state_occ = get_state(circuit) + state = photonic.Create(*state_occ) + diagram = state >> circuit + + result_quimb = diagram.eval() + + backend = PermanentBackend() + result_perm = diagram.eval(backend, return_kind="prob") + + assert dict_allclose(result_perm.prob_dist(), result_quimb.prob_dist()) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_prob_nothing_raises_no_state(circuit): + backend = PercevalBackend() + with pytest.raises(ValueError): + _ = circuit.eval(backend, task="probs") + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_prob_state_compare_with_quimb(circuit): + print(circuit) + nmodes = len(circuit.dom) + perceval_state = pcvl.BasicState(bs_state(nmodes)) + backend = PercevalBackend() + config = PercevalEvalConfig(state=perceval_state) + res_pcvl = circuit.eval(backend, config=config) + + ref = photonic.Create(*bs_state(nmodes)) >> circuit + d_ref = ref.eval().prob_dist() + d_pcvl = res_pcvl.prob_dist() + + keys = set(d_ref) | set(d_pcvl) + for k in keys: + assert math.isclose(d_ref.get(k, 0.0), d_pcvl.get(k, 0.0), rel_tol=1e-9, abs_tol=1e-12) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_prob_effect_only_errors_no_state(circuit): + nmodes = len(circuit.dom) + backend = PercevalBackend() + config = PercevalEvalConfig(effect=pcvl.BasicState([2] + [0]*(nmodes-1))) + with pytest.raises(ValueError): + _ = circuit.eval( + backend, + config=config + ) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_prob_diagram_only_effect_errors_no_state(circuit): + nmodes = len(circuit.cod) + diagram = circuit >> classical.Select(*([0]*(nmodes-1) + [2])) + backend = PercevalBackend() + with pytest.raises(ValueError): + _ = diagram.eval(backend) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_single_prob_errors_without_state_and_effect(circuit): + backend = PercevalBackend() + config = PercevalEvalConfig(task="single_prob") + with pytest.raises(ValueError): + _ = circuit.eval(backend, config=config) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_single_prob_errors_with_state_only(circuit): + backend = PercevalBackend() + nmodes = len(circuit.dom) + config = PercevalEvalConfig(state=pcvl.BasicState(bs_state(nmodes)), task="single_prob") + with pytest.raises(ValueError): + _ = circuit.eval(backend, config=config) + + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_single_prob_errors_with_effect_only(circuit): + backend = PercevalBackend() + nmodes = len(circuit.dom) + config = PercevalEvalConfig(effect=pcvl.BasicState([2] + [0]*(nmodes-1)), task="single_prob") + with pytest.raises(ValueError): + _ = circuit.eval(backend, config=config) + + backend = PercevalBackend() + nmodes = len(circuit.dom) + config = PercevalEvalConfig(effect=pcvl.BasicState([2] + [0]*(nmodes-1)), task="single_prob") + with pytest.raises(ValueError): + _ = circuit.eval(backend, config=config + ) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_single_prob_matches_quimb_selected_outcome(circuit): + nmodes = len(circuit.dom) + state_occ = bs_state(nmodes) + outcome = (2,) + (0,)*(nmodes-1) if nmodes >= 1 else tuple() + + backend = PercevalBackend() + config = PercevalEvalConfig( + state=pcvl.BasicState(state_occ), + effect=pcvl.BasicState(outcome), + task="single_prob" + ) + res = circuit.eval(backend, config=config) + + assert res.state_type is StateType.SINGLE_PROB + p_pcvl = float(res.single_prob(outcome)) + + ref = photonic.Create(*state_occ) >> circuit + p_ref = _compare_prob_for_outcome(ref, outcome) + assert math.isclose(p_pcvl, p_ref, rel_tol=1e-9, abs_tol=1e-12) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_amp_nothing_raises_no_state(circuit): + backend = PercevalBackend() + config = PercevalEvalConfig(task="amps") + with pytest.raises(ValueError): + _ = circuit.eval(backend, config=config) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_amp_state_compare_with_quimb(circuit): + + nmodes = len(circuit.dom) + state_occ = bs_state(nmodes) + + backend = PercevalBackend() + config = PercevalEvalConfig( + state=pcvl.BasicState(state_occ), + task="amps" + ) + res_pcvl = circuit.eval( + backend, + config=config + ) + + # reference + ref = photonic.Create(*state_occ) >> circuit + d_ref = ref.eval().amplitudes() + d_pcvl = res_pcvl.amplitudes() + + keys = set(d_ref) | set(d_pcvl) + for k in keys: + assert complex(d_ref.get(k, 0.0)) == pytest.approx(d_pcvl.get(k, 0.0), rel=1e-9, abs=1e-12) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_amp_effect_only_errors_no_state(circuit): + backend = PercevalBackend() + nmodes = len(circuit.dom) + config = PercevalEvalConfig( + effect=pcvl.BasicState([2] + [0]*(nmodes-1)), + task="amps" + ) + with pytest.raises(ValueError): + _ = circuit.eval( + backend, + config=config + ) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_single_amp_errors_without_state_and_effect(circuit): + backend = PercevalBackend() + config = PercevalEvalConfig(task="single_amp") + with pytest.raises(ValueError): + _ = circuit.eval(backend, config=config) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_single_amp_errors_with_state_only(circuit): + backend = PercevalBackend() + nmodes = len(circuit.dom) + config = PercevalEvalConfig( + state=pcvl.BasicState(bs_state(nmodes)), + task="single_amp" + ) + with pytest.raises(ValueError): + _ = circuit.eval(backend, config=config) + + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_single_amp_errors_with_effect_only(circuit): + backend = PercevalBackend() + nmodes = len(circuit.dom) + config = PercevalEvalConfig( + effect=pcvl.BasicState([2] + [0]*(nmodes-1)), + task="single_amp" + ) + with pytest.raises(ValueError): + _ = circuit.eval(backend, config=config + ) + +@pytest.mark.parametrize("circuit", [unitary_circuit, non_unitary_circuit]) +def test_single_amp_matches_quimb_selected_outcome(circuit): + + nmodes = len(circuit.dom) + state_occ = bs_state(nmodes) + outcome = (2,) + (0,)*(nmodes-1) if nmodes >= 1 else tuple() + + backend = PercevalBackend() + config = PercevalEvalConfig( + state=pcvl.BasicState(state_occ), + effect=pcvl.BasicState(outcome), + task="single_amp" + ) + res = circuit.eval( + backend, + config=config + ) + assert res.state_type is StateType.SINGLE_AMP + a_pcvl = res.single_amplitude(outcome) + + ref = photonic.Create(*state_occ) >> circuit + a_ref = ref.eval().amplitudes().get(outcome, 0.0) + assert a_pcvl == pytest.approx(a_ref, rel=1e-9, abs=1e-12) + +def test_sums(): + diagram = ( + photonic.Create(1, 1, 1, 1) >> + photonic.MZI(0.5, 0.5) @ photonic.MZI(0.5, 0.5) >> + photonic.qmode @ photonic.MZI(0.5, 0.5) @ photonic.qmode + ) + ( + photonic.Create(0, 2, 0, 2) >> + photonic.MZI(0.5, 0.5) @ photonic.MZI(0.5, 0.5) >> + photonic.qmode @ photonic.MZI(0.5, 0.5) @ photonic.qmode + ) + + res_perceval = diagram.eval(PercevalBackend()).prob_dist() + res_permanent = diagram.eval(PermanentBackend()).prob_dist() + res_quimb = diagram.eval().prob_dist() + print(res_perceval) + print(res_quimb) + keys = set(res_perceval) | set(res_quimb) + for k in keys: + assert math.isclose( + res_perceval.get(k, 0.0), + res_quimb.get(k, 0.0), + rel_tol=1e-9, + abs_tol=1e-12 + ) + + keys = set(res_permanent) | set(res_quimb) + for k in keys: + assert math.isclose( + res_permanent.get(k, 0.0), + res_quimb.get(k, 0.0), + rel_tol=1e-9, + abs_tol=1e-12 + )