Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions libs/core/include/cuda-qx/core/kwargs_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,54 @@ inline heterogeneous_map hetMapFromKwargs(const py::kwargs &kwargs) {
} else if (py::isinstance<py::array>(value)) {
py::array np_array = value.cast<py::array>();
py::buffer_info info = np_array.request();
auto insert_vector = [&](auto type_tag) {
using T = decltype(type_tag);
std::vector<T> vec(static_cast<T *>(info.ptr),
static_cast<T *>(info.ptr) + info.size);
result.insert(key, std::move(vec));
};
if (info.format == py::format_descriptor<double>::format()) {
insert_vector(double{});
} else if (info.format == py::format_descriptor<float>::format()) {
insert_vector(float{});
} else if (info.format == py::format_descriptor<int>::format()) {
insert_vector(int{});
} else if (info.format == py::format_descriptor<uint8_t>::format()) {
insert_vector(uint8_t{});
if (info.ndim >= 2) {
if (info.strides[0] == static_cast<py::ssize_t>(info.itemsize)) {
throw std::runtime_error(
"Array in kwargs must be in row-major order, but "
"column-major order was detected.");
}
std::vector<std::size_t> shape(static_cast<std::size_t>(info.ndim),
std::size_t(0));
for (py::ssize_t d = 0; d < info.ndim; d++)
shape[d] = static_cast<std::size_t>(info.shape[d]);

auto insert_tensor = [&](auto type_tag) {
using T = decltype(type_tag);
cudaqx::tensor<T> ten(shape);
ten.borrow(static_cast<T *>(info.ptr), shape);
result.insert(key, std::move(ten));
};
if (info.format == py::format_descriptor<double>::format()) {
insert_tensor(double{});
} else if (info.format == py::format_descriptor<float>::format()) {
insert_tensor(float{});
} else if (info.format == py::format_descriptor<int>::format()) {
insert_tensor(int{});
} else if (info.format == py::format_descriptor<uint8_t>::format()) {
insert_tensor(uint8_t{});
} else {
throw std::runtime_error("Unsupported array data type in kwargs.");
}
} else {
throw std::runtime_error("Unsupported array data type in kwargs.");
// 1D array: keep as flattened vector for backward compatibility
// (e.g. error_rate_vec used by decoders).
auto insert_vector = [&](auto type_tag) {
using T = decltype(type_tag);
std::vector<T> vec(static_cast<T *>(info.ptr),
static_cast<T *>(info.ptr) + info.size);
result.insert(key, std::move(vec));
};
if (info.format == py::format_descriptor<double>::format()) {
insert_vector(double{});
} else if (info.format == py::format_descriptor<float>::format()) {
insert_vector(float{});
} else if (info.format == py::format_descriptor<int>::format()) {
insert_vector(int{});
} else if (info.format == py::format_descriptor<uint8_t>::format()) {
insert_vector(uint8_t{});
} else {
throw std::runtime_error("Unsupported array data type in kwargs.");
}
}
} else {
throw std::runtime_error(
Expand Down
68 changes: 57 additions & 11 deletions libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class pymatching : public decoder {
// efficient.
std::map<std::pair<int64_t, int64_t>, size_t> edge2col_idx;

bool decode_to_observables = false;

// Helper function to make a canonical edge from two nodes.
std::pair<int64_t, int64_t> make_canonical_edge(int64_t node1,
int64_t node2) {
Expand Down Expand Up @@ -77,10 +79,39 @@ class pymatching : public decoder {
}
}

std::vector<std::vector<size_t>> errs2observables(block_size);
if (params.contains("O")) {
auto O = params.get<cudaqx::tensor<uint8_t>>("O");
if (O.rank() != 2) {
throw std::runtime_error(
"O must be a 2-dimensional tensor (num_observables x block_size)");
}
const size_t num_observables = O.shape()[0];
if (O.shape()[1] != block_size) {
throw std::runtime_error(
"O must be of shape (num_observables, block_size); got second "
"dimension " +
std::to_string(O.shape()[1]) + ", block_size " +
std::to_string(block_size));
}
std::vector<std::vector<uint32_t>> O_sparse;
for (size_t i = 0; i < num_observables; i++) {
O_sparse.emplace_back();
auto *row = &O.at({i, 0});
for (size_t j = 0; j < block_size; j++) {
if (row[j] > 0) {
O_sparse.back().push_back(static_cast<uint32_t>(j));
errs2observables[j].push_back(static_cast<uint32_t>(i));
}
}
}
this->set_O_sparse(O_sparse);
decode_to_observables = true;
}

user_graph = pm::UserGraph(H.shape()[0]);

auto sparse = cudaq::qec::dense_to_sparse(H);
std::vector<size_t> observables;
std::size_t col_idx = 0;
for (auto &col : sparse) {
double weight = 1.0;
Expand All @@ -90,12 +121,14 @@ class pymatching : public decoder {
}
if (col.size() == 2) {
edge2col_idx[make_canonical_edge(col[0], col[1])] = col_idx;
user_graph.add_or_merge_edge(col[0], col[1], observables, weight, 0.0,
user_graph.add_or_merge_edge(col[0], col[1],
errs2observables.at(col_idx), weight, 0.0,
merge_strategy_enum);
} else if (col.size() == 1) {
edge2col_idx[make_canonical_edge(col[0], -1)] = col_idx;
user_graph.add_or_merge_boundary_edge(col[0], observables, weight, 0.0,
merge_strategy_enum);
user_graph.add_or_merge_boundary_edge(col[0],
errs2observables.at(col_idx),
weight, 0.0, merge_strategy_enum);
} else {
throw std::runtime_error(
"Invalid column in H: " + std::to_string(col_idx) + " has " +
Expand All @@ -119,13 +152,26 @@ class pymatching : public decoder {
for (size_t i = 0; i < syndrome.size(); i++)
if (syndrome[i] > 0.5)
detection_events.push_back(i);
pm::decode_detection_events_to_edges(mwpm, detection_events, edges);
// Loop over the edge pairs
assert(edges.size() % 2 == 0);
for (size_t i = 0; i < edges.size(); i += 2) {
auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1));
auto col_idx = edge2col_idx.at(edge);
result.result[col_idx] = 1.0;
if (decode_to_observables) {
assert(O_sparse.size() == mwpm.flooder.graph.num_observables);
pm::total_weight_int weight = 0;
std::vector<uint8_t> obs(mwpm.flooder.graph.num_observables, 0);
obs.resize(mwpm.flooder.graph.num_observables);
pm::decode_detection_events(mwpm, detection_events, obs.data(), weight,
/*edge_correlations=*/false);
result.result.resize(mwpm.flooder.graph.num_observables);
for (size_t i = 0; i < mwpm.flooder.graph.num_observables; i++) {
result.result[i] = static_cast<float_t>(obs[i]);
}
} else {
pm::decode_detection_events_to_edges(mwpm, detection_events, edges);
// Loop over the edge pairs to reconstruct errors.
assert(edges.size() % 2 == 0);
for (size_t i = 0; i < edges.size(); i += 2) {
auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1));
auto col_idx = edge2col_idx.at(edge);
result.result[col_idx] = 1.0;
}
}
// An exception is thrown if no matching solution is found, so we can just
// set converged to true.
Expand Down
45 changes: 44 additions & 1 deletion libs/qec/python/tests/test_dem.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ============================================================================ #
# Copyright (c) 2024 - 2025 NVIDIA Corporation & Affiliates. #
# Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
Expand Down Expand Up @@ -275,6 +275,49 @@ def test_decoding_from_surface_code_dem_from_memory_circuit(
assert nLogicalErrorsWithDecoding < nLogicalErrorsWithoutDecoding


def test_pymatching_decode_to_observable_surface_code_dem():
"""Test PyMatching with O (observables) matrix: decoder returns observable
flips directly.cpp)."""
cudaq.set_random_seed(13)
code = qec.get_code('surface_code', distance=5)
Lz = code.get_observables_z()
p = 0.003
noise = cudaq.NoiseModel()
noise.add_all_qubit_channel("x", cudaq.Depolarization2(p), 1)
statePrep = qec.operation.prep0
nRounds = 5
nShots = 2000

syndromes, data = qec.sample_memory_circuit(code, statePrep, nShots,
nRounds, noise)

logical_measurements = (Lz @ data.transpose()) % 2
logical_measurements = logical_measurements.flatten()

syndromes = syndromes.reshape((nShots, nRounds, -1))
syndromes = syndromes[:, :, :syndromes.shape[2] // 2]
syndromes = syndromes.reshape((nShots, -1))

dem = qec.z_dem_from_memory_circuit(code, statePrep, nRounds, noise)

decoder = qec.get_decoder(
'pymatching',
dem.detector_error_matrix,
O=dem.observables_flips_matrix,
error_rate_vec=np.array(dem.error_rates),
)

dr = decoder.decode_batch(syndromes)
# With decode_to_observables=True, each e.result is observable flips
# (length num_observables), not error predictions.
obs_per_shot = np.array([e.result for e in dr], dtype=np.float64)
data_predictions = np.round(obs_per_shot).astype(np.uint8).T

nLogicalErrorsWithoutDecoding = np.sum(logical_measurements)
nLogicalErrorsWithDecoding = np.sum(data_predictions ^ logical_measurements)
assert nLogicalErrorsWithDecoding < nLogicalErrorsWithoutDecoding


def test_pcm_extend_to_n_rounds():
# This test independently compares the functionality of dem_from_memory_circuit
# (of two different numbers of rounds) to pcm_extend_to_n_rounds.
Expand Down