diff --git a/README.md b/README.md index 96d1a131d3..29a7676d4d 100644 --- a/README.md +++ b/README.md @@ -206,9 +206,7 @@ Requires `M*V/8 + Size(BRWT)` bytes of RAM, where `M` is the number of rows in t ./metagraph assemble -v /graph.dbg \ --unitigs \ -a /annotation.column.annodbg \ - --label-mask-in LABEL_1 \ - --label-mask-in LABEL_2 \ - --label-mask-out LABEL_3 \ + --label-mask-file diff_assembly_experiment_file.txt \ -o diff_assembled.fa ``` diff --git a/metagraph/benchmarks/benchmark_matrix.cpp b/metagraph/benchmarks/benchmark_matrix.cpp index ef43b6ba1f..657dd15a0b 100644 --- a/metagraph/benchmarks/benchmark_matrix.cpp +++ b/metagraph/benchmarks/benchmark_matrix.cpp @@ -115,4 +115,98 @@ BENCHMARK_TEMPLATE(BM_BRWTQueryRows, 3000000, 100, 30, 2, false, 0) ->Unit(benchmark::kMillisecond) ->DenseRange(0, 10, 1); + +template +static void BM_BRWTQueryColumns(benchmark::State& state) { + DataGenerator generator; + generator.set_seed(42); + + auto density_arg = std::vector(unique_arg, state.range(0) / 100.); + auto generated_columns = generator.generate_random_columns( + rows_arg, + unique_arg, + get_densities(unique_arg, density_arg), + std::vector(unique_arg, cols_arg / unique_arg) + ); + + std::unique_ptr matrix = experiments::generate_brwt_from_rows( + std::move(generated_columns), + arity_arg, + greedy_arg, + relax_arg + ); + + std::vector indexes; + call_ones(generator.generate_random_column(matrix->num_columns(), 1. / 10), + [&](uint64_t i) { indexes.push_back(i); } + ); + + for (auto _ : state) { + uint64_t j = 0; + #pragma omp parallel for num_threads(3) + for (size_t i = 0; i < indexes.size(); ++i) { + j += i; + for (auto t : matrix->get_column(indexes[i])) { + j += t; + } + } + } +} + +BENCHMARK_TEMPLATE(BM_BRWTQueryColumns, 3000000, 100, 30, 2, false, 0) + ->Unit(benchmark::kMillisecond) + ->DenseRange(0, 10, 1); + +template +static void BM_BRWTSliceColumns(benchmark::State& state) { + DataGenerator generator; + generator.set_seed(42); + + auto density_arg = std::vector(unique_arg, state.range(0) / 100.); + auto generated_columns = generator.generate_random_columns( + rows_arg, + unique_arg, + get_densities(unique_arg, density_arg), + std::vector(unique_arg, cols_arg / unique_arg) + ); + + std::unique_ptr matrix = experiments::generate_brwt_from_rows( + std::move(generated_columns), + arity_arg, + greedy_arg, + relax_arg + ); + + std::vector indexes; + call_ones(generator.generate_random_column(matrix->num_columns(), 1. / 10), + [&](uint64_t i) { indexes.push_back(i); } + ); + + for (auto _ : state) { + uint64_t j = 0; + #pragma omp parallel num_threads(3) + #pragma omp single + { + matrix->slice_columns(indexes, [&](auto i, auto&& bitmap) { + j += i; + bitmap.call_ones([&](auto t) { j += t; }); + }); + } + } +} + +BENCHMARK_TEMPLATE(BM_BRWTSliceColumns, 3000000, 100, 30, 2, false, 0) + ->Unit(benchmark::kMillisecond) + ->DenseRange(0, 10, 1); + } // namespace diff --git a/metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp b/metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp index d06d312463..bfad5bfc2d 100644 --- a/metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp +++ b/metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp @@ -1,5 +1,7 @@ #include "binary_matrix.hpp" +#include "common/vectors/bitmap.hpp" +#include "common/vectors/bit_vector_adaptive.hpp" #include "common/serialization.hpp" @@ -32,6 +34,15 @@ BinaryMatrix::slice_rows(const std::vector &row_ids) const { return slice; } +void BinaryMatrix::slice_columns(const std::vector &column_ids, + const ColumnCallback &callback) const { + size_t nrows = num_rows(); + for (size_t k = 0; k < column_ids.size(); ++k) { + Column j = column_ids[k]; + callback(j, bitmap_generator(get_column(j), nrows)); + } +} + template StreamRows::StreamRows(const std::string &filename, size_t offset) { std::ifstream instream(filename, std::ios::binary); diff --git a/metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp b/metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp index 085a968192..c1ea903700 100644 --- a/metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp +++ b/metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp @@ -9,6 +9,8 @@ #include "common/vector.hpp" +class bitmap; + namespace mtg { namespace annot { namespace binmat { @@ -21,6 +23,7 @@ class BinaryMatrix { typedef Vector SetBitPositions; typedef std::function RowCallback; typedef std::function ValueCallback; + typedef std::function ColumnCallback; virtual ~BinaryMatrix() {} @@ -32,9 +35,13 @@ class BinaryMatrix { virtual SetBitPositions get_row(Row row) const = 0; virtual std::vector get_rows(const std::vector &rows) const; virtual std::vector get_column(Column column) const = 0; + // get all selected rows appended with -1 and concatenated virtual std::vector slice_rows(const std::vector &rows) const; + virtual void slice_columns(const std::vector &columns, + const ColumnCallback &callback) const; + virtual bool load(std::istream &in) = 0; virtual void serialize(std::ostream &out) const = 0; diff --git a/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp b/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp index f71ae11ed9..9a7a524597 100644 --- a/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp +++ b/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp @@ -3,6 +3,10 @@ #include #include +#include + +#include + #include "common/algorithms.hpp" #include "common/serialization.hpp" @@ -189,6 +193,99 @@ std::vector BRWT::slice_rows(const std::vector &row_ids) cons return slice; } +void BRWT::slice_columns(const std::vector &column_ids, + const ColumnCallback &callback) const { + if (column_ids.empty()) + return; + + auto num_nonzero_rows = nonzero_rows_->num_set_bits(); + + // check if the column is empty + if (!num_nonzero_rows) + return; + + // check whether it is a leaf + if (!child_nodes_.size()) { + // return the index column + for (size_t k = 0; k < column_ids.size(); ++k) { + callback(column_ids[k], std::move(*nonzero_rows_->copy())); + } + + return; + } + + tsl::hopscotch_map> child_columns_map; + for (size_t i = 0; i < column_ids.size(); ++i) { + assert(column_ids[i] < num_columns()); + auto child_node = assignments_.group(column_ids[i]); + auto child_column = assignments_.rank(column_ids[i]); + + auto it = child_columns_map.find(child_node); + if (it == child_columns_map.end()) + it = child_columns_map.emplace(child_node, std::vector{}).first; + + it.value().push_back(child_column); + } + + auto process = [&](auto child_node, auto *child_columns_ptr) { + if (num_nonzero_rows == nonzero_rows_->size()) { + child_nodes_[child_node]->slice_columns(*child_columns_ptr, + [&](Column j, bitmap&& rows) { + callback(assignments_.get(child_node, j), std::move(rows)); + } + ); + } else { + const BRWT *child_node_brwt = dynamic_cast( + child_nodes_[child_node].get() + ); + if (child_node_brwt + && child_columns_ptr->size() > 1 + && !child_node_brwt->child_nodes_.size()) { + // if there are multiple column ids corresponding to the same leaf + // node, then this branch avoids doing redundant select1 calls + const auto *nonzero_rows = child_node_brwt->nonzero_rows_.get(); + size_t num_nonzero_rows = nonzero_rows->num_set_bits(); + if (num_nonzero_rows) { + std::vector set_bits; + set_bits.reserve(num_nonzero_rows); + nonzero_rows->call_ones([&](auto i) { + set_bits.push_back(nonzero_rows->select1(i + 1)); + }); + + for (size_t k = 0; k < child_columns_ptr->size() - 1; ++k) { + callback(assignments_.get(child_node, (*child_columns_ptr)[k]), + bitmap_generator(std::move(set_bits), num_rows())); + } + + callback(assignments_.get(child_node, child_columns_ptr->back()), + bitmap_generator(std::move(set_bits), num_rows())); + } + } else { + child_nodes_[child_node]->slice_columns(*child_columns_ptr, + [&](Column j, bitmap&& rows) { + size_t num_set_bits = rows.num_set_bits(); + callback(assignments_.get(child_node, j), + bitmap_generator(std::move(rows), [&](uint64_t i) { + return nonzero_rows_->select1(i + 1); + }, num_rows(), num_set_bits)); + } + ); + } + } + }; + + for (auto it = ++child_columns_map.begin(); it != child_columns_map.end(); ++it) { + auto child_node = it->first; + auto *child_columns_ptr = &it->second; + #pragma omp task firstprivate(child_node, child_columns_ptr) + process(child_node, child_columns_ptr); + } + + process(child_columns_map.begin()->first, &child_columns_map.begin()->second); + + #pragma omp taskwait +} + std::vector BRWT::get_column(Column column) const { assert(column < num_columns()); diff --git a/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp b/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp index 829341cf76..641e0a5cf7 100644 --- a/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp +++ b/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp @@ -35,6 +35,9 @@ class BRWT : public BinaryMatrix { // get all selected rows appended with -1 and concatenated std::vector slice_rows(const std::vector &rows) const override; + void slice_columns(const std::vector &columns, + const ColumnCallback &callback) const override; + bool load(std::istream &in) override; void serialize(std::ostream &out) const override; diff --git a/metagraph/src/annotation/binary_matrix/rainbowfish/rainbow.cpp b/metagraph/src/annotation/binary_matrix/rainbowfish/rainbow.cpp index 20161658fa..8e16517f10 100644 --- a/metagraph/src/annotation/binary_matrix/rainbowfish/rainbow.cpp +++ b/metagraph/src/annotation/binary_matrix/rainbowfish/rainbow.cpp @@ -131,6 +131,25 @@ Rainbow::get_column(Column column) const { return row_indices; } +template +void +Rainbow::slice_columns(const std::vector &columns, + const ColumnCallback &callback) const { + uint64_t nrows = num_rows(); + sdsl::bit_vector code_column(reduced_matrix_.num_rows()); + reduced_matrix_.slice_columns(columns, [&](Column j, bitmap&& rows) { + sdsl::util::set_to_value(code_column, false); + rows.add_to(&code_column); + + callback(j, bitmap_generator([&](const auto &index_callback) { + for (uint64_t i = 0; i < nrows; ++i) { + if (code_column[get_code(i)]) + index_callback(i); + } + }, nrows)); + }); +} + template bool Rainbow::load(std::istream &in) { try { diff --git a/metagraph/src/annotation/binary_matrix/rainbowfish/rainbow.hpp b/metagraph/src/annotation/binary_matrix/rainbowfish/rainbow.hpp index b5ef7c187b..fc7ced91aa 100644 --- a/metagraph/src/annotation/binary_matrix/rainbowfish/rainbow.hpp +++ b/metagraph/src/annotation/binary_matrix/rainbowfish/rainbow.hpp @@ -40,6 +40,9 @@ class Rainbow : public RainbowMatrix { size_t num_threads = 1) const override; std::vector get_column(Column column) const override; + void slice_columns(const std::vector &columns, + const ColumnCallback &callback) const override; + bool load(std::istream &in) override; void serialize(std::ostream &out) const override; diff --git a/metagraph/src/cli/assemble.cpp b/metagraph/src/cli/assemble.cpp index 541864ac50..75c29d02ce 100644 --- a/metagraph/src/cli/assemble.cpp +++ b/metagraph/src/cli/assemble.cpp @@ -2,8 +2,10 @@ #include +#include "common/algorithms.hpp" #include "common/logger.hpp" #include "common/unix_tools.hpp" +#include "common/threads/threading.hpp" #include "seq_io/sequence_io.hpp" #include "config/config.hpp" #include "load/load_graph.hpp" @@ -19,91 +21,188 @@ using mtg::common::logger; using mtg::graph::DeBruijnGraph; using mtg::graph::MaskedDeBruijnGraph; using mtg::graph::AnnotatedDBG; +using mtg::graph::DifferentialAssemblyConfig; + + +void clean_label_set(const AnnotatedDBG &anno_graph, + std::vector &label_set) { + label_set.erase(std::remove_if(label_set.begin(), label_set.end(), + [&](const std::string &label) { + bool exists = anno_graph.label_exists(label); + if (!exists) + logger->trace("Removing label {}", label); + + return !exists; + } + ), label_set.end()); + + std::sort(label_set.begin(), label_set.end()); + auto end = std::unique(label_set.begin(), label_set.end()); + for (auto it = end; it != label_set.end(); ++it) { + logger->trace("Removing duplicate label {}", *it); + } + label_set.erase(end, label_set.end()); +} std::unique_ptr -mask_graph(const AnnotatedDBG &anno_graph, Config *config) { +mask_graph_from_labels(const AnnotatedDBG &anno_graph, + const std::vector &label_mask_in, + const std::vector &label_mask_out, + const std::vector &label_mask_in_post, + const std::vector &label_mask_out_post, + const DifferentialAssemblyConfig &diff_config, + size_t num_threads) { auto graph = std::dynamic_pointer_cast(anno_graph.get_graph_ptr()); if (!graph.get()) throw std::runtime_error("Masking only supported for DeBruijnGraph"); - // Remove non-present labels - config->label_mask_in.erase( - std::remove_if(config->label_mask_in.begin(), - config->label_mask_in.end(), - [&](const auto &label) { - bool exists = anno_graph.label_exists(label); - if (!exists) - logger->trace("Removing mask-in label {}", label); - - return !exists; - }), - config->label_mask_in.end() - ); - - config->label_mask_out.erase( - std::remove_if(config->label_mask_out.begin(), - config->label_mask_out.end(), - [&](const auto &label) { - bool exists = anno_graph.label_exists(label); - if (!exists) - logger->trace("Removing mask-out label {}", label); - - return !exists; - }), - config->label_mask_out.end() - ); - - logger->trace("Masked in: {}", fmt::join(config->label_mask_in, " ")); - logger->trace("Masked out: {}", fmt::join(config->label_mask_out, " ")); - - if (!config->filter_by_kmer) { - return std::make_unique( - graph, - mask_nodes_by_unitig_labels( - anno_graph, - config->label_mask_in, - config->label_mask_out, - std::max(1u, get_num_threads()), - config->label_mask_in_fraction, - config->label_mask_out_fraction, - config->label_other_fraction - ) - ); + std::vector*> label_sets { + &label_mask_in, &label_mask_out, + &label_mask_in_post, &label_mask_out_post + }; + + for (const auto *label_set : label_sets) { + for (const auto *other_label_set : label_sets) { + if (label_set == other_label_set) + continue; + + if (utils::count_intersection(label_set->begin(), label_set->end(), + other_label_set->begin(), other_label_set->end())) + logger->warn("Overlapping label sets"); + } } - return std::make_unique( - graph, - mask_nodes_by_node_label( - anno_graph, - config->label_mask_in, - config->label_mask_out, - [config,&anno_graph](auto index, - auto get_num_in_labels, - auto get_num_out_labels) { - assert(index != DeBruijnGraph::npos); + logger->trace("Masked in: {}", fmt::join(label_mask_in, " ")); + logger->trace("Masked in (post-processing): {}", fmt::join(label_mask_in_post, " ")); + logger->trace("Masked out: {}", fmt::join(label_mask_out, " ")); + logger->trace("Masked out (post-processing): {}", fmt::join(label_mask_out_post, " ")); + + return std::make_unique(mask_nodes_by_label( + anno_graph, + label_mask_in, label_mask_out, + label_mask_in_post, label_mask_out_post, + diff_config, num_threads + )); +} - size_t num_in_labels = get_num_in_labels(); +DifferentialAssemblyConfig parse_diff_config(const std::string &config_str, + bool canonical) { + DifferentialAssemblyConfig diff_config; + diff_config.add_complement = canonical; - if (num_in_labels < config->label_mask_in_fraction - * config->label_mask_in.size()) - return false; + auto vals = utils::split_string(config_str, ","); + auto it = vals.begin(); + if (it != vals.end()) { + diff_config.label_mask_in_kmer_fraction = std::stof(*it); + ++it; + } + if (it != vals.end()) { + diff_config.label_mask_in_unitig_fraction = std::stof(*it); + ++it; + } + if (it != vals.end()) { + diff_config.label_mask_out_kmer_fraction = std::stof(*it); + ++it; + } + if (it != vals.end()) { + diff_config.label_mask_out_unitig_fraction = std::stof(*it); + ++it; + } + if (it != vals.end()) { + diff_config.label_mask_other_unitig_fraction = std::stof(*it); + ++it; + } - size_t num_out_labels = get_num_out_labels(); + assert(it == vals.end()); - if (num_out_labels < config->label_mask_out_fraction - * config->label_mask_out.size()) - return false; + logger->trace("Per-kmer mask in fraction: {}", diff_config.label_mask_in_kmer_fraction); + logger->trace("Per-unitig mask in fraction: {}", diff_config.label_mask_in_unitig_fraction); + logger->trace("Per-kmer mask out fraction: {}", diff_config.label_mask_out_kmer_fraction); + logger->trace("Per-unitig mask out fraction: {}", diff_config.label_mask_out_unitig_fraction); + logger->trace("Per-unitig other label fraction: {}", diff_config.label_mask_other_unitig_fraction); + logger->trace("Include reverse complements: {}", diff_config.add_complement); - size_t num_total_labels = anno_graph.get_labels(index).size(); + return diff_config; +} - return num_total_labels - num_in_labels - num_out_labels - <= config->label_other_fraction * num_total_labels; - }, - std::max(1u, get_num_threads()) - ) - ); +typedef std::function CallMaskedGraphHeader; + +void call_masked_graphs(const AnnotatedDBG &anno_graph, Config *config, + const CallMaskedGraphHeader &callback, + size_t num_parallel_graphs_masked = 1, + size_t num_threads_per_graph = 1) { + assert(!config->label_mask_file.empty()); + + std::ifstream fin(config->label_mask_file); + if (!fin.good()) { + throw std::iostream::failure("Failed to read label mask file"); + exit(1); + } + + ThreadPool thread_pool(num_parallel_graphs_masked); + std::vector shared_foreground_labels; + std::vector shared_background_labels; + + std::string line; + while (std::getline(fin, line)) { + if (line.empty() || line[0] == '#') + continue; + + if (line[0] == '@') { + logger->trace("Counting shared k-mers"); + + // shared in and out labels + auto line_split = utils::split_string(line, "\t", false); + if (line_split.size() <= 1 || line_split.size() > 3) + throw std::iostream::failure("Each line in mask file must have 2-3 fields."); + + // sync all assembly jobs before clearing shared labels + thread_pool.join(); + + shared_foreground_labels = utils::split_string(line_split[1], ","); + shared_background_labels = utils::split_string( + line_split.size() == 3 ? line_split[2] : "", + "," + ); + + clean_label_set(anno_graph, shared_foreground_labels); + clean_label_set(anno_graph, shared_background_labels); + + continue; + } + + thread_pool.enqueue([&](std::string line) { + auto line_split = utils::split_string(line, "\t", false); + if (line_split.size() <= 2 || line_split.size() > 4) + throw std::iostream::failure("Each line in mask file must have 3-4 fields."); + + auto diff_config = parse_diff_config(line_split[1], config->canonical); + + if (config->enumerate_out_sequences) + line_split[0] += "."; + + auto foreground_labels = utils::split_string(line_split[2], ","); + auto background_labels = utils::split_string( + line_split.size() == 4 ? line_split[3] : "", + "," + ); + + clean_label_set(anno_graph, foreground_labels); + clean_label_set(anno_graph, background_labels); + + callback(*mask_graph_from_labels(anno_graph, + foreground_labels, background_labels, + shared_foreground_labels, + shared_background_labels, + diff_config, num_threads_per_graph), + line_split[0]); + }, std::move(line)); + } + + thread_pool.join(); } @@ -122,15 +221,56 @@ int assemble(Config *config) { logger->trace("Graph loaded in {} sec", timer.elapsed()); - std::unique_ptr anno_graph; if (config->infbase_annotators.size()) { - anno_graph = initialize_annotated_dbg(graph, *config); + assert(config->label_mask_file.size()); + auto anno_graph = initialize_annotated_dbg(graph, *config); + + logger->trace("Generating masked graphs..."); + + std::filesystem::remove( + utils::remove_suffix(config->outfbase, ".gz", ".fasta") + ".fasta.gz" + ); + + std::mutex file_open_mutex; + std::mutex write_mutex; - logger->trace("Masking graph..."); + size_t num_parallel_assemblies = std::max( + 1u, std::min(config->parallel_assemblies, get_num_threads()) + ); + size_t num_threads_per_traversal = std::max( + size_t(1), get_num_threads() / num_parallel_assemblies + ); - graph = mask_graph(*anno_graph, config); + call_masked_graphs(*anno_graph, config, + [&](const graph::MaskedDeBruijnGraph &graph, const std::string &header) { + std::lock_guard file_lock(file_open_mutex); + seq_io::FastaWriter writer(config->outfbase, header, + config->enumerate_out_sequences, + get_num_threads() > 1, /* async write */ + "a" /* append mode */); + + if (config->unitigs || config->min_tip_size > 1) { + graph.call_unitigs([&](const auto &unitig, auto&&) { + std::lock_guard lock(write_mutex); + writer.write(unitig); + }, + num_threads_per_traversal, + config->min_tip_size, + config->kmers_in_single_form); + } else { + graph.call_sequences([&](const auto &seq, auto&&) { + std::lock_guard lock(write_mutex); + writer.write(seq); + }, + num_threads_per_traversal, + config->kmers_in_single_form); + } + }, + num_parallel_assemblies, + num_threads_per_traversal + ); - logger->trace("Masked in {} sec", timer.elapsed()); + return 0; } logger->trace("Extracting sequences from graph..."); @@ -179,6 +319,8 @@ int assemble(Config *config) { get_num_threads(), config->min_tip_size ); + + return 0; } seq_io::FastaWriter writer(config->outfbase, config->header, diff --git a/metagraph/src/cli/config/config.cpp b/metagraph/src/cli/config/config.cpp index 19dac3e31a..ec28012047 100644 --- a/metagraph/src/cli/config/config.cpp +++ b/metagraph/src/cli/config/config.cpp @@ -142,6 +142,8 @@ Config::Config(int argc, char *argv[]) { set_num_threads(atoi(get_value(i++))); } else if (!strcmp(argv[i], "--parallel-nodes")) { parallel_nodes = atoi(get_value(i++)); + } else if (!strcmp(argv[i], "--parallel-assemblies")) { + parallel_assemblies = atoi(get_value(i++)); } else if (!strcmp(argv[i], "--max-path-length")) { max_path_length = atoi(get_value(i++)); } else if (!strcmp(argv[i], "--parts-total")) { @@ -250,6 +252,8 @@ Config::Config(int argc, char *argv[]) { host_address = get_value(i++); }else if (!strcmp(argv[i], "--suffix")) { suffix = get_value(i++); + } else if (!strcmp(argv[i], "--label-mask-file")) { + label_mask_file = get_value(i++); } else if (!strcmp(argv[i], "--initialize-bloom")) { initialize_bloom = true; } else if (!strcmp(argv[i], "--bloom-fpp")) { @@ -330,18 +334,6 @@ Config::Config(int argc, char *argv[]) { print_welcome_message(); print_usage(argv[0], identity); exit(0); - } else if (!strcmp(argv[i], "--label-mask-in")) { - label_mask_in.emplace_back(get_value(i++)); - } else if (!strcmp(argv[i], "--label-mask-out")) { - label_mask_out.emplace_back(get_value(i++)); - } else if (!strcmp(argv[i], "--label-mask-in-fraction")) { - label_mask_in_fraction = std::stof(get_value(i++)); - } else if (!strcmp(argv[i], "--label-mask-out-fraction")) { - label_mask_out_fraction = std::stof(get_value(i++)); - } else if (!strcmp(argv[i], "--label-other-fraction")) { - label_other_fraction = std::stof(get_value(i++)); - } else if (!strcmp(argv[i], "--filter-by-kmer")) { - filter_by_kmer = true; } else if (!strcmp(argv[i], "--disk-swap")) { tmp_dir = get_value(i++); } else if (!strcmp(argv[i], "--disk-cap-gb")) { @@ -358,6 +350,9 @@ Config::Config(int argc, char *argv[]) { if (parallel_nodes == static_cast(-1)) parallel_nodes = get_num_threads(); + if (parallel_assemblies == static_cast(-1)) + parallel_assemblies = get_num_threads(); + if (identity == TRANSFORM && to_fasta) identity = ASSEMBLE; @@ -500,6 +495,12 @@ Config::Config(int argc, char *argv[]) { print_usage_and_exit = true; } + if ((identity == ASSEMBLE || identity == TRANSFORM) + && (infbase_annotators.size() && label_mask_file.empty())) { + std::cerr << "Error: annotator passed, but no label mask file provided" << std::endl; + print_usage_and_exit = true; + } + if (identity == ANNOTATE_COORDINATES && outfbase.empty()) outfbase = utils::remove_suffix(infbase, ".dbg", ".orhashdbg", @@ -926,13 +927,16 @@ void Config::print_usage(const std::string &prog_name, IdentityType identity) { fprintf(stderr, "\t --header [STR] \theader for sequences in FASTA output []\n"); fprintf(stderr, "\t-p --parallel [INT] \tuse multiple threads for computation [1]\n"); fprintf(stderr, "\n"); - fprintf(stderr, "\t-a --annotator [STR] \t\t\tannotator to load []\n"); - fprintf(stderr, "\t --label-mask-in [STR] \t\tlabel to include in masked graph\n"); - fprintf(stderr, "\t --label-mask-out [STR] \t\tlabel to exclude from masked graph\n"); - fprintf(stderr, "\t --label-mask-in-fraction [FLOAT] \tminimum fraction of mask-in labels among the set of masked labels [1.0]\n"); - fprintf(stderr, "\t --label-mask-out-fraction [FLOAT] \tmaximum fraction of mask-out labels among the set of masked labels [0.0]\n"); - fprintf(stderr, "\t --label-other-fraction [FLOAT] \tmaximum fraction of other labels allowed [1.0]\n"); - fprintf(stderr, "\t --filter-by-kmer \t\t\tmask out graph k-mers individually [off]\n"); + fprintf(stderr, "\t-a --annotator [STR] \t\tannotator to load []\n"); + fprintf(stderr, "\t --parallel-assemblies [INT] \tnumber of assembly experiments to run in parallel [n_threads]\n"); + fprintf(stderr, "\t --label-mask-file [STR] \tfile describing labels to mask in and out and their relative fractions []\n"); + fprintf(stderr, "\t \t\tA k-mer is an in-k-mer if it has at least in_kmer_frac in-labels.\n"); + fprintf(stderr, "\t \t\tA k-mer is an out-k-mer if it has more than out_kmer_frac out-labels.\n"); + fprintf(stderr, "\t \t\tA unitig is included if it has at least in_unitig_frac in-k-mers and at most out_unitig_frac out-k-mers.\n"); + fprintf(stderr, "\t \t\tA unitig is excluded if more than other_unitig_frac of its k-mers are not in the in- or out-label sets.\n"); + fprintf(stderr, "\t \t\texample: '\\t,,,,\\t,,...\\t,,...'\n"); + fprintf(stderr, "\t \t\tIf exp_label == @, then the second and third fields are lists of in- and out-labels, respectively, that apply to all subsequent lines.\n"); + fprintf(stderr, "\t \t\texample: '@\\t,,...\\t,,...'\n"); } break; case STATS: { fprintf(stderr, "Usage: %s stats [options] GRAPH1 [[GRAPH2] ...]\n\n", prog_name.c_str()); diff --git a/metagraph/src/cli/config/config.hpp b/metagraph/src/cli/config/config.hpp index eb4458265a..aabbbd078b 100644 --- a/metagraph/src/cli/config/config.hpp +++ b/metagraph/src/cli/config/config.hpp @@ -56,7 +56,6 @@ class Config { bool map_sequences = false; bool align_sequences = false; bool align_both_strands = false; - bool filter_by_kmer = false; bool output_json = false; bool optimize = false; @@ -86,6 +85,7 @@ class Config { unsigned int bloom_max_num_hash_functions = 10; unsigned int num_columns_cached = 10; unsigned int max_hull_forks = 4; + unsigned int parallel_assemblies = -1; // by default, run |parallel| assemblies in parallel, one per thread unsigned long long int query_batch_size_in_bytes = 100'000'000; unsigned long long int num_rows_subsampled = 1'000'000; @@ -115,9 +115,6 @@ class Config { size_t alignment_max_num_seeds_per_locus = std::numeric_limits::max(); double discovery_fraction = 0.7; - double label_mask_in_fraction = 1.0; - double label_mask_out_fraction = 0.0; - double label_other_fraction = 1.0; double min_count_quantile = 0.; double max_count_quantile = 1.; double bloom_fpp = 1.0; @@ -129,8 +126,6 @@ class Config { std::vector fnames; std::vector anno_labels; std::vector infbase_annotators; - std::vector label_mask_in; - std::vector label_mask_out; std::string outfbase; std::string infbase; std::string rename_instructions_file; @@ -141,6 +136,7 @@ class Config { std::string fasta_anno_comment_delim = UNINITIALIZED_STR; std::string header = ""; std::string host_address; + std::string label_mask_file; uint32_t max_path_length = 50; std::filesystem::path tmp_dir; diff --git a/metagraph/src/cli/load/load_annotated_graph.cpp b/metagraph/src/cli/load/load_annotated_graph.cpp index aa4d0c31bc..9a16f451fc 100644 --- a/metagraph/src/cli/load/load_annotated_graph.cpp +++ b/metagraph/src/cli/load/load_annotated_graph.cpp @@ -22,7 +22,7 @@ using mtg::common::logger; std::unique_ptr initialize_annotated_dbg(std::shared_ptr graph, const Config &config) { // TODO: check and wrap into canonical only if the graph is primary - if (config.canonical && !graph->is_canonical_mode()) + if (config.canonical && !graph->is_canonical_mode() && config.identity != Config::ASSEMBLE) graph = std::make_shared(graph, true); uint64_t max_index = graph->max_index(); @@ -85,6 +85,5 @@ std::unique_ptr initialize_annotated_dbg(const Config &config) { return initialize_annotated_dbg(load_critical_dbg(config.infbase), config); } - } // namespace cli } // namespace mtg diff --git a/metagraph/src/common/vectors/bitmap.cpp b/metagraph/src/common/vectors/bitmap.cpp index 507582c3af..2e998f7ac0 100644 --- a/metagraph/src/common/vectors/bitmap.cpp +++ b/metagraph/src/common/vectors/bitmap.cpp @@ -369,3 +369,52 @@ void bitmap_lazy::call_ones_in_range(uint64_t begin, uint64_t end, callback(i); } } + + +//////////////////////////////////////////////////////////////// +// bitmap_generator // +//////////////////////////////////////////////////////////////// + +bitmap_generator::bitmap_generator(size_t size, bool value) + : bitmap_generator([value,size](const auto &index_callback) { + if (value) { + for (size_t i = 0; i < size; ++i) { + index_callback(i); + } + } + }, size, size * value) {} + +bitmap_generator +::bitmap_generator(std::function&)>&& generator, + size_t size, + size_t num_set_bits) noexcept + : size_(size), num_set_bits_(num_set_bits), generator_(std::move(generator)) {} + +bitmap_generator +::bitmap_generator(bitmap&& base, std::function&& index_transformer, + size_t size, size_t num_set_bits) noexcept + : size_(size == static_cast(-1) ? base.size() : size), + num_set_bits_(num_set_bits == static_cast(-1) ? base.num_set_bits() : num_set_bits) { + auto get_generator = [&](bitmap&& v, auto&& transformer) { + return [&](const auto &callback) { + v.call_ones([&](auto i) { callback(transformer(i)); }); + }; + }; + + generator_ = get_generator(std::move(base), std::move(index_transformer)); +} + +void bitmap_generator::add_to(sdsl::bit_vector *other) const { + assert(other); + assert(other->size() == size()); + call_ones([other](auto i) { (*other)[i] = true; }); +} + +void bitmap_generator::call_ones_in_range(uint64_t begin, uint64_t end, + const VoidCall &callback) const { + generator_([&](uint64_t i) { + assert(i < size_); + if (i >= begin && i < end) + callback(i); + }); +} diff --git a/metagraph/src/common/vectors/bitmap.hpp b/metagraph/src/common/vectors/bitmap.hpp index 3f7dfa98c6..2a0869b12a 100644 --- a/metagraph/src/common/vectors/bitmap.hpp +++ b/metagraph/src/common/vectors/bitmap.hpp @@ -187,4 +187,40 @@ class bitmap_lazy : public bitmap { size_t num_set_bits_; }; +class bitmap_generator : public bitmap { + public: + explicit bitmap_generator(size_t size = 0, bool value = false); + + bitmap_generator(std::vector&& set_bits, size_t size) noexcept + : size_(size), num_set_bits_(set_bits.size()), + generator_([s=std::move(set_bits)](const VoidCall &callback) { + std::for_each(s.begin(), s.end(), callback); + }) {} + + bitmap_generator(std::function&)>&& generator, + size_t size, + size_t num_set_bits = -1) noexcept; + + bitmap_generator(bitmap&& base, + std::function&& index_transformer + = [](auto i) { return i; }, + size_t size = -1, + size_t num_set_bits = -1) noexcept; + + bool operator[](uint64_t) const { throw std::runtime_error("Not implemented"); } + uint64_t get_int(uint64_t, uint32_t) const { throw std::runtime_error("Not implemented"); } + + uint64_t size() const { return size_; } + uint64_t num_set_bits() const { return num_set_bits_; } + + void add_to(sdsl::bit_vector *other) const; + void call_ones_in_range(uint64_t begin, uint64_t end, + const VoidCall &callback) const; + + private: + size_t size_; + size_t num_set_bits_; + std::function&)> generator_; +}; + #endif // __BITMAP_HPP__ diff --git a/metagraph/src/graph/annotated_graph_algorithm.cpp b/metagraph/src/graph/annotated_graph_algorithm.cpp index e888219f87..908565dd44 100644 --- a/metagraph/src/graph/annotated_graph_algorithm.cpp +++ b/metagraph/src/graph/annotated_graph_algorithm.cpp @@ -1,253 +1,520 @@ #include "annotated_graph_algorithm.hpp" -#include +#include +#include "common/logger.hpp" +#include "common/seq_tools/reverse_complement.hpp" #include "common/vectors/vector_algorithm.hpp" -#include "common/vector_map.hpp" -#include "kmer/alphabets.hpp" -#include "annotation/representation/column_compressed/annotate_column_compressed.hpp" -#include "graph/alignment/aligner_helper.hpp" +#include "common/threads/threading.hpp" +#include "common/vectors/bitmap.hpp" #include "graph/representation/masked_graph.hpp" +#include "graph/representation/succinct/dbg_succinct.hpp" +#include "graph/representation/succinct/boss_construct.hpp" namespace mtg { namespace graph { +using mtg::graph::boss::BOSS; +using mtg::graph::boss::BOSSConstructor; +using mtg::common::logger; + typedef AnnotatedDBG::node_index node_index; typedef AnnotatedDBG::row_index row_index; typedef AnnotatedDBG::Annotator::Label Label; -typedef Alignment DBGAlignment; +typedef std::function LabelCountCallback; + +constexpr bool MAKE_BOSS = true; + + +/** + * Return an int_vector<>, bitmap pair, each of length anno_graph.get_graph().max_index(). + * For an index i, the int_vector will contain a packed integer representing the + * number of labels in labels_in and labels_out which the k-mer of index i is + * annotated with. The least significant half of each integer represents the count + * from labels_in, while the most significant half represents the count from + * labels_out. + * The returned bitmap is a binarization of the int_vector + */ +std::pair, std::unique_ptr> +construct_diff_label_count_vector(const AnnotatedDBG &anno_graph, + const std::vector