Skip to content
165 changes: 66 additions & 99 deletions src/assembly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::{

use bit_set::BitSet;
use clap::ValueEnum;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

use crate::{
bounds::{bound_exceeded, Bound},
Expand All @@ -37,6 +37,7 @@ use crate::{
kernels::KernelMode,
molecule::Molecule,
utils::connected_components_under_edges,
reductions::CompatGraph,
};

/// Parallelization strategy for the search phase.
Expand Down Expand Up @@ -190,21 +191,32 @@ fn fractures(
#[allow(clippy::too_many_arguments)]
fn recurse_index_search_serial(
mol: &Molecule,
matches: &[(BitSet, BitSet)],
fragments: &[BitSet],
state_index: usize,
mut best_index: usize,
largest_remove: usize,
best_index: Arc<AtomicUsize>,
bounds: &[Bound],
matches_graph: &CompatGraph,
subgraph: BitSet,
) -> (usize, usize) {
let largest_remove = {
if let Some(v) = subgraph.iter().next() {
matches_graph.weight(v) + 1
}
else {
return (state_index, 1);
}
};

// If any bounds are exceeded, halt this search branch.
if bound_exceeded(
mol,
fragments,
state_index,
best_index,
best_index.load(Relaxed),
largest_remove,
bounds,
matches_graph,
&subgraph,
) {
return (state_index, 1);
}
Expand All @@ -217,30 +229,32 @@ fn recurse_index_search_serial(
// For every pair of duplicatable subgraphs compatible with the current set
// of fragments, recurse using the fragments obtained by removing this pair
// and adding one subgraph back.
for (i, (h1, h2)) in matches.iter().enumerate() {
for v in subgraph.iter() {
let (h1, h2) = matches_graph.matches(v);
if let Some(fractures) = fractures(mol, fragments, h1, h2) {
let subgraph_clone = matches_graph.forward_neighbors(v, &subgraph);

// Recurse using the remaining matches and updated fragments.
let (child_index, child_states_searched) = recurse_index_search_serial(
mol,
&matches[i + 1..],
&fractures,
state_index - h1.len() + 1,
best_index,
h1.len(),
best_index.clone(),
bounds,
matches_graph,
subgraph_clone,
);

// Update the best assembly indices (across children states and
// the entire search) and the number of descendant states searched.
best_child_index = best_child_index.min(child_index);
best_index = best_index.min(best_child_index);
best_index.fetch_min(best_child_index, Relaxed);
states_searched += child_states_searched;
}
}

(best_child_index, states_searched)
}

/// Recursive helper for the depth-one parallel version of index_search.
///
/// Inputs:
Expand All @@ -256,29 +270,33 @@ fn recurse_index_search_serial(
#[allow(clippy::too_many_arguments)]
fn recurse_index_search_depthone(
mol: &Molecule,
matches: &[(BitSet, BitSet)],
fragments: &[BitSet],
state_index: usize,
best_index: Arc<AtomicUsize>,
bounds: &[Bound],
matches_graph: &CompatGraph,
subgraph: BitSet,
) -> (usize, usize) {
// Keep track of the number of states searched, including this one.
let states_searched = Arc::new(AtomicUsize::from(1));

// For every pair of duplicatable subgraphs compatible with the current set
// of fragments, recurse using the fragments obtained by removing this pair
// and adding one subgraph back.
matches.par_iter().enumerate().for_each(|(i, (h1, h2))| {
subgraph.iter().collect::<Vec<usize>>().par_iter().for_each(|v| {
let (h1, h2) = matches_graph.matches(*v);
if let Some(fractures) = fractures(mol, fragments, h1, h2) {
let subgraph_clone = matches_graph.forward_neighbors(*v, &subgraph);

// Recurse using the remaining matches and updated fragments.
let (child_index, child_states_searched) = recurse_index_search_depthone_helper(
let (child_index, child_states_searched) = recurse_index_search_serial(
mol,
&matches[i + 1..],
&fractures,
state_index - h1.len() + 1,
best_index.clone(),
h1.len(),
bounds,
matches_graph,
subgraph_clone,
);

// Update the best assembly indices (across children states and
Expand All @@ -291,74 +309,6 @@ fn recurse_index_search_depthone(
(best_index.load(Relaxed), states_searched.load(Relaxed))
}

/// Sequential recursive worker for the depth-one parallel version of
/// index_search: explores one assembly state and, one after another, every
/// state reachable from it.
///
/// Inputs:
/// - `mol`: The molecule whose assembly index is being calculated.
/// - `matches`: The remaining non-overlapping isomorphic subgraph pairs; a
///   child state only receives the pairs that come after the one it removed.
/// - `fragments`: The fragments making up this assembly state, one bit set per
///   fragment (presumably sets of edge indices; confirm against `fractures`).
/// - `state_index`: The assembly index of this assembly state.
/// - `best_index`: The smallest assembly index for all assembly states so far,
///   shared with the rest of the search through an atomic.
/// - `largest_remove`: An upper bound on the size of fragments that can be
///   removed from this or any descendant state.
/// - `bounds`: The list of bounding strategies to apply.
///
/// Returns, from this assembly state and any of its descendants:
/// - `usize`: The best assembly index found.
/// - `usize`: The number of assembly states searched.
#[allow(clippy::too_many_arguments)]
fn recurse_index_search_depthone_helper(
    mol: &Molecule,
    matches: &[(BitSet, BitSet)],
    fragments: &[BitSet],
    state_index: usize,
    best_index: Arc<AtomicUsize>,
    largest_remove: usize,
    bounds: &[Bound],
) -> (usize, usize) {
    // Prune this branch as soon as any bounding strategy reports that it is
    // exceeded; the state itself still counts as one searched state.
    let pruned = bound_exceeded(
        mol,
        fragments,
        state_index,
        best_index.load(Relaxed),
        largest_remove,
        bounds,
    );
    if pruned {
        return (state_index, 1);
    }

    // Best assembly index seen in this subtree (starting from this state) and
    // the running count of states visited (starting with this state).
    let mut subtree_best = state_index;
    let mut visited = 1;

    // Try each remaining pair of duplicatable subgraphs in order. The loop is
    // deliberately sequential: later siblings read the shared `best_index`
    // that earlier siblings may have lowered.
    for (pos, pair) in matches.iter().enumerate() {
        let (h1, h2) = pair;

        // Pairs that are incompatible with the current fragments are skipped.
        let next_fragments = match fractures(mol, fragments, h1, h2) {
            Some(next) => next,
            None => continue,
        };

        // Removing the pair and adding one copy back lowers the index by
        // `h1.len() - 1`; the removed size also caps future removals.
        let removed = h1.len();
        let (child_best, child_visited) = recurse_index_search_depthone_helper(
            mol,
            &matches[pos + 1..],
            &next_fragments,
            state_index - removed + 1,
            best_index.clone(),
            removed,
            bounds,
        );

        // Fold the child's result into the subtree best, publish it to the
        // search-wide best, and accumulate the visited-state count.
        subtree_best = subtree_best.min(child_best);
        best_index.fetch_min(subtree_best, Relaxed);
        visited += child_visited;
    }

    (subtree_best, visited)
}

/// Recursive helper for the parallel version of index_search.
///
Expand All @@ -377,12 +327,13 @@ fn recurse_index_search_depthone_helper(
#[allow(clippy::too_many_arguments)]
fn recurse_index_search_parallel(
mol: &Molecule,
matches: &[(BitSet, BitSet)],
fragments: &[BitSet],
state_index: usize,
best_index: Arc<AtomicUsize>,
largest_remove: usize,
bounds: &[Bound],
matches_graph: &CompatGraph,
subgraph: BitSet,
) -> (usize, usize) {
// If any bounds are exceeded, halt this search branch.
if bound_exceeded(
Expand All @@ -392,6 +343,8 @@ fn recurse_index_search_parallel(
best_index.load(Relaxed),
largest_remove,
bounds,
matches_graph,
&subgraph,
) {
return (state_index, 1);
}
Expand All @@ -404,17 +357,21 @@ fn recurse_index_search_parallel(
// For every pair of duplicatable subgraphs compatible with the current set
// of fragments, recurse using the fragments obtained by removing this pair
// and adding one subgraph back.
matches.par_iter().enumerate().for_each(|(i, (h1, h2))| {
subgraph.iter().collect::<Vec<usize>>().par_iter().for_each(|v| {
let (h1, h2) = matches_graph.matches(*v);
if let Some(fractures) = fractures(mol, fragments, h1, h2) {
let subgraph_clone = matches_graph.forward_neighbors(*v, &subgraph);

// Recurse using the remaining matches and updated fragments.
let (child_index, child_states_searched) = recurse_index_search_parallel(
mol,
&matches[i + 1..],
&fractures,
state_index - h1.len() + 1,
best_index.clone(),
h1.len(),
bounds,
matches_graph,
subgraph_clone,
);

// Update the best assembly indices (across children states and
Expand Down Expand Up @@ -509,6 +466,11 @@ pub fn index_search(

// Enumerate non-overlapping isomorphic subgraph pairs.
let matches = matches(mol, enumerate_mode, canonize_mode).collect::<Vec<_>>();
let matches_graph = CompatGraph::new(matches);
let mut subgraph = BitSet::with_capacity(matches_graph.len());
for i in 0..matches_graph.len() {
subgraph.insert(i);
}

// Initialize the first fragment as the entire graph.
let mut init = BitSet::new();
Expand All @@ -518,41 +480,46 @@ pub fn index_search(

// Search for the shortest assembly pathway recursively.
let (index, states_searched) = match parallel_mode {
ParallelMode::None => recurse_index_search_serial(
mol,
&matches,
&[init],
edge_count - 1,
edge_count - 1,
edge_count,
bounds,
),
ParallelMode::None => {
let best_index = Arc::new(AtomicUsize::from(edge_count - 1));
recurse_index_search_serial(
mol,
&[init],
edge_count - 1,
best_index,
bounds,
&matches_graph,
subgraph,
)
}
ParallelMode::DepthOne => {
let best_index = Arc::new(AtomicUsize::from(edge_count - 1));
recurse_index_search_depthone(
mol,
&matches,
&[init],
edge_count - 1,
best_index,
bounds,
&matches_graph,
subgraph,
)
}
ParallelMode::Always => {
let best_index = Arc::new(AtomicUsize::from(edge_count - 1));
recurse_index_search_parallel(
mol,
&matches,
&[init],
edge_count - 1,
best_index,
edge_count,
bounds,
&matches_graph,
subgraph,
)
}
};

(index as u32, matches.len() as u32, states_searched)
(index as u32, matches_graph.len() as u32, states_searched)
}

/// Computes a molecule's assembly index using an efficient default strategy.
Expand Down
Loading
Loading