Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libs/qec/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ add_library(${LIBRARY_NAME} SHARED
)

add_subdirectory(decoders/plugins/example)
add_subdirectory(decoders/plugins/pymatching)

if(CUDAQ_QEC_BUILD_TRT_DECODER)
add_subdirectory(decoders/plugins/trt_decoder)
Expand Down
86 changes: 86 additions & 0 deletions libs/qec/lib/decoders/plugins/pymatching/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# ============================================================================ #
# Copyright (c) 2025 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

cmake_minimum_required(VERSION 3.28 FATAL_ERROR)

set(MODULE_NAME "cudaq-qec-pymatching")

# External Dependencies
# ==============================================================================

FetchContent_Declare(
pymatching
GIT_REPOSITORY https://github.com/oscarhiggott/PyMatching
GIT_TAG e27b8c6a5f5ba10fb74d4ebb29822f2df5e12bcd # v2.3.1
PATCH_COMMAND sed -i "s/gtest_discover_tests/#gtest_discover_tests/g" CMakeLists.txt
)
FetchContent_MakeAvailable(pymatching)

project(${MODULE_NAME})

# Specify the source file for the plugin
set(PLUGIN_SRC
pymatching.cpp
)

# Create the shared library
add_library(${MODULE_NAME} SHARED ${PLUGIN_SRC})

# Don't export any symbols (specifically from static libs like PyMatching)
# to avoid conflicts with other Python packages.
target_link_options(${MODULE_NAME} PRIVATE "-Wl,--exclude-libs,ALL")

# Set the include directories for dependencies
target_include_directories(${MODULE_NAME}
PUBLIC
${CMAKE_SOURCE_DIR}/libs/qec/include
${CMAKE_SOURCE_DIR}/libs/core/include
)

# Link with required libraries
target_link_libraries(${MODULE_NAME}
PUBLIC
cudaqx-core
cudaq::cudaq-operator
PRIVATE
cudaq::cudaq-common
cudaq-qec
libpymatching
)

set_target_properties(${MODULE_NAME} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib/decoder-plugins
)

# RPATH configuration
# ==============================================================================

if (NOT SKBUILD)
set_target_properties(${MODULE_NAME} PROPERTIES
BUILD_RPATH "$ORIGIN"
INSTALL_RPATH "$ORIGIN:$ORIGIN/.."
)

# Let CMake automatically add paths of linked libraries to the RPATH:
set_target_properties(${MODULE_NAME} PROPERTIES
INSTALL_RPATH_USE_LINK_PATH TRUE)
else()
# CUDA-Q install its libraries in site-packages/lib (or dist-packages/lib)
# Thus, we need the $ORIGIN/../lib
set_target_properties(${MODULE_NAME} PROPERTIES
INSTALL_RPATH "$ORIGIN/../../lib"
)
endif()

# Install
# ==============================================================================

install(TARGETS ${MODULE_NAME}
COMPONENT qec-lib-plugins
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}/decoder-plugins
)
143 changes: 143 additions & 0 deletions libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*******************************************************************************
* Copyright (c) 2022 - 2026 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "pymatching/sparse_blossom/driver/mwpm_decoding.h"
#include "pymatching/sparse_blossom/driver/user_graph.h"
#include "cudaq/qec/decoder.h"
#include "cudaq/qec/pcm_utils.h"
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

namespace cudaq::qec {

/// @brief This is a wrapper around the PyMatching library that implements the
/// MWPM decoder.
class pymatching : public decoder {
private:
pm::UserGraph user_graph;

// Input parameters
std::vector<double> error_rate_vec;
pm::MERGE_STRATEGY merge_strategy_enum = pm::MERGE_STRATEGY::DISALLOW;

// Map of edge pairs to column indices. This does not seem particularly
// efficient.
std::map<std::pair<int64_t, int64_t>, size_t> edge2col_idx;

// Helper function to make a canonical edge from two nodes.
std::pair<int64_t, int64_t> make_canonical_edge(int64_t node1,
int64_t node2) {
return std::make_pair(std::min(node1, node2), std::max(node1, node2));
}

public:
pymatching(const cudaqx::tensor<uint8_t> &H,
const cudaqx::heterogeneous_map &params)
: decoder(H) {

if (params.contains("error_rate_vec")) {
error_rate_vec = params.get<std::vector<double>>("error_rate_vec");
if (error_rate_vec.size() != block_size) {
throw std::runtime_error("error_rate_vec must be of size block_size");
}
// Validate that the values in the error_rate_vec are between 0 and 0.5.
// Values > 0.5 would have negative LLR, which is not supported by
// PyMatching.
for (auto error_rate : error_rate_vec) {
if (error_rate <= 0.0 || error_rate > 0.5) {
throw std::runtime_error(
"error_rate_vec value is out of range (0, 0.5]");
}
}
}

if (params.contains("merge_strategy")) {
std::string merge_strategy = params.get<std::string>("merge_strategy");
if (merge_strategy == "disallow") {
merge_strategy_enum = pm::MERGE_STRATEGY::DISALLOW;
} else if (merge_strategy == "independent") {
merge_strategy_enum = pm::MERGE_STRATEGY::INDEPENDENT;
} else if (merge_strategy == "smallest_weight") {
merge_strategy_enum = pm::MERGE_STRATEGY::SMALLEST_WEIGHT;
} else if (merge_strategy == "keep_original") {
merge_strategy_enum = pm::MERGE_STRATEGY::KEEP_ORIGINAL;
} else if (merge_strategy == "replace") {
merge_strategy_enum = pm::MERGE_STRATEGY::REPLACE;
} else {
throw std::runtime_error(
"merge_strategy must be one of: disallow, independent, "
"smallest_weight, keep_original, replace");
}
}

user_graph = pm::UserGraph(H.shape()[0]);

auto sparse = cudaq::qec::dense_to_sparse(H);
std::vector<size_t> observables;
std::size_t col_idx = 0;
for (auto &col : sparse) {
double weight = 1.0;
if (col_idx < error_rate_vec.size()) {
weight = -std::log(error_rate_vec[col_idx] /
(1.0 - error_rate_vec[col_idx]));
}
if (col.size() == 2) {
edge2col_idx[make_canonical_edge(col[0], col[1])] = col_idx;
user_graph.add_or_merge_edge(col[0], col[1], observables, weight, 0.0,
merge_strategy_enum);
} else if (col.size() == 1) {
edge2col_idx[make_canonical_edge(col[0], -1)] = col_idx;
user_graph.add_or_merge_boundary_edge(col[0], observables, weight, 0.0,
merge_strategy_enum);
} else {
throw std::runtime_error(
"Invalid column in H: " + std::to_string(col_idx) + " has " +
std::to_string(col.size()) + " ones. Must have 1 or 2 ones.");
}
col_idx++;
}
}

virtual decoder_result decode(const std::vector<float_t> &syndrome) {
decoder_result result{false, std::vector<float_t>(block_size, 0.0)};
auto &mwpm = user_graph.get_mwpm_with_search_graph();
std::vector<int64_t> edges;
std::vector<uint64_t> detection_events;
detection_events.reserve(syndrome.size());
for (size_t i = 0; i < syndrome.size(); i++)
if (syndrome[i] > 0.5)
detection_events.push_back(i);
pm::decode_detection_events_to_edges(mwpm, detection_events, edges);
// Loop over the edge pairs
assert(edges.size() % 2 == 0);
for (size_t i = 0; i < edges.size(); i += 2) {
auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1));
auto col_idx = edge2col_idx.at(edge);
result.result[col_idx] = 1.0;
}
// An exception is thrown if no matching solution is found, so we can just
// set converged to true.
result.converged = true;
return result;
}

virtual ~pymatching() {}

CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
pymatching, static std::unique_ptr<decoder> create(
const cudaqx::tensor<uint8_t> &H,
const cudaqx::heterogeneous_map &params) {
return std::make_unique<pymatching>(H, params);
})
};

CUDAQ_REGISTER_TYPE(pymatching)

} // namespace cudaq::qec
24 changes: 23 additions & 1 deletion libs/qec/python/tests/test_decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ============================================================================ #
# Copyright (c) 2024 - 2025 NVIDIA Corporation & Affiliates. #
# Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
Expand Down Expand Up @@ -339,5 +339,27 @@ def test_single_error_lut_opt_results():
assert "decoding_time" not in result.opt_results # Was set to False


def test_decoder_pymatching_results():
pcm = qec.generate_random_pcm(n_rounds=2,
n_errs_per_round=10,
n_syndromes_per_round=5,
weight=2,
seed=7)
pcm, _ = qec.simplify_pcm(pcm, np.ones(pcm.shape[1]), 10)
# Pick 3 random columns from the PCM and XOR them together to get the
# syndrome.
columns = np.random.choice(pcm.shape[1], 3, replace=False)
syndrome = np.sum(pcm[:, columns], axis=1) % 2
decoder = qec.get_decoder('pymatching', pcm)
print(syndrome)
result = decoder.decode(syndrome)
assert result.converged is True
assert all(isinstance(x, float) for x in result.result)
assert all(0 <= x <= 1 for x in result.result)
actual_errors = np.zeros(pcm.shape[1], dtype=np.uint8)
actual_errors[columns] = 1
assert np.array_equal(result.result, actual_errors)


if __name__ == "__main__":
pytest.main()
9 changes: 6 additions & 3 deletions libs/qec/python/tests/test_sliding_window.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ============================================================================ #
# Copyright (c) 2025 NVIDIA Corporation & Affiliates. #
# Copyright (c) 2025 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
Expand All @@ -20,7 +20,7 @@ def setTarget():
cudaq.set_target(old_target)


@pytest.mark.parametrize("decoder_name", ["single_error_lut"])
@pytest.mark.parametrize("decoder_name", ["single_error_lut", "pymatching"])
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("num_rounds", [5, 10])
@pytest.mark.parametrize("num_windows", [1, 2, 3])
Expand Down Expand Up @@ -62,7 +62,10 @@ def test_sliding_window_1(decoder_name, batched, num_rounds, num_windows):
straddle_end_round=True,
error_rate_vec=np.array(dem.error_rates),
inner_decoder_name=decoder_name,
inner_decoder_params={'dummy_param': 1})
inner_decoder_params={
'dummy_param': 1,
'merge_strategy': 'smallest_weight'
})

if batched:
full_results = full_decoder.decode_batch(syndromes)
Expand Down
1 change: 1 addition & 0 deletions libs/qec/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,5 @@ endif()

add_subdirectory(backend-specific)
add_subdirectory(realtime)
add_subdirectory(decoders/pymatching)

29 changes: 29 additions & 0 deletions libs/qec/unittests/decoders/pymatching/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# ============================================================================ #
# Copyright (c) 2025 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

# External Dependencies
# ==============================================================================

set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)

# Bug in GCC 12 leads to spurious warnings (-Wrestrict)
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105329
if (CMAKE_COMPILER_IS_GNUCXX
AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0.0
AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 13.0.0)
target_compile_options(gtest PUBLIC --param=evrp-mode=legacy)
endif()
include(GoogleTest)

# ==============================================================================
add_compile_options(-Wno-attributes)

add_executable(test_pymatching test_pymatching.cpp)
target_link_libraries(test_pymatching PRIVATE GTest::gtest_main cudaq-qec cudaq::cudaq)
add_dependencies(CUDAQXQECUnitTests test_pymatching)
gtest_discover_tests(test_pymatching)
Loading