-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsetup_data.py
More file actions
77 lines (53 loc) · 2.03 KB
/
setup_data.py
File metadata and controls
77 lines (53 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from pathlib import Path
import tarfile
import cogent3
from cogent3.core.alignment import Alignment
# All paths below are relative, so they resolve against the current working
# directory of the process running this script.

# Gzip-compressed tar archive that the script entry point hands to
# extract_data().
COMPRESSED_DATA = Path("turtle_dataset.tar.gz")
# Alignment file loaded by get_alignment(); presumably created by unpacking
# COMPRESSED_DATA -- TODO confirm against the archive contents.
TURTLE_PATH = Path("turtle_dataset/turtle.nex")
# NEXUS file whose ``charset`` statements are parsed by get_splits() to define
# the partitions.
PARTITION_PATH = Path("turtle_dataset/partition.nex")
# Directory that write_sequences() creates and fills with one file per
# partition.
OUT_DIR = Path("turtle_partitions/")
def parse_nexus_charsets(nexus_text: str) -> dict[str, tuple[int, int]]:
    """Extract the ``charset`` ranges from the text of a NEXUS file.

    Every statement of the form ``charset <name> = <start>-<stop>;`` becomes
    one entry of the result. The keyword is matched case-insensitively,
    whitespace around ``=`` and ``-`` is optional, and the trailing ``;`` may
    be absent. Lines whose first word is not ``charset`` are ignored.

    Parameters
    ----------
    nexus_text
        full text of the NEXUS file

    Returns
    -------
    dict
        Maps each charset name, truncated at its first ``.``, to a
        ``(start, stop)`` pair. The 1-based inclusive NEXUS coordinates are
        converted to a 0-based half-open range that can be used directly as
        a slice.

    Raises
    ------
    ValueError
        if a charset statement is not of the ``name = start-stop`` form
    """
    result = {}
    for raw_line in nexus_text.splitlines():
        tokens = raw_line.split(None, 1)
        if not tokens or tokens[0].lower() != "charset":
            continue

        # Only strip a terminator that is really there; unconditionally
        # dropping the last character would eat a digit of the stop
        # coordinate when the ';' is missing.
        statement = tokens[1].strip().rstrip(";") if len(tokens) > 1 else ""
        name, equals, span = statement.partition("=")
        start, dash, stop = span.partition("-")
        name = name.strip()
        if not (name and equals and dash):
            raise ValueError(f"malformed charset statement: {raw_line.strip()!r}")

        try:
            # int() tolerates the surrounding whitespace of 'start - stop'
            first, last = int(start), int(stop)
        except ValueError as err:
            raise ValueError(
                f"non-integer charset coordinates in: {raw_line.strip()!r}"
            ) from err

        # convert 1-based inclusive coordinates to a 0-based half-open range
        result[name.split(".")[0]] = (first - 1, last)
    return result
def extract_data(tar_gz_file: os.PathLike) -> None:
with tarfile.open(tar_gz_file, "r:gz") as tar:
tar.extractall(filter="data")
def get_alignment(path: os.PathLike) -> Alignment:
    """Load a DNA alignment and shorten its sequence names.

    Parameters
    ----------
    path
        location of the alignment file, passed to
        ``cogent3.load_aligned_seqs``

    Returns
    -------
    Alignment
        the alignment with every sequence renamed to the part of its
        original name before the first ``_``, title-cased
    """

    def short_name(label: str) -> str:
        # keep only the text before the first underscore, e.g.
        # "homo_sapiens" -> "Homo"
        prefix = label.split("_")[0]
        return prefix.title()

    loaded = cogent3.load_aligned_seqs(path, moltype="dna")
    return loaded.renamed_seqs(short_name)
def get_splits(
    path: os.PathLike, aln: Alignment | None = None
) -> dict[str, Alignment]:
    """Slice an alignment into the partitions listed in a NEXUS file.

    Parameters
    ----------
    path
        location of the NEXUS file containing the ``charset`` statements
    aln
        the alignment to slice. If omitted, the module-level name ``aln`` is
        used; this keeps existing single-argument calls working, but new
        code should pass the alignment explicitly.

    Returns
    -------
    dict
        Maps each partition name to the slice ``aln[start:stop]`` for the
        0-based half-open range returned by ``parse_nexus_charsets``.

    Raises
    ------
    ValueError
        if no alignment is passed and no module-level ``aln`` exists
    """
    if aln is None:
        # Backward compatibility: this function used to read the global
        # ``aln`` implicitly, which only exists when run as a script.
        aln = globals().get("aln")
        if aln is None:
            raise ValueError(
                "no alignment available: pass `aln` explicitly to get_splits()"
            )

    # Path() so that plain strings and other os.PathLike objects work too
    partitions = parse_nexus_charsets(Path(path).read_text())
    return {name: aln[start:stop] for name, (start, stop) in partitions.items()}
def write_sequences(
    out_dir: os.PathLike, aln: Alignment, splits: dict[str, Alignment]
) -> None:
    """Write each partition alignment to ``<out_dir>/<name>.fa``.

    Before a partition is written, sequences whose row total from
    ``counts_per_seq()`` is zero are dropped (presumably sequences with no
    countable characters in that partition -- confirm against the cogent3
    documentation). A one-line summary is printed at the end.

    Parameters
    ----------
    out_dir
        directory to write into; created, including missing parents, if it
        does not exist
    aln
        the full alignment. Not used by this function; the parameter is
        kept so existing callers continue to work.
    splits
        maps partition name to the alignment for that partition

    Notes
    -----
    The output format is whatever ``Alignment.write`` selects for a ``.fa``
    path (presumably FASTA -- confirm against the cogent3 documentation).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    written = 0
    for written, (name, partition) in enumerate(splits.items(), start=1):
        num_valid = partition.counts_per_seq().row_sum().to_dict()
        # the lambda is called within this iteration, so closing over
        # num_valid is safe
        kept = partition.take_seqs_if(
            lambda seq: num_valid[seq.name] > 0,
        )
        kept.write(out_dir / f"{name}.fa")

    print(f"Wrote {written}/{len(splits)} to {out_dir}")
if __name__ == "__main__":
    # Unpack the archive so the input files referenced below exist.
    extract_data(COMPRESSED_DATA)

    # NOTE: ``aln`` has to stay a module-level name. get_splits() is called
    # without an alignment argument and picks up this global.
    aln = get_alignment(TURTLE_PATH)
    print(f"{aln.num_seqs} sequences")

    partitions = get_splits(PARTITION_PATH)
    print(f"{len(partitions)} partitions")

    # One output file per partition.
    write_sequences(OUT_DIR, aln=aln, splits=partitions)