diff --git a/libs/core/include/cuda-qx/core/kwargs_utils.h b/libs/core/include/cuda-qx/core/kwargs_utils.h index ce50f3cc..f0b0048b 100644 --- a/libs/core/include/cuda-qx/core/kwargs_utils.h +++ b/libs/core/include/cuda-qx/core/kwargs_utils.h @@ -75,22 +75,54 @@ inline heterogeneous_map hetMapFromKwargs(const py::kwargs &kwargs) { } else if (py::isinstance(value)) { py::array np_array = value.cast(); py::buffer_info info = np_array.request(); - auto insert_vector = [&](auto type_tag) { - using T = decltype(type_tag); - std::vector vec(static_cast(info.ptr), - static_cast(info.ptr) + info.size); - result.insert(key, std::move(vec)); - }; - if (info.format == py::format_descriptor::format()) { - insert_vector(double{}); - } else if (info.format == py::format_descriptor::format()) { - insert_vector(float{}); - } else if (info.format == py::format_descriptor::format()) { - insert_vector(int{}); - } else if (info.format == py::format_descriptor::format()) { - insert_vector(uint8_t{}); + if (info.ndim >= 2) { + if (info.strides[0] == static_cast(info.itemsize)) { + throw std::runtime_error( + "Array in kwargs must be in row-major order, but " + "column-major order was detected."); + } + std::vector shape(static_cast(info.ndim), + std::size_t(0)); + for (py::ssize_t d = 0; d < info.ndim; d++) + shape[d] = static_cast(info.shape[d]); + + auto insert_tensor = [&](auto type_tag) { + using T = decltype(type_tag); + cudaqx::tensor ten(shape); + ten.borrow(static_cast(info.ptr), shape); + result.insert(key, std::move(ten)); + }; + if (info.format == py::format_descriptor::format()) { + insert_tensor(double{}); + } else if (info.format == py::format_descriptor::format()) { + insert_tensor(float{}); + } else if (info.format == py::format_descriptor::format()) { + insert_tensor(int{}); + } else if (info.format == py::format_descriptor::format()) { + insert_tensor(uint8_t{}); + } else { + throw std::runtime_error("Unsupported array data type in kwargs."); + } } else { - throw 
std::runtime_error("Unsupported array data type in kwargs."); + // 1D array: keep as flattened vector for backward compatibility + // (e.g. error_rate_vec used by decoders). + auto insert_vector = [&](auto type_tag) { + using T = decltype(type_tag); + std::vector vec(static_cast(info.ptr), + static_cast(info.ptr) + info.size); + result.insert(key, std::move(vec)); + }; + if (info.format == py::format_descriptor::format()) { + insert_vector(double{}); + } else if (info.format == py::format_descriptor::format()) { + insert_vector(float{}); + } else if (info.format == py::format_descriptor::format()) { + insert_vector(int{}); + } else if (info.format == py::format_descriptor::format()) { + insert_vector(uint8_t{}); + } else { + throw std::runtime_error("Unsupported array data type in kwargs."); + } } } else { throw std::runtime_error( diff --git a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp index 514adfb6..e976a9ba 100644 --- a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp +++ b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp @@ -31,6 +31,8 @@ class pymatching : public decoder { // efficient. std::map, size_t> edge2col_idx; + bool decode_to_observables = false; + // Helper function to make a canonical edge from two nodes. 
std::pair make_canonical_edge(int64_t node1, int64_t node2) { @@ -77,10 +79,39 @@ class pymatching : public decoder { } } + std::vector> errs2observables(block_size); + if (params.contains("O")) { + auto O = params.get>("O"); + if (O.rank() != 2) { + throw std::runtime_error( + "O must be a 2-dimensional tensor (num_observables x block_size)"); + } + const size_t num_observables = O.shape()[0]; + if (O.shape()[1] != block_size) { + throw std::runtime_error( + "O must be of shape (num_observables, block_size); got second " + "dimension " + + std::to_string(O.shape()[1]) + ", block_size " + + std::to_string(block_size)); + } + std::vector> O_sparse; + for (size_t i = 0; i < num_observables; i++) { + O_sparse.emplace_back(); + auto *row = &O.at({i, 0}); + for (size_t j = 0; j < block_size; j++) { + if (row[j] > 0) { + O_sparse.back().push_back(static_cast(j)); + errs2observables[j].push_back(static_cast(i)); + } + } + } + this->set_O_sparse(O_sparse); + decode_to_observables = true; + } + user_graph = pm::UserGraph(H.shape()[0]); auto sparse = cudaq::qec::dense_to_sparse(H); - std::vector observables; std::size_t col_idx = 0; for (auto &col : sparse) { double weight = 1.0; @@ -90,12 +121,14 @@ class pymatching : public decoder { } if (col.size() == 2) { edge2col_idx[make_canonical_edge(col[0], col[1])] = col_idx; - user_graph.add_or_merge_edge(col[0], col[1], observables, weight, 0.0, + user_graph.add_or_merge_edge(col[0], col[1], + errs2observables.at(col_idx), weight, 0.0, merge_strategy_enum); } else if (col.size() == 1) { edge2col_idx[make_canonical_edge(col[0], -1)] = col_idx; - user_graph.add_or_merge_boundary_edge(col[0], observables, weight, 0.0, - merge_strategy_enum); + user_graph.add_or_merge_boundary_edge(col[0], + errs2observables.at(col_idx), + weight, 0.0, merge_strategy_enum); } else { throw std::runtime_error( "Invalid column in H: " + std::to_string(col_idx) + " has " + @@ -119,13 +152,26 @@ class pymatching : public decoder { for (size_t i = 0; i < 
syndrome.size(); i++) if (syndrome[i] > 0.5) detection_events.push_back(i); - pm::decode_detection_events_to_edges(mwpm, detection_events, edges); - // Loop over the edge pairs - assert(edges.size() % 2 == 0); - for (size_t i = 0; i < edges.size(); i += 2) { - auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1)); - auto col_idx = edge2col_idx.at(edge); - result.result[col_idx] = 1.0; + if (decode_to_observables) { + assert(O_sparse.size() == mwpm.flooder.graph.num_observables); + pm::total_weight_int weight = 0; + std::vector obs(mwpm.flooder.graph.num_observables, 0); + obs.resize(mwpm.flooder.graph.num_observables); + pm::decode_detection_events(mwpm, detection_events, obs.data(), weight, + /*edge_correlations=*/false); + result.result.resize(mwpm.flooder.graph.num_observables); + for (size_t i = 0; i < mwpm.flooder.graph.num_observables; i++) { + result.result[i] = static_cast(obs[i]); + } + } else { + pm::decode_detection_events_to_edges(mwpm, detection_events, edges); + // Loop over the edge pairs to reconstruct errors. + assert(edges.size() % 2 == 0); + for (size_t i = 0; i < edges.size(); i += 2) { + auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1)); + auto col_idx = edge2col_idx.at(edge); + result.result[col_idx] = 1.0; + } } // An exception is thrown if no matching solution is found, so we can just // set converged to true. diff --git a/libs/qec/python/tests/test_dem.py b/libs/qec/python/tests/test_dem.py index 48c3ee6d..69f57a77 100644 --- a/libs/qec/python/tests/test_dem.py +++ b/libs/qec/python/tests/test_dem.py @@ -1,5 +1,5 @@ # ============================================================================ # -# Copyright (c) 2024 - 2025 NVIDIA Corporation & Affiliates. # +# Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. # # All rights reserved. 
# # # # This source code and the accompanying materials are made available under # @@ -275,6 +275,49 @@ def test_decoding_from_surface_code_dem_from_memory_circuit( assert nLogicalErrorsWithDecoding < nLogicalErrorsWithoutDecoding +def test_pymatching_decode_to_observable_surface_code_dem(): + """Test PyMatching with O (observables) matrix: decoder returns observable + flips directly.""" + cudaq.set_random_seed(13) + code = qec.get_code('surface_code', distance=5) + Lz = code.get_observables_z() + p = 0.003 + noise = cudaq.NoiseModel() + noise.add_all_qubit_channel("x", cudaq.Depolarization2(p), 1) + statePrep = qec.operation.prep0 + nRounds = 5 + nShots = 2000 + + syndromes, data = qec.sample_memory_circuit(code, statePrep, nShots, + nRounds, noise) + + logical_measurements = (Lz @ data.transpose()) % 2 + logical_measurements = logical_measurements.flatten() + + syndromes = syndromes.reshape((nShots, nRounds, -1)) + syndromes = syndromes[:, :, :syndromes.shape[2] // 2] + syndromes = syndromes.reshape((nShots, -1)) + + dem = qec.z_dem_from_memory_circuit(code, statePrep, nRounds, noise) + + decoder = qec.get_decoder( + 'pymatching', + dem.detector_error_matrix, + O=dem.observables_flips_matrix, + error_rate_vec=np.array(dem.error_rates), + ) + + dr = decoder.decode_batch(syndromes) + # Passing O makes the decoder set decode_to_observables, so each e.result is observable flips + # (length num_observables), not error predictions. + obs_per_shot = np.array([e.result for e in dr], dtype=np.float64) + data_predictions = np.round(obs_per_shot).astype(np.uint8).T + + nLogicalErrorsWithoutDecoding = np.sum(logical_measurements) + nLogicalErrorsWithDecoding = np.sum(data_predictions ^ logical_measurements) + assert nLogicalErrorsWithDecoding < nLogicalErrorsWithoutDecoding + + def test_pcm_extend_to_n_rounds(): # This test independently compares the functionality of dem_from_memory_circuit # (of two different numbers of rounds) to pcm_extend_to_n_rounds.