diff --git a/metagraph/src/common/algorithms.hpp b/metagraph/src/common/algorithms.hpp index c203883888..c3879c91b7 100644 --- a/metagraph/src/common/algorithms.hpp +++ b/metagraph/src/common/algorithms.hpp @@ -53,8 +53,7 @@ namespace utils { size_t segment_length) { std::vector mask(array.size(), false); size_t last_occurrence - = std::find(array.data(), array.data() + array.size(), label) - - array.data(); + = std::find(array.begin(), array.end(), label) - array.begin(); for (size_t i = last_occurrence; i < array.size(); ++i) { if (array[i] == label) diff --git a/metagraph/src/graph/representation/canonical_dbg.cpp b/metagraph/src/graph/representation/canonical_dbg.cpp index fdd4bd683e..39b3798001 100644 --- a/metagraph/src/graph/representation/canonical_dbg.cpp +++ b/metagraph/src/graph/representation/canonical_dbg.cpp @@ -64,7 +64,7 @@ ::map_to_nodes_sequentially(std::string_view sequence, path.reserve(sequence.size() - get_k() + 1); if (const auto sshash = std::dynamic_pointer_cast(graph_)) { - sshash->map_to_nodes_with_rc<>(sequence, [&](node_index node, bool orientation) { + sshash->map_to_nodes_with_rc(sequence, [&](node_index node, bool orientation) { callback(node && orientation ? reverse_complement(node) : node); }, terminate); return; @@ -180,7 +180,7 @@ void CanonicalDBG::call_outgoing_kmers(node_index node, } if (const auto sshash = std::dynamic_pointer_cast(graph_)) { - sshash->call_outgoing_kmers_with_rc<>(node, [&](node_index next, char c, bool orientation) { + sshash->call_outgoing_kmers_with_rc(node, [&](node_index next, char c, bool orientation) { callback(orientation ? reverse_complement(next) : next, c); }); return; @@ -273,7 +273,7 @@ void CanonicalDBG::call_incoming_kmers(node_index node, } if (const auto sshash = std::dynamic_pointer_cast(graph_)) { - sshash->call_incoming_kmers_with_rc<>(node, [&](node_index prev, char c, bool orientation) { + sshash->call_incoming_kmers_with_rc(node, [&](node_index prev, char c, bool orientation) { callback(orientation ? reverse_complement(prev) : prev, c); }); return; diff --git a/metagraph/src/graph/representation/hash/dbg_sshash.cpp b/metagraph/src/graph/representation/hash/dbg_sshash.cpp index 9e47d1fe3d..07120a28e2 100644 --- a/metagraph/src/graph/representation/hash/dbg_sshash.cpp +++ b/metagraph/src/graph/representation/hash/dbg_sshash.cpp @@ -5,6 +5,7 @@ #include "common/seq_tools/reverse_complement.hpp" #include "common/threads/threading.hpp" #include "common/logger.hpp" +#include "common/algorithms.hpp" #include "kmer/kmer_extractor.hpp" @@ -99,32 +100,53 @@ void DBGSSHash::add_sequence(std::string_view sequence, throw std::logic_error("adding sequences not supported"); } -template -void DBGSSHash::map_to_nodes_with_rc(std::string_view sequence, - const std::function& callback, - const std::function& terminate) const { - if (terminate() || sequence.size() < k_) +template +void map_to_nodes_with_rc_impl(size_t k, + const Dict &dict, + std::string_view sequence, + const std::function& callback, + const std::function& terminate) { + size_t n = sequence.size(); + if (terminate() || n < k) return; - if (!num_nodes()) { - for (size_t i = 0; i < sequence.size() - k_ + 1 && !terminate(); ++i) { - callback(npos, false); + if (!dict.size()) { + for (size_t i = 0; i + k <= sequence.size() && !terminate(); ++i) { + callback(sshash::lookup_result()); } return; } + using kmer_t = get_kmer_t; + + std::vector invalid_char(n); + for (size_t i = 0; i < n; ++i) { + invalid_char[i] = !kmer_t::is_valid(sequence[i]); + } + + auto invalid_kmer = utils::drag_and_mark_segments(invalid_char, true, k); + + kmer_t uint_kmer = sshash::util::string_to_uint_kmer(sequence.data(), k - 1); + uint_kmer.pad_char(); + for (size_t i = k - 1; i < n && !terminate(); ++i) { + uint_kmer.drop_char(); + uint_kmer.kth_char_or(k - 1, kmer_t::char_to_uint(sequence[i])); + callback(invalid_kmer[i] ? sshash::lookup_result() + : dict.lookup_advanced_uint(uint_kmer, with_rc)); + } +} + +template +void DBGSSHash::map_to_nodes_with_rc(std::string_view sequence, + const std::function& callback, + const std::function& terminate) const { std::visit([&](const auto &dict) { - using kmer_t = get_kmer_t; - kmer_t uint_kmer = sshash::util::string_to_uint_kmer(sequence.data(), k_ - 1); - uint_kmer.pad_char(); - for (size_t i = k_ - 1; i < sequence.size() && !terminate(); ++i) { - uint_kmer.drop_char(); - uint_kmer.kth_char_or(k_ - 1, kmer_t::char_to_uint(sequence[i])); - auto res = dict.lookup_advanced_uint(uint_kmer, with_rc); + map_to_nodes_with_rc_impl(k_, dict, sequence, [&](sshash::lookup_result res) { callback(sshash_to_graph_index(res.kmer_id), res.kmer_orientation); - } + }, terminate); }, dict_); } + template void DBGSSHash::map_to_nodes_with_rc(std::string_view, const std::function&, diff --git a/metagraph/tests/annotation/test_aligner_labeled.cpp b/metagraph/tests/annotation/test_aligner_labeled.cpp index 6462cdfc73..bdeddccf7f 100644 --- a/metagraph/tests/annotation/test_aligner_labeled.cpp +++ b/metagraph/tests/annotation/test_aligner_labeled.cpp @@ -52,6 +52,7 @@ class LabeledAlignerTest : public ::testing::Test {}; typedef ::testing::Types>, std::pair>, + std::pair>, std::pair, std::pair, std::pair> FewGraphAnnotationPairTypes; diff --git a/metagraph/tests/annotation/test_annotated_dbg.cpp b/metagraph/tests/annotation/test_annotated_dbg.cpp index 278f01f1f0..d1437aa725 100644 --- a/metagraph/tests/annotation/test_annotated_dbg.cpp +++ b/metagraph/tests/annotation/test_annotated_dbg.cpp @@ -4,15 +4,12 @@ #include "gtest/gtest.h" #include "../test_helpers.hpp" +#include "../graph/all/test_dbg_helpers.hpp" #include "common/threads/threading.hpp" #include "common/vectors/bit_vector_dyn.hpp" #include "common/vectors/vector_algorithm.hpp" #include "annotation/representation/column_compressed/annotate_column_compressed.hpp" -#include "graph/representation/bitmap/dbg_bitmap.hpp" -#include "graph/representation/hash/dbg_hash_string.hpp" -#include "graph/representation/hash/dbg_hash_ordered.hpp" -#include "graph/representation/hash/dbg_hash_fast.hpp" #define protected public #define private public @@ -987,6 +984,7 @@ typedef ::testing::Types>, std::pair>, std::pair>, std::pair>, + std::pair>, std::pair, std::pair, std::pair, @@ -1016,6 +1014,7 @@ class AnnotatedDBGNoNTest : public ::testing::Test {}; typedef ::testing::Types>, std::pair>, std::pair>, + std::pair>, std::pair, std::pair, std::pair, diff --git a/metagraph/tests/annotation/test_annotated_dbg_helpers.cpp b/metagraph/tests/annotation/test_annotated_dbg_helpers.cpp index 2c6b1f5735..39e387335e 100644 --- a/metagraph/tests/annotation/test_annotated_dbg_helpers.cpp +++ b/metagraph/tests/annotation/test_annotated_dbg_helpers.cpp @@ -235,6 +235,7 @@ template std::unique_ptr build_anno_graph build_anno_graph>(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); template std::unique_ptr build_anno_graph>(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); template std::unique_ptr build_anno_graph>(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); +template std::unique_ptr build_anno_graph>(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); template std::unique_ptr build_anno_graph(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); template std::unique_ptr build_anno_graph(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); diff --git a/metagraph/tests/graph/all/test_dbg_helpers.cpp b/metagraph/tests/graph/all/test_dbg_helpers.cpp index 82c6878024..3af60c0cfb 100644 --- a/metagraph/tests/graph/all/test_dbg_helpers.cpp +++ b/metagraph/tests/graph/all/test_dbg_helpers.cpp @@ -146,6 +146,7 @@ void writeFastaFile(const std::vector& sequences, const std::string fastaFile.close(); } + template <> std::shared_ptr build_graph(uint64_t k, @@ -154,8 +155,8 @@ build_graph(uint64_t k, if (sequences.empty()) return std::make_shared(k, mode); - // use DBGHashString to get contigs for SSHash - auto string_graph = build_graph(k, sequences, mode); + // use DBGHashFast to get contigs for SSHash + auto string_graph = build_graph(k, sequences, mode); std::vector contigs; size_t num_kmers = 0;