diff --git a/include/classify_arguments.hpp b/include/classify_arguments.hpp index e57df13..8b9e6d3 100644 --- a/include/classify_arguments.hpp +++ b/include/classify_arguments.hpp @@ -32,8 +32,8 @@ struct ClassifyArguments { // Output options bool run_extract {false}; std::string category_to_extract; - std::filesystem::path extract_file; - std::filesystem::path extract_file2; + std::string prefix; + std::unordered_map> extract_category_to_file; // General options std::string log_file {"charon.log"}; @@ -62,7 +62,7 @@ struct ClassifyArguments { ss += "\tmin_diff:\t\t" + std::to_string(min_proportion_difference) + "\n\n"; ss += "\tcategory_to_extract:\t" + category_to_extract + "\n"; - ss += "\textract_file:\t\t" + extract_file.string() + "\n\n"; + ss += "\tprefix:\t" + prefix + "\n\n"; ss += "\tlog_file:\t\t" + log_file + "\n"; ss += "\tthreads:\t\t" + std::to_string(threads) + "\n"; diff --git a/include/dehost_arguments.hpp b/include/dehost_arguments.hpp index 5f1d2f2..26412b1 100644 --- a/include/dehost_arguments.hpp +++ b/include/dehost_arguments.hpp @@ -12,8 +12,14 @@ struct DehostArguments { std::filesystem::path read_file2; bool is_paired { false }; std::string db; - uint8_t chunk_size { 100 }; + // Output options + bool run_extract {false}; + std::string category_to_extract; + std::string prefix; + std::unordered_map> extract_category_to_file; + + uint8_t chunk_size { 100 }; // Stats options float lo_hi_threshold {0.15}; @@ -29,11 +35,7 @@ struct DehostArguments { float min_proportion_difference { 0.04 }; float min_prob_difference{ 0 }; - // Output options - bool run_extract {false}; - std::string category_to_extract; - std::filesystem::path extract_file; - std::filesystem::path extract_file2; + // General options std::string log_file {"charon.log"}; @@ -45,32 +47,31 @@ struct DehostArguments { std::string ss; ss += "\n\nDehost Arguments:\n\n"; - ss += "\tread_file:\t\t" + read_file.string() + "\n"; - ss += "\tread_file2:\t\t" + read_file2.string() + "\n"; - ss += "\tdb:\t\t\t" + db + "\n\n"; + ss += "\tread_file:\t\t\t" + read_file.string() + "\n"; + ss += "\tread_file2:\t\t\t" + read_file2.string() + "\n"; + ss += "\tdb:\t\t\t\t" + db + "\n\n"; - ss += "\tcategory_to_extract:\t" + category_to_extract + "\n"; - ss += "\textract_file:\t\t" + extract_file.string() + "\n\n"; - ss += "\textract_file2:\t\t" + extract_file2.string() + "\n\n"; + ss += "\tcategory_to_extract:\t\t" + category_to_extract + "\n"; + ss += "\tprefix:\t\t\t\t" + prefix + "\n\n"; - ss += "\tchunk_size:\t\t" + std::to_string(chunk_size) + "\n\n"; + ss += "\tchunk_size:\t\t\t" + std::to_string(chunk_size) + "\n\n"; ss += "\tlo_hi_threshold:\t\t" + std::to_string(lo_hi_threshold) + "\n"; - ss += "\tnum_reads_to_fit:\t" + std::to_string(num_reads_to_fit) + "\n"; - ss += "\tdist:\t\t" + dist + "\n\n"; + ss += "\tnum_reads_to_fit:\t\t" + std::to_string(num_reads_to_fit) + "\n"; + ss += "\tdist:\t\t\t\t" + dist + "\n\n"; - ss += "\tmin_length:\t\t" + std::to_string(min_length) + "\n"; - ss += "\tmin_quality:\t\t" + std::to_string(min_quality) + "\n"; + ss += "\tmin_length:\t\t\t" + std::to_string(min_length) + "\n"; + ss += "\tmin_quality:\t\t\t" + std::to_string(min_quality) + "\n"; ss += "\tmin_compression:\t\t" + std::to_string(min_compression) + "\n"; - ss += "\tconfidence_threshold:\t" + std::to_string(confidence_threshold) + "\n"; + ss += "\tconfidence_threshold:\t\t" + std::to_string(confidence_threshold) + "\n"; ss += "\thost_unique_prop_lo_threshold:\t" + std::to_string(host_unique_prop_lo_threshold) + "\n"; - ss += "\tmin_proportion_difference:\t\t" + std::to_string(min_proportion_difference) + "\n\n"; + ss += "\tmin_proportion_difference:\t" + std::to_string(min_proportion_difference) + "\n"; ss += "\tmin_prob_difference:\t\t" + std::to_string(min_prob_difference) + "\n\n"; - ss += "\tlog_file:\t\t" + log_file + "\n"; - ss += "\tthreads:\t\t" + std::to_string(threads) + "\n"; - ss += "\tverbosity:\t\t" + std::to_string(verbosity) + "\n\n"; + ss += "\tlog_file:\t\t\t" + log_file + "\n"; + ss += "\tthreads:\t\t\t" + std::to_string(threads) + "\n"; + ss += "\tverbosity:\t\t\t" + std::to_string(verbosity) + "\n\n"; return ss; } diff --git a/include/index.hpp b/include/index.hpp index 0448d37..97b81d0 100644 --- a/include/index.hpp +++ b/include/index.hpp @@ -81,6 +81,15 @@ class Index return index; } + uint8_t get_category_index(const std::string category) const + { + const auto index = summary_.category_index(category); + if (index == std::numeric_limits::max()) + PLOG_ERROR << "Index does not contain category "; + assert(index < std::numeric_limits::max()); + return index; + } + double max_fpr() const { return max_fpr_; diff --git a/include/result.hpp b/include/result.hpp index d889e71..f6dddb3 100644 --- a/include/result.hpp +++ b/include/result.hpp @@ -44,9 +44,7 @@ class Result std::vector> cached_reads_; bool run_extract_; - uint8_t extract_category_; - seqan3::sequence_file_output extract_handle_; - seqan3::sequence_file_output extract_handle2_; + std::unordered_map>> extract_handles_; public: Result() = default; @@ -59,39 +57,33 @@ class Result Result(const ClassifyArguments& opt, const InputSummary & summary): input_summary_{summary}, result_summary_(summary.num_categories()), - run_extract_(opt.run_extract), - extract_handle_{std::cout, seqan3::format_fasta{}}, - extract_handle2_{std::cout, seqan3::format_fasta{}} + run_extract_(opt.run_extract) { stats_model_ = StatsModel(opt, summary); if (opt.run_extract){ - extract_handle_ = seqan3::sequence_file_output{opt.extract_file}; - if (opt.is_paired) - extract_handle2_ = seqan3::sequence_file_output{opt.extract_file2}; - - extract_category_ = category_index(opt.category_to_extract); + for (const auto [category_index,extract_files] : opt.extract_category_to_file) { + for (const auto extract_file : extract_files) + extract_handles_[category_index].push_back(seqan3::sequence_file_output{extract_file}); + cached_reads_.reserve(opt.num_reads_to_fit*summary.num_categories()*4); + } } - cached_reads_.reserve(opt.num_reads_to_fit*summary.num_categories()*4); + }; Result(const DehostArguments& opt, const InputSummary & summary): input_summary_{summary}, result_summary_(summary.num_categories()), - run_extract_(opt.run_extract), - extract_handle_{std::cout, seqan3::format_fasta{}}, - extract_handle2_{std::cout, seqan3::format_fasta{}} + run_extract_(opt.run_extract) { stats_model_ = StatsModel(opt, summary); if (opt.run_extract){ - extract_handle_ = seqan3::sequence_file_output{opt.extract_file}; - if (opt.is_paired) - extract_handle2_ = seqan3::sequence_file_output{opt.extract_file2}; - - extract_category_ = category_index(opt.category_to_extract); + for (const auto [category_index,extract_files] : opt.extract_category_to_file) { + for (const auto extract_file : extract_files) + extract_handles_[category_index].push_back(seqan3::sequence_file_output{extract_file}); + cached_reads_.reserve(opt.num_reads_to_fit*summary.num_categories()*4); + } } - cached_reads_.reserve(opt.num_reads_to_fit*summary.num_categories()*4); - }; const InputSummary& input_summary() const @@ -104,7 +96,7 @@ class Result return input_summary_.category_index(category); } - bool classify_read(ReadEntry& read_entry, const bool dehost=false) + uint8_t classify_read(ReadEntry& read_entry, const bool dehost=false) { PLOG_VERBOSE << "Classify read " << read_entry.read_id(); if (dehost) @@ -123,28 +115,29 @@ class Result result_summary_.unclassified_count += 1; } } - return run_extract_ and read_entry.call() == extract_category_; + return read_entry.call(); + } - void extract_read(const record_type& record) + void extract_read(const uint8_t category_index, const record_type& record) { #pragma omp critical(extract_read) - extract_handle_.push_back(record); + extract_handles_[category_index][0].push_back(record); } - void extract_paired_read(const record_type& record, const record_type& record2) + void extract_paired_read(const uint8_t category_index, const record_type& record, const record_type& record2) { #pragma omp critical(extract_read) - extract_handle_.push_back(record); + extract_handles_[category_index][0].push_back(record); #pragma omp critical(extract_read2) - extract_handle2_.push_back(record2); + extract_handles_[category_index][1].push_back(record2); } void add_read(ReadEntry& read_entry, const record_type& record, bool dehost=false){ if (stats_model_.ready()) { - auto read_to_extract = classify_read(read_entry, dehost); - if (run_extract_ and read_to_extract){ - extract_read(record); + auto category_index = classify_read(read_entry, dehost); + if (run_extract_ and extract_handles_.find(category_index) != extract_handles_.end()){ + extract_read(category_index, record); } } else { PLOG_VERBOSE << "Add read " << read_entry.read_id() << " to training "; @@ -168,9 +161,9 @@ class Result void add_paired_read(ReadEntry& read_entry, const record_type& record, const record_type& record2, const bool dehost=false){ if (stats_model_.ready()) { - auto read_to_extract = classify_read(read_entry, dehost); - if (run_extract_ and read_to_extract){ - extract_paired_read(record, record2); + auto category_index = classify_read(read_entry, dehost); + if (run_extract_ and extract_handles_.find(category_index) != extract_handles_.end()){ + extract_paired_read(category_index, record, record2); } } else { PLOG_VERBOSE << "Add read " << read_entry.read_id() << " to training "; @@ -200,14 +193,15 @@ class Result { auto & read_entry = read_record.read; const auto & record = read_record.record; - bool read_to_extract = classify_read(read_entry, dehost); - if (run_extract_ and read_to_extract){ + auto category_index = classify_read(read_entry, dehost); + std::cout << +category_index << (category_index < std::numeric_limits::max()) << std::endl; + if (run_extract_ and extract_handles_.find(category_index) != extract_handles_.end()){ if (read_record.is_paired) { const auto & record2 = read_record.record2; - extract_paired_read(record, record2); + extract_paired_read(category_index, record, record2); } else { - extract_read(record); + extract_read(category_index, record); } } } diff --git a/include/utils.hpp b/include/utils.hpp index 9181cb9..4fdf67a 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -41,4 +41,6 @@ size_t max_num_hashes_for_fpr(const IndexArguments & opt); std::string sequence_to_string(const __type_pack_element<0, std::vector, std::string, std::vector>& input); float get_compression_ratio(const std::string& sequence); + +std::string get_extension(const std::filesystem::path); #endif diff --git a/src/classify_main.cpp b/src/classify_main.cpp index 5f32bd5..17b14ee 100644 --- a/src/classify_main.cpp +++ b/src/classify_main.cpp @@ -57,15 +57,10 @@ void setup_classify_subcommand(CLI::App& app) classify_subcommand->add_option("-e,--extract", opt->category_to_extract, "Reads from this category in the index will be extracted to file.") ->type_name("STRING"); - classify_subcommand->add_option("--extract_file", opt->extract_file, "Fasta/q file for output") - ->transform(make_absolute) - ->check(CLI::NonexistentPath.description("")) - ->type_name("FILE"); - - classify_subcommand->add_option("--extract_file2", opt->extract_file2, "Fasta/q file for output") - ->transform(make_absolute) + classify_subcommand->add_option("-p,--prefix", opt->prefix, "Prefix for the output files.") + ->type_name("FILE") ->check(CLI::NonexistentPath.description("")) - ->type_name("FILE"); + ->default_str(""); classify_subcommand->add_option("-d,--dist", opt->dist, "Probability distribution to use for modelling.") ->type_name("STRING"); @@ -293,21 +288,31 @@ int classify_main(ClassifyArguments & opt) opt.run_extract = (opt.category_to_extract != ""); const auto categories = index.categories(); - if (opt.run_extract and std::find(categories.begin(), categories.end(), opt.category_to_extract) == categories.end()) + if (opt.run_extract and opt.category_to_extract != "all" and std::find(categories.begin(), categories.end(), opt.category_to_extract) == categories.end()) { std::string options = ""; for (auto i: categories) options += i + " "; - PLOG_ERROR << "Cannot extract " << opt.category_to_extract << ", please chose one of [ " << options << "]"; + PLOG_ERROR << "Cannot extract " << opt.category_to_extract << ", please chose one of [ all " << options << "]"; return 1; - } else if (opt.run_extract and opt.extract_file == ""){ - opt.extract_file = opt.read_file; - if (opt.is_paired){ - opt.extract_file.replace_extension(opt.category_to_extract + opt.read_file.extension().string()); - opt.extract_file2 = opt.read_file2; - opt.extract_file2.replace_extension(opt.category_to_extract + opt.read_file.extension().string()); - } else - opt.extract_file.replace_extension(opt.category_to_extract + opt.read_file.extension().string()); + } else if (opt.run_extract){ + if (opt.prefix == "") + opt.prefix = "charon"; + std::vector to_extract; + if (opt.category_to_extract == "all") + to_extract = categories; + else + to_extract.push_back(opt.category_to_extract); + const auto extension = get_extension(opt.read_file); + for (const auto & category : to_extract){ + const auto category_index = index.get_category_index(category); + if (opt.is_paired) { + opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + "_1" + extension + ".gz"); + opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + "_2" + extension + ".gz"); + } else { + opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + extension + ".gz"); + } + } } if (opt.dist != "gamma" and opt.dist != "beta") diff --git a/src/dehost_main.cpp b/src/dehost_main.cpp index e2253a6..2cd6b75 100644 --- a/src/dehost_main.cpp +++ b/src/dehost_main.cpp @@ -230,15 +230,10 @@ void setup_dehost_subcommand(CLI::App& app) dehost_subcommand->add_option("-e,--extract", opt->category_to_extract, "Reads from this category in the index will be extracted to file.") ->type_name("STRING"); - dehost_subcommand->add_option("--extract_file", opt->extract_file, "Fasta/q file for output") - ->transform(make_absolute) - ->check(CLI::NonexistentPath.description("")) - ->type_name("FILE"); - - dehost_subcommand->add_option("--extract_file2", opt->extract_file2, "Fasta/q file for output") - ->transform(make_absolute) + dehost_subcommand->add_option("-p,--prefix", opt->prefix, "Prefix for the output files.") + ->type_name("FILE") ->check(CLI::NonexistentPath.description("")) - ->type_name("FILE"); + ->default_str(""); dehost_subcommand ->add_option("--chunk_size", opt->chunk_size, "Read file is read in chunks of this size, to be processed in parallel within a chunk.") @@ -502,21 +497,32 @@ int dehost_main(DehostArguments & opt) opt.run_extract = (opt.category_to_extract != ""); const auto categories = index.categories(); - if (opt.run_extract and std::find(categories.begin(), categories.end(), opt.category_to_extract) == categories.end()) + if (opt.run_extract and opt.category_to_extract != "all" and std::find(categories.begin(), categories.end(), opt.category_to_extract) == categories.end()) { std::string options = ""; for (auto i: categories) options += i + " "; - PLOG_ERROR << "Cannot extract " << opt.category_to_extract << ", please chose one of [ " << options << "]"; + PLOG_ERROR << "Cannot extract " << opt.category_to_extract << ", please chose one of [ all " << options << "]"; return 1; - } else if (opt.run_extract and opt.extract_file == ""){ - opt.extract_file = opt.read_file; - if (opt.is_paired){ - opt.extract_file.replace_extension(opt.category_to_extract + opt.read_file.extension().string()); - opt.extract_file2 = opt.read_file2; - opt.extract_file2.replace_extension(opt.category_to_extract + opt.read_file.extension().string()); - } else - opt.extract_file.replace_extension(opt.category_to_extract + opt.read_file.extension().string()); + } else if (opt.run_extract){ + if (opt.prefix == "") + opt.prefix = "charon"; + std::vector to_extract; + if (opt.category_to_extract == "all") + to_extract = categories; + else + to_extract.push_back(opt.category_to_extract); + + const auto extension = get_extension(opt.read_file); + for (const auto & category : to_extract){ + const auto category_index = index.get_category_index(category); + if (opt.is_paired) { + opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + "_1" + extension + ".gz"); + opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + "_2" + extension + ".gz"); + } else { + opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + extension + ".gz"); + } + } } if (opt.dist != "gamma" and opt.dist != "beta" and opt.dist != "kde") diff --git a/src/utils.cpp b/src/utils.cpp index 99f2232..68d0bd2 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -129,3 +129,11 @@ float get_compression_ratio(const std::string& sequence){ return static_cast(compressed_size)/static_cast(initial_size); } +std::string get_extension(const std::filesystem::path read_file){ + auto ext = read_file.extension().string(); + if (ext == ".gz"){ + std::filesystem::path short_read_file = read_file.stem(); + ext = short_read_file.extension().string(); + } + return ext; +}