Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions marc_db/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from marc_db import __version__
from marc_db.db import create_database, get_session, get_marc_db_url
from marc_db.ingest import ingest_from_tsvs
from marc_db.remove import remove_isolate
from marc_db.mock import fill_mock_db


Expand All @@ -13,6 +14,7 @@ def main():
" init \tInitialize a new database.\n"
" mock_db \tFill mock values into an empty db (for testing).\n"
" ingest \tIngest data from TSV files into the database.\n"
" remove \tRemove an isolate and associated records.\n"
)

parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -100,6 +102,25 @@ def main():
yes=args_ingest.yes,
session=get_session(db_url),
)
elif args.command == "remove":
parser_remove = argparse.ArgumentParser(
prog="marc_db remove",
usage="%(prog)s --sample-id SAMPLE_ID",
description="Remove an isolate and its associated data.",
)
parser_remove.add_argument(
"--sample-id", required=True, help="SampleID of the isolate to remove."
)
parser_remove.add_argument(
"--yes", action="store_true", help="Skip confirmation prompt."
)
args_remove = parser_remove.parse_args(remaining)
create_database(db_url)
remove_isolate(
sample_id=args_remove.sample_id,
yes=args_remove.yes,
session=get_session(db_url),
)
else:
parser.print_help()
sys.stderr.write("Unrecognized command.\n")
106 changes: 106 additions & 0 deletions marc_db/remove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Callable, Optional

from sqlalchemy import select
from sqlalchemy.orm import Session

from marc_db.db import get_session
from marc_db.models import (
Aliquot,
Antimicrobial,
Assembly,
AssemblyQC,
Contaminant,
Isolate,
TaxonomicAssignment,
)


def _summarize_isolate(session: Session, sample_id: str) -> dict:
    """Count the records attached to one isolate, keyed by record type.

    Aliquots and assemblies are matched directly on their ``isolate_id``.
    Assembly QC rows, taxonomic assignments, contaminants and antimicrobials
    are matched through the ids of the isolate's assemblies.

    Args:
        session: Session used to run the count queries.
        sample_id: Value compared against ``isolate_id``.

    Returns:
        A dict mapping each record-type label to its row count.
    """
    # Subquery yielding the id of every assembly that belongs to the isolate.
    owned_assembly_ids = select(Assembly.id).where(
        Assembly.isolate_id == sample_id
    )

    linked_by_isolate = (
        ("aliquots", Aliquot),
        ("assemblies", Assembly),
    )
    linked_by_assembly = (
        ("assembly_qc", AssemblyQC),
        ("taxonomic_assignments", TaxonomicAssignment),
        ("contaminants", Contaminant),
        ("antimicrobials", Antimicrobial),
    )

    summary = {}
    for label, model in linked_by_isolate:
        matching = session.query(model).filter(model.isolate_id == sample_id)
        summary[label] = matching.count()
    for label, model in linked_by_assembly:
        matching = session.query(model).filter(
            model.assembly_id.in_(owned_assembly_ids)
        )
        summary[label] = matching.count()
    return summary


def remove_isolate(
    *,
    sample_id: str,
    yes: bool = False,
    session: Optional[Session] = None,
    input_fn: Callable[[str], str] = input,
) -> None:
    """Remove a single isolate and its associated records.

    Deletes the isolate's antimicrobials, contaminants, taxonomic
    assignments, assembly QC rows, assemblies and aliquots, then the isolate
    itself. If no isolate matches ``sample_id`` a message is printed and
    nothing is deleted; no exception is raised in that case.

    Args:
        sample_id: Primary key of the isolate to remove.
        yes: When true, skip the confirmation prompt.
        session: Session to operate on. When omitted, one is obtained from
            ``get_session()`` and closed before returning.
        input_fn: Callable used to ask for confirmation; receives the prompt
            text and returns the answer. Only ``y`` / ``yes`` (case
            insensitive, surrounding whitespace ignored) confirms.

    Raises:
        Exception: Any error raised while querying or deleting is re-raised
            after the transaction opened here has been rolled back.
    """
    created_session = False
    if session is None:
        session = get_session()
        created_session = True

    # If the caller already has a transaction open, work inside a SAVEPOINT
    # so a rollback here only undoes this function's work; the enclosing
    # transaction is still the caller's to commit.
    trans = session.begin_nested() if session.in_transaction() else session.begin()
    try:
        isolate = session.get(Isolate, sample_id)
        if isolate is None:
            print(f"No isolate found with SampleID {sample_id}.")
            trans.rollback()
            return

        if not yes:
            # The summary is only shown in the prompt, so skip its queries
            # entirely when confirmation is not required.
            counts = _summarize_isolate(session, sample_id)
            print(f"Isolate {sample_id} will be removed with the following records:")
            for label, count in counts.items():
                print(f" {label.replace('_', ' ')}: {count}")
            answer = input_fn("Proceed with deletion? [y/N]: ").strip().lower()
            if answer not in {"y", "yes"}:
                trans.rollback()
                print("Removal cancelled.")
                return

        assembly_ids = select(Assembly.id).where(Assembly.isolate_id == sample_id)

        # "fetch" drops the deleted rows from the session's identity map.
        # Without it, objects already loaded in this session (including the
        # isolate fetched above) would still look alive after a SAVEPOINT
        # release, which does not expire the session.
        sync = "fetch"

        # Records that hang off assemblies go first, while the assembly rows
        # the subquery relies on still exist. Children are removed before
        # their parents -- presumably to satisfy foreign-key constraints.
        for model in (Antimicrobial, Contaminant, TaxonomicAssignment, AssemblyQC):
            session.query(model).filter(
                model.assembly_id.in_(assembly_ids)
            ).delete(synchronize_session=sync)
        for model in (Assembly, Aliquot):
            session.query(model).filter(model.isolate_id == sample_id).delete(
                synchronize_session=sync
            )
        session.query(Isolate).filter(Isolate.sample_id == sample_id).delete(
            synchronize_session=sync
        )

        trans.commit()
    except Exception:
        trans.rollback()
        raise
    finally:
        # Only close sessions this function opened itself.
        if created_session:
            session.close()
101 changes: 101 additions & 0 deletions tests/test_remove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import pandas as pd
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from marc_db.ingest import ingest_from_tsvs
from marc_db.models import (
Aliquot,
Antimicrobial,
Assembly,
AssemblyQC,
Base,
Isolate,
TaxonomicAssignment,
)
from marc_db.remove import remove_isolate


# Directory containing this test module; the TSV fixtures read by the tests
# below are resolved relative to it.
data_dir = Path(__file__).parent


def _build_session():
    """Create an in-memory SQLite database with the schema applied.

    Returns:
        A ``(session, engine)`` tuple; the caller closes the session and
        disposes of the engine.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    session_factory = sessionmaker(bind=engine)
    return session_factory(), engine


def test_remove_isolate_deletes_associations():
    """Removing ``sample1`` deletes its records and leaves ``sample2`` intact.

    Besides checking that the surviving isolate keeps every record from the
    fixture files, each assembly-linked table is also counted as a whole so
    that rows left behind for the removed isolate are detected.
    """
    session, engine = _build_session()
    isolates_df = pd.read_csv(data_dir / "test_multi_aliquot.tsv", sep="\t")
    assemblies_df = pd.read_csv(data_dir / "test_assembly_data.tsv", sep="\t")
    tax_df = pd.read_csv(data_dir / "test_taxonomic_assignment.tsv", sep="\t")
    amr_df = pd.read_csv(data_dir / "test_amr_data.tsv", sep="\t")

    ingest_from_tsvs(
        isolates=isolates_df,
        assemblies=assemblies_df,
        assembly_qcs=assemblies_df,
        taxonomic_assignments=tax_df,
        antimicrobials=amr_df,
        yes=True,
        session=session,
    )

    remove_isolate(sample_id="sample1", yes=True, session=session)

    remaining_sample = "sample2"
    expected_aliquots = isolates_df.loc[
        isolates_df["SampleID"] == remaining_sample
    ].shape[0]
    expected_assemblies = assemblies_df.loc[
        assemblies_df["SampleID"] == remaining_sample
    ].shape[0]
    expected_taxonomic = tax_df.loc[tax_df["SampleID"] == remaining_sample].shape[0]
    expected_amr = amr_df.loc[amr_df["SampleID"] == remaining_sample].shape[0]

    assert session.query(Isolate).count() == 1
    assert session.query(Aliquot).count() == expected_aliquots
    assert session.query(Assembly).count() == expected_assemblies

    remaining_assembly_ids = [
        asm.id
        for asm in session.query(Assembly).filter(
            Assembly.isolate_id == remaining_sample
        )
    ]

    assembly_linked = (
        (AssemblyQC, expected_assemblies),
        (TaxonomicAssignment, expected_taxonomic),
        (Antimicrobial, expected_amr),
    )
    for model, expected in assembly_linked:
        kept = (
            session.query(model)
            .filter(model.assembly_id.in_(remaining_assembly_ids))
            .count()
        )
        assert kept == expected, f"{model.__name__}: surviving rows were lost"
        # The filtered count alone cannot see rows that still point at the
        # removed isolate's assemblies, so compare against the whole table.
        assert (
            session.query(model).count() == kept
        ), f"{model.__name__}: rows left behind for the removed isolate"

    session.close()
    engine.dispose()


def test_remove_isolate_cancelled_does_not_delete(capsys):
    """Answering "n" at the prompt cancels the removal and keeps both isolates."""
    session, engine = _build_session()
    fixture = pd.read_csv(data_dir / "test_multi_aliquot.tsv", sep="\t")
    ingest_from_tsvs(isolates=fixture, yes=True, session=session)

    def decline(_prompt):
        return "n"

    remove_isolate(
        sample_id="sample1",
        yes=False,
        session=session,
        input_fn=decline,
    )

    printed = capsys.readouterr().out
    assert "Removal cancelled." in printed

    isolate_total = session.query(Isolate).count()
    assert isolate_total == 2

    session.close()
    engine.dispose()