Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions aligner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# aligner.py
from __future__ import annotations
import os, shutil, subprocess, tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from fasta_io import FastaReader, revcomp
from gff_io import Gene, Transcript

@dataclass
class Aln:
qname: str
rname: str
pos: int # 1-based start on target
cigar: str
mapq: int
is_rev: bool
nm: Optional[int] # edit distance if present
length_q: int

@dataclass
class GeneMap:
gid: str
rname: Optional[str]
start: Optional[int]
end: Optional[int]
strand: Optional[str]
identity: float
coverage: float
cigar: Optional[str]

def _parse_sam_line(ln: str) -> Optional[Aln]:
if ln.startswith("@"): return None
p = ln.rstrip("\n").split("\t")
if len(p) < 11: return None
qname, flag, rname, pos, mapq, cigar = p[0], int(p[1]), p[2], int(p[3]), int(p[4]), p[5]
if rname == "*": return None
tags = { t.split(":",2)[0]: t.split(":",2)[2] for t in p[11:] if ":" in t }
nm = int(tags["NM"]) if "NM" in tags else None
is_rev = bool(flag & 16)
lq = int(tags["ql"]) if "ql" in tags else 0
return Aln(qname, rname, pos, cigar, mapq, is_rev, nm, lq)

def _write_gene_fasta(tmpfa: Path, ref_fa: FastaReader, genes: Dict[str, Gene]) -> Dict[str, Tuple[str,int,int,bool]]:
"""
Write per-gene sequences (always forward in reference genomic orientation).
Returns map gid -> (seqid, start, end, gene_on_minus_strand)
"""
m: Dict[str, Tuple[str,int,int,bool]] = {}
with tmpfa.open("w") as fh:
for g in genes.values():
start, end = (g.start, g.end)
seq = ref_fa.slice(g.seqid, start, end)
if g.strand == "-":
# keep query in forward reference orientation: DO NOT revcomp here
pass
fh.write(f">{g.gid}\n")
for i in range(0, len(seq), 80):
fh.write(seq[i:i+80] + "\n")
m[g.gid] = (g.seqid, start, end, g.strand == "-")
return m

def run_minimap2(query_fa: Path, target_fa: Path, threads: int = 4, extra: Optional[List[str]] = None) -> List[Aln]:
if shutil.which("minimap2") is None:
return [] # handled by fallback
cmd = ["minimap2", "-a", "--eqx", "--end-bonus", "5", "-N", "50", "-p", "0.5", "-t", str(threads)]
if extra: cmd.extend(extra)
cmd.extend([str(target_fa), str(query_fa)])
proc = subprocess.run(cmd, check=True, text=True, stdout=subprocess.PIPE)
alns: List[Aln] = []
for ln in proc.stdout.splitlines():
al = _parse_sam_line(ln)
if al: alns.append(al)
return alns

def fallback_exact_align(ref_fa: FastaReader, tgt_fa: FastaReader, genes: Dict[str, Gene]) -> List[Aln]:
out: List[Aln] = []
for g in genes.values():
seq = ref_fa.slice(g.seqid, g.start, g.end)
for chrom in tgt_fa.chromosomes:
s = tgt_fa.get(chrom)
i = s.find(seq)
is_rev = False
if i == -1:
rc = revcomp(seq)
i = s.find(rc)
is_rev = i != -1
if i != -1:
out.append(Aln(g.gid, chrom, i+1, f"{len(seq)}M", 60, is_rev, 0, len(seq)))
break
return out

def align_genes(ref_fasta: Path, ref_genes: Dict[str, Gene], target_fasta: Path, threads: int = 4) -> Dict[str, GeneMap]:
ref = FastaReader(ref_fasta)
tgt = FastaReader(target_fasta)
with tempfile.TemporaryDirectory() as td:
qfa = Path(td) / "genes.fa"
meta = _write_gene_fasta(qfa, ref, ref_genes)
alns = run_minimap2(qfa, Path(target_fasta), threads=threads)
if not alns:
alns = fallback_exact_align(ref, tgt, ref_genes)
# choose best per gene (highest mapq; then shortest NM; then coverage)
best: Dict[str, Aln] = {}
for a in alns:
cur = best.get(a.qname)
if cur is None or (a.mapq, -(cur.nm or 1<<30)) < (a.mapq, -(a.nm or 1<<30)):
best[a.qname] = a
# convert to GeneMap (identity approx)
gmaps: Dict[str, GeneMap] = {}
for gid, aln in best.items():
qlen = ref_genes[gid].end - ref_genes[gid].start + 1
mm = aln.nm or 0
ident = max(0.0, 1.0 - (mm / max(1, qlen)))
strand = "-" if aln.is_rev else "+"
gmaps[gid] = GeneMap(gid, aln.rname, aln.pos, aln.pos + qlen - 1, strand, ident, 1.0, aln.cigar)
return gmaps
51 changes: 51 additions & 0 deletions cigar_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# cigar_map.py
from __future__ import annotations
from typing import List, Tuple

def parse_cigar(cg: str) -> List[Tuple[int, str]]:
num = ""
out: List[Tuple[int, str]] = []
for ch in cg:
if ch.isdigit():
num += ch
else:
if not num: raise ValueError(f"Bad CIGAR: {cg}")
out.append((int(num), ch))
num = ""
if num: raise ValueError(f"Trailing length in CIGAR: {cg}")
return out

def project_ref_interval(cigar: str, aln_pos_1: int, q_start_1: int, q_end_1: int) -> Tuple[int,int]:
"""
Map a query (reference gene) interval [q_start_1, q_end_1] (1-based, inclusive)
through an alignment that starts on target at aln_pos_1 with CIGAR 'cigar'.
Returns target 1-based inclusive coordinates. Raises if region not fully covered.
"""
ops = parse_cigar(cigar)
q = 1
t = aln_pos_1
t_start = t_end = None
q_s, q_e = q_start_1, q_end_1
for ln, op in ops:
if op in "MX=": # consumes both
for _ in range(ln):
if q == q_s: t_start = t
if q == q_e:
t_end = t
break
q += 1; t += 1
if t_end is not None: break
elif op == "I": # query insertion: consume query only
q += ln
elif op == "D" or op == "N": # deletion/skip: consume target only
t += ln
elif op == "S" or op == "H": # clipping: consume query
q += ln
elif op == "P": # padding: ignore
continue
else:
raise ValueError(f"Unhandled CIGAR op {op}")
if t_start is None or t_end is None:
raise RuntimeError("Interval not fully mapped by alignment")
if t_start > t_end: t_start, t_end = t_end, t_start
return t_start, t_end
54 changes: 54 additions & 0 deletions fasta_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# fasta_io.py
from __future__ import annotations
from pathlib import Path
from typing import Dict, Iterable

class FastaReader:
"""
Lightweight FASTA reader with in-memory contigs and O(1) slicing.
"""
def __init__(self, path: str | Path) -> None:
self.path = Path(path)
self._seqs: Dict[str, str] = {}

def _ensure_loaded(self) -> None:
if self._seqs:
return
cur = None
buf: list[str] = []
with self.path.open() as fh:
for ln in fh:
ln = ln.strip()
if not ln:
continue
if ln.startswith(">"):
if cur is not None:
self._seqs[cur] = "".join(buf).upper()
cur = ln[1:].split()[0]
buf = []
else:
buf.append(ln)
if cur is not None:
self._seqs[cur] = "".join(buf).upper()

@property
def chromosomes(self) -> Iterable[str]:
self._ensure_loaded()
return self._seqs.keys()

def get(self, chrom: str) -> str:
self._ensure_loaded()
return self._seqs[chrom]

def slice(self, chrom: str, start_1: int, end_1: int) -> str:
"""Inclusive, 1-based slice."""
self._ensure_loaded()
s = self._seqs[chrom]
if not (1 <= start_1 <= end_1 <= len(s)):
raise ValueError(f"Invalid slice {chrom}:{start_1}-{end_1} (len={len(s)})")
return s[start_1-1:end_1]

def revcomp(seq: str) -> str:
tr = str.maketrans("ACGTRYSWKMBDHVNacgtryswkmbdhvn",
"TGCAYRSWMKVHDBNtgcayrswmkvhdbn")
return seq.translate(tr)[::-1]
111 changes: 111 additions & 0 deletions gff_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# gff_io.py
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Iterable, Tuple

def parse_attrs(s: str) -> Dict[str, str]:
out: Dict[str, str] = {}
for kv in filter(None, s.split(";")):
if "=" in kv:
k, v = kv.split("=", 1)
out[k] = v
return out

def fmt_attrs(d: Dict[str, str]) -> str:
return ";".join(f"{k}={v}" for k, v in d.items())

@dataclass
class Exon:
start: int
end: int

@dataclass
class CDS:
start: int
end: int
phase: int = 0

@dataclass
class Transcript:
tid: str
seqid: str
start: int
end: int
strand: str
attrs: Dict[str, str] = field(default_factory=dict)
exons: List[Exon] = field(default_factory=list)
cdss: List[CDS] = field(default_factory=list)

@dataclass
class Gene:
gid: str
seqid: str
start: int
end: int
strand: str
attrs: Dict[str, str] = field(default_factory=dict)
transcripts: Dict[str, Transcript] = field(default_factory=dict)

def load_gff(path: str | Path) -> Dict[str, Gene]:
"""
Minimal, Ensembl-style GFF3 parser (gene/mRNA/exon/CDS). Keeps hierarchy.
"""
genes: Dict[str, Gene] = {}
path = Path(path)
with path.open() as fh:
for ln in fh:
if not ln or ln.startswith("#"):
continue
p = ln.rstrip("\n").split("\t")
if len(p) < 9:
continue
seqid, _src, ftype, s, e, _score, strand, phase, attr = p
s, e = int(s), int(e)
a = parse_attrs(attr)
fid = a.get("ID")
parent = a.get("Parent")
if ftype == "gene" and fid:
genes[fid] = Gene(fid, seqid, s, e, strand, a, {})
elif ftype in ("mRNA", "transcript") and fid and parent and parent in genes:
genes[parent].transcripts[fid] = Transcript(fid, seqid, s, e, strand, a)
elif ftype == "exon" and parent:
for pid in parent.split(","):
for g in genes.values():
tx = g.transcripts.get(pid)
if tx:
tx.exons.append(Exon(s, e))
break
elif ftype == "CDS" and parent:
for pid in parent.split(","):
for g in genes.values():
tx = g.transcripts.get(pid)
if tx:
cd = CDS(s, e, 0 if phase == "." else int(phase))
tx.cdss.append(cd)
break
# canonical ordering
for g in genes.values():
for tx in g.transcripts.values():
tx.exons.sort(key=lambda x: x.start)
tx.cdss.sort(key=lambda x: x.start)
return genes

def write_gff(genes: Dict[str, Gene], out: str | Path) -> None:
out = Path(out)
rows: List[Tuple] = []
for g in genes.values():
rows.append((g.seqid, "idmapper", "gene", g.start, g.end, ".", g.strand, ".", fmt_attrs(g.attrs)))
for tx in g.transcripts.values():
rows.append((tx.seqid, "idmapper", "mRNA", tx.start, tx.end, ".", tx.strand, ".", fmt_attrs(tx.attrs)))
for i, ex in enumerate(tx.exons, 1):
attrs = {"ID": f"{tx.tid}.exon{i}", "Parent": tx.tid}
rows.append((tx.seqid, "idmapper", "exon", ex.start, ex.end, ".", tx.strand, ".", fmt_attrs(attrs)))
for i, cd in enumerate(tx.cdss, 1):
attrs = {"ID": f"{tx.tid}.cds{i}", "Parent": tx.tid}
rows.append((tx.seqid, "idmapper", "CDS", cd.start, cd.end, ".", tx.strand, str(cd.phase), fmt_attrs(attrs)))
rows.sort(key=lambda r: (r[0], r[3], r[2]))
with out.open("w") as fh:
fh.write("##gff-version 3\n")
for r in rows:
fh.write("\t".join(map(str, r)) + "\n")
Loading