diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d0bde6190d..7be87f2bc3 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -7,8 +7,6 @@ on: tags: - 'v*' pull_request: - branches: - - master env: REGISTRY: ghcr.io diff --git a/metagraph/CMakeLists.txt b/metagraph/CMakeLists.txt index 8114dc0052..3e7a526998 100644 --- a/metagraph/CMakeLists.txt +++ b/metagraph/CMakeLists.txt @@ -111,6 +111,7 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") add_compile_options( -Wno-exit-time-destructors -Wno-deprecated-declarations + -Wno-vla-extension ) if (NOT CMAKE_CXX_COMPILER_ID MATCHES "AppleClang") diff --git a/metagraph/integration_tests/base.py b/metagraph/integration_tests/base.py index 015b7dcf41..d77f8cbb59 100644 --- a/metagraph/integration_tests/base.py +++ b/metagraph/integration_tests/base.py @@ -6,7 +6,14 @@ script_path = os.path.dirname(os.path.realpath(__file__)) -METAGRAPH = f'{os.getcwd()}/metagraph' +METAGRAPH_EXE = f'{os.getcwd()}/metagraph' +DNA_MODE = os.readlink(METAGRAPH_EXE).endswith("_DNA") +PROTEIN_MODE = os.readlink(METAGRAPH_EXE).endswith("_Protein") +METAGRAPH = METAGRAPH_EXE + +def update_prefix(PREFIX): + global METAGRAPH + METAGRAPH = PREFIX + METAGRAPH_EXE TEST_DATA_DIR = os.path.join(script_path, '..', 'tests', 'data') @@ -37,10 +44,19 @@ def setUpClass(cls): def _get_stats(graph_path): stats_command = METAGRAPH + ' stats ' + graph_path + ' --mmap' res = subprocess.run(stats_command.split(), stdout=PIPE, stderr=PIPE) - assert(res.returncode == 0) + if res.returncode != 0: + raise AssertionError(f"Command '{stats_command}' failed with return code {res.returncode} and error: {res.stderr.decode()}") stats_command = METAGRAPH + ' stats ' + graph_path + MMAP_FLAG res = subprocess.run(stats_command.split(), stdout=PIPE, stderr=PIPE) - return res + parsed = dict() + parsed['returncode'] = res.returncode + res = res.stdout.decode().split('\n')[2:] + for line in res: + if ': ' in line: + x, y = map(str.strip, line.split(':', 1)) + assert(x not in parsed or parsed[x] == y) + parsed[x] = y + return parsed @staticmethod def _build_graph(input, output, k, repr, mode='basic', extra_params=''): diff --git a/metagraph/integration_tests/main.py b/metagraph/integration_tests/main.py index 640e7de6fa..624a3328ec 100644 --- a/metagraph/integration_tests/main.py +++ b/metagraph/integration_tests/main.py @@ -5,6 +5,7 @@ import sys import argparse from helpers import TimeLoggingTestResult +from base import update_prefix """Run all integration tests""" @@ -32,7 +33,11 @@ def create_test_suite(filter_pattern="*"): parser = argparse.ArgumentParser(description='Metagraph integration tests.') parser.add_argument('--test_filter', dest='filter', type=str, default="*", help='filter test cases (default: run all)') + parser.add_argument('--gdb', dest='use_gdb', action='store_true', + help='run metagraph with gdb') args = parser.parse_args() + if args.use_gdb: + update_prefix('gdb -ex run -ex bt -ex quit --args ') result = unittest.TextTestRunner( resultclass=TimeLoggingTestResult diff --git a/metagraph/integration_tests/test_align.py b/metagraph/integration_tests/test_align.py index 2e3bcb0d83..231a0512d5 100644 --- a/metagraph/integration_tests/test_align.py +++ b/metagraph/integration_tests/test_align.py @@ -6,14 +6,11 @@ import glob import os -from base import TestingBase, METAGRAPH, TEST_DATA_DIR, NUM_THREADS +from base import PROTEIN_MODE, DNA_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR, NUM_THREADS """Test graph construction and alignment""" -DNA_MODE = os.readlink(METAGRAPH).endswith("_DNA") -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") - graph_file_extension = {'succinct': '.dbg', 'bitmap': '.bitmapdbg', 'hash': '.orhashdbg', @@ -35,11 +32,10 @@ def test_simple_align_all_graphs(self, representation): k=11, repr=representation, extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16438', params_str[1]) - self.assertEqual('mode: basic', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16438', params['nodes (k)']) + self.assertEqual('basic', params['mode']) stats_command = '{exe} align --align-only-forwards -i {graph} --align-min-exact-match 0.0 {reads}'.format( exe=METAGRAPH, @@ -68,11 +64,10 @@ def test_simple_align_map_all_graphs(self, representation): k=11, repr=representation, extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16438', params_str[1]) - self.assertEqual('mode: basic', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16438', params['nodes (k)']) + self.assertEqual('basic', params['mode']) stats_command = '{exe} align -i {graph} --map --count-kmers {reads}'.format( exe=METAGRAPH, @@ -99,11 +94,10 @@ def test_simple_align_map_all_graphs_subk(self, representation): k=11, repr=representation, extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16438', params_str[1]) - self.assertEqual('mode: basic', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16438', params['nodes (k)']) + self.assertEqual('basic', params['mode']) stats_command = '{exe} align -i {graph} --map --count-kmers --align-length 10 {reads}'.format( exe=METAGRAPH, @@ -134,11 +128,10 @@ def test_simple_align_map_canonical_all_graphs(self, representation): k=11, repr=representation, mode='canonical', extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 32782', params_str[1]) - self.assertEqual('mode: canonical', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('32782', params['nodes (k)']) + self.assertEqual('canonical', params['mode']) stats_command = '{exe} align -i {graph} --map --count-kmers {reads}'.format( exe=METAGRAPH, @@ -165,11 +158,10 @@ def test_simple_align_json_all_graphs(self, representation): k=11, repr=representation, extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16438', params_str[1]) - self.assertEqual('mode: basic', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16438', params['nodes (k)']) + self.assertEqual('basic', params['mode']) stats_command = '{exe} align --align-only-forwards -i {graph} --align-min-exact-match 0.0 {reads}'.format( exe=METAGRAPH, @@ -189,11 +181,10 @@ def test_simple_align_fwd_rev_comp_all_graphs(self, representation): k=11, repr=representation, extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16438', params_str[1]) - self.assertEqual('mode: basic', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16438', params['nodes (k)']) + self.assertEqual('basic', params['mode']) stats_command = '{exe} align -i {graph} --align-min-exact-match 0.0 {reads}'.format( exe=METAGRAPH, @@ -222,11 +213,10 @@ def test_simple_align_canonical_all_graphs(self, representation): k=11, repr=representation, mode='canonical', extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 32782', params_str[1]) - self.assertEqual('mode: canonical', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('32782', params['nodes (k)']) + self.assertEqual('canonical', params['mode']) stats_command = '{exe} align -i {graph} --align-min-exact-match 0.0 {reads}'.format( exe=METAGRAPH, @@ -256,11 +246,10 @@ def test_simple_align_canonical_subk_succinct(self, representation): k=11, repr=representation, mode='canonical', extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 32782', params_str[1]) - self.assertEqual('mode: canonical', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('32782', params['nodes (k)']) + self.assertEqual('canonical', params['mode']) stats_command = '{exe} align -i {graph} --align-min-exact-match 0.0 --align-min-seed-length 10 {reads}'.format( exe=METAGRAPH, @@ -286,11 +275,10 @@ def test_simple_align_primary_all_graphs(self, representation): k=11, repr=representation, mode='primary', extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT.primary' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16391', params_str[1]) - self.assertEqual('mode: primary', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT.primary' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16391', params['nodes (k)']) + self.assertEqual('primary', params['mode']) stats_command = '{exe} align -i {graph} --align-min-exact-match 0.0 {reads}'.format( exe=METAGRAPH, @@ -320,11 +308,10 @@ def test_simple_align_primary_subk_succinct(self, representation): k=11, repr=representation, mode='primary', extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT.primary' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16391', params_str[1]) - self.assertEqual('mode: primary', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT.primary' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16391', params['nodes (k)']) + self.assertEqual('primary', params['mode']) stats_command = '{exe} align -i {graph} --align-min-exact-match 0.0 --align-min-seed-length 10 {reads}'.format( exe=METAGRAPH, @@ -347,13 +334,13 @@ def test_simple_align_fwd_rev_comp_json_all_graphs(self, representation): self._build_graph(input=TEST_DATA_DIR + '/genome.MT.fa', output=self.tempdir.name + '/genome.MT', - k=11, repr=representation) + k=11, repr=representation, + extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16461', params_str[1]) - self.assertEqual('mode: basic', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16438', params['nodes (k)']) + self.assertEqual('basic', params['mode']) stats_command = '{exe} align --json -i {graph} --align-min-exact-match 0.0 {reads}'.format( exe=METAGRAPH, @@ -373,13 +360,13 @@ def test_simple_align_edit_distance_all_graphs(self, representation): self._build_graph(input=TEST_DATA_DIR + '/genome.MT.fa', output=self.tempdir.name + '/genome.MT', - k=11, repr=representation) + k=11, repr=representation, + extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) - params_str = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', params_str[0]) - self.assertEqual('nodes (k): 16461', params_str[1]) - self.assertEqual('mode: basic', params_str[2]) + params = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + self.assertEqual('11', params['k']) + self.assertEqual('16438', params['nodes (k)']) + self.assertEqual('basic', params['mode']) stats_command = '{exe} align --json --align-edit-distance -i {graph} --align-min-exact-match 0.0 {reads}'.format( exe=METAGRAPH, diff --git a/metagraph/integration_tests/test_annotate.py b/metagraph/integration_tests/test_annotate.py index 38b5f90ba4..68a98cc567 100644 --- a/metagraph/integration_tests/test_annotate.py +++ b/metagraph/integration_tests/test_annotate.py @@ -6,13 +6,11 @@ import filecmp import glob import os -from base import TestingBase, METAGRAPH, TEST_DATA_DIR, NUM_THREADS +from base import PROTEIN_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR, NUM_THREADS """Test graph annotation""" -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") - graph_file_extension = {'succinct': '.dbg', 'bitmap': '.bitmapdbg', 'hash': '.orhashdbg', @@ -44,12 +42,11 @@ def test_simple_all_graphs(self, graph_repr): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 46960', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('20', stats_graph['k']) + self.assertEqual('46960', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) for anno_repr in ['row', 'column']: # build annotation @@ -63,13 +60,15 @@ def test_simple_all_graphs(self, graph_repr): self.assertEqual(res.returncode, 0) # check annotation - res = self._get_stats(f'-a {self.tempdir.name}/annotation{anno_file_extension[anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 100', out[0]) - self.assertEqual('objects: 46960', out[1]) - self.assertEqual('density: 0.0185072', out[2]) - self.assertEqual('representation: ' + anno_repr, out[3]) + stats_annotation = self._get_stats('-a ' + self.tempdir.name + '/annotation' + anno_file_extension[anno_repr]) + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual('100', stats_annotation['labels']) + self.assertEqual(stats_graph['max index (k)'], stats_annotation['objects']) + self.assertAlmostEqual( + 0.0185072 * (int(stats_graph['nodes (k)']) / int(stats_graph['max index (k)'])), + float(stats_annotation['density']), + places=6) + self.assertEqual(anno_repr, stats_annotation['representation']) # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand(['succinct', 'bitmap', 'hash']) # , 'hashstr']: @@ -88,12 +87,11 @@ def test_simple_all_graphs_canonical(self, graph_repr): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 91584', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('20', stats_graph['k']) + self.assertEqual('91584', stats_graph['nodes (k)']) + self.assertEqual('canonical', stats_graph['mode']) for anno_repr in ['row', 'column']: # build annotation @@ -106,13 +104,15 @@ def test_simple_all_graphs_canonical(self, graph_repr): self.assertEqual(res.returncode, 0) # check annotation - res = self._get_stats(f'-a {self.tempdir.name}/annotation{anno_file_extension[anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 100', out[0]) - self.assertEqual('objects: 91584', out[1]) - self.assertEqual('density: 0.00948888', out[2]) - self.assertEqual('representation: ' + anno_repr, out[3]) + stats_annotation = self._get_stats('-a ' + self.tempdir.name + '/annotation' + anno_file_extension[anno_repr]) + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual('100', stats_annotation['labels']) + self.assertEqual(stats_graph['max index (k)'], stats_annotation['objects']) + self.assertAlmostEqual( + 0.00948888 * (int(stats_graph['nodes (k)']) / int(stats_graph['max index (k)'])), + float(stats_annotation['density']), + places=6) + self.assertEqual(anno_repr, stats_annotation['representation']) @parameterized.expand(GRAPH_TYPES) def test_simple_all_graphs_from_kmc(self, graph_repr): @@ -128,12 +128,11 @@ def test_simple_all_graphs_from_kmc(self, graph_repr): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 469983', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('469983', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) for anno_repr in ['row', 'column']: # build annotation @@ -146,13 +145,15 @@ def test_simple_all_graphs_from_kmc(self, graph_repr): self.assertEqual(res.returncode, 0) # check annotation - res = self._get_stats(f'-a {self.tempdir.name}/annotation{anno_file_extension[anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 1', out[0]) - self.assertEqual('objects: 469983', out[1]) - self.assertEqual('density: 1', out[2]) - self.assertEqual('representation: ' + anno_repr, out[3]) + stats_annotation = self._get_stats('-a ' + self.tempdir.name + '/annotation' + anno_file_extension[anno_repr]) + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual('1', stats_annotation['labels']) + self.assertEqual(stats_graph['max index (k)'], stats_annotation['objects']) + self.assertAlmostEqual( + 1 * (int(stats_graph['nodes (k)']) / int(stats_graph['max index (k)'])), + float(stats_annotation['density']), + places=6) + self.assertEqual(anno_repr, stats_annotation['representation']) @parameterized.expand(GRAPH_TYPES) def test_simple_all_graphs_from_kmc_both(self, graph_repr): @@ -168,12 +169,11 @@ def test_simple_all_graphs_from_kmc_both(self, graph_repr): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('802920', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) for anno_repr in ['row', 'column']: # build annotation @@ -186,13 +186,15 @@ def test_simple_all_graphs_from_kmc_both(self, graph_repr): self.assertEqual(res.returncode, 0) # check annotation - res = self._get_stats(f'-a {self.tempdir.name}/annotation_single{anno_file_extension[anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 1', out[0]) - self.assertEqual('objects: 802920', out[1]) - self.assertEqual('density: 0.585342', out[2]) - self.assertEqual('representation: ' + anno_repr, out[3]) + stats_annotation = self._get_stats('-a ' + self.tempdir.name + '/annotation_single' + anno_file_extension[anno_repr]) + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual('1', stats_annotation['labels']) + self.assertEqual(stats_graph['max index (k)'], stats_annotation['objects']) + self.assertAlmostEqual( + 0.585342 * (int(stats_graph['nodes (k)']) / int(stats_graph['max index (k)'])), + float(stats_annotation['density']), + places=6) + self.assertEqual(anno_repr, stats_annotation['representation']) # both strands annotate_command = f'{METAGRAPH} annotate --anno-label LabelName -p {NUM_THREADS} \ @@ -204,13 +206,15 @@ def test_simple_all_graphs_from_kmc_both(self, graph_repr): self.assertEqual(res.returncode, 0) # check annotation - res = self._get_stats(f'-a {self.tempdir.name}/annotation_both{anno_file_extension[anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 1', out[0]) - self.assertEqual('objects: 802920', out[1]) - self.assertEqual('density: 1', out[2]) - self.assertEqual('representation: ' + anno_repr, out[3]) + stats_annotation = self._get_stats('-a ' + self.tempdir.name + '/annotation_both' + anno_file_extension[anno_repr]) + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual('1', stats_annotation['labels']) + self.assertEqual(stats_graph['max index (k)'], stats_annotation['objects']) + self.assertAlmostEqual( + 1 * (int(stats_graph['nodes (k)']) / int(stats_graph['max index (k)'])), + float(stats_annotation['density']), + places=6) + self.assertEqual(anno_repr, stats_annotation['representation']) # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand(['succinct', 'bitmap', 'hash']) # , 'hashstr']: @@ -228,12 +232,11 @@ def test_simple_all_graphs_from_kmc_both_canonical(self, graph_repr): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('802920', stats_graph['nodes (k)']) + self.assertEqual('canonical', stats_graph['mode']) for anno_repr in ['row', 'column']: # build annotation @@ -246,13 +249,15 @@ def test_simple_all_graphs_from_kmc_both_canonical(self, graph_repr): self.assertEqual(res.returncode, 0) # check annotation - res = self._get_stats(f'-a {self.tempdir.name}/annotation_single{anno_file_extension[anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 1', out[0]) - self.assertEqual('objects: 802920', out[1]) - self.assertEqual('density: 0.5', out[2]) - self.assertEqual('representation: ' + anno_repr, out[3]) + stats_annotation = self._get_stats('-a ' + self.tempdir.name + '/annotation_single' + anno_file_extension[anno_repr]) + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual('1', stats_annotation['labels']) + self.assertEqual(stats_graph['max index (k)'], stats_annotation['objects']) + self.assertAlmostEqual( + 0.5 * (int(stats_graph['nodes (k)']) / int(stats_graph['max index (k)'])), + float(stats_annotation['density']), + places=6) + self.assertEqual(anno_repr, stats_annotation['representation']) # both strands annotate_command = f'{METAGRAPH} annotate --anno-label LabelName -p {NUM_THREADS} \ @@ -264,13 +269,15 @@ def test_simple_all_graphs_from_kmc_both_canonical(self, graph_repr): self.assertEqual(res.returncode, 0) # check annotation - res = self._get_stats(f'-a {self.tempdir.name}/annotation_both{anno_file_extension[anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 1', out[0]) - self.assertEqual('objects: 802920', out[1]) - self.assertEqual('density: 0.5', out[2]) - self.assertEqual('representation: ' + anno_repr, out[3]) + stats_annotation = self._get_stats('-a ' + self.tempdir.name + '/annotation_both' + anno_file_extension[anno_repr]) + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual('1', stats_annotation['labels']) + self.assertEqual(stats_graph['max index (k)'], stats_annotation['objects']) + self.assertAlmostEqual( + 0.5 * (int(stats_graph['nodes (k)']) / int(stats_graph['max index (k)'])), + float(stats_annotation['density']), + places=6) + self.assertEqual(anno_repr, stats_annotation['representation']) def test_annotate_with_disk_swap(self): graph_repr = 'succinct' @@ -288,12 +295,11 @@ def test_annotate_with_disk_swap(self): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 46960', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[graph_repr]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('20', stats_graph['k']) + self.assertEqual('46960', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) # build annotation annotate_command = f'{METAGRAPH} annotate --anno-header \ @@ -306,13 +312,15 @@ def test_annotate_with_disk_swap(self): self.assertEqual(res.returncode, 0) # check annotation - res = self._get_stats(f'-a {self.tempdir.name}/annotation{anno_file_extension[anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 100', out[0]) - self.assertEqual('objects: 46960', out[1]) - self.assertEqual('density: 0.0185072', out[2]) - self.assertEqual('representation: ' + anno_repr, out[3]) + stats_annotation = self._get_stats('-a ' + f'{self.tempdir.name}/annotation{anno_file_extension[anno_repr]}') + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual('100', stats_annotation['labels']) + self.assertEqual(stats_graph['max index (k)'], stats_annotation['objects']) + self.assertAlmostEqual( + 0.0185072 * (int(stats_graph['nodes (k)']) / int(stats_graph['max index (k)'])), + float(stats_annotation['density']), + places=6) + self.assertEqual(anno_repr, stats_annotation['representation']) @parameterized.expand(GRAPH_TYPES) def test_annotate_coordinates(self, graph_repr): diff --git a/metagraph/integration_tests/test_api.py b/metagraph/integration_tests/test_api.py index 8842f91c86..3565960298 100644 --- a/metagraph/integration_tests/test_api.py +++ b/metagraph/integration_tests/test_api.py @@ -13,10 +13,7 @@ from concurrent.futures import Future from parameterized import parameterized, parameterized_class -from base import TestingBase, METAGRAPH, TEST_DATA_DIR - -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") - +from base import PROTEIN_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR class TestAPIBase(TestingBase): @classmethod diff --git a/metagraph/integration_tests/test_assemble.py b/metagraph/integration_tests/test_assemble.py index 6f3846572b..05f084a36f 100644 --- a/metagraph/integration_tests/test_assemble.py +++ b/metagraph/integration_tests/test_assemble.py @@ -6,14 +6,12 @@ import gzip import itertools from helpers import get_test_class_name -from base import TestingBase, graph_file_extension, METAGRAPH, TEST_DATA_DIR, NUM_THREADS +from base import PROTEIN_MODE, TestingBase, graph_file_extension, METAGRAPH, TEST_DATA_DIR, NUM_THREADS from test_query import anno_file_extension, GRAPH_TYPES, ANNO_TYPES, product """Test graph assemble""" -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") - gfa_tests = { 'compacted': { 'fasta_path': TEST_DATA_DIR + '/transcripts_100.fa', diff --git a/metagraph/integration_tests/test_build.py b/metagraph/integration_tests/test_build.py index 7ded5fb30e..deeeb42da9 100644 --- a/metagraph/integration_tests/test_build.py +++ b/metagraph/integration_tests/test_build.py @@ -5,14 +5,11 @@ from tempfile import TemporaryDirectory import glob import os -from base import TestingBase, METAGRAPH, TEST_DATA_DIR +from base import PROTEIN_MODE, DNA_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR """Test graph construction""" -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") -DNA_MODE = os.readlink(METAGRAPH).endswith("_DNA") - graph_file_extension = {'succinct': '.dbg', 'bitmap': '.bitmapdbg', 'hash': '.orhashdbg', @@ -50,12 +47,11 @@ def test_simple_all_graphs(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 591997', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('20', stats_graph['k']) + self.assertEqual('591997', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) @parameterized.expand(succinct_states) def test_build_succinct_inplace(self, state): @@ -67,13 +63,12 @@ def test_build_succinct_inplace(self, state): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension['succinct']) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 597931', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('state: ' + state, out[8]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension['succinct']) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('20', stats_graph['k']) + self.assertEqual('597931', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) + self.assertEqual(state, stats_graph['state']) @parameterized.expand(['succinct']) def test_simple_bloom_graph(self, build): @@ -90,12 +85,11 @@ def test_simple_bloom_graph(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 591997', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('20', stats_graph['k']) + self.assertEqual('591997', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) convert_command = '{exe} transform -o {outfile} --initialize-bloom {bloom_param} {input}'.format( exe=METAGRAPH, @@ -136,12 +130,11 @@ def test_simple_all_graphs_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 1159851', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('20', stats_graph['k']) + self.assertEqual('1159851', stats_graph['nodes (k)']) + self.assertEqual('canonical', stats_graph['mode']) @parameterized.expand(BUILDS) def test_build_tiny_k(self, build): @@ -157,12 +150,11 @@ def test_build_tiny_k(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 2', out[0]) - self.assertEqual('nodes (k): 16', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('2', stats_graph['k']) + self.assertEqual('16', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand([repr for repr in BUILDS if repr != 'hashstr']) @@ -180,12 +172,11 @@ def test_build_tiny_k_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 2', out[0]) - self.assertEqual('nodes (k): 16', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('2', stats_graph['k']) + self.assertEqual('16', stats_graph['nodes (k)']) + self.assertEqual('canonical', stats_graph['mode']) @parameterized.expand(BUILDS) def test_build_tiny_k_parallel(self, build): @@ -199,12 +190,11 @@ def test_build_tiny_k_parallel(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 2', out[0]) - self.assertEqual('nodes (k): 16', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('2', stats_graph['k']) + self.assertEqual('16', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand([repr for repr in BUILDS if repr != 'hashstr']) @@ -221,12 +211,11 @@ def test_build_tiny_k_parallel_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 2', out[0]) - self.assertEqual('nodes (k): 16', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('2', stats_graph['k']) + self.assertEqual('16', stats_graph['nodes (k)']) + self.assertEqual('canonical', stats_graph['mode']) @parameterized.expand(BUILDS) def test_build_from_kmc(self, build): @@ -243,12 +232,11 @@ def test_build_from_kmc(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 469983', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('469983', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) @parameterized.expand(BUILDS) def test_build_from_kmc_both(self, build): @@ -265,12 +253,11 @@ def test_build_from_kmc_both(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('802920', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) @parameterized.expand([repr for repr in BUILDS if repr != 'hashstr']) @unittest.skipIf(PROTEIN_MODE, "No canonical mode for Protein alphabets") @@ -289,12 +276,11 @@ def test_build_from_kmc_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('802920', stats_graph['nodes (k)']) + self.assertEqual('canonical', stats_graph['mode']) @parameterized.expand([repr for repr in BUILDS if repr != 'hashstr']) @unittest.skipIf(PROTEIN_MODE, "No canonical mode for Protein alphabets") @@ -313,12 +299,11 @@ def test_build_from_kmc_both_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('802920', stats_graph['nodes (k)']) + self.assertEqual('canonical', stats_graph['mode']) @parameterized.expand(['succinct', 'succinct_disk']) @unittest.skipUnless(DNA_MODE, "Need to adapt suffixes for other alphabets") @@ -352,13 +337,12 @@ def test_build_chunks_from_kmc(self, build): self.assertEqual(res.returncode, 0) # Check graph - res = self._get_stats(self.tempdir.name + '/graph_from_chunks' - + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 469983', out[1]) - self.assertEqual('mode: basic', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph_from_chunks' + + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('469983', stats_graph['nodes (k)']) + self.assertEqual('basic', stats_graph['mode']) @parameterized.expand(['succinct', 'succinct_disk']) @unittest.skipUnless(DNA_MODE, "Need to adapt suffixes for other alphabets") @@ -392,13 +376,12 @@ def test_build_chunks_from_kmc_canonical(self, build): self.assertEqual(res.returncode, 0) # Check graph - res = self._get_stats(self.tempdir.name + '/graph_from_chunks' - + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats_graph = self._get_stats(self.tempdir.name + '/graph_from_chunks' + + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual('11', stats_graph['k']) + self.assertEqual('802920', stats_graph['nodes (k)']) + self.assertEqual('canonical', stats_graph['mode']) if __name__ == '__main__': diff --git a/metagraph/integration_tests/test_build_weighted.py b/metagraph/integration_tests/test_build_weighted.py index 6c176cffe7..57123d390a 100644 --- a/metagraph/integration_tests/test_build_weighted.py +++ b/metagraph/integration_tests/test_build_weighted.py @@ -7,13 +7,11 @@ import glob import os import gzip -from base import TestingBase, METAGRAPH, TEST_DATA_DIR +from base import PROTEIN_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR """Test graph construction""" -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") - graph_file_extension = {'succinct': '.dbg', 'bitmap': '.bitmapdbg', 'hash': '.orhashdbg', @@ -50,14 +48,13 @@ def test_simple_all_graphs(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 591997', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 591997', out[3]) - self.assertEqual('avg weight: 2.48587', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '20') + self.assertEqual(stats_graph['nodes (k)'], '591997') + self.assertEqual(stats_graph['mode'], 'basic') + self.assertEqual(stats_graph['nnz weights'], '591997') + self.assertEqual(stats_graph['avg weight'], '2.48587') @parameterized.expand([repr for repr in BUILDS if not (repr == 'bitmap' and PROTEIN_MODE)]) def test_simple_all_graphs_contigs(self, build): @@ -88,14 +85,13 @@ def test_simple_all_graphs_contigs(self, build): res = subprocess.run([command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 591997', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 591997', out[3]) - self.assertEqual('avg weight: 2.48587', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '20') + self.assertEqual(stats_graph['nodes (k)'], '591997') + self.assertEqual(stats_graph['mode'], 'basic') + self.assertEqual(stats_graph['nnz weights'], '591997') + self.assertEqual(stats_graph['avg weight'], '2.48587') # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand([repr for repr in BUILDS if repr != 'hashstr']) @@ -115,14 +111,13 @@ def test_simple_all_graphs_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 1159851', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 1159851', out[3]) - self.assertEqual('avg weight: 2.53761', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '20') + self.assertEqual(stats_graph['nodes (k)'], '1159851') + self.assertEqual(stats_graph['mode'], 'canonical') + self.assertEqual(stats_graph['nnz weights'], '1159851') + self.assertEqual(stats_graph['avg weight'], '2.53761') @parameterized.expand(BUILDS) def test_build_tiny_k(self, build): @@ -138,14 +133,13 @@ def test_build_tiny_k(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 2', out[0]) - self.assertEqual('nodes (k): 16', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 16', out[3]) - self.assertEqual('avg weight: 255', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '2') + self.assertEqual(stats_graph['nodes (k)'], '16') + self.assertEqual(stats_graph['mode'], 'basic') + self.assertEqual(stats_graph['nnz weights'], '16') + self.assertEqual(stats_graph['avg weight'], '255') # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand([repr for repr in BUILDS if repr != 'hashstr']) @@ -164,14 +158,13 @@ def test_build_tiny_k_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 2', out[0]) - self.assertEqual('nodes (k): 16', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 16', out[3]) - self.assertEqual('avg weight: 255', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '2') + self.assertEqual(stats_graph['nodes (k)'], '16') + self.assertEqual(stats_graph['mode'], 'canonical') + self.assertEqual(stats_graph['nnz weights'], '16') + self.assertEqual(stats_graph['avg weight'], '255') @parameterized.expand(BUILDS) def test_build_from_kmc(self, build): @@ -189,14 +182,13 @@ def test_build_from_kmc(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 469983', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 469983', out[3]) - self.assertEqual('avg weight: 3.15029', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '11') + self.assertEqual(stats_graph['nodes (k)'], '469983') + self.assertEqual(stats_graph['mode'], 'basic') + self.assertEqual(stats_graph['nnz weights'], '469983') + self.assertEqual(stats_graph['avg weight'], '3.15029') @parameterized.expand(BUILDS) def test_build_from_kmc_both(self, build): @@ -214,14 +206,13 @@ def test_build_from_kmc_both(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 802920', out[3]) - self.assertEqual('avg weight: 3.68754', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '11') + self.assertEqual(stats_graph['nodes (k)'], '802920') + self.assertEqual(stats_graph['mode'], 'basic') + self.assertEqual(stats_graph['nnz weights'], '802920') + self.assertEqual(stats_graph['avg weight'], '3.68754') # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand([repr for repr in BUILDS if repr != 'hashstr']) @@ -241,14 +232,13 @@ def test_build_from_kmc_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 802920', out[3]) - self.assertEqual('avg weight: 3.68754', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '11') + self.assertEqual(stats_graph['nodes (k)'], '802920') + self.assertEqual(stats_graph['mode'], 'canonical') + self.assertEqual(stats_graph['nnz weights'], '802920') + self.assertEqual(stats_graph['avg weight'], '3.68754') # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand([repr for repr in BUILDS if repr != 'hashstr']) @@ -268,14 +258,13 @@ def test_build_from_kmc_both_canonical(self, build): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 11', out[0]) - self.assertEqual('nodes (k): 802920', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 802920', out[3]) - self.assertEqual('avg weight: 3.68754', out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '11') + self.assertEqual(stats_graph['nodes (k)'], '802920') + self.assertEqual(stats_graph['mode'], 'canonical') + self.assertEqual(stats_graph['nnz weights'], '802920') + self.assertEqual(stats_graph['avg weight'], '3.68754') @parameterized.expand( itertools.product(BUILDS, @@ -306,14 +295,13 @@ def test_kmer_count_width(self, build, width_result): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 4', out[0]) - self.assertEqual('nodes (k): 256', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 256', out[3]) - self.assertEqual('avg weight: {}'.format(avg_count_expected), out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], '4') + self.assertEqual(stats_graph['nodes (k)'], '256') + self.assertEqual(stats_graph['mode'], 'basic') + self.assertEqual(stats_graph['nnz weights'], '256') + self.assertEqual(stats_graph['avg weight'], str(avg_count_expected)) @parameterized.expand(itertools.chain( itertools.product(BUILDS, @@ -366,14 +354,13 @@ def test_kmer_count_width_large(self, build, k_width_result): res = subprocess.run([construct_command], shell=True) self.assertEqual(res.returncode, 0) - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: {}'.format(k), out[0]) - self.assertEqual('nodes (k): 2', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 2', out[3]) - self.assertEqual('avg weight: {}'.format(avg_count_expected), out[4]) + stats_graph = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual(stats_graph['returncode'], 0) + self.assertEqual(stats_graph['k'], str(k)) + self.assertEqual(stats_graph['nodes (k)'], '2') + self.assertEqual(stats_graph['mode'], 'basic') + self.assertEqual(stats_graph['nnz weights'], '2') + self.assertEqual(stats_graph['avg weight'], str(avg_count_expected)) if __name__ == '__main__': diff --git a/metagraph/integration_tests/test_clean.py b/metagraph/integration_tests/test_clean.py index 070b396a17..b5013723ee 100644 --- a/metagraph/integration_tests/test_clean.py +++ b/metagraph/integration_tests/test_clean.py @@ -7,13 +7,11 @@ import glob import os import gzip -from base import TestingBase, METAGRAPH, TEST_DATA_DIR, NUM_THREADS +from base import PROTEIN_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR, NUM_THREADS """Test graph construction""" -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") - graph_file_extension = {'succinct': '.dbg', 'bitmap': '.bitmapdbg', 'hash': '.orhashdbg', @@ -35,13 +33,12 @@ def test_no_cleaning_contigs(self, representation): k=20, repr=representation, extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 591997', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 591997', out[3]) - self.assertEqual('avg weight: 2.48587', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual('20', stats['k']) + self.assertEqual('591997', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) + self.assertEqual('591997', stats['nnz weights']) + self.assertEqual('2.48587', stats['avg weight']) clean_fasta = self.tempdir.name + '/contigs.fasta.gz' self._clean(self.tempdir.name + '/graph' + graph_file_extension[representation], @@ -53,13 +50,12 @@ def test_no_cleaning_contigs(self, representation): k=20, repr=representation, extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 591997', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 591997', out[3]) - self.assertEqual('avg weight: 2.48587', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('20', stats['k']) + self.assertEqual('591997', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) + self.assertEqual('591997', stats['nnz weights']) + self.assertEqual('2.48587', stats['avg weight']) @parameterized.expand([repr for repr in GRAPH_TYPES if not (repr == 'bitmap' and PROTEIN_MODE)]) def test_no_cleaning_contigs_2bit_counts(self, representation): @@ -69,13 +65,12 @@ def test_no_cleaning_contigs_2bit_counts(self, representation): k=20, repr=representation, extra_params="--mask-dummy --count-kmers --count-width 2") - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 591997', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 591997', out[3]) - self.assertEqual('avg weight: 1.73589', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual('20', stats['k']) + self.assertEqual('591997', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) + self.assertEqual('591997', stats['nnz weights']) + self.assertEqual('1.73589', stats['avg weight']) clean_fasta = self.tempdir.name + '/contigs.fasta.gz' self._clean(self.tempdir.name + '/graph' + graph_file_extension[representation], @@ -87,13 +82,12 @@ def test_no_cleaning_contigs_2bit_counts(self, representation): k=20, repr=representation, extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 591997', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 591997', out[3]) - self.assertEqual('avg weight: 1.73589', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('20', stats['k']) + self.assertEqual('591997', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) + self.assertEqual('591997', stats['nnz weights']) + self.assertEqual('1.73589', stats['avg weight']) @parameterized.expand([repr for repr in GRAPH_TYPES if not (repr == 'bitmap' and PROTEIN_MODE)]) def test_clean_prune_tips_no_counts(self, representation): @@ -113,11 +107,10 @@ def test_clean_prune_tips_no_counts(self, representation): k=20, repr=representation, extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 589774', out[1]) - self.assertEqual('mode: basic', out[2]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('20', stats['k']) + self.assertEqual('589774', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) @parameterized.expand([repr for repr in GRAPH_TYPES if not (repr == 'bitmap' and PROTEIN_MODE)]) def test_clean_prune_tips(self, representation): @@ -137,13 +130,12 @@ def test_clean_prune_tips(self, representation): k=20, repr=representation, extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 589774', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 589774', out[3]) - self.assertEqual('avg weight: 2.49001', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('20', stats['k']) + self.assertEqual('589774', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) + self.assertEqual('589774', stats['nnz weights']) + self.assertEqual('2.49001', stats['avg weight']) @parameterized.expand([repr for repr in GRAPH_TYPES if not (repr == 'bitmap' and PROTEIN_MODE)]) def test_cleaning_threshold_fixed(self, representation): @@ -163,14 +155,12 @@ def test_cleaning_threshold_fixed(self, representation): k=20, repr=representation, extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 167395', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 167395', out[3]) - self.assertEqual('avg weight: 5.52732', out[4]) - + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('20', stats['k']) + self.assertEqual('167395', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) + self.assertEqual('167395', stats['nnz weights']) + self.assertEqual('5.52732', stats['avg weight']) @parameterized.expand([repr for repr in GRAPH_TYPES if not (repr == 'bitmap' and PROTEIN_MODE)]) def test_cleaning_prune_tips_threshold_fixed(self, representation): @@ -189,13 +179,12 @@ def test_cleaning_prune_tips_threshold_fixed(self, representation): k=20, repr=representation, extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 20', out[0]) - self.assertEqual('nodes (k): 167224', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 167224', out[3]) - self.assertEqual('avg weight: 5.52757', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('20', stats['k']) + self.assertEqual('167224', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) + self.assertEqual('167224', stats['nnz weights']) + self.assertEqual('5.52757', stats['avg weight']) @unittest.skipIf(PROTEIN_MODE, "No canonical mode for Protein alphabets") @@ -212,13 +201,12 @@ def test_no_cleaning_contigs(self, representation): k=31, repr=representation, mode='canonical', extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 1185814', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 1185814', out[3]) - self.assertEqual('avg weight: 2.4635', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('1185814', stats['nodes (k)']) + self.assertEqual('canonical', stats['mode']) + self.assertEqual('1185814', stats['nnz weights']) + self.assertEqual('2.4635', stats['avg weight']) clean_fasta = self.tempdir.name + '/contigs.fasta.gz' self._clean(self.tempdir.name + '/graph' + graph_file_extension[representation], @@ -230,13 +218,12 @@ def test_no_cleaning_contigs(self, representation): k=31, repr=representation, mode='canonical', extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 1185814', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 1185814', out[3]) - self.assertEqual('avg weight: 2.4635', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('1185814', stats['nodes (k)']) + self.assertEqual('canonical', stats['mode']) + self.assertEqual('1185814', stats['nnz weights']) + self.assertEqual('2.4635', stats['avg weight']) # TODO: add 'hashstr' once the canonical mode is implemented for it @parameterized.expand(['succinct', 'bitmap', 'hash']) # , 'hashstr']: @@ -247,13 +234,12 @@ def test_no_cleaning_contigs_2bit_counts(self, representation): k=31, repr=representation, mode='canonical', extra_params="--mask-dummy --count-kmers --count-width 2") - res = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 1185814', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 1185814', out[3]) - self.assertEqual('avg weight: 1.72792', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('1185814', stats['nodes (k)']) + self.assertEqual('canonical', stats['mode']) + self.assertEqual('1185814', stats['nnz weights']) + self.assertEqual('1.72792', stats['avg weight']) clean_fasta = self.tempdir.name + '/contigs.fasta.gz' self._clean(self.tempdir.name + '/graph' + graph_file_extension[representation], @@ -265,13 +251,12 @@ def test_no_cleaning_contigs_2bit_counts(self, representation): k=31, repr=representation, mode='canonical', extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 1185814', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 1185814', out[3]) - self.assertEqual('avg weight: 1.72792', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('1185814', stats['nodes (k)']) + self.assertEqual('canonical', stats['mode']) + self.assertEqual('1185814', stats['nnz weights']) + self.assertEqual('1.72792', stats['avg weight']) @parameterized.expand(['succinct', 'bitmap', 'hash']) # , 'hashstr']: def test_clean_prune_tips_no_counts(self, representation): @@ -291,11 +276,10 @@ def test_clean_prune_tips_no_counts(self, representation): k=31, repr=representation, mode='canonical', extra_params="--mask-dummy") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 1180802', out[1]) - self.assertEqual('mode: canonical', out[2]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('1180802', stats['nodes (k)']) + self.assertEqual('canonical', stats['mode']) @parameterized.expand(['succinct', 'bitmap', 'hash']) # , 'hashstr']: def test_clean_prune_tips(self, representation): @@ -315,13 +299,12 @@ def test_clean_prune_tips(self, representation): k=31, repr=representation, mode='canonical', extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 1180802', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 1180802', out[3]) - self.assertEqual('avg weight: 2.46882', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('1180802', stats['nodes (k)']) + self.assertEqual('canonical', stats['mode']) + self.assertEqual('1180802', stats['nnz weights']) + self.assertEqual('2.46882', stats['avg weight']) @parameterized.expand(GRAPH_TYPES) def test_cleaning_threshold_fixed_both_strands(self, representation): @@ -342,13 +325,12 @@ def test_cleaning_threshold_fixed_both_strands(self, representation): k=31, repr=representation, extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 331452', out[1]) - self.assertEqual('mode: basic', out[2]) - self.assertEqual('nnz weights: 331452', out[3]) - self.assertEqual('avg weight: 5.52692', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('331452', stats['nodes (k)']) + self.assertEqual('basic', stats['mode']) + self.assertEqual('331452', stats['nnz weights']) + self.assertEqual('5.52692', stats['avg weight']) @parameterized.expand(['succinct', 'bitmap', 'hash']) # , 'hashstr']: def test_cleaning_threshold_fixed(self, representation): @@ -368,13 +350,12 @@ def test_cleaning_threshold_fixed(self, representation): k=31, repr=representation, mode='canonical', extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 331452', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 331452', out[3]) - self.assertEqual('avg weight: 5.52692', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('331452', stats['nodes (k)']) + self.assertEqual('canonical', stats['mode']) + self.assertEqual('331452', stats['nnz weights']) + self.assertEqual('5.52692', stats['avg weight']) @parameterized.expand(['succinct', 'bitmap', 'hash']) # , 'hashstr']: def test_cleaning_prune_tips_threshold_fixed(self, representation): @@ -394,13 +375,12 @@ def test_cleaning_prune_tips_threshold_fixed(self, representation): k=31, repr=representation, mode='canonical', extra_params="--mask-dummy --count-kmers") - res = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('k: 31', out[0]) - self.assertEqual('nodes (k): 331266', out[1]) - self.assertEqual('mode: canonical', out[2]) - self.assertEqual('nnz weights: 331266', out[3]) - self.assertEqual('avg weight: 5.52728', out[4]) + stats = self._get_stats(self.tempdir.name + '/graph_clean' + graph_file_extension[representation]) + self.assertEqual('31', stats['k']) + self.assertEqual('331266', stats['nodes (k)']) + self.assertEqual('canonical', stats['mode']) + self.assertEqual('331266', stats['nnz weights']) + self.assertEqual('5.52728', stats['avg weight']) if __name__ == '__main__': diff --git a/metagraph/integration_tests/test_query.py b/metagraph/integration_tests/test_query.py index 25434a7b22..df3d1e0918 100644 --- a/metagraph/integration_tests/test_query.py +++ b/metagraph/integration_tests/test_query.py @@ -8,15 +8,12 @@ import os import numpy as np from helpers import get_test_class_name -from base import TestingBase, METAGRAPH, TEST_DATA_DIR, graph_file_extension +from base import PROTEIN_MODE, DNA_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR, graph_file_extension import hashlib """Test graph construction""" -DNA_MODE = os.readlink(METAGRAPH).endswith("_DNA") -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") - anno_file_extension = {'column': '.column.annodbg', 'column_coord': '.column_coord.annodbg', 'brwt_coord': '.brwt_coord.annodbg', @@ -86,13 +83,12 @@ def setUpClass(cls): 20, cls.graph_repr, 'basic', '--mask-dummy' if cls.mask_dummy else '') - res = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('k: 20' == out[0]) + stats_graph = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') + assert(stats_graph['returncode'] == 0) + assert('20' == stats_graph['k']) if cls.graph_repr != 'succinct' or cls.mask_dummy: - assert('nodes (k): 46960' == out[1]) - assert('mode: basic' == out[2]) + assert('46960' == stats_graph['nodes (k)']) + assert('basic' == stats_graph['mode']) if cls.with_bloom: convert_command = f'{METAGRAPH} transform -o {cls.tempdir.name}/graph \ @@ -122,17 +118,16 @@ def check_suffix(anno_repr, suffix): ) # check annotation - res = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('labels: 100' == out[0]) + stats_annotation = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') + assert(stats_annotation['returncode'] == 0) + assert('100' == stats_annotation['labels']) if cls.graph_repr != 'hashfast' and (cls.graph_repr != 'succinct' or cls.mask_dummy): - assert('objects: 46960' == out[1]) + assert(stats_graph['max index (k)'] == stats_annotation['objects']) if cls.anno_repr.endswith('_noswap'): cls.anno_repr = cls.anno_repr[:-len('_noswap')] - assert(f'representation: {cls.anno_repr}' == out[3]) + assert(cls.anno_repr == stats_annotation['representation']) def test_query(self): query_command = '{exe} query --batch-size 0 -i {graph} -a {annotation} --min-kmers-fraction-label 1.0 {input}'.format( @@ -574,12 +569,11 @@ def setUpClass(cls): cls._build_graph(cls.fasta_graph, cls.tempdir.name + '/graph', 5, cls.graph_repr, 'basic', '--mask-dummy') - res = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('k: 5' == out[0]) - assert('nodes (k): 12' == out[1]) - assert('mode: basic' == out[2]) + stats_graph = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') + assert(stats_graph['returncode'] == 0) + assert(stats_graph['k'] == '5') + assert(stats_graph['nodes (k)'] == '12') + assert(stats_graph['mode'] == 'basic') def check_suffix(anno_repr, suffix): match = anno_repr.endswith(suffix) @@ -597,16 +591,15 @@ def check_suffix(anno_repr, suffix): separate, no_fork_opt, no_anchor_opt) # check annotation - res = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('labels: 3' == out[0]) - assert('objects: 12' == out[1]) + stats_annotation = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') + assert(stats_annotation['returncode'] == 0) + assert(stats_annotation['labels'] == '3') + assert(stats_annotation['objects'] == stats_graph['max index (k)']) if cls.anno_repr.endswith('_noswap'): cls.anno_repr = cls.anno_repr[:-len('_noswap')] - assert(f'representation: {cls.anno_repr}' == out[3]) + assert(cls.anno_repr == stats_annotation['representation']) def test_query_coordinates(self): if not self.anno_repr.endswith('_coord'): @@ -655,13 +648,12 @@ def setUpClass(cls): 20, cls.graph_repr, 'basic', '--mask-dummy' if cls.mask_dummy else '') - res = cls._get_stats(cls.tempdir.name + '/graph' + graph_file_extension[cls.graph_repr]) - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('k: 20' == out[0]) + stats_graph = cls._get_stats(cls.tempdir.name + '/graph' + graph_file_extension[cls.graph_repr]) + assert(stats_graph['returncode'] == 0) + assert(stats_graph['k'] == '20') if cls.graph_repr != 'succinct' or cls.mask_dummy: - assert('nodes (k): 46960' == out[1]) - assert('mode: basic' == out[2]) + assert(stats_graph['nodes (k)'] == '46960') + assert(stats_graph['mode'] == 'basic') if cls.with_bloom: convert_command = f'{METAGRAPH} transform -o {cls.tempdir.name}/graph \ @@ -692,17 +684,16 @@ def check_suffix(anno_repr, suffix): ) # check annotation - res = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('labels: 1' == out[0]) + stats_annotation = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') + assert(stats_annotation['returncode'] == 0) + assert(stats_annotation['labels'] == '1') if cls.graph_repr != 'hashfast' and (cls.graph_repr != 'succinct' or cls.mask_dummy): - assert('objects: 46960' == out[1]) + assert(stats_annotation['objects'] == stats_graph['max index (k)']) if cls.anno_repr.endswith('_noswap'): cls.anno_repr = cls.anno_repr[:-len('_noswap')] - assert('representation: ' + cls.anno_repr == out[3]) + assert(cls.anno_repr == stats_annotation['representation']) def test_query(self): query_command = f'{METAGRAPH} query --batch-size 0 \ @@ -788,13 +779,12 @@ def setUpClass(cls): cls._build_graph((cls.fasta_file_1, cls.fasta_file_2), cls.tempdir.name + '/graph', cls.k, cls.graph_repr, 'basic', '--mask-dummy' if cls.mask_dummy else '') - res = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('k: 3' == out[0]) + stats_graph = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') + assert(stats_graph['returncode'] == 0) + assert(stats_graph['k'] == '3') if cls.graph_repr != 'succinct' or cls.mask_dummy: - assert('nodes (k): 12' == out[1]) - assert('mode: basic' == out[2]) + assert(stats_graph['nodes (k)'] == '12') + assert(stats_graph['mode'] == 'basic') if cls.with_bloom: convert_command = f'{METAGRAPH} transform -o {cls.tempdir.name}/graph \ @@ -812,13 +802,12 @@ def setUpClass(cls): ) # check annotation - res = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('labels: 2' == out[0]) + stats_annotation = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') + assert(stats_annotation['returncode'] == 0) + assert(stats_annotation['labels'] == '2') if cls.graph_repr != 'hashfast' and (cls.graph_repr != 'succinct' or cls.mask_dummy): - assert('objects: 12' == out[1]) - assert('representation: ' + cls.anno_repr == out[3]) + assert(stats_annotation['objects'] == stats_graph['max index (k)']) + assert(stats_annotation['representation'] == cls.anno_repr) cls.queries = [ 'AAA', @@ -968,13 +957,12 @@ def setUpClass(cls): 20, cls.graph_repr, 'canonical', '--mask-dummy' if cls.mask_dummy else '') - res = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('k: 20' == out[0]) + stats_graph = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') + assert(stats_graph['returncode'] == 0) + assert(stats_graph['k'] == '20') if cls.graph_repr != 'succinct' or cls.mask_dummy: - assert('nodes (k): 91584' == out[1]) - assert('mode: canonical' == out[2]) + assert(stats_graph['nodes (k)'] == '91584') + assert(stats_graph['mode'] == 'canonical') if cls.with_bloom: convert_command = f'{METAGRAPH} transform -o {cls.tempdir.name}/graph \ @@ -991,17 +979,16 @@ def setUpClass(cls): ) # check annotation - res = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('labels: 100' == out[0]) + stats_annotation = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') + assert(stats_annotation['returncode'] == 0) + assert(stats_annotation['labels'] == '100') if cls.graph_repr != 'hashfast' and (cls.graph_repr != 'succinct' or cls.mask_dummy): - assert('objects: 91584' == out[1]) + assert(stats_annotation['objects'] == stats_graph['max index (k)']) if cls.anno_repr.endswith('_noswap'): cls.anno_repr = cls.anno_repr[:-len('_noswap')] - assert('representation: ' + cls.anno_repr == out[3]) + assert(cls.anno_repr == stats_annotation['representation']) def test_query(self): query_command = '{exe} query --batch-size 0 -i {graph} -a {annotation} --min-kmers-fraction-label 1.0 {input}'.format( @@ -1135,13 +1122,12 @@ def setUpClass(cls): 20, cls.graph_repr, 'primary', '--mask-dummy' if cls.mask_dummy else '') - res = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('k: 20' == out[0]) + stats_graph = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') + assert(stats_graph['returncode'] == 0) + assert(stats_graph['k'] == '20') if cls.graph_repr != 'succinct' or cls.mask_dummy: - assert('nodes (k): 45792' == out[1]) - assert('mode: primary' == out[2]) + assert(stats_graph['nodes (k)'] == '45792') + assert(stats_graph['mode'] == 'primary') if cls.with_bloom: convert_command = f'{METAGRAPH} transform -o {cls.tempdir.name}/graph \ @@ -1158,17 +1144,16 @@ def setUpClass(cls): ) # check annotation - res = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('labels: 100' == out[0]) + stats_annotation = cls._get_stats(f'-a {cls.tempdir.name}/annotation{anno_file_extension[cls.anno_repr]}') + assert(stats_annotation['returncode'] == 0) + assert(stats_annotation['labels'] == '100') if cls.graph_repr != 'hashfast' and (cls.graph_repr != 'succinct' or cls.mask_dummy): - assert('objects: 45792' == out[1]) + assert(stats_annotation['objects'] == stats_graph['max index (k)']) if cls.anno_repr.endswith('_noswap'): cls.anno_repr = cls.anno_repr[:-len('_noswap')] - assert('representation: ' + cls.anno_repr == out[3]) + assert(cls.anno_repr == stats_annotation['representation']) def test_query(self): query_command = '{exe} query --batch-size 0 -i {graph} -a {annotation} --min-kmers-fraction-label 1.0 {input}'.format( diff --git a/metagraph/integration_tests/test_transform_anno.py b/metagraph/integration_tests/test_transform_anno.py index a1887db9e0..e78ee6380a 100644 --- a/metagraph/integration_tests/test_transform_anno.py +++ b/metagraph/integration_tests/test_transform_anno.py @@ -12,8 +12,6 @@ """Test operations on annotation columns""" -DNA_MODE = os.readlink(METAGRAPH).endswith("_DNA") -PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") TEST_DATA_DIR = os.path.dirname(os.path.realpath(__file__)) + '/../tests/data' NUM_THREADS = 4 @@ -30,12 +28,14 @@ def setUpClass(cls): cls.tempdir.name + '/graph', 20, cls.graph_repr, 'basic', '--mask-dummy') - res = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') - assert(res.returncode == 0) - out = res.stdout.decode().split('\n')[2:] - assert('k: 20' == out[0]) - assert('nodes (k): 46960' == out[1]) - assert('mode: basic' == out[2]) + stats_graph = cls._get_stats(f'{cls.tempdir.name}/graph{graph_file_extension[cls.graph_repr]}') + assert(stats_graph['returncode'] == 0) + assert(stats_graph['k'] == '20') + assert(stats_graph['nodes (k)'] == '46960') + assert(stats_graph['mode'] == 'basic') + + cls.num_nodes = stats_graph['nodes (k)'] + cls.max_index = stats_graph['max index (k)'] cls._annotate_graph( TEST_DATA_DIR + '/transcripts_100.fa', @@ -52,13 +52,15 @@ def setUp(self): self.annotation = f'annotation{anno_file_extension[self.anno_repr]}'; # check annotation - res = self._get_stats(f'-a {self.annotation}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 100', out[0]) - self.assertEqual('objects: 46960', out[1]) - self.assertEqual('density: 0.0185072', out[2]) - self.assertEqual(f'representation: {self.anno_repr}', out[3]) + stats_annotation = self._get_stats(f'-a {self.annotation}') + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual(stats_annotation['labels'], '100') + self.assertEqual(stats_annotation['objects'], self.max_index) + self.assertAlmostEqual( + float(stats_annotation['density']), + 0.0185072 * int(self.num_nodes) / int(self.max_index), + places=6) + self.assertEqual(stats_annotation['representation'], self.anno_repr) def tearDown(self): os.chdir(self.old_cwd) @@ -78,13 +80,15 @@ def _check_aggregation_min(self, min_count, expected_density): res = subprocess.run(command.split(), stdout=PIPE) self.assertEqual(res.returncode, 0) - res = self._get_stats(f'-a aggregated{anno_file_extension[self.anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 1', out[0]) - self.assertEqual('objects: 46960', out[1]) - self.assertEqual(f'density: {expected_density}', out[2]) - self.assertEqual(f'representation: {self.anno_repr}', out[3]) + stats_annotation = self._get_stats(f'-a aggregated{anno_file_extension[self.anno_repr]}') + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual(stats_annotation['labels'], '1') + self.assertEqual(stats_annotation['objects'], self.max_index) + self.assertAlmostEqual( + float(stats_annotation['density']), + float(expected_density) * int(self.num_nodes) / int(self.max_index), + places=5) + self.assertEqual(stats_annotation['representation'], self.anno_repr) def test_aggregate_columns(self): self._check_aggregation_min(0, 1) @@ -100,13 +104,15 @@ def _check_aggregation_min_max_value(self, min_count, max_value, expected_densit res = subprocess.run(command.split(), stdout=PIPE) self.assertEqual(res.returncode, 0) - res = self._get_stats(f'-a aggregated{anno_file_extension[self.anno_repr]}') - self.assertEqual(res.returncode, 0) - out = res.stdout.decode().split('\n')[2:] - self.assertEqual('labels: 1', out[0]) - self.assertEqual('objects: 46960', out[1]) - self.assertEqual(f'density: {expected_density}', out[2]) - self.assertEqual(f'representation: {self.anno_repr}', out[3]) + stats_annotation = self._get_stats(f'-a aggregated{anno_file_extension[self.anno_repr]}') + self.assertEqual(stats_annotation['returncode'], 0) + self.assertEqual(stats_annotation['labels'], '1') + self.assertEqual(stats_annotation['objects'], self.max_index) + self.assertAlmostEqual( + float(stats_annotation['density']), + float(expected_density) * int(self.num_nodes) / int(self.max_index), + places=5) + self.assertEqual(stats_annotation['representation'], self.anno_repr) def test_aggregate_columns_filtered(self): self._check_aggregation_min_max_value(0, 0, 0) diff --git a/metagraph/src/annotation/annotation_converters.cpp b/metagraph/src/annotation/annotation_converters.cpp index 061b16ee21..f7a567b9e7 100644 --- a/metagraph/src/annotation/annotation_converters.cpp +++ b/metagraph/src/annotation/annotation_converters.cpp @@ -10,6 +10,7 @@ #include #include "row_diff_builder.hpp" +#include "cli/load/load_graph.hpp" #include "common/logger.hpp" #include "common/algorithms.hpp" #include "common/hashers/hash.hpp" @@ -213,7 +214,7 @@ convert_row_diff_to_BRWT(RowDiffColumnAnnotator &&annotator, BRWTBottomUpBuilder::Partitioner partitioning, size_t num_parallel_nodes, size_t num_threads) { - const graph::DBGSuccinct* graph = annotator.get_matrix().graph(); + const graph::DeBruijnGraph* graph = annotator.get_matrix().graph(); auto matrix = std::make_unique( BRWTBottomUpBuilder::build(std::move(annotator.release_matrix()->diffs().data()), @@ -1539,12 +1540,13 @@ void convert_to_row_diff(const std::vector &files, if (out_dir.empty()) out_dir = "./"; + auto graph = cli::load_critical_dbg(graph_fname); if (construction_stage != RowDiffStage::COUNT_LABELS) - build_pred_succ(graph_fname, graph_fname, out_dir, + build_pred_succ(*graph, graph_fname, out_dir, ".row_count", get_num_threads()); if (construction_stage == RowDiffStage::CONVERT) { - assign_anchors(graph_fname, graph_fname, out_dir, max_path_length, + assign_anchors(*graph, graph_fname, out_dir, max_path_length, ".row_reduction", get_num_threads()); const std::string anchors_fname = graph_fname + kRowDiffAnchorExt; diff --git a/metagraph/src/annotation/binary_matrix/row_diff/row_diff.cpp b/metagraph/src/annotation/binary_matrix/row_diff/row_diff.cpp index 81f23f581b..58b83656dd 100644 --- a/metagraph/src/annotation/binary_matrix/row_diff/row_diff.cpp +++ b/metagraph/src/annotation/binary_matrix/row_diff/row_diff.cpp @@ -10,6 +10,31 @@ namespace mtg { namespace annot { + +using node_index = graph::DeBruijnGraph::node_index; + +node_index row_diff_successor(const graph::DeBruijnGraph &graph, + node_index node, + const bit_vector &rd_succ) { + if (auto* dbg_succ = dynamic_cast(&graph)) { + return dbg_succ->get_boss().row_diff_successor( + node, + rd_succ.size() ? rd_succ : dbg_succ->get_boss().get_last() + ); + } else { + assert(rd_succ.size()); + node_index succ = graph::DeBruijnGraph::npos; + graph.adjacent_outgoing_nodes(node, [&](node_index adjacent_node) { + if (rd_succ[adjacent_node]) { + succ = adjacent_node; + } + }); + assert(graph.in_graph(succ) && "a row diff successor must exist"); + return succ; + } +} + + namespace matrix { void IRowDiff::load_anchor(const std::string &filename) { @@ -41,7 +66,7 @@ void IRowDiff::load_fork_succ(const std::string &filename) { std::tuple, std::vector>, std::vector> IRowDiff::get_rd_ids(const std::vector &row_ids) const { assert(graph_ && "graph must be loaded"); - assert(!fork_succ_.size() || fork_succ_.size() == graph_->get_boss().get_last().size()); + assert(!fork_succ_.size() || fork_succ_.size() == graph_->max_index() + 1); using Row = BinaryMatrix::Row; @@ -54,18 +79,14 @@ IRowDiff::get_rd_ids(const std::vector &row_ids) const { // been reached before, and thus, will be reconstructed before this one. std::vector> rd_paths_trunc(row_ids.size()); - const graph::boss::BOSS &boss = graph_->get_boss(); - const bit_vector &rd_succ = fork_succ_.size() ? fork_succ_ : boss.get_last(); - for (size_t i = 0; i < row_ids.size(); ++i) { Row row = row_ids[i]; - graph::boss::BOSS::edge_index boss_edge = graph_->kmer_to_boss_index( - graph::AnnotatedSequenceGraph::anno_to_graph_index(row)); + node_index node = graph::AnnotatedSequenceGraph::anno_to_graph_index(row); while (true) { - row = graph::AnnotatedSequenceGraph::graph_to_anno_index( - graph_->boss_to_kmer_index(boss_edge)); + assert(graph_->in_graph(node)); + row = graph::AnnotatedSequenceGraph::graph_to_anno_index(node); auto [it, is_new] = node_to_rd.try_emplace(row, node_to_rd.size()); rd_paths_trunc[i].push_back(it.value()); @@ -80,7 +101,7 @@ IRowDiff::get_rd_ids(const std::vector &row_ids) const { if (anchor_[row]) break; - boss_edge = boss.row_diff_successor(boss_edge, rd_succ); + node = row_diff_successor(*graph_, node, fork_succ_); } } diff --git a/metagraph/src/annotation/binary_matrix/row_diff/row_diff.hpp b/metagraph/src/annotation/binary_matrix/row_diff/row_diff.hpp index f2842d58f4..67382f5a12 100644 --- a/metagraph/src/annotation/binary_matrix/row_diff/row_diff.hpp +++ b/metagraph/src/annotation/binary_matrix/row_diff/row_diff.hpp @@ -19,6 +19,11 @@ namespace mtg { namespace annot { + +graph::DeBruijnGraph::node_index row_diff_successor(const graph::DeBruijnGraph &graph, + graph::DeBruijnGraph::node_index node, + const bit_vector &rd_succ); + namespace matrix { const std::string kRowDiffAnchorExt = ".anchors"; @@ -34,8 +39,8 @@ class IRowDiff { virtual ~IRowDiff() {} - const graph::DBGSuccinct* graph() const { return graph_; } - void set_graph(const graph::DBGSuccinct *graph) { graph_ = graph; } + const graph::DeBruijnGraph* graph() const { return graph_; } + void set_graph(const graph::DeBruijnGraph *graph) { graph_ = graph; } void load_fork_succ(const std::string &filename); void load_anchor(const std::string &filename); @@ -49,7 +54,7 @@ class IRowDiff { std::tuple, std::vector>, std::vector> get_rd_ids(const std::vector &row_ids) const; - const graph::DBGSuccinct *graph_ = nullptr; + const graph::DeBruijnGraph *graph_ = nullptr; anchor_bv_type anchor_; fork_succ_bv_type fork_succ_; }; @@ -79,7 +84,7 @@ template class RowDiff : public IRowDiff, public BinaryMatrix { public: template - RowDiff(const graph::DBGSuccinct *graph = nullptr, Args&&... args) + RowDiff(const graph::DeBruijnGraph *graph = nullptr, Args&&... args) : diffs_(std::forward(args)...) { graph_ = graph; } /** @@ -117,23 +122,16 @@ std::vector RowDiff::get_column(Column column) co assert(graph_ && "graph must be loaded"); assert(anchor_.size() == diffs_.num_rows() && "anchors must be loaded"); - const graph::boss::BOSS &boss = graph_->get_boss(); - assert(!fork_succ_.size() || fork_succ_.size() == boss.get_last().size()); + assert(!fork_succ_.size() || fork_succ_.size() == graph_->max_index() + 1); std::vector result; // TODO: implement a more efficient algorithm - for (Row row = 0; row < num_rows(); ++row) { - auto edge = graph_->kmer_to_boss_index( - graph::AnnotatedSequenceGraph::anno_to_graph_index(row) - ); - - if (!boss.get_W(edge)) - continue; - + graph_->call_nodes([&](auto node) { + auto row = graph::AnnotatedDBG::graph_to_anno_index(node); SetBitPositions set_bits = get_rows({ row })[0]; if (std::binary_search(set_bits.begin(), set_bits.end(), column)) result.push_back(row); - } + }); return result; } @@ -142,7 +140,7 @@ std::vector RowDiff::get_rows(const std::vector &row_ids) const { assert(graph_ && "graph must be loaded"); assert(anchor_.size() == diffs_.num_rows() && "anchors must be loaded"); - assert(!fork_succ_.size() || fork_succ_.size() == graph_->get_boss().get_last().size()); + assert(!fork_succ_.size() || fork_succ_.size() == graph_->max_index() + 1); // get row-diff paths auto [rd_ids, rd_paths_trunc, times_traversed] = get_rd_ids(row_ids); diff --git a/metagraph/src/annotation/int_matrix/row_diff/int_row_diff.hpp b/metagraph/src/annotation/int_matrix/row_diff/int_row_diff.hpp index 36a04eace6..0bd920f53c 100644 --- a/metagraph/src/annotation/int_matrix/row_diff/int_row_diff.hpp +++ b/metagraph/src/annotation/int_matrix/row_diff/int_row_diff.hpp @@ -47,7 +47,7 @@ class IntRowDiff : public IRowDiff, public BinaryMatrix, public IntMatrix { static_assert(std::is_convertible::value); template - IntRowDiff(const graph::DBGSuccinct *graph = nullptr, Args&&... args) + IntRowDiff(const graph::DeBruijnGraph *graph = nullptr, Args&&... args) : diffs_(std::forward(args)...) { graph_ = graph; } std::vector get_column(Column j) const override; @@ -80,23 +80,16 @@ std::vector IntRowDiff::get_column(Column j) cons assert(graph_ && "graph must be loaded"); assert(anchor_.size() == diffs_.num_rows() && "anchors must be loaded"); - const graph::boss::BOSS &boss = graph_->get_boss(); - assert(!fork_succ_.size() || fork_succ_.size() == boss.get_last().size()); + assert(!fork_succ_.size() || fork_succ_.size() == graph_->max_index() + 1); // TODO: implement a more efficient algorithm std::vector result; - for (Row i = 0; i < num_rows(); ++i) { - auto edge = graph_->kmer_to_boss_index( - graph::AnnotatedSequenceGraph::anno_to_graph_index(i) - ); - - if (!boss.get_W(edge)) - continue; - + graph_->call_nodes([&](auto node) { + auto i = graph::AnnotatedDBG::graph_to_anno_index(node); SetBitPositions set_bits = get_rows({ i })[0]; if (std::binary_search(set_bits.begin(), set_bits.end(), j)) result.push_back(i); - } + }); return result; } @@ -120,7 +113,7 @@ std::vector IntRowDiff::get_row_values(const std::vector &row_ids) const { assert(graph_ && "graph must be loaded"); assert(anchor_.size() == diffs_.num_rows() && "anchors must be loaded"); - assert(!fork_succ_.size() || fork_succ_.size() == graph_->get_boss().get_last().size()); + assert(!fork_succ_.size() || fork_succ_.size() == graph_->max_index() + 1); // get row-diff paths auto [rd_ids, rd_paths_trunc, times_traversed] = get_rd_ids(row_ids); diff --git a/metagraph/src/annotation/int_matrix/row_diff/tuple_row_diff.hpp b/metagraph/src/annotation/int_matrix/row_diff/tuple_row_diff.hpp index 8c9df1cfa5..1bbeed50f3 100644 --- a/metagraph/src/annotation/int_matrix/row_diff/tuple_row_diff.hpp +++ b/metagraph/src/annotation/int_matrix/row_diff/tuple_row_diff.hpp @@ -30,7 +30,7 @@ class TupleRowDiff : public IRowDiff, public BinaryMatrix, public MultiIntMatrix static const int SHIFT = 1; // coordinates increase by 1 at each edge template - TupleRowDiff(const graph::DBGSuccinct *graph = nullptr, Args&&... args) + TupleRowDiff(const graph::DeBruijnGraph *graph = nullptr, Args&&... args) : diffs_(std::forward(args)...) { graph_ = graph; } std::vector get_column(Column j) const override; @@ -63,23 +63,16 @@ std::vector TupleRowDiff::get_column(Column j) co assert(graph_ && "graph must be loaded"); assert(anchor_.size() == diffs_.num_rows() && "anchors must be loaded"); - const graph::boss::BOSS &boss = graph_->get_boss(); - assert(!fork_succ_.size() || fork_succ_.size() == boss.get_last().size()); + assert(!fork_succ_.size() || fork_succ_.size() == graph_->max_index() + 1); // TODO: implement a more efficient algorithm std::vector result; - for (Row i = 0; i < num_rows(); ++i) { - auto edge = graph_->kmer_to_boss_index( - graph::AnnotatedSequenceGraph::anno_to_graph_index(i) - ); - - if (!boss.get_W(edge)) - continue; - + graph_->call_nodes([&](auto node) { + auto i = graph::AnnotatedDBG::graph_to_anno_index(node); SetBitPositions set_bits = get_rows({ i })[0]; if (std::binary_search(set_bits.begin(), set_bits.end(), j)) result.push_back(i); - } + }); return result; } @@ -103,7 +96,7 @@ std::vector TupleRowDiff::get_row_tuples(const std::vector &row_ids) const { assert(graph_ && "graph must be loaded"); assert(anchor_.size() == diffs_.num_rows() && "anchors must be loaded"); - assert(!fork_succ_.size() || fork_succ_.size() == graph_->get_boss().get_last().size()); + assert(!fork_succ_.size() || fork_succ_.size() == graph_->max_index() + 1); // get row-diff paths auto [rd_ids, rd_paths_trunc, times_traversed] = get_rd_ids(row_ids); diff --git a/metagraph/src/annotation/row_diff_builder.cpp b/metagraph/src/annotation/row_diff_builder.cpp index db7d36befa..05ec2a5933 100644 --- a/metagraph/src/annotation/row_diff_builder.cpp +++ b/metagraph/src/annotation/row_diff_builder.cpp @@ -11,6 +11,7 @@ #include "common/elias_fano/elias_fano_merger.hpp" #include "common/utils/file_utils.hpp" #include "common/vectors/bit_vector_sdsl.hpp" +#include "common/vectors/bit_vector_dyn.hpp" #include "graph/annotated_dbg.hpp" const uint64_t BLOCK_SIZE = 1 << 25; @@ -26,6 +27,7 @@ namespace annot { using namespace mtg::annot::matrix; using mtg::common::logger; using mtg::graph::boss::BOSS; +using node_index = graph::DeBruijnGraph::node_index; namespace fs = std::filesystem; using anchor_bv_type = RowDiff::anchor_bv_type; @@ -264,13 +266,35 @@ void sum_and_call_counts(const fs::path &dir, } } -rd_succ_bv_type route_at_forks(const graph::DBGSuccinct &graph, - const std::string &rd_succ_filename, - const std::string &count_vectors_dir, - const std::string &row_count_extension) { +std::shared_ptr get_last(const graph::DeBruijnGraph &graph) { + if (auto* dbg_succ = dynamic_cast(&graph)) { + return std::shared_ptr( + std::shared_ptr{}, &dbg_succ->get_boss().get_last()); + } else { + sdsl::bit_vector last_bv(graph.max_index() + 1); + + __atomic_thread_fence(__ATOMIC_RELEASE); + graph.call_nodes([&](node_index v) { + std::pair last = { '\0', graph::DeBruijnGraph::npos }; + graph.call_outgoing_kmers(v, [&](node_index u, char c) { + last = std::max(last, std::pair{ c, u }); + }); + + if (last.second != graph::DeBruijnGraph::npos) + set_bit(last_bv.data(), last.second, true, __ATOMIC_RELAXED); + }, []() { return false; }, get_num_threads()); + __atomic_thread_fence(__ATOMIC_ACQUIRE); + return std::make_shared(std::move(last_bv)); + } +} + +std::shared_ptr route_at_forks(const graph::DeBruijnGraph &graph, + const std::string &rd_succ_filename, + const std::string &count_vectors_dir, + const std::string &row_count_extension) { logger->trace("Assigning row-diff successors at forks..."); - rd_succ_bv_type rd_succ; + std::shared_ptr rd_succ; bool optimize_forks = false; for (const auto &p : fs::directory_iterator(count_vectors_dir)) { @@ -278,55 +302,270 @@ rd_succ_bv_type route_at_forks(const graph::DBGSuccinct &graph, optimize_forks = true; } + std::ofstream f(rd_succ_filename, ios::binary); if (optimize_forks) { logger->trace("RowDiff successors will be set to the adjacent nodes with" " the largest number of labels"); - const bit_vector &last = graph.get_boss().get_last(); + sdsl::bit_vector rd_succ_bv(graph.max_index() + 1, false); + graph::DeBruijnGraph::node_index graph_idx = to_node(0); + if (const auto *succinct = dynamic_cast(&graph)) { + const auto &boss = succinct->get_boss(); + std::vector outgoing_counts; + sum_and_call_counts(count_vectors_dir, row_count_extension, "row counts", + [&](int32_t count) { + // TODO: skip single outgoing + outgoing_counts.push_back((count + 1) * graph.in_graph(graph_idx)); + if (boss.get_last(graph_idx)) { + // pick the node with the largest count + size_t max_pos = std::max_element(outgoing_counts.rbegin(), + outgoing_counts.rend()) + - outgoing_counts.rbegin(); + if (outgoing_counts[max_pos]) { // Don't mark fake vertices as succ + rd_succ_bv[graph_idx - max_pos] = true; + } + outgoing_counts.resize(0); + } + graph_idx++; + } + ); + } else { + auto get_first_parent = [&](node_index v) { + std::pair first_in_edge { 127, graph::DeBruijnGraph::npos }; + size_t indegree = 0; + graph.call_incoming_kmers(v, [&](node_index prev, char c) { + first_in_edge = std::min(first_in_edge, std::make_pair(c, prev)); + ++indegree; + }); - std::vector outgoing_counts; + auto [c, p] = first_in_edge; - sdsl::bit_vector rd_succ_bv(last.size(), false); + assert((indegree > 0) == (p != graph::DeBruijnGraph::npos)); + size_t outdegree = !indegree ? 0 : graph.outdegree(p); + return std::make_pair(p, outdegree); + }; - sum_and_call_counts(count_vectors_dir, row_count_extension, "row counts", - [&](int32_t count) { - // TODO: skip single outgoing - outgoing_counts.push_back(count); - if (last[graph.kmer_to_boss_index(graph_idx)]) { - // pick the node with the largest count - size_t max_pos = std::max_element(outgoing_counts.rbegin(), - outgoing_counts.rend()) - - outgoing_counts.rbegin(); - rd_succ_bv[graph.kmer_to_boss_index(graph_idx - max_pos)] = true; - outgoing_counts.resize(0); + tsl::hopscotch_map>> outgoing_counts_cache; + sum_and_call_counts(count_vectors_dir, row_count_extension, "row counts", + [&](int32_t count) { + if (graph.in_graph(graph_idx)) { + auto [parent, outdegree] = get_first_parent(graph_idx); + if (outdegree > 1) { + assert(parent != graph::DeBruijnGraph::npos); + auto &bucket = outgoing_counts_cache[parent]; + bucket.emplace_back(count, graph_idx); + if (bucket.size() == outdegree) { + // all siblings visited, mark the max, then clear cache + auto max_it = std::max_element(bucket.begin(), bucket.end()); + rd_succ_bv[max_it->second] = true; + outgoing_counts_cache.erase(parent); + } + } else { + rd_succ_bv[graph_idx] = true; + } + } + graph_idx++; } - graph_idx++; + ); + + if (outgoing_counts_cache.size()) { + logger->error("{} parent nodes unaccounted for", outgoing_counts_cache.size()); + exit(1); } - ); + } - if (graph_idx != graph.num_nodes() + 1) { - logger->error("Size the count vectors is incompatible with the" - " graph: {} != {}", graph_idx - 1, graph.num_nodes()); + if (graph_idx != graph.max_index() + 1) { + logger->error("Size of the count vectors is incompatible with the" + " graph: {} != {}", graph_idx - 1, graph.max_index()); exit(1); } - rd_succ = rd_succ_bv_type(std::move(rd_succ_bv)); + rd_succ = std::make_shared(std::move(rd_succ_bv)); + rd_succ->serialize(f); } else { logger->warn("No count vectors could be found in {}. The last outgoing" " edges will be selected for assigning RowDiff successors", count_vectors_dir); + rd_succ = get_last(graph); + if (dynamic_cast(&graph)) { + rd_succ_bv_type().serialize(f); + } else { + assert(std::dynamic_pointer_cast(rd_succ)); + rd_succ_bv_type(std::move( + *std::static_pointer_cast(rd_succ) + )).serialize(f); + } } - std::ofstream f(rd_succ_filename, ios::binary); - rd_succ.serialize(f); logger->trace("RowDiff successors are assigned for forks and written to {}", rd_succ_filename); return rd_succ; } -void build_pred_succ(const std::string &graph_fname, +void row_diff_traverse(const graph::DeBruijnGraph &graph, + size_t num_threads, + size_t max_length, + const bit_vector &rd_succ, + sdsl::bit_vector *terminal) { + if (auto* dbg_succ = dynamic_cast(&graph)) { + return dbg_succ->get_boss().row_diff_traverse( + num_threads, max_length, rd_succ, terminal); + } else { + std::atomic_thread_fence(std::memory_order_release); + + sdsl::bit_vector visited(graph.max_index() + 1); + assert(terminal->size() == visited.size()); + assert(rd_succ.size() == visited.size()); + + ProgressBar progress_bar(graph.num_nodes(), "Checking nodes", std::cerr, + !common::get_verbose()); + + graph.call_nodes([&](node_index v) { + if (!graph.outdegree(v)) { + assert(rd_succ[v]); + // mark all terminals as anchors + set_bit(terminal->data(), v, true, std::memory_order_relaxed); + std::vector> traversal; + traversal.emplace_back(v, 0); + while (traversal.size()) { + auto [v, d] = traversal.back(); + traversal.pop_back(); + if (fetch_and_set_bit(visited.data(), v, true, std::memory_order_acq_rel)) + continue; + + ++progress_bar; + + if (d == max_length) { + set_bit(terminal->data(), v, true, std::memory_order_release); + d = 0; + } + + if (rd_succ[v]) { + graph.adjacent_incoming_nodes(v, [&](node_index pred) { + traversal.emplace_back(pred, d + 1); + }); + } + } + } + }, []() { return false; }, num_threads); + + graph.call_nodes([&](node_index v) { + if (!graph.indegree(v)) { + // traverse forwards from sources + size_t d = 0; + node_index last_v = graph::DeBruijnGraph::npos; + while (!fetch_and_set_bit(visited.data(), v, true, std::memory_order_acq_rel)) { + ++d; + ++progress_bar; + if (d == max_length) { + set_bit(terminal->data(), v, true, std::memory_order_release); + d = 0; + } + last_v = v; + v = row_diff_successor(graph, v, rd_succ); + } + if (last_v && !fetch_bit(terminal->data(), v, true, std::memory_order_acquire)) + set_bit(terminal->data(), last_v, true, std::memory_order_relaxed); + } + }, []() { return false; }, num_threads); + + graph.call_nodes([&](node_index v) { + if (fetch_bit(terminal->data(), v, true, std::memory_order_acquire)) { + // start at next nodes from termini + graph.adjacent_outgoing_nodes(v, [&](node_index v) { + size_t d = 0; + node_index last_v = graph::DeBruijnGraph::npos; + while (!fetch_and_set_bit(visited.data(), v, true, std::memory_order_acq_rel)) { + ++d; + ++progress_bar; + if (d == max_length) { + set_bit(terminal->data(), v, true, std::memory_order_release); + d = 0; + } + last_v = v; + v = row_diff_successor(graph, v, rd_succ); + } + if (last_v && !fetch_bit(terminal->data(), v, true, std::memory_order_acquire)) + set_bit(terminal->data(), last_v, true, std::memory_order_relaxed); + }); + } + }, []() { return false; }, num_threads); + + // forks + graph.call_nodes([&](node_index v) { + if (graph.has_multiple_outgoing(v)) { + graph.adjacent_outgoing_nodes(v, [&](node_index v) { + size_t d = 0; + node_index last_v = graph::DeBruijnGraph::npos; + while (!fetch_and_set_bit(visited.data(), v, true, std::memory_order_acq_rel)) { + ++d; + ++progress_bar; + if (d == max_length) { + set_bit(terminal->data(), v, true, std::memory_order_release); + d = 0; + } + last_v = v; + v = row_diff_successor(graph, v, rd_succ); + } + if (last_v && !fetch_bit(terminal->data(), v, true, std::memory_order_acquire)) + set_bit(terminal->data(), last_v, true, std::memory_order_relaxed); + }); + } + }, []() { return false; }, num_threads); + + // everything else + graph.call_nodes([&](node_index v) { + size_t d = 0; + node_index last_v = graph::DeBruijnGraph::npos; + while (!fetch_and_set_bit(visited.data(), v, true, std::memory_order_acq_rel)) { + ++d; + ++progress_bar; + if (d == max_length) { + set_bit(terminal->data(), v, true, std::memory_order_release); + d = 0; + } + last_v = v; + v = row_diff_successor(graph, v, rd_succ); + } + if (last_v && !fetch_bit(terminal->data(), v, true, std::memory_order_acquire)) + set_bit(terminal->data(), last_v, true, std::memory_order_relaxed); + }, []() { return false; }, num_threads); + + std::atomic_thread_fence(std::memory_order_acquire); + + // std::atomic_thread_fence(std::memory_order_release); + // auto finalised = visited; + // graph.call_nodes([&](node_index start) { + // node_index v = start; + // std::vector path; + // while (path.size() < max_length + // && !fetch_and_set_bit(visited.data(), v, true, std::memory_order_acq_rel)) { + // path.push_back(v); + // if (!graph.has_no_outgoing(v)) + // v = row_diff_successor(graph, v, rd_succ); + // } + + // if (path.empty()) + // return; + + // progress_bar += path.size(); + + // // Either a sink, or a cyclic dependency + // if (!fetch_and_set_bit(finalised.data(), v, true, std::memory_order_acq_rel)) + // set_bit(terminal->data(), v, true, std::memory_order_relaxed); + + // for (node_index v : path) { + // set_bit(finalised.data(), v, true, std::memory_order_release); + // } + // }, []() { return false; }, num_threads); + + // std::atomic_thread_fence(std::memory_order_acquire); + } +} + +void build_pred_succ(const graph::DeBruijnGraph &graph, const std::string &outfbase, const std::string &count_vectors_dir, const std::string &row_count_extension, @@ -342,70 +581,70 @@ void build_pred_succ(const std::string &graph_fname, logger->trace("Building and writing successor and predecessor files to {}.*", outfbase); - graph::DBGSuccinct graph(2); - logger->trace("Loading graph..."); - if (!graph.load(graph_fname)) { - logger->error("Cannot load graph from {}", graph_fname); - std::exit(1); - } // assign row-diff successors at forks - rd_succ_bv_type rd_succ = route_at_forks(graph, outfbase + kRowDiffForkSuccExt, - count_vectors_dir, row_count_extension); - - const BOSS &boss = graph.get_boss(); - - sdsl::bit_vector dummy = boss.mark_all_dummy_edges(num_threads); + auto rd_succ_ptr = route_at_forks(graph, outfbase + kRowDiffForkSuccExt, + count_vectors_dir, row_count_extension); + auto& rd_succ = *rd_succ_ptr; // create the succ/pred files, indexed using annotation indices - uint32_t width = sdsl::bits::hi(graph.num_nodes()) + 1; + uint32_t width = sdsl::bits::hi(graph.max_index()) + 1; sdsl::int_vector_buffer<> succ(outfbase + ".succ", std::ios::out, BUFFER_SIZE, width); sdsl::int_vector_buffer<1> succ_boundary(outfbase + ".succ_boundary", std::ios::out, BUFFER_SIZE); sdsl::int_vector_buffer<> pred(outfbase + ".pred", std::ios::out, BUFFER_SIZE, width); sdsl::int_vector_buffer<1> pred_boundary(outfbase + ".pred_boundary", std::ios::out, BUFFER_SIZE); - ProgressBar progress_bar(graph.num_nodes(), "Compute succ/pred", std::cerr, + std::optional dummy; + auto* succinct = dynamic_cast(&graph); + if (succinct) { + dummy = succinct->get_boss().mark_all_dummy_edges(num_threads); + } + + ProgressBar progress_bar(graph.max_index(), "Compute succ/pred", std::cerr, !common::get_verbose()); const uint64_t BS = 1'000'000; - // traverse BOSS table in parallel processing blocks of size |BS| + // traverse graph in parallel processing blocks of size |BS| // use static scheduling to make threads process ordered contiguous blocks #pragma omp parallel for ordered num_threads(num_threads) schedule(dynamic) - for (uint64_t start = 1; start <= graph.num_nodes(); start += BS) { - std::vector succ_buf; + for (node_index start = 1; start <= graph.max_index(); start += BS) { + std::vector succ_buf; std::vector succ_boundary_buf; - std::vector pred_buf; + std::vector pred_buf; std::vector pred_boundary_buf; - for (uint64_t i = start; i < std::min(start + BS, graph.num_nodes() + 1); ++i) { - BOSS::edge_index boss_idx = graph.kmer_to_boss_index(i); - if (!dummy[boss_idx]) { - BOSS::edge_index next = boss.fwd(boss_idx); - assert(next); - if (!dummy[next]) { - while (rd_succ.size() && !rd_succ[next]) { - next--; - assert(!boss.get_last(next)); - } - succ_buf.push_back(to_row(graph.boss_to_kmer_index(next))); + for (node_index i = start; i < std::min(start + BS, graph.max_index() + 1); ++i) { + bool skip_succ = false; + bool skip_all = !graph.in_graph(i); + + if (!skip_all && succinct) { // Legacy code for DBGSuccinct + BOSS::edge_index boss_idx = i; + if((*dummy)[boss_idx]) { + skip_all = true; + } else { + skip_succ = (*dummy)[succinct->get_boss().fwd(boss_idx)]; + } + } + + if (!skip_all) { + skip_succ |= graph.has_no_outgoing(i); + if (!skip_succ) { + auto j = row_diff_successor(graph, i, rd_succ); + succ_buf.push_back(to_row(j)); succ_boundary_buf.push_back(0); } - // compute predecessors only for row-diff successors - if (rd_succ.size() ? rd_succ[boss_idx] : boss.get_last(boss_idx)) { - BOSS::TAlphabet d = boss.get_node_last_value(boss_idx); - BOSS::edge_index back_idx = boss.bwd(boss_idx); - boss.call_incoming_to_target(back_idx, d, - [&](BOSS::edge_index pred) { - // dummy predecessors are ignored - if (!dummy[pred]) { - uint64_t node_index = graph.boss_to_kmer_index(pred); - pred_buf.push_back(to_row(node_index)); - pred_boundary_buf.push_back(0); - } + + if (rd_succ[i]) { + graph.adjacent_incoming_nodes(i, [&](auto pred) { + if (dummy && (*dummy)[pred]) { + return; } - ); + pred_buf.push_back(to_row(pred)); + pred_boundary_buf.push_back(0); + }); } } + succ_boundary_buf.push_back(1); pred_boundary_buf.push_back(1); ++progress_bar; @@ -424,7 +663,7 @@ void build_pred_succ(const std::string &graph_fname, logger->trace("Pred/succ nodes written to {}.pred/succ", outfbase); } -void assign_anchors(const std::string &graph_fname, +void assign_anchors(const graph::DeBruijnGraph &graph, const std::string &outfbase, const std::filesystem::path &count_vectors_dir, uint32_t max_length, @@ -436,14 +675,7 @@ void assign_anchors(const std::string &graph_fname, return; } - graph::DBGSuccinct graph(2); - logger->trace("Loading graph..."); - if (!graph.load(graph_fname)) { - logger->error("Cannot load graph from {}", graph_fname); - std::exit(1); - } - const BOSS &boss = graph.get_boss(); - const uint64_t num_rows = graph.num_nodes(); + const uint64_t num_rows = graph.max_index(); bool optimize_anchors = false; for (const auto &p : fs::directory_iterator(count_vectors_dir)) { @@ -451,7 +683,7 @@ void assign_anchors(const std::string &graph_fname, optimize_anchors = true; } - sdsl::bit_vector anchors_bv(boss.get_last().size(), false); + sdsl::bit_vector anchors_bv(graph.max_index() + 1, false); if (optimize_anchors) { logger->trace("Making every row with negative reduction an anchor..."); @@ -460,8 +692,9 @@ void assign_anchors(const std::string &graph_fname, sum_and_call_counts(count_vectors_dir, row_reduction_extension, "row reduction", [&](int32_t count) { // check if the reduction is negative - if (count < 0) - anchors_bv[graph.kmer_to_boss_index(to_node(i))] = true; + if (count < 0) { + anchors_bv[to_node(i)] = true; + } i++; } ); @@ -492,11 +725,12 @@ void assign_anchors(const std::string &graph_fname, if (rd_succ.size()) { logger->trace("Assigning anchors for RowDiff successors {}...", rd_succ_fname); - boss.row_diff_traverse(num_threads, max_length, rd_succ, &anchors_bv); + row_diff_traverse(graph, num_threads, max_length, rd_succ, &anchors_bv); } else { logger->warn("Assigning anchors without chosen RowDiff successors." " The last outgoing edges will be used for routing."); - boss.row_diff_traverse(num_threads, max_length, boss.get_last(), &anchors_bv); + auto last = get_last(graph); + row_diff_traverse(graph, num_threads, max_length, *last, &anchors_bv); } } @@ -505,7 +739,7 @@ void assign_anchors(const std::string &graph_fname, sdsl::bit_vector anchors(num_rows, false); for (BOSS::edge_index i = 1; i < anchors_bv.size(); ++i) { if (anchors_bv[i]) { - uint64_t graph_idx = graph.boss_to_kmer_index(i); + uint64_t graph_idx = i; assert(to_row(graph_idx) < num_rows); anchors[to_row(graph_idx)] = 1; } @@ -929,7 +1163,7 @@ void convert_batch_to_row_diff(const std::string &pred_succ_fprefix, // reduction (zero diff) __atomic_add_fetch(&row_nbits_block[chunk_idx], 1, __ATOMIC_RELAXED); } - } else { + } else if (succ || anchor[row_idx]) { bool is_anchor = anchor[row_idx]; // add current bit if this node is an anchor // or if the successor has zero diff diff --git a/metagraph/src/annotation/row_diff_builder.hpp b/metagraph/src/annotation/row_diff_builder.hpp index f57fe4c38c..47e3ffa116 100644 --- a/metagraph/src/annotation/row_diff_builder.hpp +++ b/metagraph/src/annotation/row_diff_builder.hpp @@ -16,13 +16,13 @@ void count_labels_per_row(const std::vector &source_files, const std::string &row_count_fname, bool with_coordinates = false); -void build_pred_succ(const std::string &graph_filename, +void build_pred_succ(const graph::DeBruijnGraph &graph, const std::string &outfbase, const std::string &count_vectors_dir, const std::string &row_count_extension, uint32_t num_threads); -void assign_anchors(const std::string &graph_filename, +void assign_anchors(const graph::DeBruijnGraph &graph, const std::string &outfbase, const std::filesystem::path &dest_dir, uint32_t max_length, diff --git a/metagraph/src/cli/load/load_annotated_graph.cpp b/metagraph/src/cli/load/load_annotated_graph.cpp index fdf0fbc7ec..5a2e227ff5 100644 --- a/metagraph/src/cli/load/load_annotated_graph.cpp +++ b/metagraph/src/cli/load/load_annotated_graph.cpp @@ -24,8 +24,8 @@ std::unique_ptr initialize_annotated_dbg(std::shared_ptrmax_index(); - const auto *dbg_graph = dynamic_cast(graph.get()); + auto base_graph = graph; if (graph->get_mode() == DeBruijnGraph::PRIMARY) { graph = std::make_shared(graph); logger->trace("Primary graph wrapped into canonical"); @@ -56,13 +56,7 @@ std::unique_ptr initialize_annotated_dbg(std::shared_ptr(annotation_temp->get_matrix()); if (IRowDiff *row_diff = dynamic_cast(&matrix)) { - if (!dbg_graph) { - logger->error("Only succinct de Bruijn graph representations" - " are supported for row-diff annotations"); - std::exit(1); - } - - row_diff->set_graph(dbg_graph); + row_diff->set_graph(base_graph.get()); if (auto *row_diff_column = dynamic_cast *>(&matrix)) { row_diff_column->load_anchor(config.infbase + kRowDiffAnchorExt); diff --git a/metagraph/src/cli/stats.cpp b/metagraph/src/cli/stats.cpp index f4a2cee9c6..c29ef07c09 100644 --- a/metagraph/src/cli/stats.cpp +++ b/metagraph/src/cli/stats.cpp @@ -76,6 +76,7 @@ void print_stats(const graph::DeBruijnGraph &graph, bool print_counts_hist) { std::cout << "====================== GRAPH STATS =====================" << std::endl; std::cout << "k: " << graph.get_k() << std::endl; std::cout << "nodes (k): " << graph.num_nodes() << std::endl; + std::cout << "max index (k): " << graph.max_index() << std::endl; std::cout << "mode: " << Config::graphmode_to_string(graph.get_mode()) << std::endl; if (auto weights = graph.get_extension()) { @@ -143,7 +144,6 @@ void print_stats(const graph::DeBruijnGraph &graph, bool print_counts_hist) { std::cout << std::endl; } } - std::cout << "========================================================" << std::endl; } diff --git a/metagraph/src/graph/alignment/aligner_seeder_methods.cpp b/metagraph/src/graph/alignment/aligner_seeder_methods.cpp index 306c7f6a0d..1a6d7e03d4 100644 --- a/metagraph/src/graph/alignment/aligner_seeder_methods.cpp +++ b/metagraph/src/graph/alignment/aligner_seeder_methods.cpp @@ -104,7 +104,7 @@ void suffix_to_prefix(const DBGSuccinct &dbg_succ, const auto &[first, last, seed_length] = final_range; assert(seed_length == boss.get_k()); for (boss::BOSS::edge_index i = first; i <= last; ++i) { - DBGSuccinct::node_index node = dbg_succ.boss_to_kmer_index(i); + DBGSuccinct::node_index node = dbg_succ.validate_edge(i); if (node) callback(node); } diff --git a/metagraph/src/graph/alignment/alignment.cpp b/metagraph/src/graph/alignment/alignment.cpp index b1bdd0d8a7..ef1f4fb29b 100644 --- a/metagraph/src/graph/alignment/alignment.cpp +++ b/metagraph/src/graph/alignment/alignment.cpp @@ -550,7 +550,7 @@ void Alignment::reverse_complement(const DeBruijnGraph &graph, // the node is present in the underlying graph, so use // lower-level methods const auto &boss = dbg_succ.get_boss(); - boss::BOSS::edge_index edge = dbg_succ.kmer_to_boss_index(nodes_[0]); + boss::BOSS::edge_index edge = nodes_[0]; boss::BOSS::TAlphabet edge_label = boss.get_W(edge) % boss.alph_size; // TODO: This picks the node which is found by always traversing @@ -565,7 +565,7 @@ void Alignment::reverse_complement(const DeBruijnGraph &graph, return; } - nodes_[0] = dbg_succ.boss_to_kmer_index(edge); + nodes_[0] = dbg_succ.validate_edge(edge); assert(nodes_[0]); sequence_.push_back(boss.decode(edge_label)); assert(graph.get_node_sequence(nodes_[0]) diff --git a/metagraph/src/graph/alignment/annotation_buffer.cpp b/metagraph/src/graph/alignment/annotation_buffer.cpp index 4020f312a7..a644bf2933 100644 --- a/metagraph/src/graph/alignment/annotation_buffer.cpp +++ b/metagraph/src/graph/alignment/annotation_buffer.cpp @@ -78,7 +78,7 @@ void AnnotationBuffer::fetch_queued_annotations() { continue; } - if (boss && !boss->get_W(dbg_succ->kmer_to_boss_index(base_path[i]))) { + if (boss && !boss->get_W(base_path[i])) { // skip dummy nodes if (node_to_cols_.try_emplace(base_path[i], 0).second && has_coordinates()) label_coords_.emplace_back(); diff --git a/metagraph/src/graph/graph_extensions/node_first_cache.cpp b/metagraph/src/graph/graph_extensions/node_first_cache.cpp index a945acf12f..297cdddeeb 100644 --- a/metagraph/src/graph/graph_extensions/node_first_cache.cpp +++ b/metagraph/src/graph/graph_extensions/node_first_cache.cpp @@ -36,14 +36,14 @@ void NodeFirstCache::call_incoming_edges(edge_index edge, void NodeFirstCache::call_incoming_kmers(node_index node, const IncomingEdgeCallback &callback) const { - assert(node > 0 && node <= dbg_succ_.num_nodes()); + assert(dbg_succ_.in_graph(node)); - edge_index edge = dbg_succ_.kmer_to_boss_index(node); + edge_index edge = node; call_incoming_edges(edge, [&](edge_index prev_edge) { - node_index prev = dbg_succ_.boss_to_kmer_index(prev_edge); - if (prev != DeBruijnGraph::npos) + node_index prev = prev_edge; + if (dbg_succ_.in_graph(prev)) callback(prev, get_first_char(prev_edge, edge)); } ); diff --git a/metagraph/src/graph/representation/base/dbg_wrapper.hpp b/metagraph/src/graph/representation/base/dbg_wrapper.hpp index 8ccc71ce62..ce46c0b8b8 100644 --- a/metagraph/src/graph/representation/base/dbg_wrapper.hpp +++ b/metagraph/src/graph/representation/base/dbg_wrapper.hpp @@ -71,7 +71,9 @@ class DBGWrapper : public DeBruijnGraph { virtual void call_nodes(const std::function &callback, const std::function &stop_early - = [](){ return false; }) const override = 0; + = [](){ return false; }, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const override = 0; virtual void call_kmers(const std::function &callback, diff --git a/metagraph/src/graph/representation/base/sequence_graph.cpp b/metagraph/src/graph/representation/base/sequence_graph.cpp index bd6d55c485..043e135790 100644 --- a/metagraph/src/graph/representation/base/sequence_graph.cpp +++ b/metagraph/src/graph/representation/base/sequence_graph.cpp @@ -23,12 +23,20 @@ static_assert(!(kBlockSize & 0xFF)); /*************** SequenceGraph ***************/ void SequenceGraph::call_nodes(const std::function &callback, - const std::function &stop_early) const { + const std::function &terminate, + size_t num_threads, + size_t batch_size) const { assert(num_nodes() == max_index()); const auto nnodes = num_nodes(); - for (node_index i = 1; i <= nnodes && !stop_early(); ++i) { - callback(i); + + #pragma omp parallel for num_threads(num_threads) schedule(static, batch_size) + for (node_index i = 1; i <= nnodes; ++i) { + if (terminate()) + continue; + + if (in_graph(i)) + callback(i); } } @@ -83,14 +91,14 @@ bool DeBruijnGraph::find(std::string_view sequence, void DeBruijnGraph ::adjacent_outgoing_nodes(node_index node, const std::function &callback) const { - assert(node > 0 && node <= max_index()); + assert(in_graph(node)); call_outgoing_kmers(node, [&](auto child, char) { callback(child); }); } void DeBruijnGraph ::adjacent_incoming_nodes(node_index node, const std::function &callback) const { - assert(node > 0 && node <= max_index()); + assert(in_graph(node)); call_incoming_kmers(node, [&](auto parent, char) { callback(parent); }); } @@ -122,7 +130,7 @@ void DeBruijnGraph::traverse(node_index start, for (; begin != end && !terminate(); ++begin) { start = traverse(start, *begin); - if (start == npos) + if (!in_graph(start)) return; callback(start); @@ -446,6 +454,7 @@ ::call_kmers(const std::function #include #include @@ -60,10 +62,13 @@ class SequenceGraph { const std::function &callback) const = 0; virtual void call_nodes(const std::function &callback, - const std::function &stop_early = [](){ return false; }) const; + const std::function &terminate = [](){ return false; }, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const; virtual uint64_t num_nodes() const = 0; virtual uint64_t max_index() const { return num_nodes(); }; + virtual bool in_graph(node_index node) const { return node > 0 && node <= max_index(); } virtual bool load(const std::string &filename_base) = 0; virtual void serialize(const std::string &filename_base) const = 0; @@ -203,6 +208,7 @@ class DeBruijnGraph : public SequenceGraph { const std::function &stop_early = [](){ return false; }) const; virtual size_t outdegree(node_index) const = 0; + virtual bool has_no_outgoing(node_index node) const { return outdegree(node) == 0; } virtual bool has_single_outgoing(node_index node) const { return outdegree(node) == 1; } virtual bool has_multiple_outgoing(node_index node) const { return outdegree(node) > 1; } diff --git a/metagraph/src/graph/representation/bitmap/dbg_bitmap.cpp b/metagraph/src/graph/representation/bitmap/dbg_bitmap.cpp index 577b5a72f2..7961155e1d 100644 --- a/metagraph/src/graph/representation/bitmap/dbg_bitmap.cpp +++ b/metagraph/src/graph/representation/bitmap/dbg_bitmap.cpp @@ -38,11 +38,11 @@ DBGBitmap::DBGBitmap(DBGBitmapConstructor *builder) : DBGBitmap(2) { void DBGBitmap::map_to_nodes(std::string_view sequence, const std::function &callback, const std::function &terminate) const { - for (const auto &[kmer, is_valid] : sequence_to_kmers(sequence, mode_ == CANONICAL)) { + for (const auto &[kmer, in_graph] : sequence_to_kmers(sequence, mode_ == CANONICAL)) { if (terminate()) return; - callback(is_valid ? to_node(kmer) : npos); + callback(in_graph ? to_node(kmer) : npos); } } @@ -53,17 +53,17 @@ void DBGBitmap::map_to_nodes(std::string_view sequence, void DBGBitmap::map_to_nodes_sequentially(std::string_view sequence, const std::function &callback, const std::function &terminate) const { - for (const auto &[kmer, is_valid] : sequence_to_kmers(sequence)) { + for (const auto &[kmer, in_graph] : sequence_to_kmers(sequence)) { if (terminate()) return; - callback(is_valid ? to_node(kmer) : npos); + callback(in_graph ? to_node(kmer) : npos); } } DBGBitmap::node_index DBGBitmap::traverse(node_index node, char next_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); auto kmer = node_to_kmer(node); kmer.to_next(k_, seq_encoder_.encode(next_char)); @@ -72,7 +72,7 @@ DBGBitmap::traverse(node_index node, char next_char) const { DBGBitmap::node_index DBGBitmap::traverse_back(node_index node, char prev_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); auto kmer = node_to_kmer(node); kmer.to_prev(k_, seq_encoder_.encode(prev_char)); @@ -81,7 +81,7 @@ DBGBitmap::traverse_back(node_index node, char prev_char) const { void DBGBitmap::call_outgoing_kmers(node_index node, const OutgoingEdgeCallback &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); const auto &kmer = node_to_kmer(node); @@ -96,7 +96,7 @@ void DBGBitmap::call_outgoing_kmers(node_index node, } size_t DBGBitmap::outdegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); if (complete_) return alphabet().size(); @@ -117,7 +117,7 @@ size_t DBGBitmap::outdegree(node_index node) const { } bool DBGBitmap::has_single_outgoing(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); if (complete_) return alphabet().size() == 1; @@ -142,7 +142,7 @@ bool DBGBitmap::has_single_outgoing(node_index node) const { } bool DBGBitmap::has_multiple_outgoing(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); if (complete_) return alphabet().size() > 1; @@ -168,7 +168,7 @@ bool DBGBitmap::has_multiple_outgoing(node_index node) const { void DBGBitmap::call_incoming_kmers(node_index node, const OutgoingEdgeCallback &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); const auto &kmer = node_to_kmer(node); @@ -183,7 +183,7 @@ void DBGBitmap::call_incoming_kmers(node_index node, } size_t DBGBitmap::indegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); if (complete_) return alphabet().size(); @@ -215,19 +215,19 @@ DBGBitmap::node_index DBGBitmap::kmer_to_node(std::string_view kmer) const { } uint64_t DBGBitmap::node_to_index(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); return complete_ ? node : kmers_.select1(node + 1); } DBGBitmap::Kmer DBGBitmap::node_to_kmer(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); return Kmer { complete_ ? node - 1 : kmers_.select1(node + 1) - 1 }; } std::string DBGBitmap::get_node_sequence(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); assert(sequence_to_kmers(seq_encoder_.kmer_to_sequence( node_to_kmer(node), k_)).size() == 1); assert(node == to_node(sequence_to_kmers(seq_encoder_.kmer_to_sequence( diff --git a/metagraph/src/graph/representation/canonical_dbg.cpp b/metagraph/src/graph/representation/canonical_dbg.cpp index 39b3798001..c4aeeec359 100644 --- a/metagraph/src/graph/representation/canonical_dbg.cpp +++ b/metagraph/src/graph/representation/canonical_dbg.cpp @@ -115,7 +115,7 @@ ::map_to_nodes_sequentially(std::string_view sequence, sequence.substr(1)); boss.map_to_edges(sequence.substr(1), [&](boss::BOSS::edge_index edge) { - path.push_back(dbg_succ->boss_to_kmer_index(edge)); + path.push_back(dbg_succ->validate_edge(edge)); ++it; }, []() { return false; }, @@ -285,7 +285,6 @@ void CanonicalDBG::call_incoming_kmers(node_index node, SmallVector parents(alphabet.size(), npos); // "- has_sentinel_" because there can't be a dummy sink with another non-dummy edge size_t max_num_edges_left = parents.size() - has_sentinel_; - auto incoming_kmer_callback = [&](node_index prev, char c) { assert(has_sentinel_ || c != boss::BOSS::kSentinel); assert(c == boss::BOSS::kSentinel || traverse_back(node, c) == prev); @@ -491,7 +490,9 @@ DeBruijnGraph::node_index CanonicalDBG::traverse_back(node_index node, } void CanonicalDBG::call_nodes(const std::function &callback, - const std::function &stop_early) const { + const std::function &stop_early, + size_t num_threads, + size_t batch_size) const { graph_->call_nodes( [&](node_index i) { callback(i); @@ -501,7 +502,9 @@ void CanonicalDBG::call_nodes(const std::function &callback, callback(j); } }, - stop_early + stop_early, + num_threads, + batch_size ); } @@ -601,18 +604,15 @@ ::adjacent_incoming_rc_strand(node_index node, //-> TCAAGCAGAAGACGGCATACGAGATCCTCT const boss::BOSS &boss = dbg_succ_->get_boss(); - boss::BOSS::edge_index rc_edge = get_cache().get_prefix_rc( - dbg_succ_->kmer_to_boss_index(node), - spelling_hint - ); + boss::BOSS::edge_index rc_edge = get_cache().get_prefix_rc(node, spelling_hint); if (!rc_edge) return; boss.call_outgoing(rc_edge, [&](boss::BOSS::edge_index adjacent_edge) { assert(dbg_succ_); - node_index prev = dbg_succ_->boss_to_kmer_index(adjacent_edge); - if (prev == DeBruijnGraph::npos) + node_index prev = adjacent_edge; + if (!dbg_succ_->in_graph(prev)) return; char c = boss.decode(boss.get_W(adjacent_edge) % boss.alph_size); @@ -665,18 +665,15 @@ ::adjacent_outgoing_rc_strand(node_index node, auto &cache = get_cache(); - boss::BOSS::edge_index rc_edge = cache.get_suffix_rc( - dbg_succ_->kmer_to_boss_index(node), - spelling_hint - ); + boss::BOSS::edge_index rc_edge = cache.get_suffix_rc(node, spelling_hint); if (!rc_edge) return; cache.call_incoming_edges(rc_edge, [&](edge_index prev_edge) { - node_index prev = dbg_succ_->boss_to_kmer_index(prev_edge); - if (!prev) + node_index prev = prev_edge; + if (!dbg_succ_->in_graph(prev)) return; char c = cache.get_first_char(prev_edge, rc_edge); diff --git a/metagraph/src/graph/representation/canonical_dbg.hpp b/metagraph/src/graph/representation/canonical_dbg.hpp index c4e4394019..4762b71d7e 100644 --- a/metagraph/src/graph/representation/canonical_dbg.hpp +++ b/metagraph/src/graph/representation/canonical_dbg.hpp @@ -111,7 +111,9 @@ class CanonicalDBG : public DBGWrapper { const std::function &stop_early = [](){ return false; }) const override final; virtual void call_nodes(const std::function &callback, - const std::function &stop_early = [](){ return false; }) const override final; + const std::function &stop_early = [](){ return false; }, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const override final; virtual bool operator==(const DeBruijnGraph &other) const override final; diff --git a/metagraph/src/graph/representation/hash/dbg_hash_fast.cpp b/metagraph/src/graph/representation/hash/dbg_hash_fast.cpp index 8aec35fa75..47ae4d570e 100644 --- a/metagraph/src/graph/representation/hash/dbg_hash_fast.cpp +++ b/metagraph/src/graph/representation/hash/dbg_hash_fast.cpp @@ -57,13 +57,13 @@ class DBGHashFastImpl : public DBGHashFast::DBGHashFastInterface { } void add_sequence(std::string_view sequence, - const std::function &on_insertion); + const std::function &on_insertion) override final; // Traverse graph mapping sequence to the graph nodes // and run callback for each node until the termination condition is satisfied void map_to_nodes(std::string_view sequence, const std::function &callback, - const std::function &terminate) const; + const std::function &terminate) const override final; // Traverse graph mapping sequence to the graph nodes // and run callback for each node until the termination condition is satisfied. @@ -71,19 +71,21 @@ class DBGHashFastImpl : public DBGHashFast::DBGHashFastInterface { // In canonical mode, non-canonical k-mers are NOT mapped to canonical ones void map_to_nodes_sequentially(std::string_view sequence, const std::function &callback, - const std::function &terminate) const; + const std::function &terminate) const override final; void call_outgoing_kmers(node_index node, - const OutgoingEdgeCallback &callback) const; + const OutgoingEdgeCallback &callback) const override final; void call_incoming_kmers(node_index node, - const IncomingEdgeCallback &callback) const; + const IncomingEdgeCallback &callback) const override final; void call_nodes(const std::function &callback, - const std::function &stop_early) const; + const std::function &stop_early, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const override final; // Traverse the outgoing edge - node_index traverse(node_index node, char next_char) const { + node_index traverse(node_index node, char next_char) const override final { assert(in_graph(node)); // TODO: use `next_kmer()` @@ -93,7 +95,7 @@ class DBGHashFastImpl : public DBGHashFast::DBGHashFastInterface { return get_node_index(kmer); } // Traverse the incoming edge - node_index traverse_back(node_index node, char prev_char) const { + node_index traverse_back(node_index node, char prev_char) const override final { assert(in_graph(node)); // TODO: check previous k-mer in vector similarly to `next_kmer()` @@ -103,70 +105,73 @@ class DBGHashFastImpl : public DBGHashFast::DBGHashFastInterface { return get_node_index(kmer); } - size_t outdegree(node_index) const; - bool has_single_outgoing(node_index node) const { + size_t outdegree(node_index) const override final; + bool has_single_outgoing(node_index node) const override final{ assert(in_graph(node)); return outdegree(node) == 1; } - bool has_multiple_outgoing(node_index node) const { + bool has_multiple_outgoing(node_index node) const override final { assert(in_graph(node)); return outdegree(node) > 1; } - size_t indegree(node_index) const; - bool has_no_incoming(node_index) const; - bool has_single_incoming(node_index node) const { + size_t indegree(node_index) const override final; + bool has_no_incoming(node_index) const override final; + bool has_single_incoming(node_index node) const override final { assert(in_graph(node)); return indegree(node) == 1; } - node_index kmer_to_node(std::string_view kmer) const { + node_index kmer_to_node(std::string_view kmer) const override final { assert(kmer.length() == k_); return get_node_index(seq_encoder_.encode(kmer)); } - std::string get_node_sequence(node_index node) const { + std::string get_node_sequence(node_index node) const override final { assert(in_graph(node)); return seq_encoder_.kmer_to_sequence(get_kmer(node), k_); } - size_t get_k() const { return k_; } - Mode get_mode() const { return mode_; } + size_t get_k() const override final { return k_; } + Mode get_mode() const override final { return mode_; } - uint64_t num_nodes() const { + uint64_t num_nodes() const override final { uint64_t nnodes = 0; call_nodes([&](auto) { nnodes++; }, [](){ return false; }); return nnodes; } - uint64_t max_index() const { return kmers_.size() * kAlphabetSize; } + uint64_t max_index() const override final { return kmers_.size() * kAlphabetSize; } - void serialize(std::ostream &out) const; - void serialize(const std::string &filename) const { + void serialize(std::ostream &out) const override final; + void serialize(const std::string &filename) const override final { std::ofstream out(utils::make_suffix(filename, kExtension), std::ios::binary); serialize(out); } - bool load(std::istream &in); - bool load(const std::string &filename) { + bool load(std::istream &in) override final; + bool load(const std::string &filename) override final { std::ifstream in(utils::make_suffix(filename, kExtension), std::ios::binary); return load(in); } - std::string file_extension() const { return kExtension; } + std::string file_extension() const override final { return kExtension; } - bool operator==(const DeBruijnGraph &other) const; + bool operator==(const DeBruijnGraph &other) const override final; - const std::string& alphabet() const { return seq_encoder_.alphabet; } + const std::string& alphabet() const override final { return seq_encoder_.alphabet; } - private: - bool in_graph(node_index node) const { - assert(node > 0 && node <= max_index()); + bool in_graph(node_index node) const override final { + if (node == npos) { + return false; + } + assert(DBGHashFast::DBGHashFastInterface::in_graph(node)); Flags flags = bits_[node_to_bucket(node)]; return (flags >> ((node - 1) % kAlphabetSize)) & static_cast(1); } + private: Vector> sequence_to_kmers(std::string_view sequence, bool canonical = false) const { return seq_encoder_.sequence_to_kmers(sequence, k_, canonical); @@ -234,8 +239,8 @@ void DBGHashFastImpl::add_sequence(std::string_view sequence, for (const auto &kmer_pair : sequence_to_kmers(sequence)) { // putting the structured binding in the for statement above crashes gcc 8.2.0 - const auto &[kmer, is_valid] = kmer_pair; - if (!is_valid) { + const auto &[kmer, in_graph] = kmer_pair; + if (!in_graph) { previous_valid = false; continue; } @@ -290,7 +295,7 @@ void DBGHashFastImpl::map_to_nodes_sequentially( std::string_view sequence, const std::function &callback, const std::function &terminate) const { - for (const auto &[kmer, is_valid] : sequence_to_kmers(sequence)) { + for (const auto &[kmer, in_graph] : sequence_to_kmers(sequence)) { if (terminate()) return; @@ -298,7 +303,7 @@ void DBGHashFastImpl::map_to_nodes_sequentially( || get_node_index(kmer) == npos || kmer == get_kmer(get_node_index(kmer))); - callback(is_valid ? get_node_index(kmer) : npos); + callback(in_graph ? get_node_index(kmer) : npos); // TODO: `next_kmer()` could speed this up } } @@ -309,13 +314,13 @@ template void DBGHashFastImpl::map_to_nodes(std::string_view sequence, const std::function &callback, const std::function &terminate) const { - for (const auto &[kmer, is_valid] : sequence_to_kmers(sequence, mode_ == CANONICAL)) { + for (const auto &[kmer, in_graph] : sequence_to_kmers(sequence, mode_ == CANONICAL)) { if (terminate()) return; assert(!get_node_index(kmer) || kmer == get_kmer(get_node_index(kmer))); - callback(is_valid ? get_node_index(kmer) : npos); + callback(in_graph ? get_node_index(kmer) : npos); // TODO: `next_kmer()` could speed this up } } @@ -530,13 +535,19 @@ DBGHashFastImpl::get_node_index(const Kmer &kmer) const { template void DBGHashFastImpl::call_nodes(const std::function &callback, - const std::function &stop_early) const { + const std::function &stop_early, + size_t num_threads, + size_t batch_size) const { + #pragma omp parallel for num_threads(num_threads) schedule(static, batch_size) for (size_t i = 0; i < kmers_.size(); ++i) { + if (stop_early()) + continue; + Flags flags = bits_[i]; for (TAlphabet c = 0; c < kAlphabetSize; ++c, flags >>= 1) { if (stop_early()) - return; + break; if (flags & static_cast(1)) callback(bucket_to_node(i) + c); diff --git a/metagraph/src/graph/representation/hash/dbg_hash_fast.hpp b/metagraph/src/graph/representation/hash/dbg_hash_fast.hpp index a200e98f47..a9b849301b 100644 --- a/metagraph/src/graph/representation/hash/dbg_hash_fast.hpp +++ b/metagraph/src/graph/representation/hash/dbg_hash_fast.hpp @@ -23,7 +23,7 @@ class DBGHashFast : public DeBruijnGraph { // all new real nodes and all new dummy node indexes allocated in graph. // In short: max_index[after] = max_index[before] + {num_invocations}. void add_sequence(std::string_view sequence, - const std::function &on_insertion = [](node_index) {}) { + const std::function &on_insertion = [](node_index) {}) override final { hash_dbg_->add_sequence(sequence, on_insertion); } @@ -31,7 +31,7 @@ class DBGHashFast : public DeBruijnGraph { // and run callback for each node until the termination condition is satisfied void map_to_nodes(std::string_view sequence, const std::function &callback, - const std::function &terminate = [](){ return false; }) const { + const std::function &terminate = [](){ return false; }) const override final { hash_dbg_->map_to_nodes(sequence, callback, terminate); } @@ -41,73 +41,77 @@ class DBGHashFast : public DeBruijnGraph { // In canonical mode, non-canonical k-mers are NOT mapped to canonical ones void map_to_nodes_sequentially(std::string_view sequence, const std::function &callback, - const std::function &terminate = [](){ return false; }) const { + const std::function &terminate = [](){ return false; }) const override final { hash_dbg_->map_to_nodes_sequentially(sequence, callback, terminate); } void call_nodes(const std::function &callback, - const std::function &stop_early = [](){ return false; }) const { - hash_dbg_->call_nodes(callback, stop_early); + const std::function &stop_early = [](){ return false; }, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const override final { + hash_dbg_->call_nodes(callback, stop_early, num_threads, batch_size); } void call_outgoing_kmers(node_index node, - const OutgoingEdgeCallback &callback) const { + const OutgoingEdgeCallback &callback) const override final { hash_dbg_->call_outgoing_kmers(node, callback); } void call_incoming_kmers(node_index node, - const IncomingEdgeCallback &callback) const { + const IncomingEdgeCallback &callback) const override final { hash_dbg_->call_incoming_kmers(node, callback); } // Traverse the outgoing edge - node_index traverse(node_index node, char next_char) const { + node_index traverse(node_index node, char next_char) const override final { return hash_dbg_->traverse(node, next_char); } // Traverse the incoming edge - node_index traverse_back(node_index node, char prev_char) const { + node_index traverse_back(node_index node, char prev_char) const override final { return hash_dbg_->traverse_back(node, prev_char); } - size_t outdegree(node_index node) const { return hash_dbg_->outdegree(node); } - bool has_single_outgoing(node_index node) const { return hash_dbg_->has_single_outgoing(node); } - bool has_multiple_outgoing(node_index node) const { return hash_dbg_->has_multiple_outgoing(node); } + size_t outdegree(node_index node) const override final { return hash_dbg_->outdegree(node); } + bool has_single_outgoing(node_index node) const override final { return hash_dbg_->has_single_outgoing(node); } + bool has_multiple_outgoing(node_index node) const override final { return hash_dbg_->has_multiple_outgoing(node); } - size_t indegree(node_index node) const { return hash_dbg_->indegree(node); } - bool has_no_incoming(node_index node) const { return hash_dbg_->has_no_incoming(node); } - bool has_single_incoming(node_index node) const { return hash_dbg_->has_single_incoming(node); } + size_t indegree(node_index node) const override final { return hash_dbg_->indegree(node); } + bool has_no_incoming(node_index node) const override final { return hash_dbg_->has_no_incoming(node); } + bool has_single_incoming(node_index node) const override final { return hash_dbg_->has_single_incoming(node); } - node_index kmer_to_node(std::string_view kmer) const { + node_index kmer_to_node(std::string_view kmer) const override final { return hash_dbg_->kmer_to_node(kmer); } - std::string get_node_sequence(node_index node) const { + std::string get_node_sequence(node_index node) const override final { return hash_dbg_->get_node_sequence(node); } - size_t get_k() const { return hash_dbg_->get_k(); } - Mode get_mode() const { return hash_dbg_->get_mode(); } + size_t get_k() const override final { return hash_dbg_->get_k(); } + Mode get_mode() const override final { return hash_dbg_->get_mode(); } - uint64_t num_nodes() const { return hash_dbg_->num_nodes(); } - uint64_t max_index() const { return hash_dbg_->max_index(); } + uint64_t num_nodes() const override final { return hash_dbg_->num_nodes(); } + uint64_t max_index() const override final { return hash_dbg_->max_index(); } + + bool in_graph(node_index node) const override final { return hash_dbg_->in_graph(node); } void serialize(std::ostream &out) const { hash_dbg_->serialize(out); } - void serialize(const std::string &filename) const { hash_dbg_->serialize(filename); } + void serialize(const std::string &filename) const override final { hash_dbg_->serialize(filename); } bool load(std::istream &in); - bool load(const std::string &filename); + bool load(const std::string &filename) override final; - std::string file_extension() const { return kExtension; } + std::string file_extension() const override final { return kExtension; } - bool operator==(const DeBruijnGraph &other) const { + bool operator==(const DeBruijnGraph &other) const override final { if (this == &other) return true; return other == *hash_dbg_; } - const std::string& alphabet() const { return hash_dbg_->alphabet(); } + const std::string& alphabet() const override final { return hash_dbg_->alphabet(); } static constexpr auto kExtension = ".hashfastdbg"; diff --git a/metagraph/src/graph/representation/hash/dbg_hash_ordered.cpp b/metagraph/src/graph/representation/hash/dbg_hash_ordered.cpp index 0cf773eb78..7aa911e815 100644 --- a/metagraph/src/graph/representation/hash/dbg_hash_ordered.cpp +++ b/metagraph/src/graph/representation/hash/dbg_hash_ordered.cpp @@ -139,8 +139,8 @@ void DBGHashOrderedImpl::add_sequence(std::string_view sequence, node_index prev_pos = kmers_.size(); #endif - for (const auto &[kmer, is_valid] : sequence_to_kmers(sequence)) { - skipped.push_back(skip() || !is_valid); + for (const auto &[kmer, in_graph] : sequence_to_kmers(sequence)) { + skipped.push_back(skip() || !in_graph); if (skipped.back()) continue; @@ -170,7 +170,7 @@ void DBGHashOrderedImpl::add_sequence(std::string_view sequence, reverse_complement(rev_comp.begin(), rev_comp.end()); auto it = skipped.end(); - for (const auto &[kmer, is_valid] : sequence_to_kmers(rev_comp)) { + for (const auto &[kmer, in_graph] : sequence_to_kmers(rev_comp)) { if (!*(--it) && kmers_.insert(kmer).second) on_insertion(kmers_.size()); } @@ -190,12 +190,12 @@ void DBGHashOrderedImpl::map_to_nodes_sequentially( node_index prev_index = n_nodes; #endif - for (const auto &[kmer, is_valid] : sequence_to_kmers(sequence)) { + for (const auto &[kmer, in_graph] : sequence_to_kmers(sequence)) { if (terminate()) return; #if _DBGHash_LINEAR_PATH_OPTIMIZATIONS - if (!is_valid) { + if (!in_graph) { prev_index = n_nodes; callback(npos); } else if (prev_index < n_nodes && get_kmer(prev_index + 1) == kmer) { @@ -205,7 +205,7 @@ void DBGHashOrderedImpl::map_to_nodes_sequentially( callback(prev_index = get_index(kmer)); } #else - callback(is_valid ? get_index(kmer) : npos); + callback(in_graph ? get_index(kmer) : npos); #endif } } @@ -221,12 +221,12 @@ void DBGHashOrderedImpl::map_to_nodes(std::string_view sequence, node_index prev_index = n_nodes; #endif - for (const auto &[kmer, is_valid] : sequence_to_kmers(sequence, mode_ == CANONICAL)) { + for (const auto &[kmer, in_graph] : sequence_to_kmers(sequence, mode_ == CANONICAL)) { if (terminate()) return; #if _DBGHash_LINEAR_PATH_OPTIMIZATIONS - if (!is_valid) { + if (!in_graph) { prev_index = n_nodes; callback(npos); } else if (prev_index < n_nodes && get_kmer(prev_index + 1) == kmer) { @@ -236,7 +236,7 @@ void DBGHashOrderedImpl::map_to_nodes(std::string_view sequence, callback(prev_index = get_index(kmer)); } #else - callback(is_valid ? get_index(kmer) : npos); + callback(in_graph ? get_index(kmer) : npos); #endif } } @@ -244,7 +244,7 @@ void DBGHashOrderedImpl::map_to_nodes(std::string_view sequence, template void DBGHashOrderedImpl::call_outgoing_kmers(node_index node, const OutgoingEdgeCallback &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); const auto &kmer = get_kmer(node); @@ -261,7 +261,7 @@ void DBGHashOrderedImpl::call_outgoing_kmers(node_index node, template void DBGHashOrderedImpl::call_incoming_kmers(node_index node, const IncomingEdgeCallback &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); const auto &kmer = get_kmer(node); @@ -278,7 +278,7 @@ void DBGHashOrderedImpl::call_incoming_kmers(node_index node, template typename DBGHashOrderedImpl::node_index DBGHashOrderedImpl::traverse(node_index node, char next_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); auto kmer = get_kmer(node); kmer.to_next(k_, seq_encoder_.encode(next_char)); @@ -297,7 +297,7 @@ DBGHashOrderedImpl::traverse(node_index node, char next_char) const { template typename DBGHashOrderedImpl::node_index DBGHashOrderedImpl::traverse_back(node_index node, char prev_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); auto kmer = get_kmer(node); kmer.to_prev(k_, seq_encoder_.encode(prev_char)); @@ -315,7 +315,7 @@ DBGHashOrderedImpl::traverse_back(node_index node, char prev_char) const { template size_t DBGHashOrderedImpl::outdegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); size_t outdegree = 0; @@ -334,7 +334,7 @@ size_t DBGHashOrderedImpl::outdegree(node_index node) const { template bool DBGHashOrderedImpl::has_single_outgoing(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); bool outgoing_edge_detected = false; @@ -357,7 +357,7 @@ bool DBGHashOrderedImpl::has_single_outgoing(node_index node) const { template bool DBGHashOrderedImpl::has_multiple_outgoing(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); bool outgoing_edge_detected = false; @@ -380,7 +380,7 @@ bool DBGHashOrderedImpl::has_multiple_outgoing(node_index node) const { template size_t DBGHashOrderedImpl::indegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); size_t indegree = 0; @@ -399,7 +399,7 @@ size_t DBGHashOrderedImpl::indegree(node_index node) const { template bool DBGHashOrderedImpl::has_no_incoming(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); const auto &kmer = get_kmer(node); @@ -416,7 +416,7 @@ bool DBGHashOrderedImpl::has_no_incoming(node_index node) const { template bool DBGHashOrderedImpl::has_single_incoming(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); bool incoming_edge_detected = false; @@ -447,7 +447,7 @@ DBGHashOrderedImpl::kmer_to_node(std::string_view kmer) const { template std::string DBGHashOrderedImpl::get_node_sequence(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); return seq_encoder_.kmer_to_sequence(get_kmer(node), k_); } @@ -560,7 +560,7 @@ DBGHashOrderedImpl::get_index(const Kmer &kmer) const { template const KMER& DBGHashOrderedImpl::get_kmer(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); assert(node == get_index(*(kmers_.nth(node - 1)))); return *(kmers_.nth(node - 1)); diff --git a/metagraph/src/graph/representation/hash/dbg_hash_string.cpp b/metagraph/src/graph/representation/hash/dbg_hash_string.cpp index ac7132be68..361d3ab0c4 100644 --- a/metagraph/src/graph/representation/hash/dbg_hash_string.cpp +++ b/metagraph/src/graph/representation/hash/dbg_hash_string.cpp @@ -63,7 +63,7 @@ DBGHashString::traverse(node_index node, char next_char) const { DBGHashString::node_index DBGHashString::traverse_back(node_index node, char prev_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); std::string kmer = get_node_sequence(node); kmer.pop_back(); return kmer_to_node(prev_char + kmer); @@ -72,7 +72,7 @@ DBGHashString::traverse_back(node_index node, char prev_char) const { void DBGHashString ::call_outgoing_kmers(node_index node, const OutgoingEdgeCallback &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); auto prefix = get_node_sequence(node).substr(1); @@ -86,7 +86,7 @@ ::call_outgoing_kmers(node_index node, void DBGHashString ::call_incoming_kmers(node_index node, const IncomingEdgeCallback &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); std::string suffix = get_node_sequence(node); suffix.pop_back(); @@ -99,7 +99,7 @@ ::call_incoming_kmers(node_index node, } size_t DBGHashString::outdegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); size_t outdegree = 0; @@ -117,7 +117,7 @@ size_t DBGHashString::outdegree(node_index node) const { } bool DBGHashString::has_single_outgoing(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); bool outgoing_edge_detected = false; @@ -139,7 +139,7 @@ bool DBGHashString::has_single_outgoing(node_index node) const { } bool DBGHashString::has_multiple_outgoing(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); bool outgoing_edge_detected = false; @@ -161,7 +161,7 @@ bool DBGHashString::has_multiple_outgoing(node_index node) const { } size_t DBGHashString::indegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); size_t indegree = 0; @@ -180,7 +180,7 @@ size_t DBGHashString::indegree(node_index node) const { } bool DBGHashString::has_no_incoming(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); auto prev = get_node_sequence(node); prev.pop_back(); @@ -197,7 +197,7 @@ bool DBGHashString::has_no_incoming(node_index node) const { } bool DBGHashString::has_single_incoming(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); bool incoming_edge_detected = false; @@ -239,7 +239,7 @@ DBGHashString::kmer_to_node(std::string_view kmer) const { } std::string DBGHashString::get_node_sequence(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); assert((kmers_.begin() + (node - 1))->length() == k_); return *(kmers_.begin() + (node - 1)); } diff --git a/metagraph/src/graph/representation/hash/dbg_sshash.cpp b/metagraph/src/graph/representation/hash/dbg_sshash.cpp index 4275a0f28c..df78ef2c0e 100644 --- a/metagraph/src/graph/representation/hash/dbg_sshash.cpp +++ b/metagraph/src/graph/representation/hash/dbg_sshash.cpp @@ -186,7 +186,7 @@ DBGSSHash::node_index DBGSSHash::reverse_complement(node_index node) const { if (node == npos) return npos; - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); if (node > dict_size()) return node - dict_size(); @@ -220,13 +220,13 @@ void DBGSSHash::map_to_nodes_sequentially(std::string_view sequence, } DBGSSHash::node_index DBGSSHash::traverse(node_index node, char next_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); // TODO: if a node is in the middle of a unitig, then we only need to check the next node index return kmer_to_node(get_node_sequence(node).substr(1) + next_char); } DBGSSHash::node_index DBGSSHash::traverse_back(node_index node, char prev_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); // TODO: if a node is in the middle of a unitig, then we only need to check the previous node index std::string string_kmer = prev_char + get_node_sequence(node); string_kmer.pop_back(); @@ -237,7 +237,7 @@ template void DBGSSHash::call_outgoing_kmers_with_rc( node_index node, const std::function& callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); std::string kmer = get_node_sequence(node); std::visit([&](const auto &dict) { using kmer_t = get_kmer_t; @@ -264,7 +264,7 @@ template void DBGSSHash::call_incoming_kmers_with_rc( node_index node, const std::function& callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); std::string kmer = get_node_sequence(node); std::visit([&](const auto &dict) { using kmer_t = get_kmer_t; @@ -314,14 +314,14 @@ void DBGSSHash::call_incoming_kmers(node_index node, } size_t DBGSSHash::outdegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); size_t res = 0; adjacent_outgoing_nodes(node, [&](node_index) { ++res; }); return res; } size_t DBGSSHash::indegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); size_t res = 0; adjacent_incoming_nodes(node, [&](node_index) { ++res; }); return res; @@ -329,13 +329,26 @@ size_t DBGSSHash::indegree(node_index node) const { void DBGSSHash::call_nodes( const std::function& callback, - const std::function &terminate) const { - for (size_t node_idx = 1; !terminate() && node_idx <= dict_size(); ++node_idx) { + const std::function &terminate, + size_t num_threads, + size_t batch_size) const { + #pragma omp parallel for num_threads(num_threads) schedule(static, batch_size) + for (size_t node_idx = 1; node_idx <= dict_size(); ++node_idx) { + if (terminate()) + continue; + callback(node_idx); } + if (terminate()) + return; + if (mode_ == CANONICAL) { - for (size_t node_idx = 1; !terminate() && node_idx <= dict_size(); ++node_idx) { + #pragma omp parallel for num_threads(num_threads) schedule(static, batch_size) + for (size_t node_idx = 1; node_idx <= dict_size(); ++node_idx) { + if (terminate()) + continue; + size_t rc_node_idx = reverse_complement(node_idx); if (rc_node_idx != node_idx) callback(rc_node_idx); @@ -371,7 +384,7 @@ DBGSSHash::node_index DBGSSHash::kmer_to_node(std::string_view kmer) const { } std::string DBGSSHash::get_node_sequence(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); std::string str_kmer(k_, ' '); node_index node_canonical = node > dict_size() ? node - dict_size() : node; uint64_t ssh_idx = graph_index_to_sshash(node_canonical); diff --git a/metagraph/src/graph/representation/hash/dbg_sshash.hpp b/metagraph/src/graph/representation/hash/dbg_sshash.hpp index 12c1a407bd..8d503b6206 100644 --- a/metagraph/src/graph/representation/hash/dbg_sshash.hpp +++ b/metagraph/src/graph/representation/hash/dbg_sshash.hpp @@ -76,7 +76,9 @@ class DBGSSHash : public DeBruijnGraph { node_index traverse_back(node_index node, char prev_char) const override; void call_nodes(const std::function& callback, - const std::function &terminate = [](){ return false; }) const override; + const std::function &terminate = [](){ return false; }, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const override; size_t outdegree(node_index) const override; diff --git a/metagraph/src/graph/representation/masked_graph.cpp b/metagraph/src/graph/representation/masked_graph.cpp index 319a936237..2af9d9f108 100644 --- a/metagraph/src/graph/representation/masked_graph.cpp +++ b/metagraph/src/graph/representation/masked_graph.cpp @@ -11,76 +11,76 @@ namespace graph { // Traverse the outgoing edge MaskedDeBruijnGraph::node_index MaskedDeBruijnGraph ::traverse(node_index node, char next_char) const { - assert(in_subgraph(node)); + assert(in_graph(node)); auto index = graph_->traverse(node, next_char); - return index && in_subgraph(index) ? index : npos; + return index && in_graph(index) ? index : npos; } // Traverse the incoming edge MaskedDeBruijnGraph::node_index MaskedDeBruijnGraph ::traverse_back(node_index node, char prev_char) const { - assert(in_subgraph(node)); + assert(in_graph(node)); auto index = graph_->traverse_back(node, prev_char); - return index && in_subgraph(index) ? index : npos; + return index && in_graph(index) ? index : npos; } size_t MaskedDeBruijnGraph::outdegree(node_index node) const { - assert(in_subgraph(node)); + assert(in_graph(node)); size_t outdegree = 0; graph_->adjacent_outgoing_nodes(node, [&](auto index) { - outdegree += in_subgraph(index); + outdegree += in_graph(index); }); return outdegree; } size_t MaskedDeBruijnGraph::indegree(node_index node) const { - assert(in_subgraph(node)); + assert(in_graph(node)); size_t indegree = 0; graph_->adjacent_incoming_nodes(node, [&](auto index) { - indegree += in_subgraph(index); + indegree += in_graph(index); }); return indegree; } void MaskedDeBruijnGraph ::adjacent_outgoing_nodes(node_index node, const std::function &callback) const { - assert(in_subgraph(node)); + assert(in_graph(node)); graph_->adjacent_outgoing_nodes(node, [&](auto node) { - if (in_subgraph(node)) + if (in_graph(node)) callback(node); }); } void MaskedDeBruijnGraph ::adjacent_incoming_nodes(node_index node, const std::function &callback) const { - assert(in_subgraph(node)); + assert(in_graph(node)); graph_->adjacent_incoming_nodes(node, [&](auto node) { - if (in_subgraph(node)) + if (in_graph(node)) callback(node); }); } void MaskedDeBruijnGraph ::call_outgoing_kmers(node_index kmer, const OutgoingEdgeCallback &callback) const { - assert(in_subgraph(kmer)); + assert(in_graph(kmer)); graph_->call_outgoing_kmers(kmer, [&](const auto &index, auto c) { - if (in_subgraph(index)) + if (in_graph(index)) callback(index, c); }); } void MaskedDeBruijnGraph ::call_incoming_kmers(node_index kmer, const IncomingEdgeCallback &callback) const { - assert(in_subgraph(kmer)); + assert(in_graph(kmer)); graph_->call_incoming_kmers(kmer, [&](const auto &index, auto c) { - if (in_subgraph(index)) + if (in_graph(index)) callback(index, c); }); } @@ -91,14 +91,14 @@ bit_vector_stat get_boss_mask(const DBGSuccinct &dbg_succ, sdsl::bit_vector mask_bv(dbg_succ.get_boss().num_edges() + 1, false); if (only_valid_nodes_in_mask) { kmers_in_graph.call_ones([&](auto i) { - assert(dbg_succ.kmer_to_boss_index(i)); - mask_bv[dbg_succ.kmer_to_boss_index(i)] = true; + assert(i); + mask_bv[i] = true; }); } else { dbg_succ.call_nodes([&](auto i) { - assert(dbg_succ.kmer_to_boss_index(i)); + assert(i); if (kmers_in_graph[i]) - mask_bv[dbg_succ.kmer_to_boss_index(i)] = true; + mask_bv[i] = true; }); } return bit_vector_stat(std::move(mask_bv)); @@ -113,7 +113,7 @@ void MaskedDeBruijnGraph::call_sequences(const CallPath &callback, dbg_succ->get_boss().call_sequences([&](std::string&& sequence, auto&& path) { for (auto &node : path) { - node = dbg_succ->boss_to_kmer_index(node); + node = dbg_succ->validate_edge(node); } callback(sequence, path); @@ -134,7 +134,7 @@ void MaskedDeBruijnGraph::call_unitigs(const CallPath &callback, dbg_succ->get_boss().call_unitigs([&](std::string&& sequence, auto&& path) { for (auto &node : path) { - node = dbg_succ->boss_to_kmer_index(node); + node = dbg_succ->validate_edge(node); } callback(sequence, path); @@ -150,7 +150,9 @@ void MaskedDeBruijnGraph::call_unitigs(const CallPath &callback, void MaskedDeBruijnGraph ::call_nodes(const std::function &callback, - const std::function &stop_early) const { + const std::function &stop_early, + size_t num_threads, + size_t batch_size) const { assert(max_index() + 1 == kmers_in_graph_->size()); bool stop = false; @@ -158,26 +160,38 @@ ::call_nodes(const std::function &callback, if (only_valid_nodes_in_mask_) { // iterate only through the nodes marked in the mask // TODO: add terminate to call_ones - kmers_in_graph_->call_ones([&](auto index) { - if (stop || !index) - return; + size_t batch_size = kmers_in_graph_->size() / num_threads; - assert(in_subgraph(index)); + #pragma omp parallel for num_threads(num_threads) schedule(static) + for (node_index begin = 0; begin <= kmers_in_graph_->size(); begin += batch_size) { + if (stop) + continue; - if (stop_early()) { - stop = true; - } else { - callback(index); - } - }); + size_t end = std::min(begin + batch_size, kmers_in_graph_->size()); + + kmers_in_graph_->call_ones_in_range(begin, end, [&](auto index) { + if (stop || !index) + return; + + assert(in_graph(index)); + + if (stop_early()) { + stop = true; + } else { + callback(index); + } + }); + } } else { // call all nodes in the base graph and check the mask graph_->call_nodes( [&](auto index) { - if (in_subgraph(index)) + if (in_graph(index)) callback(index); }, - stop_early + stop_early, + num_threads, + batch_size ); } } @@ -195,7 +209,7 @@ ::call_kmers(const std::function &callback throw early_term(); if (index) { - assert(in_subgraph(index)); + assert(in_graph(index)); // TODO: make this more efficient callback(index, get_node_sequence(index)); } @@ -204,7 +218,7 @@ ::call_kmers(const std::function &callback } else { // call all nodes in the base graph and check the mask graph_->call_kmers([&](node_index index, const std::string &seq) { - if (in_subgraph(index)) + if (in_graph(index)) callback(index, seq); }, stop_early); } @@ -218,7 +232,7 @@ void MaskedDeBruijnGraph::map_to_nodes(std::string_view sequence, graph_->map_to_nodes( sequence, [&](const node_index &index) { - callback(index && in_subgraph(index) ? index : npos); + callback(index && in_graph(index) ? index : npos); }, terminate ); @@ -234,7 +248,7 @@ ::map_to_nodes_sequentially(std::string_view sequence, graph_->map_to_nodes_sequentially( sequence, [&](const node_index &index) { - callback(index && in_subgraph(index) ? index : npos); + callback(index && in_graph(index) ? index : npos); }, terminate ); @@ -243,7 +257,7 @@ ::map_to_nodes_sequentially(std::string_view sequence, // Get string corresponding to |node_index|. // Note: Not efficient if sequences in nodes overlap. Use sparingly. std::string MaskedDeBruijnGraph::get_node_sequence(node_index index) const { - assert(in_subgraph(index)); + assert(in_graph(index)); return graph_->get_node_sequence(index); } diff --git a/metagraph/src/graph/representation/masked_graph.hpp b/metagraph/src/graph/representation/masked_graph.hpp index 7da93deb80..2d2fdd2e62 100644 --- a/metagraph/src/graph/representation/masked_graph.hpp +++ b/metagraph/src/graph/representation/masked_graph.hpp @@ -99,10 +99,12 @@ class MaskedDeBruijnGraph : public DBGWrapper { virtual size_t indegree(node_index) const override; virtual void call_nodes(const std::function &callback, - const std::function &stop_early = [](){ return false; }) const override; + const std::function &stop_early = [](){ return false; }, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const override; - virtual inline bool in_subgraph(node_index node) const { - assert(node > 0 && node <= max_index()); + virtual inline bool in_graph(node_index node) const override final { + assert(DBGWrapper::in_graph(node)); assert(kmers_in_graph_.get()); return (*kmers_in_graph_)[node]; diff --git a/metagraph/src/graph/representation/rc_dbg.hpp b/metagraph/src/graph/representation/rc_dbg.hpp index 8becf30a8b..5adad8fc01 100644 --- a/metagraph/src/graph/representation/rc_dbg.hpp +++ b/metagraph/src/graph/representation/rc_dbg.hpp @@ -130,9 +130,11 @@ class RCDBG : public DBGWrapper { virtual void call_nodes(const std::function &callback, const std::function &stop_early - = [](){ return false; }) const override { + = [](){ return false; }, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const override final { // all node IDs are the same - graph_->call_nodes(callback, stop_early); + graph_->call_nodes(callback, stop_early, num_threads, batch_size); } virtual void call_kmers(const std::function &callback, diff --git a/metagraph/src/graph/representation/succinct/boss.cpp b/metagraph/src/graph/representation/succinct/boss.cpp index 54060256db..d4a2ceda9d 100644 --- a/metagraph/src/graph/representation/succinct/boss.cpp +++ b/metagraph/src/graph/representation/succinct/boss.cpp @@ -2782,6 +2782,7 @@ void BOSS::row_diff_traverse(size_t num_threads, traverse_dummy_edges(*this, NULL, NULL, num_threads, [&](edge_index edge, size_t depth) { assert(depth <= get_k()); + unset_bit(terminal->data(), edge, async); set_bit(dummy.data(), edge, async); if (depth < get_k()) set_bit(visited.data(), edge, async); diff --git a/metagraph/src/graph/representation/succinct/boss.hpp b/metagraph/src/graph/representation/succinct/boss.hpp index 49a8883616..dce2d6c568 100644 --- a/metagraph/src/graph/representation/succinct/boss.hpp +++ b/metagraph/src/graph/representation/succinct/boss.hpp @@ -168,6 +168,7 @@ class BOSS { // pick the row-diff successor if (&rd_succ != last_ && !get_last(edge - 1)) { while (!rd_succ[edge]) { + assert(edge > 0); edge--; assert(!get_last(edge) && "a row-diff successor must exist"); } diff --git a/metagraph/src/graph/representation/succinct/dbg_succinct.cpp b/metagraph/src/graph/representation/succinct/dbg_succinct.cpp index 915c76af49..20c6d17fc8 100644 --- a/metagraph/src/graph/representation/succinct/dbg_succinct.cpp +++ b/metagraph/src/graph/representation/succinct/dbg_succinct.cpp @@ -82,27 +82,27 @@ bool DBGSuccinct::find(std::string_view sequence, // Traverse the outgoing edge node_index DBGSuccinct::traverse(node_index node, char next_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); // return npos if the character is invalid if (boss_graph_->encode(next_char) == boss_graph_->alph_size) return npos; // dbg node is a boss edge - BOSS::edge_index boss_edge = kmer_to_boss_index(node); + BOSS::edge_index boss_edge = node; boss_edge = boss_graph_->fwd(boss_edge); - return boss_to_kmer_index( + return validate_edge( boss_graph_->pick_edge(boss_edge, boss_graph_->encode(next_char)) ); } // Traverse the incoming edge node_index DBGSuccinct::traverse_back(node_index node, char prev_char) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); // dbg node is a boss edge - BOSS::edge_index edge = boss_graph_->bwd(kmer_to_boss_index(node)); - return boss_to_kmer_index( + BOSS::edge_index edge = boss_graph_->bwd(node); + return validate_edge( boss_graph_->pick_incoming_edge(edge, boss_graph_->encode(prev_char)) ); } @@ -128,11 +128,11 @@ inline void call_outgoing(const BOSS &boss, void DBGSuccinct::call_outgoing_kmers(node_index node, const OutgoingEdgeCallback &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - call_outgoing(*boss_graph_, kmer_to_boss_index(node), [&](auto i) { - auto next = boss_to_kmer_index(i); - if (next != npos) + call_outgoing(*boss_graph_, node, [&](auto i) { + auto next = i; + if (in_graph(next)) callback(next, boss_graph_->decode(boss_graph_->get_W(i) % boss_graph_->alph_size)); }); @@ -140,9 +140,9 @@ void DBGSuccinct::call_outgoing_kmers(node_index node, void DBGSuccinct::call_incoming_kmers(node_index node, const IncomingEdgeCallback &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto edge = kmer_to_boss_index(node); + auto edge = node; boss_graph_->call_incoming_to_target(boss_graph_->bwd(edge), boss_graph_->get_node_last_value(edge), @@ -150,8 +150,8 @@ void DBGSuccinct::call_incoming_kmers(node_index node, assert(boss_graph_->get_W(incoming_boss_edge) % boss_graph_->alph_size == boss_graph_->get_node_last_value(edge)); - auto prev = boss_to_kmer_index(incoming_boss_edge); - if (prev != npos) { + auto prev = incoming_boss_edge; + if (in_graph(prev)) { callback(prev, boss_graph_->decode( boss_graph_->get_minus_k_value(incoming_boss_edge, get_k() - 2).first @@ -164,20 +164,20 @@ void DBGSuccinct::call_incoming_kmers(node_index node, void DBGSuccinct::adjacent_outgoing_nodes(node_index node, const std::function &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - call_outgoing(*boss_graph_, kmer_to_boss_index(node), [&](auto i) { - auto next = boss_to_kmer_index(i); - if (next != npos) + call_outgoing(*boss_graph_, node, [&](auto i) { + auto next = i; + if (in_graph(next)) callback(next); }); } void DBGSuccinct::adjacent_incoming_nodes(node_index node, const std::function &callback) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto edge = kmer_to_boss_index(node); + auto edge = node; boss_graph_->call_incoming_to_target(boss_graph_->bwd(edge), boss_graph_->get_node_last_value(edge), @@ -185,13 +185,49 @@ void DBGSuccinct::adjacent_incoming_nodes(node_index node, assert(boss_graph_->get_W(incoming_boss_edge) % boss_graph_->alph_size == boss_graph_->get_node_last_value(edge)); - auto prev = boss_to_kmer_index(incoming_boss_edge); - if (prev != npos) + auto prev = incoming_boss_edge; + if (in_graph(prev)) callback(prev); } ); } +void DBGSuccinct::call_nodes(const std::function &callback, + const std::function &terminate, + size_t num_threads, + size_t batch_size) const { + if (valid_edges_) { + size_t block_size = max_index() / num_threads; + + #pragma omp parallel for num_threads(num_threads) schedule(static) + for (size_t begin = 1; begin <= max_index(); begin += block_size) { + if (terminate()) + continue; + + size_t end = std::min(begin + block_size, + static_cast(max_index() + 1)); + try { + valid_edges_->call_ones_in_range(begin, end, [&](uint64_t i) { + callback(i); + if (terminate()) + throw early_term(); + }); + } catch (early_term&) {} + } + } else if (!terminate()) { + try { + call_sequences([&](const std::string&, const auto &path) { + for (node_index node : path) { + if (terminate()) + throw early_term(); + + callback(node); + } + }, num_threads); + } catch (early_term&) {} + } +} + void DBGSuccinct::add_sequence(std::string_view sequence, const std::function &on_insertion) { if (sequence.size() < get_k()) @@ -223,7 +259,7 @@ void DBGSuccinct::add_sequence(std::string_view sequence, // Call all new nodes inserted including the dummy ones, unless they // are masked out. - on_insertion(boss_to_kmer_index(new_boss_edge)); + on_insertion(validate_edge(new_boss_edge)); } assert(!valid_edges_.get() || !(*valid_edges_)[0]); @@ -234,9 +270,9 @@ void DBGSuccinct::add_sequence(std::string_view sequence, } std::string DBGSuccinct::get_node_sequence(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto boss_edge = kmer_to_boss_index(node); + auto boss_edge = node; return boss_graph_->get_node_str(boss_edge) + boss_graph_->decode(boss_graph_->get_W(boss_edge) % boss_graph_->alph_size); @@ -256,7 +292,7 @@ void DBGSuccinct::map_to_nodes_sequentially(std::string_view sequence, boss_graph_->map_to_edges( sequence, - [&](BOSS::edge_index i) { callback(boss_to_kmer_index(i)); }, + [&](BOSS::edge_index i) { callback(validate_edge(i)); }, terminate, [&]() { if (!is_missing()) @@ -297,8 +333,8 @@ ::call_nodes_with_suffix_matching_longest_prefix( assert(first == last); auto edge = boss_graph_->pick_edge(last, encoded.back()); if (edge) { - auto kmer_index = boss_to_kmer_index(edge); - if (kmer_index != npos) { + auto kmer_index = edge; + if (in_graph(kmer_index)) { assert(str.size() == get_k()); assert(get_node_sequence(kmer_index) == str); callback(kmer_index, get_k()); @@ -322,8 +358,8 @@ ::call_nodes_with_suffix_matching_longest_prefix( boss_graph_->call_incoming_to_target(boss_graph_->bwd(e), boss_graph_->get_node_last_value(e), [&](BOSS::edge_index incoming_edge_idx) { - auto kmer_index = boss_to_kmer_index(incoming_edge_idx); - if (kmer_index != npos) { + auto kmer_index = incoming_edge_idx; + if (in_graph(kmer_index)) { assert(get_node_sequence(kmer_index).substr(get_k() - match_size) == str.substr(0, match_size)); nodes.emplace_back(kmer_index); @@ -344,8 +380,8 @@ ::call_nodes_with_suffix_matching_longest_prefix( boss_graph_->call_incoming_to_target(boss_graph_->bwd(e), boss_graph_->get_node_last_value(e), [&](BOSS::edge_index incoming_edge_idx) { - auto kmer_index = boss_to_kmer_index(incoming_edge_idx); - if (kmer_index != npos) { + auto kmer_index = incoming_edge_idx; + if (in_graph(kmer_index)) { assert(get_node_sequence(kmer_index).substr(get_k() - match_size) == str.substr(0, match_size)); callback(kmer_index, match_size); @@ -361,13 +397,13 @@ void DBGSuccinct::traverse(node_index start, const char *end, const std::function &callback, const std::function &terminate) const { - assert(start > 0 && start <= num_nodes()); + assert(in_graph(start)); assert(end >= begin); if (terminate()) return; - auto edge = kmer_to_boss_index(start); + auto edge = start; assert(edge); BOSS::TAlphabet w; @@ -379,8 +415,8 @@ void DBGSuccinct::traverse(node_index start, edge = boss_graph_->fwd(edge, w % boss_graph_->alph_size); edge = boss_graph_->pick_edge(edge, boss_graph_->encode(*begin)); - start = boss_to_kmer_index(edge); - if (start == npos) + start = edge; + if (!in_graph(start)) return; callback(start); @@ -442,13 +478,13 @@ void DBGSuccinct::map_to_nodes(std::string_view sequence, for (size_t i = 0; i < boss_edges.size() && !terminate(); ++i) { // the definition of a canonical k-mer is redefined: // use k-mer with smaller index in the BOSS table. - callback(boss_to_kmer_index(boss_edges[i])); + callback(validate_edge(boss_edges[i])); } } else { boss_graph_->map_to_edges( sequence, - [&](BOSS::edge_index i) { callback(boss_to_kmer_index(i)); }, + [&](BOSS::edge_index i) { callback(validate_edge(i)); }, terminate, [&]() { if (!is_missing()) @@ -468,7 +504,7 @@ void DBGSuccinct::call_sequences(const CallPath &callback, boss_graph_->call_sequences( [&](std::string&& seq, auto&& path) { for (auto &node : path) { - node = boss_to_kmer_index(node); + node = validate_edge(node); } callback(std::move(seq), std::move(path)); }, @@ -485,7 +521,7 @@ void DBGSuccinct::call_unitigs(const CallPath &callback, boss_graph_->call_unitigs( [&](std::string&& seq, auto&& path) { for (auto &node : path) { - node = boss_to_kmer_index(node); + node = validate_edge(node); } callback(std::move(seq), std::move(path)); }, @@ -500,8 +536,8 @@ ::call_kmers(const std::function &callback const std::function &stop_early) const { assert(boss_graph_.get()); boss_graph_->call_kmers([&](auto index, const std::string &seq) { - auto node = boss_to_kmer_index(index); - assert(node != npos); + auto node = index; + assert(in_graph(node)); callback(node, seq); }, stop_early); } @@ -509,17 +545,17 @@ ::call_kmers(const std::function &callback void DBGSuccinct ::call_source_nodes(const std::function &callback) const { boss_graph_->call_start_edges([&](auto boss_edge) { - auto node = boss_to_kmer_index(boss_edge); - assert(node != npos); + auto node = boss_edge; + assert(in_graph(node)); assert(!indegree(node)); callback(node); }); } size_t DBGSuccinct::outdegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto boss_edge = kmer_to_boss_index(node); + auto boss_edge = node; if (boss_edge == 1) return boss_graph_->succ_last(1) - 1; @@ -543,9 +579,9 @@ size_t DBGSuccinct::outdegree(node_index node) const { } bool DBGSuccinct::has_single_outgoing(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto boss_edge = kmer_to_boss_index(node); + auto boss_edge = node; if (boss_edge == 1) return boss_graph_->succ_last(1) == 2; @@ -569,9 +605,9 @@ bool DBGSuccinct::has_single_outgoing(node_index node) const { } bool DBGSuccinct::has_multiple_outgoing(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto boss_edge = kmer_to_boss_index(node); + auto boss_edge = node; if (boss_edge == 1) return boss_graph_->succ_last(1) > 2; @@ -586,9 +622,9 @@ bool DBGSuccinct::has_multiple_outgoing(node_index node) const { } size_t DBGSuccinct::indegree(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto boss_edge = kmer_to_boss_index(node); + auto boss_edge = node; if (boss_edge == 1) return 1; @@ -602,9 +638,9 @@ size_t DBGSuccinct::indegree(node_index node) const { } bool DBGSuccinct::has_no_incoming(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto boss_edge = kmer_to_boss_index(node); + auto boss_edge = node; if (boss_edge == 1) return false; @@ -618,9 +654,9 @@ bool DBGSuccinct::has_no_incoming(node_index node) const { } bool DBGSuccinct::has_single_incoming(node_index node) const { - assert(node > 0 && node <= num_nodes()); + assert(in_graph(node)); - auto boss_edge = kmer_to_boss_index(node); + auto boss_edge = node; if (boss_edge == 1) return false; @@ -645,6 +681,10 @@ uint64_t DBGSuccinct::num_nodes() const { : boss_graph_->num_edges(); } +uint64_t DBGSuccinct::max_index() const { + return boss_graph_->num_edges(); +} + bool DBGSuccinct::load_without_mask(const std::string &filename) { // release the old mask valid_edges_.reset(); @@ -757,24 +797,26 @@ void DBGSuccinct::serialize(const std::string &filename) const { throw std::ios_base::failure("Can't write to file " + out_filename); } - if (!valid_edges_.get()) - return; + auto serialize_valid_edges = [&](auto &&valid_edges) { + assert((boss_graph_->get_state() == BOSS::State::STAT + && dynamic_cast(valid_edges.get())) + || (boss_graph_->get_state() == BOSS::State::FAST + && dynamic_cast(valid_edges.get())) + || (boss_graph_->get_state() == BOSS::State::DYN + && dynamic_cast(valid_edges.get())) + || (boss_graph_->get_state() == BOSS::State::SMALL + && dynamic_cast(valid_edges.get()))); + + const auto out_filename = prefix + kDummyMaskExtension; + std::ofstream out = utils::open_new_ofstream(out_filename); + if (!out.good()) + throw std::ios_base::failure("Can't write to file " + out_filename); - assert((boss_graph_->get_state() == BOSS::State::STAT - && dynamic_cast(valid_edges_.get())) - || (boss_graph_->get_state() == BOSS::State::FAST - && dynamic_cast(valid_edges_.get())) - || (boss_graph_->get_state() == BOSS::State::DYN - && dynamic_cast(valid_edges_.get())) - || (boss_graph_->get_state() == BOSS::State::SMALL - && dynamic_cast(valid_edges_.get()))); - - const auto out_filename = prefix + kDummyMaskExtension; - std::ofstream out = utils::open_new_ofstream(out_filename); - if (!out.good()) - throw std::ios_base::failure("Can't write to file " + out_filename); + valid_edges->serialize(out); + }; - valid_edges_->serialize(out); + if (valid_edges_) + serialize_valid_edges(valid_edges_); if (bloom_filter_) { std::ofstream bloom_out = utils::open_new_ofstream(prefix + kBloomFilterExtension); @@ -855,9 +897,7 @@ BOSS::State DBGSuccinct::get_state() const { return boss_graph_->get_state(); } -void DBGSuccinct::mask_dummy_kmers(size_t num_threads, bool with_pruning) { - valid_edges_.reset(); - +std::unique_ptr DBGSuccinct::generate_valid_kmer_mask(size_t num_threads, bool with_pruning) const { auto vector_mask = with_pruning ? boss_graph_->prune_and_mark_all_dummy_edges(num_threads) : boss_graph_->mark_all_dummy_edges(num_threads); @@ -865,50 +905,54 @@ void DBGSuccinct::mask_dummy_kmers(size_t num_threads, bool with_pruning) { vector_mask.flip(); switch (get_state()) { - case BOSS::State::STAT: { - valid_edges_ = std::make_unique(std::move(vector_mask)); - break; - } - case BOSS::State::FAST: { - valid_edges_ = std::make_unique(std::move(vector_mask)); - break; - } - case BOSS::State::DYN: { - valid_edges_ = std::make_unique(std::move(vector_mask)); - break; - } - case BOSS::State::SMALL: { - valid_edges_ = std::make_unique(std::move(vector_mask)); - break; - } + case BOSS::State::STAT: + return std::make_unique(std::move(vector_mask)); + case BOSS::State::FAST: + return std::make_unique(std::move(vector_mask)); + case BOSS::State::DYN: + return std::make_unique(std::move(vector_mask)); + case BOSS::State::SMALL: + return std::make_unique(std::move(vector_mask)); + default: + throw std::runtime_error("Invalid state"); } +} + +void DBGSuccinct::mask_dummy_kmers(size_t num_threads, bool with_pruning) { + valid_edges_.reset(); + + valid_edges_ = generate_valid_kmer_mask(num_threads, with_pruning); assert(valid_edges_.get()); assert(valid_edges_->size() == boss_graph_->num_edges() + 1); assert(!(*valid_edges_)[0]); } -uint64_t DBGSuccinct::kmer_to_boss_index(node_index node) const { - assert(node > 0); - assert(node <= num_nodes()); +bool DBGSuccinct::in_graph(node_index node) const { + return DeBruijnGraph::in_graph(node) && (!valid_edges_ || (*valid_edges_)[node]); +} +node_index DBGSuccinct::validate_edge(node_index node) const { + return in_graph(node) ? node : npos; +} +node_index DBGSuccinct::select_node(uint64_t rank) const { + assert(rank <= num_nodes()); - if (!valid_edges_.get()) - return node; + if (!valid_edges_.get() || !rank) + return rank; - return valid_edges_->select1(node); + return valid_edges_->select1(rank); } -DBGSuccinct::node_index DBGSuccinct::boss_to_kmer_index(uint64_t boss_index) const { - assert(boss_index <= boss_graph_->num_edges()); - assert(!valid_edges_.get() || boss_index < valid_edges_->size()); +uint64_t DBGSuccinct::rank_node(node_index node) const { + assert(node <= max_index()); - if (!valid_edges_.get() || !boss_index) - return boss_index; + if (!valid_edges_.get() || !node) + return node; - if (!(*valid_edges_)[boss_index]) + if (!(*valid_edges_)[node]) return npos; - return valid_edges_->rank1(boss_index); + return valid_edges_->rank1(node); } void DBGSuccinct diff --git a/metagraph/src/graph/representation/succinct/dbg_succinct.hpp b/metagraph/src/graph/representation/succinct/dbg_succinct.hpp index bdbabe3104..940c26313c 100644 --- a/metagraph/src/graph/representation/succinct/dbg_succinct.hpp +++ b/metagraph/src/graph/representation/succinct/dbg_succinct.hpp @@ -37,6 +37,11 @@ class DBGSuccinct : public DeBruijnGraph { virtual void adjacent_incoming_nodes(node_index node, const std::function &callback) const override final; + virtual void call_nodes(const std::function &callback, + const std::function &terminate = [](){ return false; }, + size_t num_threads = 1, + size_t batch_size = 1'000'000) const override final; + // Insert sequence to graph and invoke callback |on_insertion| for each new // node index augmenting the range [1,...,max_index], including those not // pointing to any real node in graph. That is, the callback is invoked for @@ -110,6 +115,7 @@ class DBGSuccinct : public DeBruijnGraph { * edges in the BOSS graph (because an edge in the BOSS graph represents a k-mer). */ virtual uint64_t num_nodes() const override final; + virtual uint64_t max_index() const override final; virtual void mask_dummy_kmers(size_t num_threads, bool with_pruning) final; @@ -174,8 +180,10 @@ class DBGSuccinct : public DeBruijnGraph { virtual void call_source_nodes(const std::function &callback) const override final; - uint64_t kmer_to_boss_index(node_index kmer_index) const; - node_index boss_to_kmer_index(uint64_t boss_index) const; + virtual bool in_graph(node_index node) const override final; + node_index validate_edge(node_index node) const; + node_index select_node(uint64_t rank) const; + uint64_t rank_node(node_index node) const; void initialize_bloom_filter_from_fpr(double false_positive_rate, uint32_t max_num_hash_functions = -1); @@ -197,6 +205,8 @@ class DBGSuccinct : public DeBruijnGraph { Mode mode_; std::unique_ptr> bloom_filter_; + + std::unique_ptr generate_valid_kmer_mask(size_t num_threads, bool with_pruning) const; }; } // namespace graph diff --git a/metagraph/tests/annotation/row_diff/test_row_diff.cpp b/metagraph/tests/annotation/row_diff/test_row_diff.cpp index 158171a44e..5bf9d0416b 100644 --- a/metagraph/tests/annotation/row_diff/test_row_diff.cpp +++ b/metagraph/tests/annotation/row_diff/test_row_diff.cpp @@ -17,6 +17,10 @@ using ::testing::_; using mtg::annot::matrix::RowDiff; using mtg::annot::matrix::ColumnMajor; +static auto graph_to_anno_index(graph::DeBruijnGraph::node_index node) { + return graph::AnnotatedDBG::graph_to_anno_index(node); +} + typedef RowDiff::anchor_bv_type anchor_bv_type; TEST(RowDiff, Empty) { @@ -95,28 +99,28 @@ TEST(RowDiff, GetRows) { annot.load_anchor(fterm_temp.name()); auto rows = annot.get_rows({ 3, 3, 3, 3, 5, 5, 6, 7, 8, 9, 10, 11 }); - EXPECT_EQ("CTAG", graph.get_node_sequence(4)); + EXPECT_EQ("CTAG", graph.get_node_sequence(graph.select_node(4))); ASSERT_THAT(rows[3], ElementsAre(0, 1)); - EXPECT_EQ("AGCT", graph.get_node_sequence(6)); + EXPECT_EQ("AGCT", graph.get_node_sequence(graph.select_node(6))); ASSERT_THAT(rows[5], ElementsAre(1)); - EXPECT_EQ("CTCT", graph.get_node_sequence(7)); + EXPECT_EQ("CTCT", graph.get_node_sequence(graph.select_node(7))); ASSERT_THAT(rows[6], ElementsAre(0)); - EXPECT_EQ("TAGC", graph.get_node_sequence(8)); + EXPECT_EQ("TAGC", graph.get_node_sequence(graph.select_node(8))); ASSERT_THAT(rows[7], ElementsAre(1)); - EXPECT_EQ("ACTA", graph.get_node_sequence(9)); + EXPECT_EQ("ACTA", graph.get_node_sequence(graph.select_node(9))); ASSERT_THAT(rows[8], ElementsAre(1)); - EXPECT_EQ("ACTC", graph.get_node_sequence(10)); + EXPECT_EQ("ACTC", graph.get_node_sequence(graph.select_node(10))); ASSERT_THAT(rows[9], ElementsAre(0)); - EXPECT_EQ("GCTA", graph.get_node_sequence(11)); + EXPECT_EQ("GCTA", graph.get_node_sequence(graph.select_node(11))); ASSERT_THAT(rows[10], ElementsAre(1)); - EXPECT_EQ("TCTA", graph.get_node_sequence(12)); + EXPECT_EQ("TCTA", graph.get_node_sequence(graph.select_node(12))); ASSERT_THAT(rows[11], ElementsAre(0)); } @@ -149,28 +153,28 @@ TEST(RowDiff, GetAnnotation) { RowDiff annot(&graph, std::move(mat)); annot.load_anchor(fterm_temp.name()); - EXPECT_EQ("CTAG", graph.get_node_sequence(4)); + EXPECT_EQ("CTAG", graph.get_node_sequence(graph.select_node(4))); ASSERT_THAT(annot.get_rows({3})[0], ElementsAre(0, 1)); - EXPECT_EQ("AGCT", graph.get_node_sequence(6)); + EXPECT_EQ("AGCT", graph.get_node_sequence(graph.select_node(6))); ASSERT_THAT(annot.get_rows({5})[0], ElementsAre(1)); - EXPECT_EQ("CTCT", graph.get_node_sequence(7)); + EXPECT_EQ("CTCT", graph.get_node_sequence(graph.select_node(7))); ASSERT_THAT(annot.get_rows({6})[0], ElementsAre(0)); - EXPECT_EQ("TAGC", graph.get_node_sequence(8)); + EXPECT_EQ("TAGC", graph.get_node_sequence(graph.select_node(8))); ASSERT_THAT(annot.get_rows({7})[0], ElementsAre(1)); - EXPECT_EQ("ACTA", graph.get_node_sequence(9)); + EXPECT_EQ("ACTA", graph.get_node_sequence(graph.select_node(9))); ASSERT_THAT(annot.get_rows({8})[0], ElementsAre(1)); - EXPECT_EQ("ACTC", graph.get_node_sequence(10)); + EXPECT_EQ("ACTC", graph.get_node_sequence(graph.select_node(10))); ASSERT_THAT(annot.get_rows({9})[0], ElementsAre(0)); - EXPECT_EQ("GCTA", graph.get_node_sequence(11)); + EXPECT_EQ("GCTA", graph.get_node_sequence(graph.select_node(11))); ASSERT_THAT(annot.get_rows({10})[0], ElementsAre(1)); - EXPECT_EQ("TCTA", graph.get_node_sequence(12)); + EXPECT_EQ("TCTA", graph.get_node_sequence(graph.select_node(12))); ASSERT_THAT(annot.get_rows({11})[0], ElementsAre(0)); } @@ -187,47 +191,66 @@ TEST(RowDiff, GetAnnotationMasked) { graph.mask_dummy_kmers(1, false); // build annotation - sdsl::bit_vector bterminal = { 0, 0, 0, 0, 1, 0, 1, 0 }; + sdsl::bit_vector bterminal_masked = { 0, 0, 0, 0, 1, 0, 1, 0 }; + sdsl::bit_vector bterminal(graph.max_index() + 1); + sdsl::bit_vector cols_masked[2] = { + { 1, 0, 0, 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 1, 0, 1, 1 } + }; + sdsl::bit_vector cols_concrete[2]; + cols_concrete[0].resize(graph.max_index() + 1); + cols_concrete[1].resize(graph.max_index() + 1); + graph.call_nodes([&](auto i) { + auto rank = graph_to_anno_index(graph.rank_node(i)); + bterminal[graph_to_anno_index(i)] = bterminal_masked[rank]; + cols_concrete[0][graph_to_anno_index(i)] = cols_masked[0][rank]; + cols_concrete[1][graph_to_anno_index(i)] = cols_masked[1][rank]; + }); anchor_bv_type terminal(bterminal); utils::TempFile fterm_temp; std::ofstream fterm(fterm_temp.name(), ios::binary); terminal.serialize(fterm); fterm.flush(); - + std::vector> cols(2); - cols[0] = std::make_unique( - std::initializer_list({ 1, 0, 0, 0, 0, 0, 0, 0 })); - cols[1] = std::make_unique( - std::initializer_list({ 0, 0, 0, 0, 1, 0, 1, 1 })); + cols[0] = std::make_unique(std::move(cols_concrete[0])); + cols[1] = std::make_unique(std::move(cols_concrete[1])); ColumnMajor mat(std::move(cols)); RowDiff annot(&graph, std::move(mat)); annot.load_anchor(fterm_temp.name()); + EXPECT_EQ("CTAG", graph.get_node_sequence(graph.select_node(1))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(1))})[0], + ElementsAre(0, 1)); - EXPECT_EQ("CTAG", graph.get_node_sequence(1)); - ASSERT_THAT(annot.get_rows({0})[0], ElementsAre(0, 1)); + EXPECT_EQ("AGCT", graph.get_node_sequence(graph.select_node(2))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(2))})[0], + ElementsAre(1)); - EXPECT_EQ("AGCT", graph.get_node_sequence(2)); - ASSERT_THAT(annot.get_rows({1})[0], ElementsAre(1)); + EXPECT_EQ("CTCT", graph.get_node_sequence(graph.select_node(3))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(3))})[0], + ElementsAre(0)); - EXPECT_EQ("CTCT", graph.get_node_sequence(3)); - ASSERT_THAT(annot.get_rows({2})[0], ElementsAre(0)); + EXPECT_EQ("TAGC", graph.get_node_sequence(graph.select_node(4))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(4))})[0], + ElementsAre(1)); - EXPECT_EQ("TAGC", graph.get_node_sequence(4)); - ASSERT_THAT(annot.get_rows({3})[0], ElementsAre(1)); + EXPECT_EQ("ACTA", graph.get_node_sequence(graph.select_node(5))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(5))})[0], + ElementsAre(1)); - EXPECT_EQ("ACTA", graph.get_node_sequence(5)); - ASSERT_THAT(annot.get_rows({4})[0], ElementsAre(1)); - - EXPECT_EQ("ACTC", graph.get_node_sequence(6)); - ASSERT_THAT(annot.get_rows({5})[0], ElementsAre(0)); + EXPECT_EQ("ACTC", graph.get_node_sequence(graph.select_node(6))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(6))})[0], + ElementsAre(0)); - EXPECT_EQ("GCTA", graph.get_node_sequence(7)); - ASSERT_THAT(annot.get_rows({6})[0], ElementsAre(1)); + EXPECT_EQ("GCTA", graph.get_node_sequence(graph.select_node(7))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(7))})[0], + ElementsAre(1)); - EXPECT_EQ("TCTA", graph.get_node_sequence(8)); - ASSERT_THAT(annot.get_rows({7})[0], ElementsAre(0)); + EXPECT_EQ("TCTA", graph.get_node_sequence(graph.select_node(8))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(8))})[0], + ElementsAre(0)); } /** @@ -260,34 +283,34 @@ TEST(RowDiff, GetAnnotationBifurcation) { RowDiff annot(&graph, std::move(mat)); annot.load_anchor(fterm_temp.name()); - EXPECT_EQ("CTAG", graph.get_node_sequence(4)); + EXPECT_EQ("CTAG", graph.get_node_sequence(graph.select_node(4))); ASSERT_THAT(annot.get_rows({3})[0], ElementsAre(0, 1)); - EXPECT_EQ("CTAT", graph.get_node_sequence(5)); + EXPECT_EQ("CTAT", graph.get_node_sequence(graph.select_node(5))); ASSERT_THAT(annot.get_rows({4})[0], ElementsAre(1)); - EXPECT_EQ("TACT", graph.get_node_sequence(6)); + EXPECT_EQ("TACT", graph.get_node_sequence(graph.select_node(6))); ASSERT_THAT(annot.get_rows({5})[0], ElementsAre(0)); - EXPECT_EQ("AGCT", graph.get_node_sequence(7)); + EXPECT_EQ("AGCT", graph.get_node_sequence(graph.select_node(7))); ASSERT_THAT(annot.get_rows({6})[0], ElementsAre(0, 1)); - EXPECT_EQ("CTCT", graph.get_node_sequence(8)); + EXPECT_EQ("CTCT", graph.get_node_sequence(graph.select_node(8))); ASSERT_THAT(annot.get_rows({7})[0], ElementsAre(1)); - EXPECT_EQ("TAGC", graph.get_node_sequence(9)); + EXPECT_EQ("TAGC", graph.get_node_sequence(graph.select_node(9))); ASSERT_THAT(annot.get_rows({8})[0], ElementsAre(0, 1)); - EXPECT_EQ("ACTA", graph.get_node_sequence(12)); + EXPECT_EQ("ACTA", graph.get_node_sequence(graph.select_node(12))); ASSERT_THAT(annot.get_rows({11})[0], ElementsAre(0)); - EXPECT_EQ("ACTC", graph.get_node_sequence(13)); + EXPECT_EQ("ACTC", graph.get_node_sequence(graph.select_node(13))); ASSERT_THAT(annot.get_rows({12})[0], ElementsAre(1)); - EXPECT_EQ("GCTA", graph.get_node_sequence(14)); + EXPECT_EQ("GCTA", graph.get_node_sequence(graph.select_node(14))); ASSERT_THAT(annot.get_rows({13})[0], ElementsAre(0, 1)); - EXPECT_EQ("TCTA", graph.get_node_sequence(15)); + EXPECT_EQ("TCTA", graph.get_node_sequence(graph.select_node(15))); ASSERT_THAT(annot.get_rows({14})[0], ElementsAre(1)); } @@ -299,57 +322,77 @@ TEST(RowDiff, GetAnnotationBifurcationMasked) { graph.mask_dummy_kmers(1, false); // build annotation - sdsl::bit_vector bterminal = { 0, 1, 0, 0, 0, 0, 1, 0, 1, 0 }; + sdsl::bit_vector bterminal_masked = { 0, 1, 0, 0, 0, 0, 1, 0, 1, 0 }; + sdsl::bit_vector bterminal(graph.max_index() + 1); + sdsl::bit_vector cols_masked[2] = { + {0, 0, 1, 0, 0, 0, 1, 0, 1, 0 }, + {0, 1, 1, 0, 0, 0, 0, 0, 1, 0 } + }; + sdsl::bit_vector cols_concrete[2]; + cols_concrete[0].resize(graph.max_index() + 1); + cols_concrete[1].resize(graph.max_index() + 1); + graph.call_nodes([&](auto i) { + auto rank = graph_to_anno_index(graph.rank_node(i)); + bterminal[graph_to_anno_index(i)] = bterminal_masked[rank]; + cols_concrete[0][graph_to_anno_index(i)] = cols_masked[0][rank]; + cols_concrete[1][graph_to_anno_index(i)] = cols_masked[1][rank]; + }); anchor_bv_type terminal(bterminal); utils::TempFile fterm_temp; std::ofstream fterm(fterm_temp.name(), ios::binary); terminal.serialize(fterm); fterm.flush(); + + std::vector> cols(2); + cols[0] = std::make_unique(std::move(cols_concrete[0])); + cols[1] = std::make_unique(std::move(cols_concrete[1])); Vector diffs = { 1, 0, 1, 0, 0, 1 }; sdsl::bit_vector boundary = { 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1 }; - - std::vector> cols(2); - cols[0] = std::make_unique( - std::initializer_list({0, 0, 1, 0, 0, 0, 1, 0, 1, 0 })); - cols[1] = std::make_unique( - std::initializer_list({0, 1, 1, 0, 0, 0, 0, 0, 1, 0 })); - ColumnMajor mat(std::move(cols)); RowDiff annot(&graph, std::move(mat)); annot.load_anchor(fterm_temp.name()); - EXPECT_EQ("CTAG", graph.get_node_sequence(1)); - ASSERT_THAT(annot.get_rows({0})[0], ElementsAre(0, 1)); - - EXPECT_EQ("CTAT", graph.get_node_sequence(2)); - ASSERT_THAT(annot.get_rows({1})[0], ElementsAre(1)); + EXPECT_EQ("CTAG", graph.get_node_sequence(graph.select_node(1))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(1))})[0], + ElementsAre(0, 1)); + EXPECT_EQ("CTAT", graph.get_node_sequence(graph.select_node(2))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(2))})[0], + ElementsAre(1)); - EXPECT_EQ("TACT", graph.get_node_sequence(3)); - ASSERT_THAT(annot.get_rows({2})[0], ElementsAre(0)); + EXPECT_EQ("TACT", graph.get_node_sequence(graph.select_node(3))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(3))})[0], + ElementsAre(0)); - EXPECT_EQ("AGCT", graph.get_node_sequence(4)); - ASSERT_THAT(annot.get_rows({3})[0], ElementsAre(0, 1)); + EXPECT_EQ("AGCT", graph.get_node_sequence(graph.select_node(4))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(4))})[0], + ElementsAre(0, 1)); - EXPECT_EQ("CTCT", graph.get_node_sequence(5)); - ASSERT_THAT(annot.get_rows({4})[0], ElementsAre(1)); + EXPECT_EQ("CTCT", graph.get_node_sequence(graph.select_node(5))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(5))})[0], + ElementsAre(1)); - EXPECT_EQ("TAGC", graph.get_node_sequence(6)); - ASSERT_THAT(annot.get_rows({5})[0], ElementsAre(0, 1)); + EXPECT_EQ("TAGC", graph.get_node_sequence(graph.select_node(6))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(6))})[0], + ElementsAre(0, 1)); - EXPECT_EQ("ACTA", graph.get_node_sequence(7)); - ASSERT_THAT(annot.get_rows({6})[0], ElementsAre(0)); + EXPECT_EQ("ACTA", graph.get_node_sequence(graph.select_node(7))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(7))})[0], + ElementsAre(0)); - EXPECT_EQ("ACTC", graph.get_node_sequence(8)); - ASSERT_THAT(annot.get_rows({7})[0], ElementsAre(1)); + EXPECT_EQ("ACTC", graph.get_node_sequence(graph.select_node(8))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(8))})[0], + ElementsAre(1)); - EXPECT_EQ("GCTA", graph.get_node_sequence(9)); - ASSERT_THAT(annot.get_rows({8})[0], ElementsAre(0, 1)); + EXPECT_EQ("GCTA", graph.get_node_sequence(graph.select_node(9))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(9))})[0], + ElementsAre(0, 1)); - EXPECT_EQ("TCTA", graph.get_node_sequence(10)); - ASSERT_THAT(annot.get_rows({9})[0], ElementsAre(1)); + EXPECT_EQ("TCTA", graph.get_node_sequence(graph.select_node(10))); + ASSERT_THAT(annot.get_rows({graph_to_anno_index(graph.select_node(10))})[0], + ElementsAre(1)); } } // namespace diff --git a/metagraph/tests/annotation/test_aligner_labeled.cpp b/metagraph/tests/annotation/test_aligner_labeled.cpp index bdeddccf7f..dc4658b046 100644 --- a/metagraph/tests/annotation/test_aligner_labeled.cpp +++ b/metagraph/tests/annotation/test_aligner_labeled.cpp @@ -55,7 +55,8 @@ typedef ::testing::Types>, std::pair>, std::pair, std::pair, - std::pair> FewGraphAnnotationPairTypes; + std::pair, + std::pair> FewGraphAnnotationPairTypes; TYPED_TEST_SUITE(LabeledAlignerTest, FewGraphAnnotationPairTypes); diff --git a/metagraph/tests/annotation/test_annotated_dbg.cpp b/metagraph/tests/annotation/test_annotated_dbg.cpp index d1437aa725..52ad5bd077 100644 --- a/metagraph/tests/annotation/test_annotated_dbg.cpp +++ b/metagraph/tests/annotation/test_annotated_dbg.cpp @@ -509,8 +509,8 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsWithoutDummy) { ); EXPECT_EQ(num_nodes, anno_graph.get_graph().num_nodes()); - EXPECT_TRUE(anno_graph.get_annotator().num_objects() + k - < dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) + EXPECT_EQ(anno_graph.get_annotator().num_objects(), + dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) << dynamic_cast(anno_graph.get_graph()).get_boss(); EXPECT_FALSE(anno_graph.label_exists("First")); @@ -537,7 +537,7 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsWithoutDummy) { ); anno_graph.annotator_->insert_rows(edge_to_row_idx(inserted_nodes)); - EXPECT_EQ(anno_graph.get_graph().num_nodes() + 1, inserted_nodes.size()); + EXPECT_EQ(anno_graph.get_graph().max_index() + 1, inserted_nodes.size()); ASSERT_EQ(std::vector { "First" }, anno_graph.get_labels(seq_first, 1)); @@ -556,8 +556,8 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsWithoutDummy) { EXPECT_TRUE(anno_graph.label_exists("Third")); EXPECT_FALSE(anno_graph.label_exists("Fourth")); - EXPECT_TRUE(anno_graph.get_annotator().num_objects() + k - < dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) + EXPECT_EQ(anno_graph.get_annotator().num_objects(), + dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) << dynamic_cast(anno_graph.get_graph()).get_boss(); EXPECT_EQ(std::vector { "First" }, @@ -627,8 +627,8 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsWithoutDummyParallel) { std::make_unique>(graph->max_index()) ); - EXPECT_TRUE(anno_graph.get_annotator().num_objects() + k - < dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) + EXPECT_EQ(anno_graph.get_annotator().num_objects(), + dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) << dynamic_cast(anno_graph.get_graph()).get_boss(); EXPECT_FALSE(anno_graph.label_exists("First")); @@ -661,7 +661,7 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsWithoutDummyParallel) { ); anno_graph.annotator_->insert_rows(edge_to_row_idx(inserted_nodes)); - EXPECT_EQ(anno_graph.get_graph().num_nodes() + 1, inserted_nodes.size()); + EXPECT_EQ(anno_graph.get_graph().max_index() + 1, inserted_nodes.size()); ASSERT_EQ(std::vector { "First" }, anno_graph.get_labels(seq_first, 1)); @@ -685,8 +685,8 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsWithoutDummyParallel) { EXPECT_TRUE(anno_graph.label_exists("Third")); EXPECT_FALSE(anno_graph.label_exists("Fourth")); - EXPECT_TRUE(anno_graph.get_annotator().num_objects() + k - < dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) + EXPECT_EQ(anno_graph.get_annotator().num_objects(), + dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) << dynamic_cast(anno_graph.get_graph()).get_boss(); EXPECT_EQ(std::vector { "First" }, @@ -767,8 +767,8 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsPruneDummy) { EXPECT_FALSE(anno_graph.label_exists("Third")); EXPECT_FALSE(anno_graph.label_exists("Fourth")); - EXPECT_TRUE(anno_graph.get_annotator().num_objects() + 1 - < dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) + EXPECT_EQ(anno_graph.get_annotator().num_objects(), + dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) << dynamic_cast(anno_graph.get_graph()).get_boss(); ASSERT_EQ(std::vector { "First" }, @@ -783,7 +783,7 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsPruneDummy) { ); anno_graph.annotator_->insert_rows(edge_to_row_idx(inserted_nodes)); - EXPECT_EQ(anno_graph.get_graph().num_nodes() + 1, inserted_nodes.size()); + EXPECT_EQ(anno_graph.get_graph().max_index() + 1, inserted_nodes.size()); ASSERT_EQ(std::vector { "First" }, anno_graph.get_labels(seq_first, 1)); @@ -802,8 +802,8 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsPruneDummy) { EXPECT_TRUE(anno_graph.label_exists("Third")); EXPECT_FALSE(anno_graph.label_exists("Fourth")); - EXPECT_TRUE(anno_graph.get_annotator().num_objects() + 1 - < dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) + EXPECT_EQ(anno_graph.get_annotator().num_objects(), + dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) << dynamic_cast(anno_graph.get_graph()).get_boss(); EXPECT_EQ(std::vector { "First" }, @@ -890,8 +890,8 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsPruneDummyParallel) { EXPECT_FALSE(anno_graph.label_exists("Third")); EXPECT_FALSE(anno_graph.label_exists("Fourth")); - EXPECT_TRUE(anno_graph.get_annotator().num_objects() + 1 - < dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) + EXPECT_EQ(anno_graph.get_annotator().num_objects(), + dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) << dynamic_cast(anno_graph.get_graph()).get_boss(); ASSERT_EQ(std::vector { "First" }, @@ -906,7 +906,7 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsPruneDummyParallel) { ); anno_graph.annotator_->insert_rows(edge_to_row_idx(inserted_nodes)); - EXPECT_EQ(anno_graph.get_graph().num_nodes() + 1, inserted_nodes.size()); + EXPECT_EQ(anno_graph.get_graph().max_index() + 1, inserted_nodes.size()); ASSERT_EQ(std::vector { "First" }, anno_graph.get_labels(seq_first, 1)); @@ -930,8 +930,8 @@ TEST(AnnotatedDBG, ExtendGraphAddTwoPathsPruneDummyParallel) { EXPECT_TRUE(anno_graph.label_exists("Third")); EXPECT_FALSE(anno_graph.label_exists("Fourth")); - EXPECT_TRUE(anno_graph.get_annotator().num_objects() + 1 - < dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) + EXPECT_EQ(anno_graph.get_annotator().num_objects(), + dynamic_cast(anno_graph.get_graph()).get_boss().num_edges()) << dynamic_cast(anno_graph.get_graph()).get_boss(); EXPECT_EQ(std::vector { "First" }, diff --git a/metagraph/tests/annotation/test_annotated_dbg_helpers.cpp b/metagraph/tests/annotation/test_annotated_dbg_helpers.cpp index 39e387335e..0787d69fb9 100644 --- a/metagraph/tests/annotation/test_annotated_dbg_helpers.cpp +++ b/metagraph/tests/annotation/test_annotated_dbg_helpers.cpp @@ -90,9 +90,6 @@ std::unique_ptr build_anno_graph(uint64_t k, } if constexpr(std::is_same_v) { - static_assert(std::is_same_v); - assert(dynamic_cast(base_graph.get())); - std::filesystem::path tmp_dir = utils::create_temp_dir("", "test_col"); auto out_fs_path = tmp_dir/"test_col"; std::string out_path = out_fs_path; @@ -183,9 +180,6 @@ std::unique_ptr build_anno_graph(uint64_t k, return std::make_unique(graph, std::move(annotator)); } else if constexpr(std::is_same_v) { - static_assert(std::is_same_v); - assert(dynamic_cast(base_graph.get())); - std::filesystem::path tmp_dir = utils::create_temp_dir("", "test_col"); auto out_fs_path = tmp_dir/"test_col"; std::string out_path = out_fs_path; @@ -217,7 +211,7 @@ std::unique_ptr build_anno_graph(uint64_t k, auto rd_path = out_path + RowDiffDiskAnnotator::kExtension; auto annotator = std::make_unique( - annot::LabelEncoder<>(), static_cast(base_graph.get())); + annot::LabelEncoder<>(), base_graph.get()); if (!annotator->load(rd_path)) { logger->error("Cannot load annotations from {}", rd_path); exit(1); @@ -243,6 +237,8 @@ template std::unique_ptr build_anno_graph build_anno_graph(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); template std::unique_ptr build_anno_graph(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); +template std::unique_ptr build_anno_graph(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); +template std::unique_ptr build_anno_graph(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); template std::unique_ptr build_anno_graph(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); template std::unique_ptr build_anno_graph(uint64_t, const std::vector &, const std::vector&, DeBruijnGraph::Mode, bool); diff --git a/metagraph/tests/annotation/test_converters.cpp b/metagraph/tests/annotation/test_converters.cpp index cd39b07b47..d16b77ef8d 100644 --- a/metagraph/tests/annotation/test_converters.cpp +++ b/metagraph/tests/annotation/test_converters.cpp @@ -19,6 +19,13 @@ using namespace mtg; using namespace mtg::annot; using namespace ::testing; +static auto graph_to_anno_index(graph::DeBruijnGraph::node_index node) { + return graph::AnnotatedDBG::graph_to_anno_index(node); +} +static auto anno_to_graph_index(graph::AnnotatedDBG::row_index row) { + return graph::AnnotatedDBG::anno_to_graph_index(row); +} + const std::string test_data_dir = "../tests/data"; const std::string test_dump_basename = test_data_dir + "/dump_test"; const std::string test_dump_basename_row_compressed_merge = test_dump_basename + "_row_compressed_merge"; @@ -189,9 +196,9 @@ TEST(RowDiff, succ) { */ const std::vector expected_succ = { 3, 0, 4, 2 }; - const std::vector expected_succ_boundary = { 1, 0, 1, 0, 1, 0, 1, 0, 1 }; + const std::vector expected_succ_boundary = { 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1 }; const std::vector expected_pred = { 2, 4, 1, 3 }; - const std::vector expected_pred_boundary = { 0, 1, 1, 0, 1, 0, 1, 0, 1 }; + const std::vector expected_pred_boundary = { 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1 }; for (uint32_t max_depth : { 1, 3, 5 }) { std::filesystem::remove_all(dst_dir); @@ -211,7 +218,10 @@ TEST(RowDiff, succ) { sdsl::int_vector_buffer succ(succ_file, std::ios::in); ASSERT_EQ(expected_succ.size(), succ.size()); for (uint32_t i = 0; i < succ.size(); ++i) { - EXPECT_EQ(expected_succ[i], succ[i]) << max_depth << " " << i; + EXPECT_EQ( + anno_to_graph_index(expected_succ[i]), + graph->rank_node(anno_to_graph_index(succ[i])) + ) << max_depth << " " << i; } sdsl::int_vector_buffer<1> succ_boundary(succ_boundary_file, std::ios::in); @@ -223,7 +233,10 @@ TEST(RowDiff, succ) { sdsl::int_vector_buffer pred(pred_file, std::ios::in); EXPECT_EQ(expected_pred.size(), pred.size()); for (uint32_t i = 0; i < pred.size(); ++i) { - EXPECT_EQ(expected_pred[i], pred[i]) << max_depth << " " << i; + EXPECT_EQ( + anno_to_graph_index(expected_pred[i]), + graph->rank_node(anno_to_graph_index(pred[i])) + ) << max_depth << " " << i; } sdsl::int_vector_buffer<1> pred_boundary(pred_boundary_file, std::ios::in); @@ -279,8 +292,10 @@ TEST(RowDiff, ConvertFromColumnCompressedSameLabels) { std::unique_ptr graph = create_graph(3, { "ACGTCAC" }); graph->serialize(graph_fname); - ColumnCompressed source_annot(5); - source_annot.add_labels({ 0, 1, 2, 3, 4 }, labels); + ColumnCompressed source_annot(graph->max_index()); + std::vector edges(graph->max_index()); + std::iota(begin(edges), end(edges), 0); + source_annot.add_labels(edges, labels); source_annot.serialize(annot_fname); convert_to_row_diff({ annot_fname }, graph_fname, 1e9, max_depth, dst_dir, dst_dir, RowDiffStage::COMPUTE_REDUCTION); @@ -293,13 +308,13 @@ TEST(RowDiff, ConvertFromColumnCompressedSameLabels) { .load_anchor(graph_fname + matrix::kRowDiffAnchorExt); ASSERT_EQ(labels.size(), annotator.num_labels()); - ASSERT_EQ(5u, annotator.num_objects()); + ASSERT_EQ(graph->max_index(), annotator.num_objects()); EXPECT_EQ(labels.size() * expected_relations[max_depth - 1], annotator.num_relations()); - for (uint32 i = 0; i < annotator.num_objects(); ++i) { - ASSERT_THAT(annotator.get_labels(i), ContainerEq(labels)); - } + graph->call_nodes([&](uint32_t node_idx) { + ASSERT_THAT(annotator.get_labels(graph_to_anno_index(node_idx)), ContainerEq(labels)); + }); } } std::filesystem::remove_all(dst_dir); @@ -326,8 +341,10 @@ TEST(RowDiff, ConvertFromColumnCompressedSameLabelsMultipleColumns) { std::vector sources; for (const std::string &label : labels) { - ColumnCompressed source_annot(5); - source_annot.add_labels({ 0, 1, 2, 3, 4 }, { label }); + ColumnCompressed source_annot(graph->max_index()); + std::vector edges(graph->max_index()); + std::iota(begin(edges), end(edges), 0); + source_annot.add_labels(edges, { label }); const std::string annot_fname = dst_dir/(label + ColumnCompressed<>::kExtension); source_annot.serialize(annot_fname); @@ -346,12 +363,12 @@ TEST(RowDiff, ConvertFromColumnCompressedSameLabelsMultipleColumns) { .load_anchor(graph_fname + matrix::kRowDiffAnchorExt); ASSERT_EQ(1, annotator.num_labels()); - ASSERT_EQ(5u, annotator.num_objects()); + ASSERT_EQ(graph->max_index(), annotator.num_objects()); EXPECT_EQ(expected_relations[max_depth - 1], annotator.num_relations()); - for (uint32 idx = 0; idx < annotator.num_objects(); ++idx) { - ASSERT_THAT(annotator.get_labels(idx), ElementsAre(labels[i])); - } + graph->call_nodes([&](uint32_t node_idx) { + ASSERT_THAT(annotator.get_labels(graph_to_anno_index(node_idx)), ElementsAre(labels[i])); + }); } } } @@ -382,13 +399,13 @@ void test_row_diff(uint32_t k, graph->mask_dummy_kmers(1, false); graph->serialize(graph_fname); - ColumnCompressed initial_annotation(graph->num_nodes()); + ColumnCompressed initial_annotation(graph->max_index()); std::unordered_set all_labels; - for (uint32_t anno_idx = 0; anno_idx < graph->num_nodes(); ++anno_idx) { - const std::vector &labels = annotations[anno_idx]; - initial_annotation.add_labels({anno_idx}, labels); + graph->call_nodes([&](uint32_t node_idx) { + const auto &labels = annotations[graph_to_anno_index(graph->rank_node(node_idx))]; + initial_annotation.add_labels({graph_to_anno_index(node_idx)}, labels); std::for_each(labels.begin(), labels.end(), [&](auto l) { all_labels.insert(l); }); - } + }); initial_annotation.serialize(annot_fname); @@ -402,12 +419,12 @@ void test_row_diff(uint32_t k, .load_anchor(graph_fname + matrix::kRowDiffAnchorExt); ASSERT_EQ(all_labels.size(), annotator.num_labels()); - ASSERT_EQ(graph->num_nodes(), annotator.num_objects()); + ASSERT_EQ(graph->max_index(), annotator.num_objects()); - for (uint32_t anno_idx = 0; anno_idx < graph->num_nodes(); ++anno_idx) { - ASSERT_THAT(annotator.get_labels(anno_idx), - UnorderedElementsAreArray(annotations[anno_idx])); - } + graph->call_nodes([&](uint32_t node_idx) { + ASSERT_THAT(annotator.get_labels(graph_to_anno_index(node_idx)), + UnorderedElementsAreArray(annotations[graph_to_anno_index(graph->rank_node(node_idx))])); + }); std::filesystem::remove_all(dst_dir); } @@ -433,14 +450,14 @@ void test_row_diff_separate_columns(uint32_t k, graph->serialize(graph_fname); std::map> col_annotations; - for (uint32_t anno_idx = 0; anno_idx < graph->num_nodes(); ++anno_idx) { - for (const auto &label : annotations[anno_idx]) { - col_annotations[label].push_back(anno_idx); + graph->call_nodes([&](auto node_idx) { + for (const auto &label : annotations[graph_to_anno_index(graph->rank_node(node_idx))]) { + col_annotations[label].push_back(graph_to_anno_index(node_idx)); } - } + }); for (const auto& [label, indices] : col_annotations) { - ColumnCompressed initial_annotation(graph->num_nodes()); + ColumnCompressed initial_annotation(graph->max_index()); initial_annotation.add_labels(indices, {label}); std::string annot_fname = dst_dir/("anno_" + label + ColumnCompressed<>::kExtension); @@ -460,7 +477,7 @@ void test_row_diff_separate_columns(uint32_t k, const_cast &>(annotator.get_matrix()) .load_anchor(graph_fname + matrix::kRowDiffAnchorExt); - ASSERT_EQ(graph->num_nodes(), annotator.num_objects()); + ASSERT_EQ(graph->max_index(), annotator.num_objects()); std::vector actual_indices; annotator.call_objects(label, diff --git a/metagraph/tests/graph/succinct/test_dbg_succinct.cpp b/metagraph/tests/graph/succinct/test_dbg_succinct.cpp index 88a25e2716..1bc51ffd59 100644 --- a/metagraph/tests/graph/succinct/test_dbg_succinct.cpp +++ b/metagraph/tests/graph/succinct/test_dbg_succinct.cpp @@ -20,9 +20,9 @@ TEST(DBGSuccinct, get_degree_with_source_dummy) { + std::string(k, 'T')); // dummy source k-mer: '$$$$$' - EXPECT_EQ(std::string(k, '$'), graph->get_node_sequence(1)); - EXPECT_EQ(1ull, graph->outdegree(1)); - EXPECT_EQ(1ull, graph->indegree(1)); + EXPECT_EQ(std::string(k, '$'), graph->get_node_sequence(graph->select_node(1))); + EXPECT_EQ(1ull, graph->outdegree(graph->select_node(1))); + EXPECT_EQ(1ull, graph->indegree(graph->select_node(1))); // 'AAAAA' auto node_A = graph->kmer_to_node(std::string(k, 'A')); @@ -40,7 +40,7 @@ TEST(DBGSuccinct, get_degree_with_source_dummy) { graph->mask_dummy_kmers(1, false); // dummy source k-mer: '$$$$$' - EXPECT_NE(std::string(k, '$'), graph->get_node_sequence(1)); + EXPECT_NE(std::string(k, '$'), graph->get_node_sequence(graph->select_node(1))); // 'AAAAA' node_A = graph->kmer_to_node(std::string(k, 'A')); @@ -65,9 +65,9 @@ TEST(DBGSuccinct, get_degree_with_source_and_sink_dummy) { + std::string(k - 1, 'T')); // dummy source k-mer: '$$$$$' - EXPECT_EQ(std::string(k, '$'), graph->get_node_sequence(1)); - EXPECT_EQ(1ull, graph->outdegree(1)); - EXPECT_EQ(1ull, graph->indegree(1)); + EXPECT_EQ(std::string(k, '$'), graph->get_node_sequence(graph->select_node(1))); + EXPECT_EQ(1ull, graph->outdegree(graph->select_node(1))); + EXPECT_EQ(1ull, graph->indegree(graph->select_node(1))); // 'AAAAA' auto node_A = graph->kmer_to_node(std::string(k, 'A')); @@ -85,7 +85,7 @@ TEST(DBGSuccinct, get_degree_with_source_and_sink_dummy) { graph->mask_dummy_kmers(1, false); // dummy source k-mer: '$$$$$' - EXPECT_NE(std::string(k, '$'), graph->get_node_sequence(1)); + EXPECT_NE(std::string(k, '$'), graph->get_node_sequence(graph->select_node(1))); // 'AAAAA' node_A = graph->kmer_to_node(std::string(k, 'A')); @@ -109,7 +109,7 @@ TEST(DBGSuccinct, is_single_outgoing_simple) { uint64_t single_outgoing_counter = 0; for (DBGSuccinct::node_index i = 1; i <= graph->num_nodes(); ++i) { - if (graph->outdegree(i) == 1) + if (graph->outdegree(graph->select_node(i)) == 1) single_outgoing_counter++; } @@ -126,7 +126,7 @@ TEST(DBGSuccinct, is_single_outgoing_for_multiple_valid_edges) { uint64_t single_outgoing_counter = 0; for (DBGSuccinct::node_index i = 1; i <= graph->num_nodes(); ++i) { - if (graph->outdegree(i) == 1) + if (graph->outdegree(graph->select_node(i)) == 1) single_outgoing_counter++; } diff --git a/metagraph/tests/graph/test_masked_graph.cpp b/metagraph/tests/graph/test_masked_graph.cpp index c0bfa065ce..cff59d019d 100644 --- a/metagraph/tests/graph/test_masked_graph.cpp +++ b/metagraph/tests/graph/test_masked_graph.cpp @@ -536,7 +536,7 @@ TYPED_TEST(MaskedDeBruijnGraphTest, CallUnitigsMaskPath) { graph.map_to_nodes( unitig, [&](const auto &index) { - EXPECT_TRUE(graph.in_subgraph(index)); + EXPECT_TRUE(graph.in_graph(index)); EXPECT_NE(DeBruijnGraph::npos, index); } ); @@ -721,7 +721,7 @@ TYPED_TEST(MaskedDeBruijnGraphTest, CheckNodes) { std::multiset ref_nodes; full_graph->call_nodes([&](auto i) { - if (graph.in_subgraph(i)) + if (graph.in_graph(i)) ref_nodes.insert(i); }); @@ -803,7 +803,7 @@ TYPED_TEST(MaskedDeBruijnGraphTest, CheckOutgoingNodes) { full_graph->adjacent_outgoing_nodes(node, [&](auto i) { outnodes_full.push_back(i); }); outnodes_full.erase(std::remove_if(outnodes_full.begin(), outnodes_full.end(), - [&](auto i) { return !graph.in_subgraph(i); }), + [&](auto i) { return !graph.in_graph(i); }), outnodes_full.end()); EXPECT_EQ(convert_to_set(outnodes_full), convert_to_set(outnodes)); } @@ -847,7 +847,7 @@ TYPED_TEST(MaskedDeBruijnGraphTest, CheckIncomingNodes) { full_graph->adjacent_incoming_nodes(node, [&](auto i) { innodes_full.push_back(i); }); innodes_full.erase(std::remove_if(innodes_full.begin(), innodes_full.end(), - [&](auto i) { return !graph.in_subgraph(i); }), + [&](auto i) { return !graph.in_graph(i); }), innodes_full.end()); EXPECT_EQ(convert_to_set(innodes_full), convert_to_set(innodes)); }