Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion bin/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def merge_all_assignments(list_assignment_files, output_file):
for assignment_file in list_assignment_files:
new_assignments = KrakenAssignments(assignment_file, load=True)
changes = kraken_assignments.update(new_assignments, changes)

kraken_assignments.save()
return changes

Expand Down
80 changes: 69 additions & 11 deletions bin/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,34 @@ def get_percentage(self, taxon_id, denominator="classified"):
# print(f"{count}/{total} = {percentage:.2f}")
return percentage

def add_sorted_descendants(self, taxon_id, sorted_list):
    """
    Append taxon_id and all of its descendants to sorted_list in pre-order.

    Children are visited highest count first; ties are broken by name
    (in reverse alphabetical order, since both keys share one reversed sort).

    Parameters:
        taxon_id (str): a taxon ID
        sorted_list (list): list of taxon_ids to be appended to (mutated in place)
    """
    sorted_list.append(taxon_id)
    ordered_children = sorted(
        self.entries[taxon_id].children,
        key=lambda child_id: (self.entries[child_id].count, self.entries[child_id].name),
        reverse=True,
    )
    for child_id in ordered_children:
        self.add_sorted_descendants(child_id, sorted_list)

def sort_entries(self):
    """
    Re-order the entries dict so that every parent precedes its children,
    with siblings ordered by count (highest first).

    Entry "0" (unclassified) is kept first when present; the classified
    tree is then emitted depth-first from root "1" via
    add_sorted_descendants. Relies on dict preserving insertion order
    (guaranteed since Python 3.7).

    NOTE(review): any entry reachable from neither "0" nor "1" is
    silently dropped from self.entries by this rebuild — confirm that
    is intended.
    """
    sorted_ids = []

    if "0" in self.entries:
        sorted_ids.append("0")
    if "1" in self.entries:
        self.add_sorted_descendants("1", sorted_ids)

    # Rebuild in the computed order (dict comprehension instead of
    # dict([...]) — same result, idiomatic and avoids a throwaway list).
    self.entries = {taxon_id: self.entries[taxon_id] for taxon_id in sorted_ids}

def to_source_target_df(
self, out_file="source_target.csv", max_rank=None, domain=None
):
Expand Down Expand Up @@ -550,15 +578,23 @@ def get_mrca(self, taxon_id_1, taxon_id_2):

entry1 = self.entries[taxon_id_1]
entry2 = self.entries[taxon_id_2]

if taxon_id_1 == "1" and taxon_id_1 in entry2.hierarchy:
# print(f"MRCA of old {taxon_id_1} and new {taxon_id_2} is {taxon_id_1}")
return taxon_id_1

while (
i < len(entry1.hierarchy)
and i < len(entry2.hierarchy)
and entry1.hierarchy[i] == entry2.hierarchy[i]
):
if i == len(entry1.hierarchy) - 1 or i == len(entry2.hierarchy) - 1:
break
elif entry1.hierarchy[i+1] != entry2.hierarchy[i+1]:
break
i += 1
# print(f"MRCA of old {taxon_id_1} and new {taxon_id_2} is {entry1.hierarchy[i]}")

# print(f"MRCA of old {taxon_id_1} and new {taxon_id_2} is {entry1.hierarchy} position {i}")
return entry1.hierarchy[i]

def update_counts(self, changes):
Expand All @@ -570,52 +606,72 @@ def update_counts(self, changes):
"""
for old_taxon_id in changes:
for new_taxon_id in changes[old_taxon_id]:
print(
f"Moving {changes[old_taxon_id][new_taxon_id]} counts from {old_taxon_id} to {new_taxon_id}"
)
print(f"Moving {changes[old_taxon_id][new_taxon_id]} counts from {old_taxon_id} to {new_taxon_id}")

mrca = self.get_mrca(old_taxon_id, new_taxon_id)
print(f"MRCA of {old_taxon_id} and {new_taxon_id} is {mrca}")

self.entries[old_taxon_id].ucount -= changes[old_taxon_id][new_taxon_id]
self.entries[old_taxon_id].count -= changes[old_taxon_id][new_taxon_id]
assert self.entries[old_taxon_id].ucount >= 0
print(f"Removing {changes[old_taxon_id][new_taxon_id]} ucounts from {old_taxon_id}")

mrca = self.get_mrca(old_taxon_id, new_taxon_id)
if not (old_taxon_id == "1" and old_taxon_id in self.entries[new_taxon_id].hierarchy):
self.entries[old_taxon_id].count -= changes[old_taxon_id][new_taxon_id]
print(f"Removing {changes[old_taxon_id][new_taxon_id]} counts from {old_taxon_id}")

assert self.entries[old_taxon_id].ucount >= 0

if old_taxon_id != "0":
for taxon_id in reversed(self.entries[old_taxon_id].hierarchy):
if taxon_id != mrca:
self.entries[taxon_id].count -= changes[old_taxon_id][
new_taxon_id
]
print(f"Removing {changes[old_taxon_id][new_taxon_id]} counts from {taxon_id}")
assert self.entries[taxon_id].count >= 0
elif taxon_id == mrca:
break

self.entries[new_taxon_id].ucount += changes[old_taxon_id][new_taxon_id]
self.entries[new_taxon_id].count += changes[old_taxon_id][new_taxon_id]
if not (new_taxon_id == "1" and new_taxon_id in self.entries[old_taxon_id].hierarchy):
self.entries[new_taxon_id].ucount += changes[old_taxon_id][new_taxon_id]
self.entries[new_taxon_id].count += changes[old_taxon_id][new_taxon_id]
print(f"Adding {changes[old_taxon_id][new_taxon_id]} counts and ucounts to {new_taxon_id}")

for taxon_id in reversed(self.entries[new_taxon_id].hierarchy):
if taxon_id != mrca:
self.entries[taxon_id].count += changes[old_taxon_id][
new_taxon_id
]
print(f"Adding {changes[old_taxon_id][new_taxon_id]} counts to {taxon_id}")

elif taxon_id == mrca:
break

self.unclassified = self.entries["0"].count
self.classified = self.entries["1"].count if "1" in self.entries else 0
if (self.total != self.classified + self.unclassified):
print(f"Broke after {old_taxon_id} and {new_taxon_id} with {self.unclassified}, {self.classified}")
assert self.total == self.classified + self.unclassified

def clean(self):
    """
    Remove entries with zero counts and detach them from their parents.

    The root entries "0" (unclassified) and "1" (classified) are always
    retained, even when their counts are zero, so later bookkeeping that
    reads them does not break.
    """
    # Collect all zero-count taxa up front so membership checks are O(1).
    set_zeroes = {
        taxon_id
        for taxon_id, entry in self.entries.items()
        if entry.count == 0 and taxon_id not in ("0", "1")
    }
    for taxon_id in set_zeroes:
        if taxon_id not in self.entries:
            continue  # defensive: already removed
        entry = self.entries[taxon_id]
        if entry.parent in self.entries:
            print("Removing zero entry", taxon_id, "with parent", entry.parent, "and children", self.entries[entry.parent].children)
            # Drop the dangling reference from the surviving parent.
            self.entries[entry.parent].children.remove(taxon_id)
        for child in entry.children:
            if child in self.entries:
                # A surviving child of a zero-count entry should itself be
                # zero-count — presumably counts accumulate up the tree;
                # TODO(review) confirm.
                assert child in set_zeroes
        del self.entries[taxon_id]
    print(f"Removed {len(set_zeroes)} zero-count entries")

def update(self, new_report, changes):
"""
Expand Down Expand Up @@ -654,6 +710,8 @@ def save(self, file_name=None):
"""
Save the KrakenReport object in kraken report format
"""
self.sort_entries()

if not file_name:
file_name = self.file_name
with open(file_name, "w") as out:
Expand Down
55 changes: 55 additions & 0 deletions modules/utils.nf
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import groovy.json.JsonBuilder

include { paired_concatenate } from '../modules/preprocess'

process get_versions {

conda 'environment.yml'
Expand Down Expand Up @@ -90,3 +92,56 @@ workflow get_fastq_ch {
emit:
fastq_ch
}

// Build per-sample fastq channels from whichever input params are set.
// Emits two channels:
//   processed_fastq - one [id, fastq(s)] tuple per input unit
//   combined_fastq  - the same reads collapsed to a single fastq per sample
workflow get_fastq_channels {
take:
unique_id

main:
// Inputs are resolved from one of four mutually exclusive param
// combinations, checked in priority order: --run_dir, then
// --paired + --fastq1/--fastq2, then --fastq, then --fastq_dir.
if (params.run_dir) {
run_dir = file("${params.run_dir}", type: "dir", checkIfExists: true)
if (params.paired) {
// Pair R1/R2 files by shared prefix, then concatenate each pair
// into a single fastq (paired_concatenate from modules/preprocess).
fastq_ch = Channel.fromFilePairs("${run_dir}/*_R{1,2}*.f*q*", type: "file", checkIfExists: true)

paired_concatenate(fastq_ch)
paired_concatenate.out.concatenated_fastq.set { combined_fastq_ch }
}
else {
// One sample per immediate subdirectory; get_fq_files_in_dir
// (defined elsewhere in this module) gathers its fastq files.
fastq_ch = Channel.fromPath("${run_dir}/*", type: "dir", checkIfExists: true, maxDepth: 1).map { [it.baseName, get_fq_files_in_dir(it)] }
// Already one entry per sample, so the combined channel is a copy.
fastq_ch.tap { combined_fastq_ch }
}
}
else if (params.paired && params.fastq1 && params.fastq2) {
// Explicit R1/R2 pair for a single sample; concatenate into one fastq.
fastq1 = file(params.fastq1, type: "file", checkIfExists: true)
fastq2 = file(params.fastq2, type: "file", checkIfExists: true)
fastq_ch = Channel.from([[unique_id, fastq1, fastq2]])

paired_concatenate(fastq_ch)
paired_concatenate.out.concatenated_fastq.set { combined_fastq_ch }
}
else if (params.fastq) {
// Single fastq file: both emitted channels carry the same tuple.
fastq = file(params.fastq, type: "file", checkIfExists: true)
fastq_ch = Channel.from([[unique_id, fastq]])

fastq_ch.tap { combined_fastq_ch }
}
else if (params.fastq_dir) {
// Flat directory of fastq files, all belonging to one sample.
fastqdir = file("${params.fastq_dir}", type: "dir", checkIfExists: true)
Channel.fromPath(fastqdir / "*.f*q*", type: "file")
.set { input_ch }
fastq_ch = input_ch.map { fastq -> [unique_id, fastq] }

// Merge every file into one "<unique_id>.fq.gz" via collectFile,
// then re-key by the merged file's simple name.
fastq_ch
.map { unique_id, fastq -> [unique_id + ".fq.gz", fastq] }
.collectFile()
.map { it -> [it.simpleName, it] }
.set { combined_fastq_ch }
}
else {
error "No input fastq files provided. Please provide either --run_dir, --fastq_dir, --fastq or --fastq1 and --fastq2 with --paired."
}

emit:
processed_fastq = fastq_ch
combined_fastq = combined_fastq_ch
}
37 changes: 27 additions & 10 deletions subworkflows/classify.nf
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,17 @@ process merge_classifications {
"""
}

workflow classify {
workflow reclassify {
take:
fastq_ch
concat_fastq_ch
assignments_ch
kreport_ch
raise_server

main:
default_classify(concat_fastq_ch, raise_server)

if (params.run_viral_reclassification) {
setup_taxonomy()
extract_virus_fraction(fastq_ch, default_classify.out.assignments, default_classify.out.kreport, setup_taxonomy.out)
viral_classify(extract_virus_fraction.out.virus, raise_server)
default_classify.out.assignments
.join(default_classify.out.kreport, by: [0, 1])
viral_classify(fastq_ch, raise_server)
assignments_ch
.join(kreport_ch, by: [0, 1])
.map { unique_id, database_name, assignments, kreport -> [unique_id, assignments, kreport] }
.set { default_ch }
viral_classify.out.assignments
Expand All @@ -86,6 +82,27 @@ workflow classify {
merge_classifications(merge_ch)
assignments = merge_classifications.out.assignments
kreport = merge_classifications.out.kreport

emit:
assignments = assignments
kreport = kreport
}

workflow classify {
take:
fastq_ch
concat_fastq_ch
raise_server

main:
default_classify(concat_fastq_ch, raise_server)

if (params.run_viral_reclassification) {
setup_taxonomy()
extract_virus_fraction(fastq_ch, default_classify.out.assignments, default_classify.out.kreport, setup_taxonomy.out)
reclassify(extract_virus_fraction.out.virus, default_classify.out.assignments, default_classify.out.kreport, raise_server)
assignments = reclassify.out.assignments
kreport = reclassify.out.kreport
}
else {
assignments = default_classify.out.assignments
Expand Down
17 changes: 12 additions & 5 deletions subworkflows/run_module.nf
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
// workflow to run kraken, check for human, run qc checks and generate html report for a single sample fastq
include { get_params_and_versions ; get_fastq_ch } from '../modules/utils'
include { get_params_and_versions ; get_fastq_channels } from '../modules/utils'

include { preprocess } from '../modules/preprocess'
include { qc_checks } from '../modules/qc_checks'
include { centrifuge_classify } from '../modules/centrifuge_classification'
include { kraken_classify } from '../modules/kraken_classification'
include { sourmash_classify } from '../modules/sourmash_classification'
include { check_hcid_status } from '../modules/check_hcid_status'
include { check_spike_status } from '../modules/check_spike_status'
include { extract_all } from '../modules/extract_all'
include { classify_virus_fastq } from '../modules/classify_novel_viruses'
include { classify ; reclassify } from '../subworkflows/classify'

workflow run_module {
take:
unique_id

main:
get_params_and_versions(unique_id)
get_fastq_ch(unique_id)
fastq_ch = get_fastq_ch.out
get_fastq_channels(unique_id)
fastq_ch = get_fastq_channels.out.processed_fastq

if (params.module == "preprocess") {
preprocess(unique_id)
Expand All @@ -30,7 +30,7 @@ workflow run_module {
centrifuge_classify(fastq_ch)
}
else if (params.module == "kraken_classification") {
kraken_classify(fastq_ch, "default", params.raise_server)
classify(fastq_ch, get_fastq_channels.out.combined_fastq, params.raise_server)
}
else if (params.module == "sourmash_classification") {
sourmash_classify(fastq_ch)
Expand All @@ -53,6 +53,13 @@ workflow run_module {
kreport_ch = Channel.of([unique_id, "default", kreport])
taxonomy_dir = file(params.taxonomy, type: "dir", checkIfExists: true)
extract_all(fastq_ch, assignments_ch, kreport_ch, taxonomy_dir)
}
else if (params.module == "kraken_reclassification") {
assignments = file(params.kraken_assignments, type: "file", checkIfExists: true)
assignments_ch = Channel.of([unique_id, "default", assignments])
kreport = file(params.kraken_report, type: "file", checkIfExists: true)
kreport_ch = Channel.of([unique_id, "default", kreport])
reclassify(fastq_ch, assignments_ch, kreport_ch, params.raise_server)
}
else if (params.module == "classify_novel_viruses") {
classify_virus_fastq(fastq_ch)
Expand Down