# Review notes for this patch (commentary only; placed ahead of the first "diff --git" header).
#
# Files touched:
#   marc_db/cli.py        (modified)  - wires up a new "remove" sub-command.
#   marc_db/remove.py     (new file)  - _summarize_isolate() and remove_isolate().
#   tests/test_remove.py  (new file)  - two tests for remove_isolate().
#
# marc_db/cli.py
#   - Imports remove_isolate and adds a "remove" line to the usage text.
#   - New `elif args.command == "remove"` branch: builds an argparse parser with a required
#     --sample-id option and a --yes store_true flag, parses `remaining`, calls
#     create_database(db_url), then remove_isolate(sample_id=..., yes=..., session=get_session(db_url)).
#
# marc_db/remove.py
#   - _summarize_isolate(session, sample_id) returns a dict of row counts under the keys
#     "aliquots", "assemblies", "assembly_qc", "taxonomic_assignments", "contaminants" and
#     "antimicrobials". Aliquot/Assembly rows are matched on isolate_id == sample_id; the other
#     four are matched on assembly_id IN (select Assembly.id for that isolate).
#   - remove_isolate(*, sample_id, yes=False, session=None, input_fn=input):
#       * All parameters are keyword-only. If session is None it calls get_session() and closes
#         that session in the `finally` clause; a caller-supplied session is not closed here.
#       * Uses session.begin_nested() when session.in_transaction() is true, else session.begin().
#       * If session.get(Isolate, sample_id) is None: prints a "No isolate found" message,
#         rolls back and returns.
#       * Unless yes is true: prints the counts, asks input_fn("Proceed with deletion? [y/N]: ");
#         any answer other than "y"/"yes" (after strip + lower) rolls back, prints
#         "Removal cancelled." and returns.
#       * Deletes, in this order: Antimicrobial, Contaminant, TaxonomicAssignment, AssemblyQC,
#         Assembly, Aliquot, Isolate - each via query(...).delete(synchronize_session=False) -
#         then trans.commit(). Any exception triggers trans.rollback() and is re-raised.
#   - NOTE(review): on the begin_nested() path, trans.commit() is called on the nested transaction
#     only; presumably the caller is expected to commit the enclosing transaction - confirm.
#   - NOTE(review): the Isolate returned by session.get() is loaded before the bulk deletes, which
#     run with synchronize_session=False - verify that a stale in-session Isolate cannot matter
#     to callers that keep using the same session.
#
# tests/test_remove.py
#   - _build_session() creates an in-memory SQLite engine ("sqlite:///:memory:"), a session and
#     the tables from Base.metadata; returns (session, engine).
#   - test_remove_isolate_deletes_associations ingests four TSV fixtures, removes "sample1" with
#     yes=True and asserts the remaining Isolate/Aliquot/Assembly/AssemblyQC/TaxonomicAssignment/
#     Antimicrobial counts against the "sample2" rows of the fixture dataframes.
#   - test_remove_isolate_cancelled_does_not_delete passes input_fn=lambda _: "n" and asserts that
#     "Removal cancelled." was printed and that two Isolate rows remain.
#   - NOTE(review): neither test imports Contaminant or passes contaminant data to
#     ingest_from_tsvs, so the Contaminant delete is not exercised here - consider adding coverage.
#
# NOTE(review): the patch text below has its line breaks and whitespace runs collapsed (hunk lines
# are run together); presumably the original line structure must be restored before it can be
# applied - confirm against the original commit.
diff --git a/marc_db/cli.py b/marc_db/cli.py index 852fceb..834fae9 100644 --- a/marc_db/cli.py +++ b/marc_db/cli.py @@ -3,6 +3,7 @@ from marc_db import __version__ from marc_db.db import create_database, get_session, get_marc_db_url from marc_db.ingest import ingest_from_tsvs +from marc_db.remove import remove_isolate from marc_db.mock import fill_mock_db @@ -13,6 +14,7 @@ def main(): " init \tInitialize a new database.\n" " mock_db \tFill mock values into an empty db (for testing).\n" " ingest \tIngest data from TSV files into the database.\n" + " remove \tRemove an isolate and associated records.\n" ) parser = argparse.ArgumentParser( @@ -100,6 +102,25 @@ def main(): yes=args_ingest.yes, session=get_session(db_url), ) + elif args.command == "remove": + parser_remove = argparse.ArgumentParser( + prog="marc_db remove", + usage="%(prog)s --sample-id SAMPLE_ID", + description="Remove an isolate and its associated data.", + ) + parser_remove.add_argument( + "--sample-id", required=True, help="SampleID of the isolate to remove." + ) + parser_remove.add_argument( + "--yes", action="store_true", help="Skip confirmation prompt." 
+ ) + args_remove = parser_remove.parse_args(remaining) + create_database(db_url) + remove_isolate( + sample_id=args_remove.sample_id, + yes=args_remove.yes, + session=get_session(db_url), + ) else: parser.print_help() sys.stderr.write("Unrecognized command.\n") diff --git a/marc_db/remove.py b/marc_db/remove.py new file mode 100644 index 0000000..0b2826e --- /dev/null +++ b/marc_db/remove.py @@ -0,0 +1,106 @@ +from typing import Callable, Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from marc_db.db import get_session +from marc_db.models import ( + Aliquot, + Antimicrobial, + Assembly, + AssemblyQC, + Contaminant, + Isolate, + TaxonomicAssignment, +) + + +def _summarize_isolate(session: Session, sample_id: str) -> dict: + assembly_ids = select(Assembly.id).where(Assembly.isolate_id == sample_id) + return { + "aliquots": session.query(Aliquot) + .filter(Aliquot.isolate_id == sample_id) + .count(), + "assemblies": session.query(Assembly) + .filter(Assembly.isolate_id == sample_id) + .count(), + "assembly_qc": session.query(AssemblyQC) + .filter(AssemblyQC.assembly_id.in_(assembly_ids)) + .count(), + "taxonomic_assignments": session.query(TaxonomicAssignment) + .filter(TaxonomicAssignment.assembly_id.in_(assembly_ids)) + .count(), + "contaminants": session.query(Contaminant) + .filter(Contaminant.assembly_id.in_(assembly_ids)) + .count(), + "antimicrobials": session.query(Antimicrobial) + .filter(Antimicrobial.assembly_id.in_(assembly_ids)) + .count(), + } + + +def remove_isolate( + *, + sample_id: str, + yes: bool = False, + session: Optional[Session] = None, + input_fn: Callable[[str], str] = input, +): + """Remove a single isolate and its associated records.""" + + created_session = False + if session is None: + session = get_session() + created_session = True + + trans = session.begin_nested() if session.in_transaction() else session.begin() + try: + isolate = session.get(Isolate, sample_id) + if isolate is None: + print(f"No 
isolate found with SampleID {sample_id}.") + trans.rollback() + return + + counts = _summarize_isolate(session, sample_id) + + if not yes: + print(f"Isolate {sample_id} will be removed with the following records:") + for label, count in counts.items(): + print(f" {label.replace('_', ' ')}: {count}") + answer = input_fn("Proceed with deletion? [y/N]: ").strip().lower() + if answer not in {"y", "yes"}: + trans.rollback() + print("Removal cancelled.") + return + + assembly_ids = select(Assembly.id).where(Assembly.isolate_id == sample_id) + + session.query(Antimicrobial).filter( + Antimicrobial.assembly_id.in_(assembly_ids) + ).delete(synchronize_session=False) + session.query(Contaminant).filter( + Contaminant.assembly_id.in_(assembly_ids) + ).delete(synchronize_session=False) + session.query(TaxonomicAssignment).filter( + TaxonomicAssignment.assembly_id.in_(assembly_ids) + ).delete(synchronize_session=False) + session.query(AssemblyQC).filter( + AssemblyQC.assembly_id.in_(assembly_ids) + ).delete(synchronize_session=False) + session.query(Assembly).filter( + Assembly.isolate_id == sample_id + ).delete(synchronize_session=False) + session.query(Aliquot).filter(Aliquot.isolate_id == sample_id).delete( + synchronize_session=False + ) + session.query(Isolate).filter(Isolate.sample_id == sample_id).delete( + synchronize_session=False + ) + + trans.commit() + except Exception: + trans.rollback() + raise + finally: + if created_session: + session.close() diff --git a/tests/test_remove.py b/tests/test_remove.py new file mode 100644 index 0000000..ed17036 --- /dev/null +++ b/tests/test_remove.py @@ -0,0 +1,101 @@ +import pandas as pd +from pathlib import Path +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from marc_db.ingest import ingest_from_tsvs +from marc_db.models import ( + Aliquot, + Antimicrobial, + Assembly, + AssemblyQC, + Base, + Isolate, + TaxonomicAssignment, +) +from marc_db.remove import remove_isolate + + +data_dir = 
Path(__file__).parent + + +def _build_session(): + engine = create_engine("sqlite:///:memory:") + Session = sessionmaker(bind=engine) + session = Session() + Base.metadata.create_all(engine) + return session, engine + + +def test_remove_isolate_deletes_associations(): + session, engine = _build_session() + isolates_df = pd.read_csv(data_dir / "test_multi_aliquot.tsv", sep="\t") + assemblies_df = pd.read_csv(data_dir / "test_assembly_data.tsv", sep="\t") + tax_df = pd.read_csv(data_dir / "test_taxonomic_assignment.tsv", sep="\t") + amr_df = pd.read_csv(data_dir / "test_amr_data.tsv", sep="\t") + + ingest_from_tsvs( + isolates=isolates_df, + assemblies=assemblies_df, + assembly_qcs=assemblies_df, + taxonomic_assignments=tax_df, + antimicrobials=amr_df, + yes=True, + session=session, + ) + + remove_isolate(sample_id="sample1", yes=True, session=session) + + remaining_sample = "sample2" + expected_aliquots = isolates_df.loc[ + isolates_df["SampleID"] == remaining_sample + ].shape[0] + expected_assemblies = assemblies_df.loc[ + assemblies_df["SampleID"] == remaining_sample + ].shape[0] + expected_taxonomic = tax_df.loc[tax_df["SampleID"] == remaining_sample].shape[0] + expected_amr = amr_df.loc[amr_df["SampleID"] == remaining_sample].shape[0] + + assert session.query(Isolate).count() == 1 + assert session.query(Aliquot).count() == expected_aliquots + assert session.query(Assembly).count() == expected_assemblies + + remaining_assembly_ids = [ + asm.id + for asm in session.query(Assembly).filter( + Assembly.isolate_id == remaining_sample + ) + ] + assert session.query(AssemblyQC).filter( + AssemblyQC.assembly_id.in_(remaining_assembly_ids) + ).count() == expected_assemblies + assert session.query(TaxonomicAssignment).filter( + TaxonomicAssignment.assembly_id.in_(remaining_assembly_ids) + ).count() == expected_taxonomic + assert session.query(Antimicrobial).filter( + Antimicrobial.assembly_id.in_(remaining_assembly_ids) + ).count() == expected_amr + + session.close() + 
engine.dispose() + + +def test_remove_isolate_cancelled_does_not_delete(capsys): + session, engine = _build_session() + isolates_df = pd.read_csv(data_dir / "test_multi_aliquot.tsv", sep="\t") + + ingest_from_tsvs(isolates=isolates_df, yes=True, session=session) + + remove_isolate( + sample_id="sample1", + yes=False, + session=session, + input_fn=lambda _: "n", + ) + + captured = capsys.readouterr() + assert "Removal cancelled." in captured.out + assert session.query(Isolate).count() == 2 + + session.close() + engine.dispose()