Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/classify_arguments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ struct ClassifyArguments {
// Output options
bool run_extract {false};
std::string category_to_extract;
std::filesystem::path extract_file;
std::filesystem::path extract_file2;
std::string prefix;
std::unordered_map<uint8_t, std::vector<std::filesystem::path>> extract_category_to_file;

// General options
std::string log_file {"charon.log"};
Expand Down Expand Up @@ -62,7 +62,7 @@ struct ClassifyArguments {
ss += "\tmin_diff:\t\t" + std::to_string(min_proportion_difference) + "\n\n";

ss += "\tcategory_to_extract:\t" + category_to_extract + "\n";
ss += "\textract_file:\t\t" + extract_file.string() + "\n\n";
ss += "\tprefix:\t" + prefix + "\n\n";

ss += "\tlog_file:\t\t" + log_file + "\n";
ss += "\tthreads:\t\t" + std::to_string(threads) + "\n";
Expand Down
45 changes: 23 additions & 22 deletions include/dehost_arguments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ struct DehostArguments {
std::filesystem::path read_file2;
bool is_paired { false };
std::string db;
uint8_t chunk_size { 100 };

// Output options
bool run_extract {false};
std::string category_to_extract;
std::string prefix;
std::unordered_map<uint8_t, std::vector<std::filesystem::path>> extract_category_to_file;

uint8_t chunk_size { 100 };

// Stats options
float lo_hi_threshold {0.15};
Expand All @@ -29,11 +35,7 @@ struct DehostArguments {
float min_proportion_difference { 0.04 };
float min_prob_difference{ 0 };

// Output options
bool run_extract {false};
std::string category_to_extract;
std::filesystem::path extract_file;
std::filesystem::path extract_file2;


// General options
std::string log_file {"charon.log"};
Expand All @@ -45,32 +47,31 @@ struct DehostArguments {
std::string ss;

ss += "\n\nDehost Arguments:\n\n";
ss += "\tread_file:\t\t" + read_file.string() + "\n";
ss += "\tread_file2:\t\t" + read_file2.string() + "\n";
ss += "\tdb:\t\t\t" + db + "\n\n";
ss += "\tread_file:\t\t\t" + read_file.string() + "\n";
ss += "\tread_file2:\t\t\t" + read_file2.string() + "\n";
ss += "\tdb:\t\t\t\t" + db + "\n\n";

ss += "\tcategory_to_extract:\t" + category_to_extract + "\n";
ss += "\textract_file:\t\t" + extract_file.string() + "\n\n";
ss += "\textract_file2:\t\t" + extract_file2.string() + "\n\n";
ss += "\tcategory_to_extract:\t\t" + category_to_extract + "\n";
ss += "\tprefix:\t\t\t\t" + prefix + "\n\n";

ss += "\tchunk_size:\t\t" + std::to_string(chunk_size) + "\n\n";
ss += "\tchunk_size:\t\t\t" + std::to_string(chunk_size) + "\n\n";
ss += "\tlo_hi_threshold:\t\t" + std::to_string(lo_hi_threshold) + "\n";
ss += "\tnum_reads_to_fit:\t" + std::to_string(num_reads_to_fit) + "\n";
ss += "\tdist:\t\t" + dist + "\n\n";
ss += "\tnum_reads_to_fit:\t\t" + std::to_string(num_reads_to_fit) + "\n";
ss += "\tdist:\t\t\t\t" + dist + "\n\n";

ss += "\tmin_length:\t\t" + std::to_string(min_length) + "\n";
ss += "\tmin_quality:\t\t" + std::to_string(min_quality) + "\n";
ss += "\tmin_length:\t\t\t" + std::to_string(min_length) + "\n";
ss += "\tmin_quality:\t\t\t" + std::to_string(min_quality) + "\n";
ss += "\tmin_compression:\t\t" + std::to_string(min_compression) + "\n";
ss += "\tconfidence_threshold:\t" + std::to_string(confidence_threshold) + "\n";
ss += "\tconfidence_threshold:\t\t" + std::to_string(confidence_threshold) + "\n";
ss += "\thost_unique_prop_lo_threshold:\t" + std::to_string(host_unique_prop_lo_threshold) + "\n";
ss += "\tmin_proportion_difference:\t\t" + std::to_string(min_proportion_difference) + "\n\n";
ss += "\tmin_proportion_difference:\t" + std::to_string(min_proportion_difference) + "\n";
ss += "\tmin_prob_difference:\t\t" + std::to_string(min_prob_difference) + "\n\n";



ss += "\tlog_file:\t\t" + log_file + "\n";
ss += "\tthreads:\t\t" + std::to_string(threads) + "\n";
ss += "\tverbosity:\t\t" + std::to_string(verbosity) + "\n\n";
ss += "\tlog_file:\t\t\t" + log_file + "\n";
ss += "\tthreads:\t\t\t" + std::to_string(threads) + "\n";
ss += "\tverbosity:\t\t\t" + std::to_string(verbosity) + "\n\n";

return ss;
}
Expand Down
9 changes: 9 additions & 0 deletions include/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ class Index
return index;
}

uint8_t get_category_index(const std::string category) const
{
const auto index = summary_.category_index(category);
if (index == std::numeric_limits<uint8_t>::max())
PLOG_ERROR << "Index does not contain category ";
assert(index < std::numeric_limits<uint8_t>::max());
return index;
}

double max_fpr() const
{
return max_fpr_;
Expand Down
72 changes: 33 additions & 39 deletions include/result.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ class Result
std::vector<ReadRecord<record_type>> cached_reads_;

bool run_extract_;
uint8_t extract_category_;
seqan3::sequence_file_output<outfile_field_ids, outfile_format> extract_handle_;
seqan3::sequence_file_output<outfile_field_ids, outfile_format> extract_handle2_;
std::unordered_map<uint8_t, std::vector<seqan3::sequence_file_output<outfile_field_ids, outfile_format>>> extract_handles_;

public:
Result() = default;
Expand All @@ -59,39 +57,33 @@ class Result
Result(const ClassifyArguments& opt, const InputSummary & summary):
input_summary_{summary},
result_summary_(summary.num_categories()),
run_extract_(opt.run_extract),
extract_handle_{std::cout, seqan3::format_fasta{}},
extract_handle2_{std::cout, seqan3::format_fasta{}}
run_extract_(opt.run_extract)
{
stats_model_ = StatsModel(opt, summary);
if (opt.run_extract){
extract_handle_ = seqan3::sequence_file_output{opt.extract_file};
if (opt.is_paired)
extract_handle2_ = seqan3::sequence_file_output{opt.extract_file2};

extract_category_ = category_index(opt.category_to_extract);
for (const auto [category_index,extract_files] : opt.extract_category_to_file) {
for (const auto extract_file : extract_files)
extract_handles_[category_index].push_back(seqan3::sequence_file_output{extract_file});
cached_reads_.reserve(opt.num_reads_to_fit*summary.num_categories()*4);
}
}
cached_reads_.reserve(opt.num_reads_to_fit*summary.num_categories()*4);


};

Result(const DehostArguments& opt, const InputSummary & summary):
input_summary_{summary},
result_summary_(summary.num_categories()),
run_extract_(opt.run_extract),
extract_handle_{std::cout, seqan3::format_fasta{}},
extract_handle2_{std::cout, seqan3::format_fasta{}}
run_extract_(opt.run_extract)
{
stats_model_ = StatsModel(opt, summary);
if (opt.run_extract){
extract_handle_ = seqan3::sequence_file_output{opt.extract_file};
if (opt.is_paired)
extract_handle2_ = seqan3::sequence_file_output{opt.extract_file2};

extract_category_ = category_index(opt.category_to_extract);
for (const auto [category_index,extract_files] : opt.extract_category_to_file) {
for (const auto extract_file : extract_files)
extract_handles_[category_index].push_back(seqan3::sequence_file_output{extract_file});
cached_reads_.reserve(opt.num_reads_to_fit*summary.num_categories()*4);
}
}
cached_reads_.reserve(opt.num_reads_to_fit*summary.num_categories()*4);

};

const InputSummary& input_summary() const
Expand All @@ -104,7 +96,7 @@ class Result
return input_summary_.category_index(category);
}

bool classify_read(ReadEntry& read_entry, const bool dehost=false)
uint8_t classify_read(ReadEntry& read_entry, const bool dehost=false)
{
PLOG_VERBOSE << "Classify read " << read_entry.read_id();
if (dehost)
Expand All @@ -123,28 +115,29 @@ class Result
result_summary_.unclassified_count += 1;
}
}
return run_extract_ and read_entry.call() == extract_category_;
return read_entry.call();

}

void extract_read(const record_type& record)
void extract_read(const uint8_t category_index, const record_type& record)
{
#pragma omp critical(extract_read)
extract_handle_.push_back(record);
extract_handles_[category_index][0].push_back(record);
}

void extract_paired_read(const record_type& record, const record_type& record2)
void extract_paired_read(const uint8_t category_index, const record_type& record, const record_type& record2)
{
#pragma omp critical(extract_read)
extract_handle_.push_back(record);
extract_handles_[category_index][0].push_back(record);
#pragma omp critical(extract_read2)
extract_handle2_.push_back(record2);
extract_handles_[category_index][1].push_back(record2);
}

void add_read(ReadEntry& read_entry, const record_type& record, bool dehost=false){
if (stats_model_.ready()) {
auto read_to_extract = classify_read(read_entry, dehost);
if (run_extract_ and read_to_extract){
extract_read(record);
auto category_index = classify_read(read_entry, dehost);
if (run_extract_ and extract_handles_.find(category_index) != extract_handles_.end()){
extract_read(category_index, record);
}
} else {
PLOG_VERBOSE << "Add read " << read_entry.read_id() << " to training ";
Expand All @@ -168,9 +161,9 @@ class Result

void add_paired_read(ReadEntry& read_entry, const record_type& record, const record_type& record2, const bool dehost=false){
if (stats_model_.ready()) {
auto read_to_extract = classify_read(read_entry, dehost);
if (run_extract_ and read_to_extract){
extract_paired_read(record, record2);
auto category_index = classify_read(read_entry, dehost);
if (run_extract_ and extract_handles_.find(category_index) != extract_handles_.end()){
extract_paired_read(category_index, record, record2);
}
} else {
PLOG_VERBOSE << "Add read " << read_entry.read_id() << " to training ";
Expand Down Expand Up @@ -200,14 +193,15 @@ class Result
{
auto & read_entry = read_record.read;
const auto & record = read_record.record;
bool read_to_extract = classify_read(read_entry, dehost);
if (run_extract_ and read_to_extract){
auto category_index = classify_read(read_entry, dehost);
std::cout << +category_index << (category_index < std::numeric_limits<uint8_t>::max()) << std::endl;
if (run_extract_ and extract_handles_.find(category_index) != extract_handles_.end()){
if (read_record.is_paired)
{
const auto & record2 = read_record.record2;
extract_paired_read(record, record2);
extract_paired_read(category_index, record, record2);
} else {
extract_read(record);
extract_read(category_index, record);
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions include/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,6 @@ size_t max_num_hashes_for_fpr(const IndexArguments & opt);
std::string sequence_to_string(const __type_pack_element<0, std::vector<seqan3::dna5>, std::string, std::vector<seqan3::phred94>>& input);

float get_compression_ratio(const std::string& sequence);

std::string get_extension(const std::filesystem::path);
#endif
41 changes: 23 additions & 18 deletions src/classify_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,10 @@ void setup_classify_subcommand(CLI::App& app)
classify_subcommand->add_option("-e,--extract", opt->category_to_extract, "Reads from this category in the index will be extracted to file.")
->type_name("STRING");

classify_subcommand->add_option("--extract_file", opt->extract_file, "Fasta/q file for output")
->transform(make_absolute)
->check(CLI::NonexistentPath.description(""))
->type_name("FILE");

classify_subcommand->add_option("--extract_file2", opt->extract_file2, "Fasta/q file for output")
->transform(make_absolute)
classify_subcommand->add_option("-p,--prefix", opt->prefix, "Prefix for the output files.")
->type_name("FILE")
->check(CLI::NonexistentPath.description(""))
->type_name("FILE");
->default_str("<prefix>");

classify_subcommand->add_option("-d,--dist", opt->dist, "Probability distribution to use for modelling.")
->type_name("STRING");
Expand Down Expand Up @@ -293,21 +288,31 @@ int classify_main(ClassifyArguments & opt)

opt.run_extract = (opt.category_to_extract != "");
const auto categories = index.categories();
if (opt.run_extract and std::find(categories.begin(), categories.end(), opt.category_to_extract) == categories.end())
if (opt.run_extract and opt.category_to_extract != "all" and std::find(categories.begin(), categories.end(), opt.category_to_extract) == categories.end())
{
std::string options = "";
for (auto i: categories)
options += i + " ";
PLOG_ERROR << "Cannot extract " << opt.category_to_extract << ", please chose one of [ " << options << "]";
PLOG_ERROR << "Cannot extract " << opt.category_to_extract << ", please chose one of [ all " << options << "]";
return 1;
} else if (opt.run_extract and opt.extract_file == ""){
opt.extract_file = opt.read_file;
if (opt.is_paired){
opt.extract_file.replace_extension(opt.category_to_extract + opt.read_file.extension().string());
opt.extract_file2 = opt.read_file2;
opt.extract_file2.replace_extension(opt.category_to_extract + opt.read_file.extension().string());
} else
opt.extract_file.replace_extension(opt.category_to_extract + opt.read_file.extension().string());
} else if (opt.run_extract){
if (opt.prefix == "")
opt.prefix = "charon";
std::vector<std::string> to_extract;
if (opt.category_to_extract == "all")
to_extract = categories;
else
to_extract.push_back(opt.category_to_extract);
const auto extension = get_extension(opt.read_file);
for (const auto & category : to_extract){
const auto category_index = index.get_category_index(category);
if (opt.is_paired) {
opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + "_1" + extension + ".gz");
opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + "_2" + extension + ".gz");
} else {
opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + extension + ".gz");
}
}
}

if (opt.dist != "gamma" and opt.dist != "beta")
Expand Down
42 changes: 24 additions & 18 deletions src/dehost_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,10 @@ void setup_dehost_subcommand(CLI::App& app)
dehost_subcommand->add_option("-e,--extract", opt->category_to_extract, "Reads from this category in the index will be extracted to file.")
->type_name("STRING");

dehost_subcommand->add_option("--extract_file", opt->extract_file, "Fasta/q file for output")
->transform(make_absolute)
->check(CLI::NonexistentPath.description(""))
->type_name("FILE");

dehost_subcommand->add_option("--extract_file2", opt->extract_file2, "Fasta/q file for output")
->transform(make_absolute)
dehost_subcommand->add_option("-p,--prefix", opt->prefix, "Prefix for the output files.")
->type_name("FILE")
->check(CLI::NonexistentPath.description(""))
->type_name("FILE");
->default_str("<prefix>");

dehost_subcommand
->add_option("--chunk_size", opt->chunk_size, "Read file is read in chunks of this size, to be processed in parallel within a chunk.")
Expand Down Expand Up @@ -502,21 +497,32 @@ int dehost_main(DehostArguments & opt)

opt.run_extract = (opt.category_to_extract != "");
const auto categories = index.categories();
if (opt.run_extract and std::find(categories.begin(), categories.end(), opt.category_to_extract) == categories.end())
if (opt.run_extract and opt.category_to_extract != "all" and std::find(categories.begin(), categories.end(), opt.category_to_extract) == categories.end())
{
std::string options = "";
for (auto i: categories)
options += i + " ";
PLOG_ERROR << "Cannot extract " << opt.category_to_extract << ", please chose one of [ " << options << "]";
PLOG_ERROR << "Cannot extract " << opt.category_to_extract << ", please chose one of [ all " << options << "]";
return 1;
} else if (opt.run_extract and opt.extract_file == ""){
opt.extract_file = opt.read_file;
if (opt.is_paired){
opt.extract_file.replace_extension(opt.category_to_extract + opt.read_file.extension().string());
opt.extract_file2 = opt.read_file2;
opt.extract_file2.replace_extension(opt.category_to_extract + opt.read_file.extension().string());
} else
opt.extract_file.replace_extension(opt.category_to_extract + opt.read_file.extension().string());
} else if (opt.run_extract){
if (opt.prefix == "")
opt.prefix = "charon";
std::vector<std::string> to_extract;
if (opt.category_to_extract == "all")
to_extract = categories;
else
to_extract.push_back(opt.category_to_extract);

const auto extension = get_extension(opt.read_file);
for (const auto & category : to_extract){
const auto category_index = index.get_category_index(category);
if (opt.is_paired) {
opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + "_1" + extension + ".gz");
opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + "_2" + extension + ".gz");
} else {
opt.extract_category_to_file[category_index].push_back(opt.prefix + "_" + category + extension + ".gz");
}
}
}

if (opt.dist != "gamma" and opt.dist != "beta" and opt.dist != "kde")
Expand Down
8 changes: 8 additions & 0 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,11 @@ float get_compression_ratio(const std::string& sequence){
return static_cast<double>(compressed_size)/static_cast<double>(initial_size);
}

/// Return the extension of a read file, looking through a trailing ".gz".
///
/// For a path whose extension is ".gz" the extension of the remaining stem is
/// returned instead (e.g. "reads.fastq.gz" yields ".fastq"); otherwise the
/// path's own extension is returned. The result keeps the leading dot and is
/// empty when there is no extension to report.
std::string get_extension(const std::filesystem::path read_file){
    const std::string outer_ext = read_file.extension().string();
    if (outer_ext != ".gz")
        return outer_ext;

    // Compressed input: report the extension sitting underneath ".gz".
    return read_file.stem().extension().string();
}