From c9bb0a21f5e7094d978a2b63467928c3aa92be37 Mon Sep 17 00:00:00 2001 From: Harun Mustafa Date: Wed, 21 Oct 2020 18:35:10 +0200 Subject: [PATCH 1/2] Differential assembly support for canonical and primary graphs DBGMode used in build_anno_graph --- README.md | 4 +- .../binary_matrix/base/binary_matrix.cpp | 11 + .../binary_matrix/base/binary_matrix.hpp | 7 + metagraph/src/cli/assemble.cpp | 290 ++++++-- metagraph/src/cli/config/config.cpp | 42 +- metagraph/src/cli/config/config.hpp | 8 +- .../src/cli/load/load_annotated_graph.cpp | 3 +- metagraph/src/common/vectors/bitmap.cpp | 49 ++ metagraph/src/common/vectors/bitmap.hpp | 36 + .../src/graph/annotated_graph_algorithm.cpp | 639 +++++++++++++----- .../src/graph/annotated_graph_algorithm.hpp | 95 ++- .../src/graph/representation/masked_graph.hpp | 1 + .../annotation/test_annotated_dbg_helpers.cpp | 37 +- .../annotation/test_annotated_dbg_helpers.hpp | 15 +- .../test_annotated_graph_algorithm.cpp | 395 ++++++----- .../tests/annotation/test_matrix_helpers.cpp | 43 ++ 16 files changed, 1148 insertions(+), 527 deletions(-) diff --git a/README.md b/README.md index 96d1a131d3..29a7676d4d 100644 --- a/README.md +++ b/README.md @@ -206,9 +206,7 @@ Requires `M*V/8 + Size(BRWT)` bytes of RAM, where `M` is the number of rows in t ./metagraph assemble -v /graph.dbg \ --unitigs \ -a /annotation.column.annodbg \ - --label-mask-in LABEL_1 \ - --label-mask-in LABEL_2 \ - --label-mask-out LABEL_3 \ + --label-mask-file diff_assembly_experiment_file.txt \ -o diff_assembled.fa ``` diff --git a/metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp b/metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp index d06d312463..bfad5bfc2d 100644 --- a/metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp +++ b/metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp @@ -1,5 +1,7 @@ #include "binary_matrix.hpp" +#include "common/vectors/bitmap.hpp" +#include "common/vectors/bit_vector_adaptive.hpp" #include "common/serialization.hpp" @@ -32,6 +34,15 @@ BinaryMatrix::slice_rows(const std::vector &row_ids) const { return slice; } +void BinaryMatrix::slice_columns(const std::vector &column_ids, + const ColumnCallback &callback) const { + size_t nrows = num_rows(); + for (size_t k = 0; k < column_ids.size(); ++k) { + Column j = column_ids[k]; + callback(j, bitmap_generator(get_column(j), nrows)); + } +} + template StreamRows::StreamRows(const std::string &filename, size_t offset) { std::ifstream instream(filename, std::ios::binary); diff --git a/metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp b/metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp index 085a968192..c1ea903700 100644 --- a/metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp +++ b/metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp @@ -9,6 +9,8 @@ #include "common/vector.hpp" +class bitmap; + namespace mtg { namespace annot { namespace binmat { @@ -21,6 +23,7 @@ class BinaryMatrix { typedef Vector SetBitPositions; typedef std::function RowCallback; typedef std::function ValueCallback; + typedef std::function ColumnCallback; virtual ~BinaryMatrix() {} @@ -32,9 +35,13 @@ class BinaryMatrix { virtual SetBitPositions get_row(Row row) const = 0; virtual std::vector get_rows(const std::vector &rows) const; virtual std::vector get_column(Column column) const = 0; + // get all selected rows appended with -1 and concatenated virtual std::vector slice_rows(const std::vector &rows) const; + virtual void slice_columns(const std::vector &columns, + const ColumnCallback &callback) const; + virtual bool load(std::istream &in) = 0; virtual void serialize(std::ostream &out) const = 0; diff --git a/metagraph/src/cli/assemble.cpp b/metagraph/src/cli/assemble.cpp index 541864ac50..75c29d02ce 100644 --- a/metagraph/src/cli/assemble.cpp +++ b/metagraph/src/cli/assemble.cpp @@ -2,8 +2,10 @@ #include +#include "common/algorithms.hpp" #include "common/logger.hpp" #include "common/unix_tools.hpp" +#include "common/threads/threading.hpp" #include "seq_io/sequence_io.hpp" #include "config/config.hpp" #include "load/load_graph.hpp" @@ -19,91 +21,188 @@ using mtg::common::logger; using mtg::graph::DeBruijnGraph; using mtg::graph::MaskedDeBruijnGraph; using mtg::graph::AnnotatedDBG; +using mtg::graph::DifferentialAssemblyConfig; + + +void clean_label_set(const AnnotatedDBG &anno_graph, + std::vector &label_set) { + label_set.erase(std::remove_if(label_set.begin(), label_set.end(), + [&](const std::string &label) { + bool exists = anno_graph.label_exists(label); + if (!exists) + logger->trace("Removing label {}", label); + + return !exists; + } + ), label_set.end()); + + std::sort(label_set.begin(), label_set.end()); + auto end = std::unique(label_set.begin(), label_set.end()); + for (auto it = end; it != label_set.end(); ++it) { + logger->trace("Removing duplicate label {}", *it); + } + label_set.erase(end, label_set.end()); +} std::unique_ptr -mask_graph(const AnnotatedDBG &anno_graph, Config *config) { +mask_graph_from_labels(const AnnotatedDBG &anno_graph, + const std::vector &label_mask_in, + const std::vector &label_mask_out, + const std::vector &label_mask_in_post, + const std::vector &label_mask_out_post, + const DifferentialAssemblyConfig &diff_config, + size_t num_threads) { auto graph = std::dynamic_pointer_cast(anno_graph.get_graph_ptr()); if (!graph.get()) throw std::runtime_error("Masking only supported for DeBruijnGraph"); - // Remove non-present labels - config->label_mask_in.erase( - std::remove_if(config->label_mask_in.begin(), - config->label_mask_in.end(), - [&](const auto &label) { - bool exists = anno_graph.label_exists(label); - if (!exists) - logger->trace("Removing mask-in label {}", label); - - return !exists; - }), - config->label_mask_in.end() - ); - - config->label_mask_out.erase( - std::remove_if(config->label_mask_out.begin(), - config->label_mask_out.end(), - [&](const auto &label) { - bool exists = anno_graph.label_exists(label); - if (!exists) - logger->trace("Removing mask-out label {}", label); - - return !exists; - }), - config->label_mask_out.end() - ); - - logger->trace("Masked in: {}", fmt::join(config->label_mask_in, " ")); - logger->trace("Masked out: {}", fmt::join(config->label_mask_out, " ")); - - if (!config->filter_by_kmer) { - return std::make_unique( - graph, - mask_nodes_by_unitig_labels( - anno_graph, - config->label_mask_in, - config->label_mask_out, - std::max(1u, get_num_threads()), - config->label_mask_in_fraction, - config->label_mask_out_fraction, - config->label_other_fraction - ) - ); + std::vector*> label_sets { + &label_mask_in, &label_mask_out, + &label_mask_in_post, &label_mask_out_post + }; + + for (const auto *label_set : label_sets) { + for (const auto *other_label_set : label_sets) { + if (label_set == other_label_set) + continue; + + if (utils::count_intersection(label_set->begin(), label_set->end(), + other_label_set->begin(), other_label_set->end())) + logger->warn("Overlapping label sets"); + } } - return std::make_unique( - graph, - mask_nodes_by_node_label( - anno_graph, - config->label_mask_in, - config->label_mask_out, - [config,&anno_graph](auto index, - auto get_num_in_labels, - auto get_num_out_labels) { - assert(index != DeBruijnGraph::npos); + logger->trace("Masked in: {}", fmt::join(label_mask_in, " ")); + logger->trace("Masked in (post-processing): {}", fmt::join(label_mask_in_post, " ")); + logger->trace("Masked out: {}", fmt::join(label_mask_out, " ")); + logger->trace("Masked out (post-processing): {}", fmt::join(label_mask_out_post, " ")); + + return std::make_unique(mask_nodes_by_label( + anno_graph, + label_mask_in, label_mask_out, + label_mask_in_post, label_mask_out_post, + diff_config, num_threads + )); +} - size_t num_in_labels = get_num_in_labels(); +DifferentialAssemblyConfig parse_diff_config(const std::string &config_str, + bool canonical) { + DifferentialAssemblyConfig diff_config; + diff_config.add_complement = canonical; - if (num_in_labels < config->label_mask_in_fraction - * config->label_mask_in.size()) - return false; + auto vals = utils::split_string(config_str, ","); + auto it = vals.begin(); + if (it != vals.end()) { + diff_config.label_mask_in_kmer_fraction = std::stof(*it); + ++it; + } + if (it != vals.end()) { + diff_config.label_mask_in_unitig_fraction = std::stof(*it); + ++it; + } + if (it != vals.end()) { + diff_config.label_mask_out_kmer_fraction = std::stof(*it); + ++it; + } + if (it != vals.end()) { + diff_config.label_mask_out_unitig_fraction = std::stof(*it); + ++it; + } + if (it != vals.end()) { + diff_config.label_mask_other_unitig_fraction = std::stof(*it); + ++it; + } - size_t num_out_labels = get_num_out_labels(); + assert(it == vals.end()); - if (num_out_labels < config->label_mask_out_fraction - * config->label_mask_out.size()) - return false; + logger->trace("Per-kmer mask in fraction: {}", diff_config.label_mask_in_kmer_fraction); + logger->trace("Per-unitig mask in fraction: {}", diff_config.label_mask_in_unitig_fraction); + logger->trace("Per-kmer mask out fraction: {}", diff_config.label_mask_out_kmer_fraction); + logger->trace("Per-unitig mask out fraction: {}", diff_config.label_mask_out_unitig_fraction); + logger->trace("Per-unitig other label fraction: {}", diff_config.label_mask_other_unitig_fraction); + logger->trace("Include reverse complements: {}", diff_config.add_complement); - size_t num_total_labels = anno_graph.get_labels(index).size(); + return diff_config; +} - return num_total_labels - num_in_labels - num_out_labels - <= config->label_other_fraction * num_total_labels; - }, - std::max(1u, get_num_threads()) - ) - ); +typedef std::function CallMaskedGraphHeader; + +void call_masked_graphs(const AnnotatedDBG &anno_graph, Config *config, + const CallMaskedGraphHeader &callback, + size_t num_parallel_graphs_masked = 1, + size_t num_threads_per_graph = 1) { + assert(!config->label_mask_file.empty()); + + std::ifstream fin(config->label_mask_file); + if (!fin.good()) { + throw std::iostream::failure("Failed to read label mask file"); + exit(1); + } + + ThreadPool thread_pool(num_parallel_graphs_masked); + std::vector shared_foreground_labels; + std::vector shared_background_labels; + + std::string line; + while (std::getline(fin, line)) { + if (line.empty() || line[0] == '#') + continue; + + if (line[0] == '@') { + logger->trace("Counting shared k-mers"); + + // shared in and out labels + auto line_split = utils::split_string(line, "\t", false); + if (line_split.size() <= 1 || line_split.size() > 3) + throw std::iostream::failure("Each line in mask file must have 2-3 fields."); + + // sync all assembly jobs before clearing shared labels + thread_pool.join(); + + shared_foreground_labels = utils::split_string(line_split[1], ","); + shared_background_labels = utils::split_string( + line_split.size() == 3 ? line_split[2] : "", + "," + ); + + clean_label_set(anno_graph, shared_foreground_labels); + clean_label_set(anno_graph, shared_background_labels); + + continue; + } + + thread_pool.enqueue([&](std::string line) { + auto line_split = utils::split_string(line, "\t", false); + if (line_split.size() <= 2 || line_split.size() > 4) + throw std::iostream::failure("Each line in mask file must have 3-4 fields."); + + auto diff_config = parse_diff_config(line_split[1], config->canonical); + + if (config->enumerate_out_sequences) + line_split[0] += "."; + + auto foreground_labels = utils::split_string(line_split[2], ","); + auto background_labels = utils::split_string( + line_split.size() == 4 ? line_split[3] : "", + "," + ); + + clean_label_set(anno_graph, foreground_labels); + clean_label_set(anno_graph, background_labels); + + callback(*mask_graph_from_labels(anno_graph, + foreground_labels, background_labels, + shared_foreground_labels, + shared_background_labels, + diff_config, num_threads_per_graph), + line_split[0]); + }, std::move(line)); + } + + thread_pool.join(); } @@ -122,15 +221,56 @@ int assemble(Config *config) { logger->trace("Graph loaded in {} sec", timer.elapsed()); - std::unique_ptr anno_graph; if (config->infbase_annotators.size()) { - anno_graph = initialize_annotated_dbg(graph, *config); + assert(config->label_mask_file.size()); + auto anno_graph = initialize_annotated_dbg(graph, *config); + + logger->trace("Generating masked graphs..."); + + std::filesystem::remove( + utils::remove_suffix(config->outfbase, ".gz", ".fasta") + ".fasta.gz" + ); + + std::mutex file_open_mutex; + std::mutex write_mutex; - logger->trace("Masking graph..."); + size_t num_parallel_assemblies = std::max( + 1u, std::min(config->parallel_assemblies, get_num_threads()) + ); + size_t num_threads_per_traversal = std::max( + size_t(1), get_num_threads() / num_parallel_assemblies + ); - graph = mask_graph(*anno_graph, config); + call_masked_graphs(*anno_graph, config, + [&](const graph::MaskedDeBruijnGraph &graph, const std::string &header) { + std::lock_guard file_lock(file_open_mutex); + seq_io::FastaWriter writer(config->outfbase, header, + config->enumerate_out_sequences, + get_num_threads() > 1, /* async write */ + "a" /* append mode */); + + if (config->unitigs || config->min_tip_size > 1) { + graph.call_unitigs([&](const auto &unitig, auto&&) { + std::lock_guard lock(write_mutex); + writer.write(unitig); + }, + num_threads_per_traversal, + config->min_tip_size, + config->kmers_in_single_form); + } else { + graph.call_sequences([&](const auto &seq, auto&&) { + std::lock_guard lock(write_mutex); + writer.write(seq); + }, + num_threads_per_traversal, + config->kmers_in_single_form); + } + }, + num_parallel_assemblies, + num_threads_per_traversal + ); - logger->trace("Masked in {} sec", timer.elapsed()); + return 0; } logger->trace("Extracting sequences from graph..."); @@ -179,6 +319,8 @@ int assemble(Config *config) { get_num_threads(), config->min_tip_size ); + + return 0; } seq_io::FastaWriter writer(config->outfbase, config->header, diff --git a/metagraph/src/cli/config/config.cpp b/metagraph/src/cli/config/config.cpp index 19dac3e31a..ec28012047 100644 --- a/metagraph/src/cli/config/config.cpp +++ b/metagraph/src/cli/config/config.cpp @@ -142,6 +142,8 @@ Config::Config(int argc, char *argv[]) { set_num_threads(atoi(get_value(i++))); } else if (!strcmp(argv[i], "--parallel-nodes")) { parallel_nodes = atoi(get_value(i++)); + } else if (!strcmp(argv[i], "--parallel-assemblies")) { + parallel_assemblies = atoi(get_value(i++)); } else if (!strcmp(argv[i], "--max-path-length")) { max_path_length = atoi(get_value(i++)); } else if (!strcmp(argv[i], "--parts-total")) { @@ -250,6 +252,8 @@ Config::Config(int argc, char *argv[]) { host_address = get_value(i++); }else if (!strcmp(argv[i], "--suffix")) { suffix = get_value(i++); + } else if (!strcmp(argv[i], "--label-mask-file")) { + label_mask_file = get_value(i++); } else if (!strcmp(argv[i], "--initialize-bloom")) { initialize_bloom = true; } else if (!strcmp(argv[i], "--bloom-fpp")) { @@ -330,18 +334,6 @@ Config::Config(int argc, char *argv[]) { print_welcome_message(); print_usage(argv[0], identity); exit(0); - } else if (!strcmp(argv[i], "--label-mask-in")) { - label_mask_in.emplace_back(get_value(i++)); - } else if (!strcmp(argv[i], "--label-mask-out")) { - label_mask_out.emplace_back(get_value(i++)); - } else if (!strcmp(argv[i], "--label-mask-in-fraction")) { - label_mask_in_fraction = std::stof(get_value(i++)); - } else if (!strcmp(argv[i], "--label-mask-out-fraction")) { - label_mask_out_fraction = std::stof(get_value(i++)); - } else if (!strcmp(argv[i], "--label-other-fraction")) { - label_other_fraction = std::stof(get_value(i++)); - } else if (!strcmp(argv[i], "--filter-by-kmer")) { - filter_by_kmer = true; } else if (!strcmp(argv[i], "--disk-swap")) { tmp_dir = get_value(i++); } else if (!strcmp(argv[i], "--disk-cap-gb")) { @@ -358,6 +350,9 @@ Config::Config(int argc, char *argv[]) { if (parallel_nodes == static_cast(-1)) parallel_nodes = get_num_threads(); + if (parallel_assemblies == static_cast(-1)) + parallel_assemblies = get_num_threads(); + if (identity == TRANSFORM && to_fasta) identity = ASSEMBLE; @@ -500,6 +495,12 @@ Config::Config(int argc, char *argv[]) { print_usage_and_exit = true; } + if ((identity == ASSEMBLE || identity == TRANSFORM) + && (infbase_annotators.size() && label_mask_file.empty())) { + std::cerr << "Error: annotator passed, but no label mask file provided" << std::endl; + print_usage_and_exit = true; + } + if (identity == ANNOTATE_COORDINATES && outfbase.empty()) outfbase = utils::remove_suffix(infbase, ".dbg", ".orhashdbg", @@ -926,13 +927,16 @@ void Config::print_usage(const std::string &prog_name, IdentityType identity) { fprintf(stderr, "\t --header [STR] \theader for sequences in FASTA output []\n"); fprintf(stderr, "\t-p --parallel [INT] \tuse multiple threads for computation [1]\n"); fprintf(stderr, "\n"); - fprintf(stderr, "\t-a --annotator [STR] \t\t\tannotator to load []\n"); - fprintf(stderr, "\t --label-mask-in [STR] \t\tlabel to include in masked graph\n"); - fprintf(stderr, "\t --label-mask-out [STR] \t\tlabel to exclude from masked graph\n"); - fprintf(stderr, "\t --label-mask-in-fraction [FLOAT] \tminimum fraction of mask-in labels among the set of masked labels [1.0]\n"); - fprintf(stderr, "\t --label-mask-out-fraction [FLOAT] \tmaximum fraction of mask-out labels among the set of masked labels [0.0]\n"); - fprintf(stderr, "\t --label-other-fraction [FLOAT] \tmaximum fraction of other labels allowed [1.0]\n"); - fprintf(stderr, "\t --filter-by-kmer \t\t\tmask out graph k-mers individually [off]\n"); + fprintf(stderr, "\t-a --annotator [STR] \t\tannotator to load []\n"); + fprintf(stderr, "\t --parallel-assemblies [INT] \tnumber of assembly experiments to run in parallel [n_threads]\n"); + fprintf(stderr, "\t --label-mask-file [STR] \tfile describing labels to mask in and out and their relative fractions []\n"); + fprintf(stderr, "\t \t\tA k-mer is an in-k-mer if it has at least in_kmer_frac in-labels.\n"); + fprintf(stderr, "\t \t\tA k-mer is an out-k-mer if it has more than out_kmer_frac out-labels.\n"); + fprintf(stderr, "\t \t\tA unitig is included if it has at least in_unitig_frac in-k-mers and at most out_unitig_frac out-k-mers.\n"); + fprintf(stderr, "\t \t\tA unitig is excluded if more than other_unitig_frac of its k-mers are not in the in- or out-label sets.\n"); + fprintf(stderr, "\t \t\texample: '\\t,,,,\\t,,...\\t,,...'\n"); + fprintf(stderr, "\t \t\tIf exp_label == @, then the second and third fields are lists of in- and out-labels, respectively, that apply to all subsequent lines.\n"); + fprintf(stderr, "\t \t\texample: '@\\t,,...\\t,,...'\n"); } break; case STATS: { fprintf(stderr, "Usage: %s stats [options] GRAPH1 [[GRAPH2] ...]\n\n", prog_name.c_str()); diff --git a/metagraph/src/cli/config/config.hpp b/metagraph/src/cli/config/config.hpp index eb4458265a..aabbbd078b 100644 --- a/metagraph/src/cli/config/config.hpp +++ b/metagraph/src/cli/config/config.hpp @@ -56,7 +56,6 @@ class Config { bool map_sequences = false; bool align_sequences = false; bool align_both_strands = false; - bool filter_by_kmer = false; bool output_json = false; bool optimize = false; @@ -86,6 +85,7 @@ class Config { unsigned int bloom_max_num_hash_functions = 10; unsigned int num_columns_cached = 10; unsigned int max_hull_forks = 4; + unsigned int parallel_assemblies = -1; // by default, run |parallel| assemblies in parallel, one per thread unsigned long long int query_batch_size_in_bytes = 100'000'000; unsigned long long int num_rows_subsampled = 1'000'000; @@ -115,9 +115,6 @@ class Config { size_t alignment_max_num_seeds_per_locus = std::numeric_limits::max(); double discovery_fraction = 0.7; - double label_mask_in_fraction = 1.0; - double label_mask_out_fraction = 0.0; - double label_other_fraction = 1.0; double min_count_quantile = 0.; double max_count_quantile = 1.; double bloom_fpp = 1.0; @@ -129,8 +126,6 @@ class Config { std::vector fnames; std::vector anno_labels; std::vector infbase_annotators; - std::vector label_mask_in; - std::vector label_mask_out; std::string outfbase; std::string infbase; std::string rename_instructions_file; @@ -141,6 +136,7 @@ class Config { std::string fasta_anno_comment_delim = UNINITIALIZED_STR; std::string header = ""; std::string host_address; + std::string label_mask_file; uint32_t max_path_length = 50; std::filesystem::path tmp_dir; diff --git a/metagraph/src/cli/load/load_annotated_graph.cpp b/metagraph/src/cli/load/load_annotated_graph.cpp index aa4d0c31bc..9a16f451fc 100644 --- a/metagraph/src/cli/load/load_annotated_graph.cpp +++ b/metagraph/src/cli/load/load_annotated_graph.cpp @@ -22,7 +22,7 @@ using mtg::common::logger; std::unique_ptr initialize_annotated_dbg(std::shared_ptr graph, const Config &config) { // TODO: check and wrap into canonical only if the graph is primary - if (config.canonical && !graph->is_canonical_mode()) + if (config.canonical && !graph->is_canonical_mode() && config.identity != Config::ASSEMBLE) graph = std::make_shared(graph, true); uint64_t max_index = graph->max_index(); @@ -85,6 +85,5 @@ std::unique_ptr initialize_annotated_dbg(const Config &config) { return initialize_annotated_dbg(load_critical_dbg(config.infbase), config); } - } // namespace cli } // namespace mtg diff --git a/metagraph/src/common/vectors/bitmap.cpp b/metagraph/src/common/vectors/bitmap.cpp index 507582c3af..2e998f7ac0 100644 --- a/metagraph/src/common/vectors/bitmap.cpp +++ b/metagraph/src/common/vectors/bitmap.cpp @@ -369,3 +369,52 @@ void bitmap_lazy::call_ones_in_range(uint64_t begin, uint64_t end, callback(i); } } + + +//////////////////////////////////////////////////////////////// +// bitmap_generator // +//////////////////////////////////////////////////////////////// + +bitmap_generator::bitmap_generator(size_t size, bool value) + : bitmap_generator([value,size](const auto &index_callback) { + if (value) { + for (size_t i = 0; i < size; ++i) { + index_callback(i); + } + } + }, size, size * value) {} + +bitmap_generator +::bitmap_generator(std::function&)>&& generator, + size_t size, + size_t num_set_bits) noexcept + : size_(size), num_set_bits_(num_set_bits), generator_(std::move(generator)) {} + +bitmap_generator +::bitmap_generator(bitmap&& base, std::function&& index_transformer, + size_t size, size_t num_set_bits) noexcept + : size_(size == static_cast(-1) ? base.size() : size), + num_set_bits_(num_set_bits == static_cast(-1) ? base.num_set_bits() : num_set_bits) { + auto get_generator = [&](bitmap&& v, auto&& transformer) { + return [&](const auto &callback) { + v.call_ones([&](auto i) { callback(transformer(i)); }); + }; + }; + + generator_ = get_generator(std::move(base), std::move(index_transformer)); +} + +void bitmap_generator::add_to(sdsl::bit_vector *other) const { + assert(other); + assert(other->size() == size()); + call_ones([other](auto i) { (*other)[i] = true; }); +} + +void bitmap_generator::call_ones_in_range(uint64_t begin, uint64_t end, + const VoidCall &callback) const { + generator_([&](uint64_t i) { + assert(i < size_); + if (i >= begin && i < end) + callback(i); + }); +} diff --git a/metagraph/src/common/vectors/bitmap.hpp b/metagraph/src/common/vectors/bitmap.hpp index 3f7dfa98c6..2a0869b12a 100644 --- a/metagraph/src/common/vectors/bitmap.hpp +++ b/metagraph/src/common/vectors/bitmap.hpp @@ -187,4 +187,40 @@ class bitmap_lazy : public bitmap { size_t num_set_bits_; }; +class bitmap_generator : public bitmap { + public: + explicit bitmap_generator(size_t size = 0, bool value = false); + + bitmap_generator(std::vector&& set_bits, size_t size) noexcept + : size_(size), num_set_bits_(set_bits.size()), + generator_([s=std::move(set_bits)](const VoidCall &callback) { + std::for_each(s.begin(), s.end(), callback); + }) {} + + bitmap_generator(std::function&)>&& generator, + size_t size, + size_t num_set_bits = -1) noexcept; + + bitmap_generator(bitmap&& base, + std::function&& index_transformer + = [](auto i) { return i; }, + size_t size = -1, + size_t num_set_bits = -1) noexcept; + + bool operator[](uint64_t) const { throw std::runtime_error("Not implemented"); } + uint64_t get_int(uint64_t, uint32_t) const { throw std::runtime_error("Not implemented"); } + + uint64_t size() const { return size_; } + uint64_t num_set_bits() const { return num_set_bits_; } + + void add_to(sdsl::bit_vector *other) const; + void call_ones_in_range(uint64_t begin, uint64_t end, + const VoidCall &callback) const; + + private: + size_t size_; + size_t num_set_bits_; + std::function&)> generator_; +}; + #endif // __BITMAP_HPP__ diff --git a/metagraph/src/graph/annotated_graph_algorithm.cpp b/metagraph/src/graph/annotated_graph_algorithm.cpp index e888219f87..908565dd44 100644 --- a/metagraph/src/graph/annotated_graph_algorithm.cpp +++ b/metagraph/src/graph/annotated_graph_algorithm.cpp @@ -1,253 +1,520 @@ #include "annotated_graph_algorithm.hpp" -#include +#include +#include "common/logger.hpp" +#include "common/seq_tools/reverse_complement.hpp" #include "common/vectors/vector_algorithm.hpp" -#include "common/vector_map.hpp" -#include "kmer/alphabets.hpp" -#include "annotation/representation/column_compressed/annotate_column_compressed.hpp" -#include "graph/alignment/aligner_helper.hpp" +#include "common/threads/threading.hpp" +#include "common/vectors/bitmap.hpp" #include "graph/representation/masked_graph.hpp" +#include "graph/representation/succinct/dbg_succinct.hpp" +#include "graph/representation/succinct/boss_construct.hpp" namespace mtg { namespace graph { +using mtg::graph::boss::BOSS; +using mtg::graph::boss::BOSSConstructor; +using mtg::common::logger; + typedef AnnotatedDBG::node_index node_index; typedef AnnotatedDBG::row_index row_index; typedef AnnotatedDBG::Annotator::Label Label; -typedef Alignment DBGAlignment; +typedef std::function LabelCountCallback; + +constexpr bool MAKE_BOSS = true; + + +/** + * Return an int_vector<>, bitmap pair, each of length anno_graph.get_graph().max_index(). + * For an index i, the int_vector will contain a packed integer representing the + * number of labels in labels_in and labels_out which the k-mer of index i is + * annotated with. The least significant half of each integer represents the count + * from labels_in, while the most significant half represents the count from + * labels_out. + * The returned bitmap is a binarization of the int_vector + */ +std::pair, std::unique_ptr> +construct_diff_label_count_vector(const AnnotatedDBG &anno_graph, + const std::vector