diff --git a/scripts/summarize_kraken2_reports.py b/scripts/summarize_kraken2_reports.py index 1ab3377..400aac1 100644 --- a/scripts/summarize_kraken2_reports.py +++ b/scripts/summarize_kraken2_reports.py @@ -2,78 +2,59 @@ from typing import Iterator, TextIO -allowed_ranks = {"D": "k", "P": "p", "C": "c", "O": "o", "F": "f", "G": "g", "S": "s"} -rank_order = {"D": 1, "P": 2, "C": 3, "O": 4, "F": 5, "G": 6, "S": 7} - - -def parse_kraken2_tsv_report( - file_handler: TextIO, -) -> Iterator[dict[str, float | int | str]]: +def parse_kraken2_tsv_report(file_handler: TextIO) -> Iterator[dict]: for line in file_handler: - data = line.strip().split("\t") - if len(data) == 6: - ( - percentage, - fragments_covered, - fragments_assigned, - rank, - taxon_id, - scientific_name, - ) = data - if rank in allowed_ranks.keys(): - yield { - "percentage": float(percentage), - "fragments_covered": int(fragments_covered), - "fragments_assigned": int(fragments_assigned), - "rank": rank, - "taxon_id": int(taxon_id), - "scientific_name": scientific_name, - } - - -def consensus_lineage_str(rank_stack: list[str]) -> str: - missing_ranks = [k for k, v in rank_order.items() if v > len(rank_stack)] - rank_stack += [f"{allowed_ranks[r]}__" for r in missing_ranks] - return "; ".join(rank_stack) + parts = line.rstrip("\n").split("\t") + if len(parts) != 6: + continue + percent, clade_reads, direct_reads, rank_code, taxid, name = parts + depth = len(name) - len(name.lstrip()) # leading spaces as taxonomic depth + yield { + "percentage": float(percent), + "fragments_covered": int(clade_reads), + "fragments_assigned": int(direct_reads), + "rank_code": rank_code, + "taxon_id": int(taxid), + "scientific_name": name.strip(), + "depth": depth // 2, # kraken2 uses 2 spaces per level + } + + +def consensus_lineage_str(lineage_stack: list[str]) -> str: + # Fill missing levels with empty placeholders if needed + full_lineage = lineage_stack + [f"__"] * (max(7 - len(lineage_stack), 0)) + return "; ".join(full_lineage) def create_kraken2_tsv_report( - reports: list[Iterator[dict[str, float | int | str]]], report_names: list[str] + reports: list[Iterator[dict]], report_names: list[str] ) -> tuple[dict[str, dict[int, int]], dict[int, str]]: consensus_lineages = {} report_counts = {} for report, report_name in zip(reports, report_names): - rank_stack = [] + lineage_stack = [] counts = {} - for line in report: - if line["rank"] in allowed_ranks.keys(): - # Update fragments assigned count - counts[line["taxon_id"]] = line["fragments_assigned"] - - # Update rank stack - if len(rank_stack) < rank_order[line["rank"]]: - rank_stack.append( - f"{allowed_ranks[line['rank']]}__{line['scientific_name'].lstrip()}" - ) - elif len(rank_stack) == rank_order[line["rank"]]: - rank_stack[-1] = ( - f"{allowed_ranks[line['rank']]}__{line['scientific_name'].lstrip()}" - ) - else: - rank_stack = rank_stack[: rank_order[line["rank"]]] - rank_stack[-1] = ( - f"{allowed_ranks[line['rank']]}__{line['scientific_name'].lstrip()}" - ) - - # Update consensus lineages - if line["taxon_id"] not in consensus_lineages: - consensus_lineages[line["taxon_id"]] = consensus_lineage_str( - rank_stack - ) - - # Update report counts + for entry in report: + depth = entry["depth"] + name = entry["scientific_name"] + taxon_id = entry["taxon_id"] + assigned = entry["fragments_assigned"] + + # Update stack for lineage at current depth + if len(lineage_stack) <= depth: + lineage_stack.extend(["__"] * (depth - len(lineage_stack) + 1)) + lineage_stack = lineage_stack[: depth + 1] + lineage_stack[depth] = f"__{name}" + + # Store read count + counts[taxon_id] = assigned + + # Store consensus lineage string (once) + if taxon_id not in consensus_lineages: + consensus_lineages[taxon_id] = consensus_lineage_str(lineage_stack) + report_counts[report_name] = counts return report_counts, consensus_lineages @@ -84,26 +65,26 @@ def write_kraken2_tsv_summary( consensus_lineages: dict[int, str], file_handler: TextIO, ) -> None: - # Write header - header = ( - "#OTU ID\t" - + "\t".join([k for k, _ in report_counts.items()]) - + "\tConsensus Lineage\n" - ) + sample_names = list(report_counts.keys()) + header = "#OTU ID\t" + "\t".join(sample_names) + "\tConsensus Lineage\n" file_handler.write(header) - # Loop through consensus lineages - for taxon_id, lineage in consensus_lineages.items(): - output_line = f"{taxon_id}\t" - for report_name in [k for k, _ in report_counts.items()]: - output_line += f"{report_counts[report_name].get(taxon_id, 0)}\t" - file_handler.write(output_line + f"{lineage}\n") - - -report_names = [Path(x).stem for x in snakemake.input.reports] -report_counts, consensus_lineages = create_kraken2_tsv_report( - [parse_kraken2_tsv_report(open(x)) for x in snakemake.input.reports], report_names -) -write_kraken2_tsv_summary( - report_counts, consensus_lineages, open(snakemake.output.summary, "w") -) + for taxid, lineage in consensus_lineages.items(): + counts = [str(report_counts[sample].get(taxid, 0)) for sample in sample_names] + file_handler.write(f"{taxid}\t" + "\t".join(counts) + f"\t{lineage}\n") + + +# If using Snakemake +if "snakemake" in globals(): + reports = snakemake.input.reports # type: ignore + summary = snakemake.output.summary # type: ignore + + report_names = [Path(p).stem for p in reports] + parsed_reports = [parse_kraken2_tsv_report(open(p)) for p in reports] + + report_counts, consensus_lineages = create_kraken2_tsv_report( + parsed_reports, report_names + ) + + with open(summary, "w") as out_f: + write_kraken2_tsv_summary(report_counts, consensus_lineages, out_f) diff --git a/scripts/test_summarize_kraken2_reports.py b/scripts/test_summarize_kraken2_reports.py index 1c821ba..fd1570e 100644 --- a/scripts/test_summarize_kraken2_reports.py +++ b/scripts/test_summarize_kraken2_reports.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from scripts.summarize_kraken2_reports_f import ( +from scripts.summarize_kraken2_reports import ( parse_kraken2_tsv_report, create_kraken2_tsv_report, write_kraken2_tsv_summary, @@ -52,6 +52,7 @@ def test_parse_kraken2_tsv_report(reports): reports, _ = reports report = reports[0] parsed_report = parse_kraken2_tsv_report(open(report)) + assert False assert list(next(parsed_report).keys()) == [ "percentage",