From 9f39c5c845730cdf281d0b9c450c84cfa6137e8f Mon Sep 17 00:00:00 2001 From: xyg123 <33658607+xyg123@users.noreply.github.com> Date: Wed, 14 Jan 2026 13:58:27 +0000 Subject: [PATCH 01/16] feat: e2g interval features for l2g (#1144) * feat: e2g interval features for l2g * chore: pre-commit auto fixes [...] * feat: config for interval features * fix: splat columns in convert_from_long_to_wide * fix: interval start end to integers * feat: overlap intervals using bin candidates * fix: interval start end as integers * fix: binned interval join * fix: trainer _setup redundant parameter * chore: fix docstring in interval feature methods * chore: fix interval feature descriptions * chore: clarify caching strategy * fix: address comments and adjust tests accordingly --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> Co-authored-by: project-defiant --- .../datasets/l2g_features/intervals.md | 13 + .../datasources/intervals/_intervals.md | 8 +- src/gentropy/assets/schemas/intervals.json | 4 +- src/gentropy/common/spark.py | 2 +- src/gentropy/config.py | 7 + .../dataset/l2g_features/intervals.py | 461 ++++++++++++++++++ src/gentropy/datasource/intervals/e2g.py | 4 +- src/gentropy/l2g.py | 11 + src/gentropy/method/l2g/feature_factory.py | 8 +- src/gentropy/method/l2g/trainer.py | 3 +- tests/gentropy/dataset/test_l2g_feature.py | 185 +++++++ 11 files changed, 693 insertions(+), 13 deletions(-) create mode 100644 docs/python_api/datasets/l2g_features/intervals.md create mode 100644 src/gentropy/dataset/l2g_features/intervals.py diff --git a/docs/python_api/datasets/l2g_features/intervals.md b/docs/python_api/datasets/l2g_features/intervals.md new file mode 100644 index 000000000..eec2aa2e7 --- /dev/null +++ b/docs/python_api/datasets/l2g_features/intervals.md @@ -0,0 +1,13 @@ +--- +title: Epigenetic regulatory region features +--- + +## List of 
features + +::: gentropy.dataset.l2g_features.intervals.E2gMeanFeature +::: gentropy.dataset.l2g_features.intervals.E2gMeanNeighbourhoodFeature + +## Common logic + +::: gentropy.dataset.l2g_features.intervals.e2g_interval_feature_wide_logic +::: gentropy.dataset.l2g_features.intervals.get_or_make_e2g_wide diff --git a/docs/python_api/datasources/intervals/_intervals.md b/docs/python_api/datasources/intervals/_intervals.md index ac1a6a09c..6d47f0758 100644 --- a/docs/python_api/datasources/intervals/_intervals.md +++ b/docs/python_api/datasources/intervals/_intervals.md @@ -8,14 +8,10 @@ In this section, we provide a list of studies that focus on interaction and inte 1. **E2G (Gschwind et al., Nov 2023):** _Title:_ "An encyclopedia of enhancer-gene regulatory interactions in the human genome". - This study comprises of a large, curated compendium of enhancer→gene links built by integrating multiple evidence types (epigenomic signals, 3D contacts, expression correlations, and CRISPR perturbations) across many biosamples. The resource reports confidence/score per enhancer–gene pair and is organised by biosample/cell type. - -DOI: 10.1101/2023.11.09.563812 + This study comprises of a large, curated compendium of enhancer→gene links built by integrating multiple evidence types (epigenomic signals, 3D contacts, expression correlations, and CRISPR perturbations) across many biosamples from ENCODE. The resource reports confidence/score per enhancer–gene pair and is organised by biosample/cell type. DOI: 10.1101/2023.11.09.563812 2. **EPIraction (Nurtdinov et al., Feb 2025):** _Title:_ "EPIraction - an atlas of candidate enhancer-gene interactions in human tissues and cell lines". - This study is a genome-wide atlas of candidate enhancer–gene links inferred primarily from H3K27ac ChIP-seq (enhancer activity) integrated with Hi-C contact probabilities, scored per tissue/cell line—methodologically similar in spirit to ABC-style scoring. 
- -DOI: 10.1101/2025.02.18.638885 + This study is a genome-wide atlas of candidate enhancer–gene links inferred primarily from H3K27ac ChIP-seq (enhancer activity) integrated with Hi-C contact probabilities, scored per tissue/cell line—methodologically similar in spirit to ABC-style scoring. A UCSC track hub is available. DOI: 10.1101/2025.02.18.638885 For in-depth details on each study, you may refer to the respective publications. diff --git a/src/gentropy/assets/schemas/intervals.json b/src/gentropy/assets/schemas/intervals.json index 45217fab2..5cdc4b71e 100644 --- a/src/gentropy/assets/schemas/intervals.json +++ b/src/gentropy/assets/schemas/intervals.json @@ -10,13 +10,13 @@ "metadata": {}, "name": "start", "nullable": false, - "type": "string" + "type": "integer" }, { "metadata": {}, "name": "end", "nullable": false, - "type": "string" + "type": "integer" }, { "metadata": {}, diff --git a/src/gentropy/common/spark.py b/src/gentropy/common/spark.py index 1b6ac65da..1f005b9ed 100644 --- a/src/gentropy/common/spark.py +++ b/src/gentropy/common/spark.py @@ -93,7 +93,7 @@ def convert_from_long_to_wide( +---+---------+---------+ """ - return df.groupBy(id_vars).pivot(var_name).agg(f.first(value_name)) + return df.groupBy(*id_vars).pivot(var_name).agg(f.first(value_name)) def nullify_empty_array(column: Column) -> Column: diff --git a/src/gentropy/config.py b/src/gentropy/config.py index 95d92b37a..9baacb7e6 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -265,6 +265,9 @@ class LocusToGeneConfig(StepConfig): "vepMaximumNeighbourhood", "vepMean", "vepMeanNeighbourhood", + # intervals + "e2gMean", + "e2gMeanNeighbourhood", # other "geneCount500kb", "proteinGeneCount500kb", @@ -311,6 +314,7 @@ class LocusToGeneFeatureMatrixConfig(StepConfig): colocalisation_path: str | None = None study_index_path: str | None = None target_index_path: str | None = None + intervals_path: str | None = None feature_matrix_path: str = MISSING features_list: list[str] = 
field( default_factory=lambda: [ @@ -345,6 +349,9 @@ class LocusToGeneFeatureMatrixConfig(StepConfig): "vepMaximumNeighbourhood", "vepMean", "vepMeanNeighbourhood", + # intervals + "e2gMean", + "e2gMeanNeighbourhood", # other "geneCount500kb", "proteinGeneCount500kb", diff --git a/src/gentropy/dataset/l2g_features/intervals.py b/src/gentropy/dataset/l2g_features/intervals.py new file mode 100644 index 000000000..12e0659a6 --- /dev/null +++ b/src/gentropy/dataset/l2g_features/intervals.py @@ -0,0 +1,461 @@ +"""Collection of methods that extract features from the interval datasets.""" + +from __future__ import annotations + +from typing import Any + +import pyspark.sql.functions as f +from pyspark.sql import DataFrame, Window + +from gentropy.common.processing import extract_chromosome, extract_position +from gentropy.common.spark import convert_from_wide_to_long +from gentropy.dataset.intervals import Intervals +from gentropy.dataset.l2g_features.l2g_feature import L2GFeature +from gentropy.dataset.l2g_gold_standard import L2GGoldStandard +from gentropy.dataset.study_locus import StudyLocus + + +def _explode_interval_bins( + iv: DataFrame, + *, + bin_size: int, + max_bins_per_interval: int, +) -> DataFrame: + """Given iv(df): columns [iv_chromosome, start, end, geneId, score]. 
+ + Returns columns with interval bins exploded: [iv_chromosome, start, end, geneId, score, iv_bin] + + Args: + iv (DataFrame): Intervals DataFrame + bin_size (int): Size of bins for the binned overlap + max_bins_per_interval (int): Maximum number of bins to explode per interval + + Returns: + DataFrame: DataFrame with interval bins exploded + """ + start_bin = (f.col("start") / f.lit(bin_size)).cast("long") + end_bin = (f.col("end") / f.lit(bin_size)).cast("long") + n_bins = end_bin - start_bin + f.lit(1) + + df = ( + iv.withColumn("start_bin", start_bin) + .withColumn("end_bin", end_bin) + .withColumn("n_bins", n_bins) + .filter(f.col("n_bins") > 0) + .filter(f.col("n_bins") <= f.lit(max_bins_per_interval)) + .withColumn("bin_seq", f.sequence(f.col("start_bin"), f.col("end_bin"))) + .withColumn("iv_bin", f.explode("bin_seq")) + .drop("bin_seq", "start_bin", "end_bin", "n_bins") + ) + return df + + +def e2g_interval_feature_wide_logic_binned( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + *, + intervals: Intervals, + base_name: str = "e2gMean", + pp_min: float = 0.001, + bin_size: int = 50_000, + max_bins_per_interval: int = 1000, + repartitions_variants: int | None = None, + repartitions_intervals: int | None = None, +) -> DataFrame: + """Computes the feature using a bin accelerated overlap. 
+ + 1) Bin variants: var_bin = floor(position / bin_size) + 2) Explode interval bins: iv_bin across [start_bin, end_bin] with safety cap + 3) Join on (chromosome, bin), then exact position filter + 4) Per variant per gene take max(score); weight by PP; sum to gene per locus + 5) Add neighbourhood ratio within locus + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci + that will be used for annotation + intervals (Intervals): The dataset containing interval information + base_name (str): The base name of the feature + pp_min (float): Minimum posterior probability to consider a variant + bin_size (int): Size of bins for the binned overlap + max_bins_per_interval (int): Maximum number of bins to explode per interval + repartitions_variants (int | None): Number of repartitions for variant side + repartitions_intervals (int | None): Number of repartitions for interval side + + Returns: + DataFrame: a WIDE DF with studyLocusId, geneId, e2gMean, e2gMeanNeighbourhood, neighbourhood is ratio-centred: + e2gMeanNeighbourhood = e2gMean / mean(e2gMean within locus) + """ + sl = study_loci_to_annotate.df.alias("sl") + iv = intervals.df.alias("iv") + + study_loci_exploded = ( + sl.withColumn("variantInLocus", f.explode_outer("locus")) + .withColumn( + "chromosome", + extract_chromosome(f.col("variantInLocus").getField("variantId")), + ) + .withColumn( + "position", + extract_position(f.col("variantInLocus").getField("variantId")).cast("int"), + ) + .withColumn("pp", f.col("variantInLocus.posteriorProbability").cast("double")) + .filter(f.col("pp") >= f.lit(pp_min)) + .select( + f.col("studyLocusId").alias("studyLocusId"), + f.col("chromosome").alias("sl_chromosome"), + f.col("position").alias("position"), + f.col("pp").alias("pp"), + ) + .filter( + f.col("sl_chromosome").isNotNull() + & f.col("position").isNotNull() + & f.col("pp").isNotNull() + ) + .alias("slx") + ) + + # Intervals minimal selection + intervals_filtered = ( + 
iv.select( + f.col("chromosome").alias("iv_chromosome"), + f.col("start").cast("int").alias("start"), + f.col("end").cast("int").alias("end"), + f.col("geneId").alias("geneId"), + f.col("score").cast("double").alias("score"), + ) + .filter(f.col("score").isNotNull()) + .alias("ivf") + ) + + # Add bins on both sides + slx_binned = study_loci_exploded.withColumn( + "var_bin", (f.col("position") / f.lit(bin_size)).cast("long") + ) + if repartitions_variants: + slx_binned = slx_binned.repartition( + repartitions_variants, "sl_chromosome", "var_bin" + ) + else: + slx_binned = slx_binned.repartition("sl_chromosome", "var_bin") + + ivf_binned = _explode_interval_bins( + intervals_filtered, + bin_size=bin_size, + max_bins_per_interval=max_bins_per_interval, + ) + if repartitions_intervals: + ivf_binned = ivf_binned.repartition( + repartitions_intervals, "iv_chromosome", "iv_bin" + ) + else: + ivf_binned = ivf_binned.repartition("iv_chromosome", "iv_bin") + + # Bin join then exact positional filter + joined = ( + slx_binned.alias("cs") + .join( + ivf_binned.alias("iv"), + on=[ + f.col("cs.sl_chromosome") == f.col("iv.iv_chromosome"), + f.col("cs.var_bin") == f.col("iv.iv_bin"), + ], + how="inner", + ) + .filter( + (f.col("cs.position") >= f.col("iv.start")) + & (f.col("cs.position") <= f.col("iv.end")) + ) + .select( + f.col("cs.studyLocusId").alias("studyLocusId"), + f.col("cs.sl_chromosome").alias("chromosome"), + f.col("cs.position").alias("position"), + f.col("cs.pp").alias("pp"), + f.col("iv.geneId").alias("geneId"), + f.col("iv.score").alias("score"), + ) + ) + + # Per variant per gene max interval score, keep pp + per_variant_gene = joined.groupBy( + "studyLocusId", "chromosome", "position", "geneId" + ).agg( + f.max("score").alias("maxScore"), + f.first("pp", ignorenulls=True).alias("pp"), + ) + + # Weight and aggregate to gene per locus + base_df = ( + per_variant_gene.withColumn( + "weightedIntervalScore", f.col("maxScore") * f.col("pp") + ) + 
.groupBy("studyLocusId", "geneId") + .agg(f.sum("weightedIntervalScore").alias(base_name)) + ).persist() + + # Neighbourhood ratio within locus, using locus max as the denominator + w = Window.partitionBy("studyLocusId") + with_max = base_df.withColumn("regional_max", f.max(base_name).over(w)) + neigh_ratio = f.when( + f.col("regional_max") != 0, f.col(base_name) / f.col("regional_max") + ).otherwise(f.lit(0.0)) + + wide = with_max.select( + "studyLocusId", + "geneId", + f.col(base_name).alias(base_name), + neigh_ratio.alias(f"{base_name}Neighbourhood"), + ) + return wide + + +def e2g_interval_feature_wide_logic( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + *, + intervals: Intervals, + base_name: str = "e2gMean", + use_binned: bool = True, + pp_min: float = 0.001, + bin_size: int = 50_000, + max_bins_per_interval: int = 200, + repartitions_variants: int | None = None, + repartitions_intervals: int | None = None, +) -> DataFrame: + """Wrapper that defaults to the binned implementation. + + Set use_binned=False to fall back to a plain overlap if ever needed. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci + that will be used for annotation + intervals (Intervals): The dataset containing interval information + base_name (str): The base name of the feature + use_binned (bool): Whether to use the binned overlap logic + pp_min (float): Minimum posterior probability to consider a variant + bin_size (int): Size of bins for the binned overlap + max_bins_per_interval (int): Maximum number of bins to explode per interval + repartitions_variants (int | None): Number of repartitions for variant side + repartitions_intervals (int | None): Number of repartitions for interval side + + Returns: + DataFrame: a WIDE DF with studyLocusId, geneId, e2gMean, e2gMeanNeighbourhood, neighbourhood is ratio-centred: + e2gMeanNeighbourhood = e2gMean / mean(e2gMean within locus) + """ + if use_binned: + return e2g_interval_feature_wide_logic_binned( + study_loci_to_annotate, + intervals=intervals, + base_name=base_name, + pp_min=pp_min, + bin_size=bin_size, + max_bins_per_interval=max_bins_per_interval, + repartitions_variants=repartitions_variants, + repartitions_intervals=repartitions_intervals, + ) + + # Fallback: original plain overlap logic (kept for completeness) + sl = study_loci_to_annotate.df.alias("sl") + iv = intervals.df.alias("iv") + study_loci_exploded = ( + sl.withColumn("variantInLocus", f.explode_outer("locus")) + .withColumn( + "chromosome", + extract_chromosome(f.col("variantInLocus").getField("variantId")), + ) + .withColumn( + "position", + extract_position(f.col("variantInLocus").getField("variantId")).cast("int"), + ) + .withColumn( + "posteriorProbability", + f.col("variantInLocus.posteriorProbability").cast("double"), + ) + .filter(f.col("posteriorProbability") > f.lit(pp_min)) + .select( + f.col("studyLocusId").alias("studyLocusId"), + f.col("chromosome").alias("sl_chromosome"), + f.col("position").alias("position"), + f.col("posteriorProbability").alias("pp"), + ) + 
.alias("slx") + ) + intervals_filtered = iv.select( + f.col("chromosome").alias("iv_chromosome"), + f.col("start").cast("int").alias("start"), + f.col("end").cast("int").alias("end"), + f.col("geneId").alias("geneId"), + f.col("score").alias("score"), + ).alias("ivf") + + joined = study_loci_exploded.join( + intervals_filtered, + (f.col("slx.sl_chromosome") == f.col("ivf.iv_chromosome")) + & (f.col("position") >= f.col("start")) + & (f.col("position") <= f.col("end")), + "inner", + ).select( + f.col("studyLocusId"), + f.col("slx.sl_chromosome").alias("chromosome"), + f.col("position"), + f.col("pp"), + f.col("geneId"), + f.col("score"), + ) + + per_variant_gene = joined.groupBy( + "studyLocusId", "chromosome", "position", "geneId" + ).agg( + f.max("score").alias("maxScore"), + f.first("pp", ignorenulls=True).alias("pp"), + ) + + base_df = ( + per_variant_gene.withColumn( + "weightedIntervalScore", f.col("maxScore") * f.col("pp") + ) + .groupBy("studyLocusId", "geneId") + .agg(f.sum("weightedIntervalScore").alias(base_name)) + ).persist() + + w = Window.partitionBy("studyLocusId") + with_max = base_df.withColumn("regional_max", f.max(base_name).over(w)) + neigh_ratio = f.when( + f.col("regional_max") != 0, f.col(base_name) / f.col("regional_max") + ).otherwise(f.lit(0.0)) + + wide = with_max.select( + "studyLocusId", + "geneId", + f.col(base_name).alias(base_name), + neigh_ratio.alias(f"{base_name}Neighbourhood"), + ) + return wide + + +def get_or_make_e2g_wide( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + *, + feature_dependency: dict[str, Any], + base_name: str = "e2gMean", + use_binned: bool = True, + pp_min: float = 0.001, + bin_size: int = 50_000, + max_bins_per_interval: int = 200, + repartitions_variants: int | None = None, + repartitions_intervals: int | None = None, +) -> DataFrame: + """Compute or retrieve the e2g wide feature DataFrame with optional binned join settings. 
+ + This method implements a caching registry within the `feature_dependency` dictionary object defined by parent caller. + The method stores the reference to wide e2g dataframe execution plan under specific cache_key, + so subsequent feature factory calls to the E2GFeature.compute() can reference the cached resource instead of recomputing the plan. + + Note: + The caching mechanism acts on the `feature_dependency` dictionary and modifies it in place as of side effect. + + The cache key incorporates parameters that affect output. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci + that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the e2g information + base_name (str): The base name of the feature + use_binned (bool): Whether to use the binned overlap logic + pp_min (float): Minimum posterior probability to consider a variant + bin_size (int): Size of bins for the binned overlap + max_bins_per_interval (int): Maximum number of bins to explode per interval + repartitions_variants (int | None): Number of repartitions for variant side + repartitions_intervals (int | None): Number of repartitions for interval side + + Returns: + DataFrame: Features dataset + """ + cache_key = f"_e2g_wide::{base_name}::binned={use_binned}::ppmin={pp_min}::bin={bin_size}::cap={max_bins_per_interval}" + if cache_key not in feature_dependency: + wide = e2g_interval_feature_wide_logic( + study_loci_to_annotate, + intervals=feature_dependency["intervals"], + base_name=base_name, + use_binned=use_binned, + pp_min=pp_min, + bin_size=bin_size, + max_bins_per_interval=max_bins_per_interval, + repartitions_variants=repartitions_variants, + repartitions_intervals=repartitions_intervals, + ).persist() + feature_dependency[cache_key] = wide + return feature_dependency[cache_key] + + +class E2gMeanFeature(L2GFeature): + """e2gMean feature from E2G intervals.""" + + feature_dependency_type = Intervals + 
feature_name = "e2gMean" + + @classmethod + def compute( + cls: type[E2gMeanFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> E2gMeanFeature: + """Compute e2gMean feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci + that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the e2g information, expecting intervals + + Returns: + E2gMeanFeature: Computed e2gMean feature. + """ + wide = get_or_make_e2g_wide( + study_loci_to_annotate, + feature_dependency=feature_dependency, + base_name=cls.feature_name, + use_binned=True, + ) + df_long = convert_from_wide_to_long( + wide.select("studyLocusId", "geneId", cls.feature_name), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + value_vars=(cls.feature_name,), + ) + return cls(_df=df_long, _schema=cls.get_schema()) + + +class E2gMeanNeighbourhoodFeature(L2GFeature): + """e2gMeanNeighbourhood feature from E2G intervals.""" + + feature_dependency_type = Intervals + feature_name = "e2gMeanNeighbourhood" + + @classmethod + def compute( + cls: type[E2gMeanNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> E2gMeanNeighbourhoodFeature: + """Compute e2gMeanNeighbourhood feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci + that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the e2g information, expecting intervals + + Returns: + E2gMeanNeighbourhoodFeature: Computed e2gMeanNeighbourhood feature. 
+ """ + wide = get_or_make_e2g_wide( + study_loci_to_annotate, + feature_dependency=feature_dependency, + base_name="e2gMean", + use_binned=True, + ) + df_long = convert_from_wide_to_long( + wide.select("studyLocusId", "geneId", cls.feature_name), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + value_vars=(cls.feature_name,), + ) + return cls(_df=df_long, _schema=cls.get_schema()) diff --git a/src/gentropy/datasource/intervals/e2g.py b/src/gentropy/datasource/intervals/e2g.py index d3419ea60..25052b2ea 100644 --- a/src/gentropy/datasource/intervals/e2g.py +++ b/src/gentropy/datasource/intervals/e2g.py @@ -154,8 +154,8 @@ def parse( _df=( parsed.select( f.col("chromosome"), - f.col("start").cast("string"), - f.col("end").cast("string"), + f.col("start").cast("integer"), + f.col("end").cast("integer"), f.col("geneId"), f.col("biosampleName"), f.col("intervalType"), diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 3fc3394b0..46872a490 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -13,6 +13,7 @@ from gentropy.common.session import Session from gentropy.common.spark import calculate_harmonic_sum from gentropy.dataset.colocalisation import Colocalisation +from gentropy.dataset.intervals import Intervals from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.l2g_prediction import L2GPrediction @@ -39,6 +40,7 @@ def __init__( colocalisation_path: str | None = None, study_index_path: str | None = None, target_index_path: str | None = None, + intervals_path: str | None = None, feature_matrix_path: str, append_null_features: bool = False, ) -> None: @@ -52,6 +54,7 @@ def __init__( colocalisation_path (str | None): Path to the colocalisation dataset study_index_path (str | None): Path to the study index dataset target_index_path (str | None): Path to the target index dataset + intervals_path (str | None): Path 
to the interval dataset feature_matrix_path (str): Path to the L2G feature matrix output dataset append_null_features (bool): Whether to append null features to the feature matrix. Defaults to False. """ @@ -83,12 +86,20 @@ def __init__( if target_index_path else None ) + + intervals = ( + Intervals.from_parquet(session, intervals_path, recursiveFileLookup=True) + if intervals_path + else None + ) + features_input_loader = L2GFeatureInputLoader( variant_index=variant_index, colocalisation=coloc, study_index=studies, study_locus=credible_set, target_index=target_index, + intervals=intervals, ) fm = credible_set.filter(f.col("studyType") == "gwas").build_feature_matrix( diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index ca1ffb339..7954f87b6 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -43,6 +43,10 @@ VepMeanFeature, VepMeanNeighbourhoodFeature, ) +from gentropy.dataset.l2g_features.intervals import ( + E2gMeanFeature, + E2gMeanNeighbourhoodFeature, +) from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus @@ -127,6 +131,8 @@ class FeatureFactory: "vepMeanNeighbourhood": VepMeanNeighbourhoodFeature, "vepMaximum": VepMaximumFeature, "vepMaximumNeighbourhood": VepMaximumNeighbourhoodFeature, + "e2gMean": E2gMeanFeature, + "e2gMeanNeighbourhood": E2gMeanNeighbourhoodFeature, "geneCount500kb": GeneCountFeature, "proteinGeneCount500kb": ProteinGeneCountFeature, "isProteinCoding": ProteinCodingFeature, @@ -181,7 +187,7 @@ def compute_feature( Args: feature_name (str): name of the feature - features_input_loader (L2GFeatureInputLoader): Object that contais features input. + features_input_loader (L2GFeatureInputLoader): Object that contains features input. 
Returns: L2GFeature: instantiated feature object diff --git a/src/gentropy/method/l2g/trainer.py b/src/gentropy/method/l2g/trainer.py index bc0128ad1..65726aa2d 100644 --- a/src/gentropy/method/l2g/trainer.py +++ b/src/gentropy/method/l2g/trainer.py @@ -35,6 +35,7 @@ from matplotlib.axes._axes import Axes from shap._explanation import Explanation from wandb.sdk.wandb_run import Run + import logging @@ -435,7 +436,7 @@ def run_all_folds() -> None: config = dict(sweep_run.config) # Reset wandb setup to ensure clean state - _setup(_reset=True) + _setup() wandb_termlog(f"Sweep URL: {sweep_url}") wandb_termlog(f"Sweep Group URL: {sweep_group_url}") diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index bc2ee2108..b805ec92b 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -22,6 +22,7 @@ from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.target_index import TargetIndex +from gentropy.dataset.intervals import Intervals from gentropy.dataset.l2g_features.colocalisation import ( EQtlColocClppMaximumFeature, EQtlColocClppMaximumNeighbourhoodFeature, @@ -68,6 +69,10 @@ CredibleSetConfidenceFeature, ProteinCodingFeature, ) + +from gentropy.dataset.l2g_features.intervals import ( + e2g_interval_feature_wide_logic, +) from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.variant_index import VariantIndex @@ -705,6 +710,186 @@ def _setup( ) +class TestE2GIntervalFeatures: + """Tests for e2g base + neighbourhood features, single-overlap implementation.""" + + @pytest.fixture(autouse=True) + def _setup(self, spark: SparkSession) -> None: + """Set up test fixtures.""" + self.sample_study_locus = StudyLocus( + _df=spark.createDataFrame( + [ + { + "studyLocusId": "SL_1", + "variantId": "1_1000001_A_C", + "studyId": "study1", + "locus": [ + {"variantId": "1_1000001_A_C", 
"posteriorProbability": 0.6}, + { + "variantId": "1_1000200_G_T", + "posteriorProbability": 0.0005, + }, # filtered out + ], + "chromosome": "1", + }, + ], + StudyLocus.get_schema(), + ), + _schema=StudyLocus.get_schema(), + ) + + # --- Intervals fixture built to your schema (score: double; resourceScore: array[struct] or None) --- + # Only required/non-nullable fields must be populated; others can be None + intervals_rows = [ + # Two overlapping intervals for SAME variant position and geneA (tests max(score) per variant–gene) + { + "chromosome": "1", + "start": 1000000, + "end": 1000100, + "geneId": "geneA", + "score": 0.8, + "distanceToTss": None, + "resourceScore": None, + "datasourceId": "dummy", + "intervalType": "pchic", + "pmid": None, + "biofeature": None, + "biosampleName": None, + "biosampleId": None, + "studyId": None, + "intervalId": None, + }, + { + "chromosome": "1", + "start": 1000000, + "end": 1000100, + "geneId": "geneA", + "score": 0.5, + "distanceToTss": None, + "resourceScore": None, + "datasourceId": "dummy", + "intervalType": "pchic", + "pmid": None, + "biofeature": None, + "biosampleName": None, + "biosampleId": None, + "studyId": None, + "intervalId": None, + }, + # Another gene overlapping the kept variant position + { + "chromosome": "1", + "start": 1000000, + "end": 1000500, + "geneId": "geneB", + "score": 0.2, + "distanceToTss": None, + "resourceScore": None, + "datasourceId": "dummy", + "intervalType": "pchic", + "pmid": None, + "biofeature": None, + "biosampleName": None, + "biosampleId": None, + "studyId": None, + "intervalId": None, + }, + # Interval near low-PP variant (won't contribute because PP threshold filters it before the join) + { + "chromosome": "1", + "start": 1000150, + "end": 1000300, + "geneId": "geneA", + "score": 1.0, + "distanceToTss": None, + "resourceScore": None, + "datasourceId": "dummy", + "intervalType": "pchic", + "pmid": None, + "biofeature": None, + "biosampleName": None, + "biosampleId": None, + "studyId": 
None, + "intervalId": None, + }, + ] + self.sample_intervals = Intervals( + _df=spark.createDataFrame(intervals_rows, schema=Intervals.get_schema()), + _schema=Intervals.get_schema(), + ) + + def test_e2g_interval_feature_wide_once_base_and_neighbourhood( + self, spark: SparkSession + ) -> None: + """Base e2gMean and ratio neighbourhood computed correctly in one pass.""" + wide = e2g_interval_feature_wide_logic( + self.sample_study_locus, + intervals=self.sample_intervals, + base_name="e2gMean", + ) + + observed = ( + wide.select( + "studyLocusId", + "geneId", + f.round("e2gMean", 4).alias("e2gMean"), + f.round("e2gMeanNeighbourhood", 4).alias("e2gMeanNeighbourhood"), + ) + .orderBy("geneId") + .collect() + ) + + # Calculations: + # kept variant: 1_1000001 with PP=0.6 + # geneA: max(score)=0.8 -> 0.8*0.6 = 0.48 + # geneB: score=0.2 -> 0.2*0.6 = 0.12 + # locus max = 0.48; ratios: geneA=1.0, geneB=0.25 + expected = [ + ("SL_1", "geneA", 0.48, 1), + ("SL_1", "geneB", 0.12, 0.25), + ] + assert [ + (r["studyLocusId"], r["geneId"], r["e2gMean"], r["e2gMeanNeighbourhood"]) + for r in observed + ] == expected + + def test_e2g_interval_feature_wide_once_base_and_neighbourhood_no_bin( + self, spark: SparkSession + ) -> None: + """Base e2gMean and ratio neighbourhood computed correctly in one pass.""" + wide = e2g_interval_feature_wide_logic( + self.sample_study_locus, + intervals=self.sample_intervals, + base_name="e2gMean", + use_binned=False, + ) + + observed = ( + wide.select( + "studyLocusId", + "geneId", + f.round("e2gMean", 4).alias("e2gMean"), + f.round("e2gMeanNeighbourhood", 4).alias("e2gMeanNeighbourhood"), + ) + .orderBy("geneId") + .collect() + ) + + # Calculations: + # kept variant: 1_1000001 with PP=0.6 + # geneA: max(score)=0.8 -> 0.8*0.6 = 0.48 + # geneB: score=0.2 -> 0.2*0.6 = 0.12 + # locus max = 0.48; ratios: geneA=1.0, geneB=0.25 + expected = [ + ("SL_1", "geneA", 0.48, 1), + ("SL_1", "geneB", 0.12, 0.25), + ] + assert [ + (r["studyLocusId"], r["geneId"], 
r["e2gMean"], r["e2gMeanNeighbourhood"]) + for r in observed + ] == expected + + class TestCommonVepFeatureLogic: """Test the common_vep_feature_logic methods.""" From 8f92be7a06a0006857adde791aa3385534d1da1d Mon Sep 17 00:00:00 2001 From: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:30:08 +0000 Subject: [PATCH 02/16] feat(intervals): quality control (#1186) Co-authored-by: project-defiant --- .vscode/settings.json | 6 + src/gentropy/assets/schemas/contig_index.json | 23 + src/gentropy/assets/schemas/intervals.json | 22 +- src/gentropy/assets/schemas/target_index.json | 8 +- src/gentropy/config.py | 23 +- src/gentropy/dataset/contig_index.py | 65 ++ src/gentropy/dataset/dataset.py | 58 +- src/gentropy/dataset/intervals.py | 583 ++++++++++++++++-- src/gentropy/dataset/study_index.py | 11 +- src/gentropy/dataset/study_locus.py | 13 +- src/gentropy/dataset/summary_statistics.py | 16 +- src/gentropy/dataset/target_index.py | 27 + src/gentropy/datasource/intervals/e2g.py | 117 ++-- .../datasource/intervals/epiraction.py | 126 ++-- src/gentropy/intervals.py | 102 ++- src/gentropy/study_locus_validation.py | 33 +- src/gentropy/study_validation.py | 27 +- .../dataset/test_dataset_exclusion.py | 10 +- tests/gentropy/dataset/test_intervals.py | 92 +++ tests/gentropy/dataset/test_l2g_feature.py | 8 +- .../gentropy/datasource/intervals/test_e2g.py | 3 - tests/gentropy/step/test_interval_step.py | 201 ++++++ tests/gentropy/step/test_study_qc.py | 1 - 23 files changed, 1297 insertions(+), 278 deletions(-) create mode 100644 src/gentropy/assets/schemas/contig_index.json create mode 100644 src/gentropy/dataset/contig_index.py create mode 100644 tests/gentropy/dataset/test_intervals.py create mode 100644 tests/gentropy/step/test_interval_step.py diff --git a/.vscode/settings.json b/.vscode/settings.json index cb6f93e9b..fb6b857de 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -35,7 +35,9 @@ "bgzip", "biobank", 
"biosample", + "biosamples", "colocalisation", + "contig", "diffpval", "eqtl", "finngen", @@ -47,6 +49,10 @@ "harmonised", "Harmonises", "Harmonising", + "iend", + "INTRAGENIC", + "istart", + "itype", "liftover", "logpval", "logsum", diff --git a/src/gentropy/assets/schemas/contig_index.json b/src/gentropy/assets/schemas/contig_index.json new file mode 100644 index 000000000..8d855c031 --- /dev/null +++ b/src/gentropy/assets/schemas/contig_index.json @@ -0,0 +1,23 @@ +{ + "fields": [ + { + "metadata": {}, + "name": "id", + "nullable": false, + "type": "string" + }, + { + "metadata": {}, + "name": "start", + "nullable": false, + "type": "long" + }, + { + "metadata": {}, + "name": "end", + "nullable": false, + "type": "long" + } + ], + "type": "struct" +} diff --git a/src/gentropy/assets/schemas/intervals.json b/src/gentropy/assets/schemas/intervals.json index 5cdc4b71e..209171367 100644 --- a/src/gentropy/assets/schemas/intervals.json +++ b/src/gentropy/assets/schemas/intervals.json @@ -10,13 +10,13 @@ "metadata": {}, "name": "start", "nullable": false, - "type": "integer" + "type": "long" }, { "metadata": {}, "name": "end", "nullable": false, - "type": "integer" + "type": "long" }, { "metadata": {}, @@ -92,6 +92,12 @@ "nullable": true, "type": "string" }, + { + "metadata": {}, + "name": "biosampleFromSourceId", + "nullable": true, + "type": "string" + }, { "metadata": {}, "name": "biosampleId", @@ -107,8 +113,18 @@ { "metadata": {}, "name": "intervalId", - "nullable": true, + "nullable": false, "type": "string" + }, + { + "metadata": {}, + "name": "qualityControls", + "type": { + "type": "array", + "elementType": "string", + "containsNull": true + }, + "nullable": true } ], "type": "struct" diff --git a/src/gentropy/assets/schemas/target_index.json b/src/gentropy/assets/schemas/target_index.json index 01535b08c..076312aa0 100644 --- a/src/gentropy/assets/schemas/target_index.json +++ b/src/gentropy/assets/schemas/target_index.json @@ -682,25 +682,25 @@ "metadata": 
{}, "name": "probeMinerScore", "nullable": true, - "type": "long" + "type": "double" }, { "metadata": {}, "name": "probesDrugsScore", "nullable": true, - "type": "long" + "type": "double" }, { "metadata": {}, "name": "scoreInCells", "nullable": true, - "type": "long" + "type": "double" }, { "metadata": {}, "name": "scoreInOrganisms", "nullable": true, - "type": "long" + "type": "double" }, { "metadata": {}, diff --git a/src/gentropy/config.py b/src/gentropy/config.py index 9baacb7e6..ca0c8fefe 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -515,10 +515,29 @@ class IntervalE2GStepConfig(StepConfig): """Interval E2G step configuration.""" target_index_path: str = MISSING + biosample_mapping_path: str = MISSING + biosample_index_path: str = MISSING + chromosome_contig_index_path: str = MISSING interval_source: str = MISSING - interval_e2g_path: str = MISSING + valid_output_path: str = MISSING + invalid_output_path: str = MISSING + min_valid_score: float = 0.6 + max_valid_score: float = 1.0 + invalid_qc_reasons: list[str] = field( + default_factory=lambda: [ + "UNRESOLVED_TARGET", + "UNKNOWN_BIOSAMPLE", + "SCORE_OUTSIDE_BOUNDS", + "UNKNOWN_INTERVAL_TYPE", + "AMBIGUOUS_SCORE", + "UNKNOWN_PROJECT_ID", + "INVALID_CHROMOSOME", + "INVALID_RANGE", + "AMBIGUOUS_INTERVAL_TYPE", + ] + ) - _target_: str = "gentropy.variant_index.IntervalE2GStep" + _target_: str = "gentropy.intervals.IntervalE2GStep" @dataclass diff --git a/src/gentropy/dataset/contig_index.py b/src/gentropy/dataset/contig_index.py new file mode 100644 index 000000000..22625af87 --- /dev/null +++ b/src/gentropy/dataset/contig_index.py @@ -0,0 +1,65 @@ +"""Contig (chromosome) index.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from pyspark.sql import functions as f + +from gentropy.common.schemas import parse_spark_schema +from gentropy.dataset.dataset import Dataset + +if TYPE_CHECKING: + from pyspark.sql.types import 
StructType + + +@dataclass +class ContigIndex(Dataset): + """Contig index. + + A contig index captures contiguous data structure with `id`, `start`, `end` fields. + This dataset can represent chromosome bounds, contigs or scaffolds. + + The indexing is expected to be 0-based. + + Examples: + --- + >>> df = spark.createDataFrame([ + ... ("1", 0, 248956422), + ... ("2", 0, 242193529), + ... ("X", 0, 156040895),], + ... schema=["id", "start", "end"]) + >>> contig_index = ContigIndex(_df=df) + >>> contig_index.canonical().df.show() + +---+-----+---------+ + | id|start| end| + +---+-----+---------+ + | 1| 0|248956422| + | 2| 0|242193529| + | X| 0|156040895| + +---+-----+---------+ + + """ + + CANONICAL_CHROMOSOMES = [str(i) for i in range(1, 23)] + ["X", "Y", "MT"] + """Canonical chromosomes""" + + @classmethod + def get_schema(cls: type[ContigIndex]) -> StructType: + """Provide the schema for the ContigIndex dataset. + + Returns: + StructType: The schema of the ContigIndex dataset. + """ + return parse_spark_schema("contig_index.json") + + def canonical(self) -> ContigIndex: + """Get the canonical subpart of the index. + + Returns: + ContigIndex: Filtered by canonical chromosomes. 
+ """ + return ContigIndex( + _df=self.df.filter(f.col("id").isin(self.CANONICAL_CHROMOSOMES)) + ) diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py index aef303056..1b6dc66be 100644 --- a/src/gentropy/dataset/dataset.py +++ b/src/gentropy/dataset/dataset.py @@ -3,10 +3,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass from enum import Enum from functools import reduce -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, ParamSpec, Self, TypeVar from pyspark.sql import DataFrame from pyspark.sql import functions as f @@ -16,13 +17,35 @@ from gentropy.common.schemas import SchemaValidationError, compare_struct_schemas if TYPE_CHECKING: - from enum import Enum - from pyspark.sql import Column from pyspark.sql.types import StructType from gentropy.common.session import Session +T = TypeVar("T", bound="Dataset") +P = ParamSpec("P") + + +class DatasetValidationResult(NamedTuple, Generic[T]): + """Dataset validation result.""" + + valid: T + invalid: T + + +def qc_test(func: Callable[P, T]) -> Callable[P, T]: + """Decorator to mark methods as quality control tests. + + Args: + func (Callable[P, T]): Function to be decorated. The function should take any parameters and return Dataset derivative. + + Returns: + Callable[P, T]: Decorated function. + """ + setattr(func, "__is_qc_test__", True) + + return func + @dataclass class Dataset(ABC): @@ -200,7 +223,9 @@ def validate_schema(self: Dataset) -> None: f"Schema validation failed for {type(self).__name__}", discrepancies ) - def valid_rows(self: Self, invalid_flags: list[str], invalid: bool = False) -> Self: + def valid_rows( + self: Self, invalid_flags: list[str] + ) -> DatasetValidationResult[Self]: """Filters `Dataset` according to a list of quality control flags. Only `Dataset` classes with a QC column can be validated. 
This method checks do following steps: @@ -210,10 +235,9 @@ def valid_rows(self: Self, invalid_flags: list[str], invalid: bool = False) -> S Args: invalid_flags (list[str]): List of quality control flags to be excluded. - invalid (bool): If True returns the invalid rows, instead of the valid. Defaults to False. Returns: - Self: filtered dataset. + DatasetValidationResult[Self]: A named tuple with valid and invalid Datasets. Raises: ValueError: If the Dataset does not contain a QC column or if the invalid_flags elements do not exist in QC mappings flags. @@ -243,10 +267,10 @@ def valid_rows(self: Self, invalid_flags: list[str], invalid: bool = False) -> S f.array([f.lit(i) for i in invalid_reasons]), qc ) # Returning the filtered dataset: - if invalid: - return self.filter(~filterCondition) - else: - return self.filter(filterCondition) + return DatasetValidationResult( + valid=self.filter(filterCondition), + invalid=self.filter(~filterCondition), + ) def drop_infinity_values(self: Self, *cols: str) -> Self: """Drop infinity values from Double typed column. @@ -404,3 +428,17 @@ def generate_identifier(uniqueness_defining_columns: list[str]) -> Column: for column in uniqueness_defining_columns ] return f.md5(f.concat(*hashable_columns)) + + @classmethod + def qc_tests(cls: type[Self]) -> list[Callable[..., Self]]: + """Get all quality control test methods defined in the Dataset class. + + Returns: + list[Callable[..., Self]]: List of quality control test methods. 
+ """ + qc_methods = [] + for attribute_name in dir(cls): + attribute = getattr(cls, attribute_name) + if callable(attribute) and getattr(attribute, "__is_qc_test__", False): + qc_methods.append(attribute) + return qc_methods diff --git a/src/gentropy/dataset/intervals.py b/src/gentropy/dataset/intervals.py index 6b919de86..7fb8f7bdb 100644 --- a/src/gentropy/dataset/intervals.py +++ b/src/gentropy/dataset/intervals.py @@ -3,21 +3,76 @@ from __future__ import annotations from dataclasses import dataclass +from enum import Enum from typing import TYPE_CHECKING +from pyspark.sql import Window +from pyspark.sql import functions as f +from pyspark.sql import types as t + from gentropy.common.schemas import parse_spark_schema from gentropy.dataset.biosample_index import BiosampleIndex -from gentropy.dataset.dataset import Dataset +from gentropy.dataset.contig_index import ContigIndex +from gentropy.dataset.dataset import Dataset, DatasetValidationResult, qc_test from gentropy.dataset.target_index import TargetIndex if TYPE_CHECKING: - from pyspark.sql import DataFrame, SparkSession + from pyspark.sql import Column from pyspark.sql.types import StructType +class IntervalDataSource(str, Enum): + """Enum for interval data sources.""" + + E2G = "E2G" + EPIRACTION = "epiraction" + + +class IntervalQualityCheck(str, Enum): + """Enum for interval quality check reasons.""" + + UNRESOLVED_TARGET = "Target/gene identifier could not match to reference" + UNKNOWN_BIOSAMPLE = "Biosample identifier was not found in the reference" + SCORE_OUTSIDE_BOUNDS = "Score was above or below specified thresholds" + UNKNOWN_INTERVAL_TYPE = "Interval type is not supported" + AMBIGUOUS_SCORE = "Interval has a duplicate with different score" + UNKNOWN_PROJECT_ID = "Project id could not be resolved to any known dataset" + INVALID_CHROMOSOME = "Interval chromosome was not found in contig index" + INVALID_RANGE = "Interval range exceeded chromosome bounds" + AMBIGUOUS_INTERVAL_TYPE = ( + "Multiple 
interval types for the same (region, geneId) pair" + ) + + +class IntervalType(str, Enum): + """Enum representing interval type.""" + + PROMOTER = "promoter" # Promoter region + ENHANCER = "enhancer" # Enhancer region + INTRAGENIC = "intragenic" # Within gene + INTERGENIC = "intergenic" # Between genes + GENIC = "genic" # Within or near gene + + @dataclass class Intervals(Dataset): - """Intervals dataset links genes to genomic regions based on genome interaction studies.""" + """Intervals dataset links genes to genomic regions based on genome interaction studies. + + Examples: + >>> data = [("1", 100, 200, "ENSG1", "E2G", "promoter", "interval1"),] + >>> schema = "chromosome STRING, start LONG, end LONG, geneId STRING, datasourceId STRING, intervalType STRING, intervalId STRING" + >>> df = spark.createDataFrame(data=data, schema=schema) + >>> intervals = Intervals(_df=df) + >>> intervals.df.show(truncate=False) + +----------+-----+---+------+------------+------------+----------+ + |chromosome|start|end|geneId|datasourceId|intervalType|intervalId| + +----------+-----+---+------+------------+------------+----------+ + |1 |100 |200|ENSG1 |E2G |promoter |interval1 | + +----------+-----+---+------+------------+------------+----------+ + + """ + + id_cols = ["chromosome", "start", "end", "geneId", "studyId", "intervalType"] @classmethod def get_schema(cls: type[Intervals]) -> StructType: @@ -29,48 +84,502 @@ def get_schema(cls: type[Intervals]) -> StructType: return parse_spark_schema("intervals.json") @classmethod - def from_source( - cls: type[Intervals], - spark: SparkSession, - source_name: str, - source_path: str, - target_index: TargetIndex, - biosample_index: BiosampleIndex, - biosample_mapping: DataFrame, + def get_QC_column_name(cls: type[Intervals]) -> str: + """Abstract method to get the QC column name. Assumes None unless overridden by child classes. + + Returns: + str: QC column name. 
+ """ + return "qualityControls" + + @classmethod + def get_QC_mappings(cls: type[Intervals]) -> dict[str, str]: + """Quality control flag to QC column category mappings. + + Returns: + dict[str, str]: Mapping between flag name and QC column category value. + + Examples: + >>> mappings = Intervals.get_QC_mappings() + >>> for key, value in mappings.items(): + ... print(f"{key}: {value}") + UNRESOLVED_TARGET: Target/gene identifier could not match to reference + UNKNOWN_BIOSAMPLE: Biosample identifier was not found in the reference + SCORE_OUTSIDE_BOUNDS: Score was above or below specified thresholds + UNKNOWN_INTERVAL_TYPE: Interval type is not supported + AMBIGUOUS_SCORE: Interval has a duplicate with different score + UNKNOWN_PROJECT_ID: Project id could not be resolved to any known dataset + INVALID_CHROMOSOME: Interval chromosome was not found in contig index + INVALID_RANGE: Interval range exceeded chromosome bounds + AMBIGUOUS_INTERVAL_TYPE: Multiple interval types for the same (region, geneId) pair + + """ + return {member.name: member.value for member in IntervalQualityCheck} + + @staticmethod + def distance_to_tss( + istart: Column, iend: Column, itype: Column, tss: Column + ) -> Column: + """Compute distance from interval to TSS. + + Args: + istart (Column): Interval start position. + iend (Column): Interval end position. + itype (Column): Interval type. + tss (Column): Transcription start site position. + + Returns: + Column: Distance from interval to TSS. + + Examples: + >>> data = [(100, 200, 'enhancer', 150), # tss within interval + ... (300, 400, 'promoter', 350), # promoter type always 0 distance + ... (500, 600, 'enhancer', 400), # tss 100 bp away the istart + ... (700, 800, 'enhancer', None)] # tss is null + >>> df = spark.createDataFrame(data, ['istart', 'iend', 'itype', 'tss']) + >>> df.withColumn('distanceToTss', Intervals.distance_to_tss( + ... f.col('istart'), f.col('iend'), f.col('itype'), f.col('tss')) + ... 
).show() + +------+----+--------+----+-------------+ + |istart|iend| itype| tss|distanceToTss| + +------+----+--------+----+-------------+ + | 100| 200|enhancer| 150| 0| + | 300| 400|promoter| 350| 0| + | 500| 600|enhancer| 400| 100| + | 700| 800|enhancer|NULL| NULL| + +------+----+--------+----+-------------+ + + """ + is_promoter = itype == f.lit(IntervalType.PROMOTER.value) + tss_in_interval = (tss >= istart) & (tss <= iend) + + expr = ( + f.when((is_promoter) | (tss_in_interval), f.lit(0)) + .when(tss.isNull(), f.lit(None).cast(t.IntegerType())) + .otherwise(f.least(f.abs(tss - istart), f.abs(tss - iend))) + ) + + return expr.cast(t.IntegerType()).alias("distanceToTss") + + @qc_test + def validate_datasource_id(self: Intervals) -> Intervals: + """Validate datasourceId in the Intervals dataset. + + Returns: + Intervals: Intervals dataset with invalid datasourceId flagged. + + Examples: + >>> data = [("1", 100, 200, "UNKNOWN_ID", "promoter", "interval1"), + ... ("1", 150, 250, "E2G", "enhancer", "interval2"), + ... ("2", 300, 400, "epiraction", "intragenic", "interval3"), + ... 
("2", 350, 450, "", "promoter", "interval4")] + >>> schema = "chromosome STRING, start LONG, end LONG, datasourceId STRING, intervalType STRING, intervalId STRING" + >>> df = spark.createDataFrame(data=data, schema=schema) + >>> intervals = Intervals(_df=df) + >>> validated_intervals = intervals.validate_datasource_id() + >>> validated_intervals.df.select("intervalId", "qualityControls").show(truncate=False) + +----------+-------------------------------------------------------+ + |intervalId|qualityControls | + +----------+-------------------------------------------------------+ + |interval1 |[Project id could not be resolved to any known dataset]| + |interval2 |[] | + |interval3 |[] | + |interval4 |[Project id could not be resolved to any known dataset]| + +----------+-------------------------------------------------------+ + + """ + qc_column = self.get_QC_column_name() + if qc_column not in self.df.columns: + self.df = self.df.withColumn( + qc_column, f.array().cast(t.ArrayType(t.StringType())) + ) + valid_df = self.df.withColumn( + qc_column, + self.update_quality_flag( + f.col(qc_column), + ~f.col("datasourceId").isin([ds.value for ds in IntervalDataSource]), + IntervalQualityCheck.UNKNOWN_PROJECT_ID, + ), + ) + return Intervals(_df=valid_df) + + @qc_test + def validate_interval_range( + self: Intervals, contig_index: ContigIndex + ) -> Intervals: + """Validate chromosome labels in the Intervals dataset. + + Args: + contig_index (ContigIndex): Contig index. + + Returns: + Intervals: Intervals dataset with invalid chromosome labels flagged. + + Examples: + >>> contig_data = [("1", 0, 250), + ... ("2", 0, 200)] + >>> contig_schema = "id STRING, start LONG, end LONG" + >>> contig_df = spark.createDataFrame(data=contig_data, schema=contig_schema) + >>> contig_index = ContigIndex(_df=contig_df) + >>> data = [("UNKNOWN_CHR", 100, 200, "E2G", "promoter", "interval1"), + ... ("1", 150, 250, "E2G", "enhancer", "interval2"), + ... 
("2", 300, 400, "E2G", "intragenic", "interval3")] + >>> schema = "chromosome STRING, start LONG, end LONG, datasourceId STRING, intervalType STRING, intervalId STRING" + >>> df = spark.createDataFrame(data=data, schema=schema) + >>> intervals = Intervals(_df=df) + >>> validated_intervals = intervals.validate_interval_range(contig_index) + >>> validated_intervals.df.select("intervalId", "qualityControls").show(truncate=False) + +----------+---------------------------------------------------+ + |intervalId|qualityControls | + +----------+---------------------------------------------------+ + |interval1 |[Interval chromosome was not found in contig index]| + |interval2 |[] | + |interval3 |[Interval range exceeded chromosome bounds] | + +----------+---------------------------------------------------+ + + """ + qc_column = self.get_QC_column_name() + if qc_column not in self.df.columns: + self.df = self.df.withColumn( + qc_column, f.array().cast(t.ArrayType(t.StringType())) + ) + chromosomes = f.broadcast( + contig_index.canonical().df.select( + f.col("start").alias("contigStart"), + f.col("end").alias("contigEnd"), + f.col("id").alias("chromosome"), + ) + ) + valid_df = ( + self.df.repartitionByRange("chromosome") + .join(chromosomes, on="chromosome", how="left") + .withColumn( + qc_column, + self.update_quality_flag( + f.col(qc_column), + # The chromosome is not canonical, + # resulting in empty contig bounds after left join + ((f.col("contigStart").isNull()) | (f.col("contigEnd").isNull())), + IntervalQualityCheck.INVALID_CHROMOSOME, + ), + ) + .withColumn( + qc_column, + self.update_quality_flag( + f.col(qc_column), + # interval Range exceeds bounds the contig range + ( + (f.col("start") < f.col("contigStart")) + | (f.col("end") > f.col("contigEnd")) + ), + IntervalQualityCheck.INVALID_RANGE, + ), + ) + .drop("contigStart", "contigEnd") + ) + + return Intervals(_df=valid_df, _schema=Intervals.get_schema()) + + @qc_test + def validate_target(self: Intervals, 
target_index: TargetIndex) -> Intervals: + """Validate targets in the Intervals dataset. + + Args: + target_index (TargetIndex): Target index. + + Returns: + Intervals: Intervals dataset with invalid targets flagged. + + Examples: + >>> target_data = [("ENSG1",), ("ENSG2",)] + >>> target_schema = "id STRING" + >>> target_df = spark.createDataFrame(data=target_data, schema=target_schema) + >>> target_index = TargetIndex(_df=target_df) + >>> data = [("1", 100, 200, "ENSG1", "E2G", "promoter", "interval1"), + ... ("1", 150, 250, "", "E2G", "enhancer", "interval2"), + ... ("2", 300, 400, "OTHER", "epiraction", "intragenic", "interval3")] + >>> schema = "chromosome STRING, start LONG, end LONG, geneId STRING, datasourceId STRING, intervalType STRING, intervalId STRING" + >>> df = spark.createDataFrame(data=data, schema=schema) + >>> intervals = Intervals(_df=df) + >>> validated_intervals = intervals.validate_target(target_index) + >>> validated_intervals.df.select("intervalId", "qualityControls").show(truncate=False) + +----------+-----------------------------------------------------+ + |intervalId|qualityControls | + +----------+-----------------------------------------------------+ + |interval1 |[] | + |interval2 |[Target/gene identifier could not match to reference]| + |interval3 |[Target/gene identifier could not match to reference]| + +----------+-----------------------------------------------------+ + + """ + qc_column = self.get_QC_column_name() + if qc_column not in self.df.columns: + self.df = self.df.withColumn( + qc_column, f.array().cast(t.ArrayType(t.StringType())) + ) + gene_set = target_index.df.select( + f.col("id").alias("geneId"), f.lit(True).alias("isIdFound") + ) + validated_df = ( + self.df.join(gene_set, on="geneId", how="left") + .withColumn( + qc_column, + self.update_quality_flag( + f.col(qc_column), + f.col("isIdFound").isNull(), + IntervalQualityCheck.UNRESOLVED_TARGET, + ), + ) + .drop("isIdFound") + ) + return Intervals(_df=validated_df, 
_schema=Intervals.get_schema()) + + @qc_test + def validate_biosample( + self: Intervals, biosample_index: BiosampleIndex ) -> Intervals: - """Collect interval data for a particular source. + """Validate biosamples in the Intervals dataset. Args: - spark (SparkSession): Spark session - source_name (str): Name of the interval source - source_path (str): Path to the interval source file - target_index (TargetIndex): Target index - biosample_index (BiosampleIndex): Biosample index - biosample_mapping (DataFrame): Biosample mapping DataFrame + biosample_index (BiosampleIndex): Biosample index. Returns: - Intervals: Intervals dataset + Intervals: Intervals dataset with invalid biosamples flagged. - Raises: - ValueError: If the source name is not recognised + Examples: + >>> biosample_data = [("BS1", "name1"), ("BS2", "name2")] + >>> biosample_schema = "biosampleId STRING, biosampleName STRING" + >>> biosample_df = spark.createDataFrame(data=biosample_data, schema=biosample_schema) + >>> biosample_index = BiosampleIndex(_df=biosample_df) + >>> data = [("1", 100, 200, "E2G", "promoter", "interval1", "BS1"), + ... 
("1", 150, 250, "E2G", "enhancer", "interval2", "UNKNOWN_BS")] + >>> schema = "chromosome STRING, start LONG, end LONG, datasourceId STRING, intervalType STRING, intervalId STRING, biosampleId STRING" + >>> df = spark.createDataFrame(data=data, schema=schema) + >>> intervals = Intervals(_df=df) + >>> validated_intervals = intervals.validate_biosample(biosample_index) + >>> validated_intervals.df.select("intervalId", "qualityControls").show(truncate=False) + +----------+-----------------------------------------------------+ + |intervalId|qualityControls | + +----------+-----------------------------------------------------+ + |interval1 |[] | + |interval2 |[Biosample identifier was not found in the reference]| + +----------+-----------------------------------------------------+ + """ - from gentropy.datasource.intervals.e2g import IntervalsE2G - from gentropy.datasource.intervals.epiraction import IntervalsEpiraction - - if source_name == "e2g": - raw = IntervalsE2G.read(spark, source_path) - return IntervalsE2G.parse( - raw_e2g_df=raw, - biosample_mapping=biosample_mapping, - target_index=target_index, - biosample_index=biosample_index, + qc_column = self.get_QC_column_name() + if qc_column not in self.df.columns: + self.df = self.df.withColumn( + qc_column, f.array().cast(t.ArrayType(t.StringType())) ) + biosample_set = biosample_index.df.select( + f.col("biosampleId"), f.lit(True).alias("isIdFound") + ) + validated_df = ( + self.df.join(biosample_set, on="biosampleId", how="left") + .withColumn( + qc_column, + self.update_quality_flag( + f.col(qc_column), + f.col("isIdFound").isNull(), + IntervalQualityCheck.UNKNOWN_BIOSAMPLE, + ), + ) + .drop("isIdFound") + ) + return Intervals(_df=validated_df, _schema=Intervals.get_schema()) + + @qc_test + def validate_interval_type(self: Intervals) -> Intervals: + """Validate interval types in the Intervals dataset. + + Returns: + Intervals: Intervals dataset with invalid interval types flagged. 
- if source_name == "epiraction": - raw = IntervalsEpiraction.read(spark, source_path) - return IntervalsEpiraction.parse( - raw_epiraction_df=raw, - target_index=target_index, + Examples: + >>> data = [("1", 100, 200, "ENSG1", "E2G", "promoter", "interval1"), + ... ("1", 150, 250, "ENSG2", "E2G", "enhancer", "interval2"), + ... ("2", 300, 400, "ENSG3", "E2G", "intragenic", "interval3"), + ... ("2", 300, 400, "ENSG3", "E2G", "intergenic", "interval4"), + ... ("2", 400, 500, "ENSG4", "E2G", "other", "interval5"), + ... ("2", 450, 550, "ENSG5", "E2G", "", "interval6")] + >>> schema = "chromosome STRING, start LONG, end LONG, geneId STRING, datasourceId STRING, intervalType STRING, intervalId STRING" + >>> df = spark.createDataFrame(data=data, schema=schema) + >>> intervals = Intervals(_df=df) + >>> validated_intervals = intervals.validate_interval_type() + >>> validated_intervals.df.select("intervalType", "qualityControls").show(truncate=False) + +------------+------------------------------------------------------------+ + |intervalType|qualityControls | + +------------+------------------------------------------------------------+ + |promoter |[] | + |enhancer |[] | + |intragenic |[Multiple interval types for the same (region, geneId) pair]| + |intergenic |[Multiple interval types for the same (region, geneId) pair]| + |other |[Interval type is not supported] | + | |[Interval type is not supported] | + +------------+------------------------------------------------------------+ + + """ + qc_column = self.get_QC_column_name() + if qc_column not in self.df.columns: + self.df = self.df.withColumn( + qc_column, f.array().cast(t.ArrayType(t.StringType())) ) + valid_df = self.df.withColumn( + qc_column, + self.update_quality_flag( + f.col(qc_column), + ~f.col("intervalType").isin( + [interval_type.value for interval_type in IntervalType] + ), + IntervalQualityCheck.UNKNOWN_INTERVAL_TYPE, + ), + ) + + window = Window.partitionBy("chromosome", "start", "end", "geneId") - 
raise ValueError(f"Unknown interval source: {source_name!r}") + valid_df = valid_df.withColumn( + qc_column, + self.update_quality_flag( + f.col(qc_column), + f.size(f.collect_set("intervalType").over(window)) > 1, + IntervalQualityCheck.AMBIGUOUS_INTERVAL_TYPE, + ), + ) + + return Intervals(_df=valid_df) + + @qc_test + def validate_score( + self: Intervals, min_score: float, max_score: float + ) -> Intervals: + """Validate scores in the Intervals dataset. + + Args: + min_score (float): Minimum acceptable score. + max_score (float): Maximum acceptable score. + + Returns: + Intervals: Intervals dataset with invalid scores flagged. + + Examples: + >>> data = [("1", 100, 200, "E2G", "promoter", 0.5, "interval1"), + ... ("1", 150, 250, "E2G", "enhancer", -1.0, "interval2"), + ... ("2", 300, 400, "E2G", "intragenic", 2.0, "interval3"), + ... ("2", 350, 450, "E2G", "promoter", None, "interval4")] + >>> schema = "chromosome STRING, start LONG, end LONG, datasourceId STRING, intervalType STRING, score DOUBLE, intervalId STRING" + >>> df = spark.createDataFrame(data=data, schema=schema) + >>> intervals = Intervals(_df=df) + >>> validated_intervals = intervals.validate_score(min_score=0.0, max_score=1.0) + >>> validated_intervals.df.select("intervalId", "qualityControls").show(truncate=False) + +----------+-----------------------------------------------+ + |intervalId|qualityControls | + +----------+-----------------------------------------------+ + |interval1 |[] | + |interval2 |[Score was above or below specified thresholds]| + |interval3 |[Score was above or below specified thresholds]| + |interval4 |[Score was above or below specified thresholds]| + +----------+-----------------------------------------------+ + + """ + qc_column = self.get_QC_column_name() + if qc_column not in self.df.columns: + self.df = self.df.withColumn( + qc_column, f.array().cast(t.ArrayType(t.StringType())) + ) + valid_df = self.df.withColumn( + qc_column, + self.update_quality_flag( + 
f.col(qc_column), + ~f.col("score").between(min_score, max_score) | f.col("score").isNull(), + IntervalQualityCheck.SCORE_OUTSIDE_BOUNDS, + ), + ) + return Intervals(_df=valid_df) + + @qc_test + def validate_id_has_unique_score(self: Intervals) -> Intervals: + """Validate unique (id, score) group. + + The assumption is that the same interval (defined as chromosome, start, end, biosampleId, geneId, studyId, intervalType) should not have different scores. + + Returns: + Intervals: Intervals dataset with ambiguous scores flagged. + + Examples: + >>> data = [("1", 100, 200, "ENSG1", "S1", "BS1", "E2G", "promoter", 0.5, "interval1"), + ... ("1", 100, 200, "ENSG1", "S1", "BS1", "E2G", "promoter", 0.7, "interval2"), + ... ("2", 300, 400, "ENSG2", "S1", "BS2", "E2G", "enhancer", 0.9, "interval3")] + >>> schema = "chromosome STRING, start LONG, end LONG, geneId STRING, studyId STRING, biosampleId STRING, datasourceId STRING, intervalType STRING, score DOUBLE, intervalId STRING" + >>> df = spark.createDataFrame(data=data, schema=schema) + >>> intervals = Intervals(_df=df) + >>> validated_intervals = intervals.validate_id_has_unique_score() + >>> validated_intervals.df.select("intervalId", "qualityControls").show(truncate=False) + +----------+-----------------------------------------------+ + |intervalId|qualityControls | + +----------+-----------------------------------------------+ + |interval1 |[Interval has a duplicate with different score]| + |interval2 |[Interval has a duplicate with different score]| + |interval3 |[] | + +----------+-----------------------------------------------+ + + """ + qc_column = self.get_QC_column_name() + if qc_column not in self.df.columns: + self.df = self.df.withColumn( + qc_column, f.array().cast(t.ArrayType(t.StringType())) + ) + w = Window().partitionBy( + "chromosome", + "start", + "end", + "biosampleId", + "geneId", + "studyId", + "intervalType", + ) + valid_df = self.df.withColumn( + qc_column, + self.update_quality_flag( + 
f.col(qc_column), + (f.size(f.array_distinct(f.collect_list(f.col("score")).over(w))) > 1), + IntervalQualityCheck.AMBIGUOUS_SCORE, + ), + ) + + return Intervals(_df=valid_df) + + def qc( + self, + contig_index: ContigIndex, + target_index: TargetIndex, + biosample_index: BiosampleIndex, + min_valid_score: float, + max_valid_score: float, + invalid_qc_reasons: list[str] | None = None, + ) -> DatasetValidationResult[Intervals]: + """Perform Quality Control over Intervals dataset. + + Args: + contig_index (ContigIndex): Contig index. + target_index (TargetIndex): Target index. + biosample_index (BiosampleIndex): Biosample index. + min_valid_score (float): Minimum valid score for interval QC. + max_valid_score (float): Maximum valid score for interval QC. + invalid_qc_reasons (list[str] | None): List of invalid quality check reason names from `IntervalQualityCheck` (e.g. ['INVALID_CHROMOSOME']). + + Returns: + DatasetValidationResult[Intervals]: Valid and invalid Intervals datasets. + """ + if invalid_qc_reasons is None: + invalid_qc_reasons = [] + return ( + self.validate_datasource_id() + .validate_interval_range(contig_index) + .validate_target(target_index) + .validate_biosample(biosample_index) + .validate_interval_type() + .validate_score(min_valid_score, max_valid_score) + .validate_id_has_unique_score() + .persist() + .valid_rows(invalid_qc_reasons) + ) diff --git a/src/gentropy/dataset/study_index.py b/src/gentropy/dataset/study_index.py index 386956679..3b3e8ea6e 100644 --- a/src/gentropy/dataset/study_index.py +++ b/src/gentropy/dataset/study_index.py @@ -17,7 +17,7 @@ from gentropy.assets import data from gentropy.common.schemas import parse_spark_schema from gentropy.common.spark import convert_from_wide_to_long, filter_array_struct -from gentropy.dataset.dataset import Dataset +from gentropy.dataset.dataset import Dataset, qc_test if TYPE_CHECKING: from pyspark.sql import Column, DataFrame @@ -275,6 +275,7 @@ def has_summarystats(self: StudyIndex) -> 
Column: """ return self.df.hasSumstats + @qc_test def validate_unique_study_id(self: StudyIndex) -> StudyIndex: """Validating the uniqueness of study identifiers and flagging duplicated studies. @@ -293,6 +294,7 @@ def validate_unique_study_id(self: StudyIndex) -> StudyIndex: _schema=StudyIndex.get_schema(), ) + @qc_test def validate_project_id( self: StudyIndex, deprecated_project_ids: list[str] ) -> StudyIndex: @@ -358,6 +360,7 @@ def _normalise_disease( ) ) + @qc_test def validate_disease(self: StudyIndex, disease_map: DataFrame) -> StudyIndex: """Validate diseases in the study index dataset. @@ -429,6 +432,7 @@ def validate_disease(self: StudyIndex, disease_map: DataFrame) -> StudyIndex: _schema=StudyIndex.get_schema(), ) + @qc_test def validate_study_type(self: StudyIndex) -> StudyIndex: """Validating study type and flag unsupported types. @@ -453,6 +457,7 @@ def validate_study_type(self: StudyIndex) -> StudyIndex: ) return StudyIndex(_df=validated_df, _schema=StudyIndex.get_schema()) + @qc_test def validate_target(self: StudyIndex, target_index: TargetIndex) -> StudyIndex: """Validating gene identifiers in the study index against the provided target index. @@ -492,6 +497,7 @@ def validate_target(self: StudyIndex, target_index: TargetIndex) -> StudyIndex: return StudyIndex(_df=validated_df, _schema=StudyIndex.get_schema()) + @qc_test def validate_biosample( self: StudyIndex, biosample_index: BiosampleIndex ) -> StudyIndex: @@ -541,6 +547,7 @@ def validate_biosample( return StudyIndex(_df=validated_df, _schema=StudyIndex.get_schema()) + @qc_test def annotate_sumstats_qc( self: StudyIndex, sumstats_qc: SummaryStatisticsQC, @@ -650,6 +657,7 @@ def annotate_sumstats_qc( _schema=StudyIndex.get_schema(), ) + @qc_test def validate_analysis_flags(self: StudyIndex) -> StudyIndex: """Validating analysis flags in the study index dataset. 
@@ -671,6 +679,7 @@ def validate_analysis_flags(self: StudyIndex) -> StudyIndex: ) return StudyIndex(_df=df, _schema=StudyIndex.get_schema()) + @qc_test def deconvolute_studies(self: StudyIndex) -> StudyIndex: """Deconvolute the study index dataset. diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py index 875bee215..9c0d55bfd 100644 --- a/src/gentropy/dataset/study_locus.py +++ b/src/gentropy/dataset/study_locus.py @@ -19,7 +19,7 @@ ) from gentropy.common.stats import get_logsum, neglogpval_from_pvalue from gentropy.config import WindowBasedClumpingStepConfig -from gentropy.dataset.dataset import Dataset +from gentropy.dataset.dataset import Dataset, qc_test from gentropy.dataset.study_index import StudyQualityCheck from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.dataset.variant_index import VariantIndex @@ -156,6 +156,7 @@ class StudyLocus(Dataset): This dataset captures associations between study/traits and a genetic loci as provided by finemapping methods. """ + @qc_test def validate_study(self: StudyLocus, study_index: StudyIndex) -> StudyLocus: """Flagging study loci if the corresponding study has issues. @@ -228,6 +229,7 @@ def validate_study(self: StudyLocus, study_index: StudyIndex) -> StudyLocus: _schema=self.get_schema(), ) + @qc_test def annotate_study_type(self: StudyLocus, study_index: StudyIndex) -> StudyLocus: """Gets study type from study index and adds it to study locus. @@ -246,6 +248,7 @@ def annotate_study_type(self: StudyLocus, study_index: StudyIndex) -> StudyLocus _schema=self.get_schema(), ) + @qc_test def validate_chromosome_label(self: StudyLocus) -> StudyLocus: """Flagging study loci, where chromosome is coded not as 1:22, X, Y, Xy and MT. 
@@ -274,6 +277,7 @@ def validate_chromosome_label(self: StudyLocus) -> StudyLocus: _schema=self.get_schema(), ) + @qc_test def validate_variant_identifiers( self: StudyLocus, variant_index: VariantIndex ) -> StudyLocus: @@ -333,6 +337,7 @@ def validate_variant_identifiers( _schema=self.get_schema(), ) + @qc_test def validate_lead_pvalue(self: StudyLocus, pvalue_cutoff: float) -> StudyLocus: """Flag associations below significant threshold. @@ -370,6 +375,7 @@ def validate_lead_pvalue(self: StudyLocus, pvalue_cutoff: float) -> StudyLocus: _schema=self.get_schema(), ) + @qc_test def validate_unique_study_locus_id(self: StudyLocus) -> StudyLocus: """Validating the uniqueness of study-locus identifiers and flagging duplicated studyloci. @@ -429,6 +435,7 @@ def _qc_subsignificant_associations( StudyLocusQualityCheck.SUBSIGNIFICANT_FLAG, ) + @qc_test def qc_abnormal_pips( self: StudyLocus, sum_pips_lower_threshold: float = 0.99, @@ -1169,6 +1176,7 @@ def exclude_region( _schema=StudyLocus.get_schema(), ) + @qc_test def qc_MHC_region(self: StudyLocus) -> StudyLocus: """Adds qualityControl flag when lead overlaps with MHC region. @@ -1192,6 +1200,7 @@ def qc_MHC_region(self: StudyLocus) -> StudyLocus: ) return self + @qc_test def qc_redundant_top_hits_from_PICS(self: StudyLocus) -> StudyLocus: """Flag associations from top hits when the study contains other PICS associations from summary statistics. @@ -1230,6 +1239,7 @@ def qc_redundant_top_hits_from_PICS(self: StudyLocus) -> StudyLocus: _schema=StudyLocus.get_schema(), ) + @qc_test def qc_explained_by_SuSiE(self: StudyLocus) -> StudyLocus: """Flag associations that are explained by SuSiE associations. @@ -1308,6 +1318,7 @@ def qc_explained_by_SuSiE(self: StudyLocus) -> StudyLocus: _schema=StudyLocus.get_schema(), ) + @qc_test def _qc_no_population(self: StudyLocus) -> StudyLocus: """Flag associations where the study doesn't have population information to resolve LD. 
diff --git a/src/gentropy/dataset/summary_statistics.py b/src/gentropy/dataset/summary_statistics.py index 96e1acf59..f348cfc91 100644 --- a/src/gentropy/dataset/summary_statistics.py +++ b/src/gentropy/dataset/summary_statistics.py @@ -76,12 +76,16 @@ def window_based_clumping( """ from gentropy.method.window_based_clumping import WindowBasedClumping - return WindowBasedClumping.clump( - # Before clumping, we filter the summary statistics by p-value: - self.pvalue_filter(gwas_significance), - distance=distance, - # After applying the clumping, we filter the clumped loci by the flag: - ).valid_rows(["WINDOW_CLUMPED"]) + return ( + WindowBasedClumping.clump( + # Before clumping, we filter the summary statistics by p-value: + self.pvalue_filter(gwas_significance), + distance=distance, + # After applying the clumping, we filter the clumped loci by the flag: + ) + .valid_rows(["WINDOW_CLUMPED"]) + .valid + ) def locus_breaker_clumping( self: SummaryStatistics, diff --git a/src/gentropy/dataset/target_index.py b/src/gentropy/dataset/target_index.py index 9be508d0f..517ff7f1d 100644 --- a/src/gentropy/dataset/target_index.py +++ b/src/gentropy/dataset/target_index.py @@ -1,10 +1,12 @@ """Target index dataset.""" + from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING import pyspark.sql.functions as f +from pyspark.sql import types as t from gentropy.common.schemas import parse_spark_schema from gentropy.dataset.dataset import Dataset @@ -74,3 +76,28 @@ def symbols_lut(self: TargetIndex) -> DataFrame: f.col("genomicLocation.chromosome").alias("chromosome"), "tss", ) + + def tss_lut(self: TargetIndex) -> DataFrame: + """Gene TSS lookup table. + + The TSS is determined using the following priority: + 1. preferred TSS from target index + 2. canonical transcript start|end based on strand + 3. genomic location start|end based on strand + + Returns: + DataFrame: Gene LUT for TSS mapping containing `geneId` and `tss` columns. 
+ """ + ct_tss = f.when( + f.col("canonicalTranscript.strand") == "+", + f.col("canonicalTranscript.start"), + ).when( + f.col("canonicalTranscript.strand") == "-", + f.col("canonicalTranscript.end"), + ) + gl_tss = f.when( + f.col("genomicLocation.strand") == 1, f.col("genomicLocation.start") + ).when(f.col("genomicLocation.strand") == -1, f.col("genomicLocation.end")) + + tss = f.coalesce(f.col("tss"), ct_tss, gl_tss).cast(t.LongType()).alias("tss") + return self.df.select(f.col("id").alias("geneId"), tss) diff --git a/src/gentropy/datasource/intervals/e2g.py b/src/gentropy/datasource/intervals/e2g.py index 25052b2ea..b2e0cc19d 100644 --- a/src/gentropy/datasource/intervals/e2g.py +++ b/src/gentropy/datasource/intervals/e2g.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING, ClassVar -import pyspark.sql.functions as f +from pyspark.sql import functions as f +from pyspark.sql import types as t -from gentropy.dataset.biosample_index import BiosampleIndex -from gentropy.dataset.intervals import Intervals +from gentropy.common.processing import normalize_chromosome +from gentropy.dataset.intervals import IntervalDataSource, Intervals from gentropy.dataset.target_index import TargetIndex if TYPE_CHECKING: @@ -17,9 +18,7 @@ class IntervalsE2G: """Interval dataset from E2G.""" - DATASET_NAME: ClassVar[str] = "E2G" PMID: ClassVar[str] = "38014075" # PMID for the E2G paper - VALID_INTERVAL_TYPES: ClassVar[list[str]] = ["promoter", "genic", "intergenic"] @staticmethod def read(spark: SparkSession, path: str) -> DataFrame: @@ -45,7 +44,6 @@ def parse( raw_e2g_df: DataFrame, biosample_mapping: DataFrame, target_index: TargetIndex, - biosample_index: BiosampleIndex, ) -> Intervals: """Parse E2G dataset. @@ -53,20 +51,27 @@ def parse( raw_e2g_df (DataFrame): Raw E2G DataFrame. biosample_mapping (DataFrame): Biosample mapping DataFrame. target_index (TargetIndex): Target index. - biosample_index (BiosampleIndex): Biosample index. Returns: Intervals: Parsed Intervals dataset. 
""" + if "#chr" in raw_e2g_df.columns: + chr_col = "#chr" + else: + chr_col = "chr" base = ( raw_e2g_df.withColumn( "studyId", f.regexp_extract(f.col("file_path"), r"([^/]+)\.bed\.gz$", 1) ) - .withColumn("chromosome", f.regexp_replace(f.col("chr"), "^chr", "")) + .withColumn("chromosome", normalize_chromosome(f.col(chr_col))) + .withColumn("start", f.col("start").cast("long")) + .withColumn("end", f.col("end").cast("long")) .withColumnRenamed("TargetGeneEnsemblID", "geneId") .withColumnRenamed("CellType", "biosampleName") .withColumnRenamed("Score", "score") + .withColumn("score", f.col("score").cast("double")) .withColumnRenamed("class", "intervalType") + .withColumn("intervalType", f.lower(f.trim(f.col("intervalType")))) .withColumn( "resourceScore", f.array( @@ -82,91 +87,53 @@ def parse( ), ), ) - .withColumn("start", f.col("start").cast("long")) - .withColumn("end", f.col("end").cast("long")) - ) - - base = base.withColumn( - "intervalType", f.lower(f.trim(f.col("intervalType"))) - ).filter(f.col("intervalType").isin(cls.VALID_INTERVAL_TYPES)) - - # Target Index: preferred TSS + fallbacks (canonical transcript, genomicLocation) - ti = target_index._df.select( - f.col("id").alias("geneId"), - f.col("tss").cast("long").alias("tss_primary"), - f.col("canonicalTranscript.chromosome").alias("ct_chr"), - f.col("canonicalTranscript.start").cast("long").alias("ct_start"), - f.col("canonicalTranscript.end").cast("long").alias("ct_end"), - f.col("canonicalTranscript.strand").alias("ct_strand"), - f.col("genomicLocation.chromosome").alias("gl_chr"), - f.col("genomicLocation.start").cast("long").alias("gl_start"), - f.col("genomicLocation.end").cast("long").alias("gl_end"), - f.col("genomicLocation.strand").cast("int").alias("gl_strand"), - ) - - ct_tss = f.when(f.col("ct_strand") == "+", f.col("ct_start")).when( - f.col("ct_strand") == "-", f.col("ct_end") - ) - gl_tss = f.when(f.col("gl_strand") == 1, f.col("gl_start")).when( - f.col("gl_strand") == -1, f.col("gl_end") 
- ) - - ti_with_tss = ti.withColumn( - "tss_coalesced", f.coalesce(f.col("tss_primary"), ct_tss, gl_tss) - ) - - # Join & recompute distanceToTss - joined = base.alias("iv").join( - ti_with_tss.alias("ti"), on="geneId", how="inner" ) - tss = f.col("ti.tss_coalesced") - dist_core = f.when( - (tss >= f.col("iv.start")) & (tss <= f.col("iv.end")), f.lit(0) - ).otherwise( - f.least(f.abs(tss - f.col("iv.start")), f.abs(tss - f.col("iv.end"))) + tss_lut = target_index.tss_lut() + biosample_lut = biosample_mapping.select("biosampleName", "biosampleId") + # Add distance to TSS from interval bounds + joined = ( + base.alias("iv") + .join(tss_lut.alias("ti"), on="geneId", how="left") + .join(biosample_lut.alias("bl"), on="biosampleName", how="left") ) - distance_expr = ( - f.when(f.col("iv.intervalType") == "promoter", f.lit(0)) - .when(tss.isNull(), f.lit(None).cast("long")) - .otherwise(dist_core) - ) - parsed = ( - joined.withColumn("distanceToTss", distance_expr.cast("double")) - .withColumn( - "intervalId", - f.sha1( - f.concat_ws("_", "chromosome", "start", "end", "geneId", "studyId") + joined.withColumn( + "distanceToTss", + Intervals.distance_to_tss( + f.col("iv.start"), + f.col("iv.end"), + f.col("iv.intervalType"), + f.col("ti.tss"), ), ) - .join( - biosample_mapping.select("biosampleName", "biosampleId"), - on="biosampleName", - how="left", - ) - .join( - biosample_index.df.select("biosampleId"), on="biosampleId", how="inner" + .withColumn( + "intervalId", + Intervals.generate_identifier(Intervals.id_cols), ) + .withColumn("qualityControls", f.array().cast("array")) ) return Intervals( _df=( parsed.select( f.col("chromosome"), - f.col("start").cast("integer"), - f.col("end").cast("integer"), + f.col("start"), + f.col("end"), f.col("geneId"), - f.col("biosampleName"), - f.col("intervalType"), - f.col("distanceToTss").cast("integer"), - f.col("score").cast("double"), + f.col("score"), + f.col("distanceToTss"), f.col("resourceScore"), - 
f.lit(cls.DATASET_NAME).alias("datasourceId"), + f.lit(IntervalDataSource.E2G.value).alias("datasourceId"), + f.col("intervalType"), f.lit(cls.PMID).alias("pmid"), - f.col("studyId"), + f.lit(None).cast(t.StringType()).alias("biofeature"), + f.col("biosampleName"), + f.lit(None).cast(t.StringType()).alias("biosampleFromSourceId"), f.col("biosampleId"), + f.col("studyId"), f.col("intervalId"), + f.col("qualityControls"), ) ), _schema=Intervals.get_schema(), diff --git a/src/gentropy/datasource/intervals/epiraction.py b/src/gentropy/datasource/intervals/epiraction.py index 0c321a889..920fd86f0 100644 --- a/src/gentropy/datasource/intervals/epiraction.py +++ b/src/gentropy/datasource/intervals/epiraction.py @@ -4,9 +4,11 @@ from typing import TYPE_CHECKING, ClassVar -import pyspark.sql.functions as f +from pyspark.sql import functions as f +from pyspark.sql import types as t -from gentropy.dataset.intervals import Intervals +from gentropy.common.processing import normalize_chromosome +from gentropy.dataset.intervals import IntervalDataSource, Intervals from gentropy.dataset.target_index import TargetIndex if TYPE_CHECKING: @@ -16,9 +18,7 @@ class IntervalsEpiraction: """Interval dataset from EPIraction.""" - DATASET_NAME: ClassVar[str] = "epiraction" PMID: ClassVar[str] = "40027634" - VALID_INTERVAL_TYPES: ClassVar[list[str]] = ["promoter", "enhancer"] @staticmethod def read(spark: SparkSession, path: str) -> DataFrame: @@ -53,13 +53,24 @@ def parse( Returns: Intervals: Parsed Intervals dataset. 
""" + if "#chr" in raw_epiraction_df.columns: + chr_col = "#chr" + else: + chr_col = "chr" base = ( - raw_epiraction_df.filter(f.col("class").isin(cls.VALID_INTERVAL_TYPES)) - .withColumn("chromosome", f.regexp_replace(f.col("chr"), r"^chr", "")) + raw_epiraction_df.withColumn( + "studyId", + f.regexp_extract(f.input_file_name(), r"([^/]+)\.bed\.gz$", 1), + ) + .withColumn("chromosome", normalize_chromosome(f.col(chr_col))) + .withColumn("start", f.col("start").cast("long")) + .withColumn("end", f.col("end").cast("long")) .withColumnRenamed("TargetGeneEnsemblID", "geneId") .withColumnRenamed("CellType", "biosampleName") .withColumnRenamed("Score", "score") + .withColumn("score", f.col("score").cast("double")) .withColumnRenamed("class", "intervalType") + .withColumn("intervalType", f.lower(f.trim(f.col("intervalType")))) .withColumn( "resourceScore", f.array( @@ -89,88 +100,47 @@ def parse( ), ), ) - .withColumn("start", f.col("start").cast("long")) - .withColumn("end", f.col("end").cast("long")) - .withColumn("intervalType", f.lower(f.trim(f.col("intervalType")))) - ) - - # Target Index: preferred TSS (+ fallbacks) - ti = target_index._df.select( - f.col("id").alias("geneId"), - f.col("tss").cast("long").alias("tss_primary"), - f.col("canonicalTranscript.start").cast("long").alias("ct_start"), - f.col("canonicalTranscript.end").cast("long").alias("ct_end"), - f.col("canonicalTranscript.strand").alias("ct_strand"), - f.col("genomicLocation.start").cast("long").alias("gl_start"), - f.col("genomicLocation.end").cast("long").alias("gl_end"), - f.col("genomicLocation.strand").cast("int").alias("gl_strand"), - ) - - ct_tss = f.when(f.col("ct_strand") == "+", f.col("ct_start")).when( - f.col("ct_strand") == "-", f.col("ct_end") - ) - gl_tss = f.when(f.col("gl_strand") == 1, f.col("gl_start")).when( - f.col("gl_strand") == -1, f.col("gl_end") - ) - - ti_with_tss = ti.withColumn( - "tss_from_target_index", f.coalesce(f.col("tss_primary"), ct_tss, gl_tss) ) + tss_lut = 
target_index.tss_lut() - has_input_tss = "distanceToTSS" in base.columns - base_with_fallback = ( - base.withColumn("tss_from_input", f.col("distanceToTSS").cast("long")) - if has_input_tss - else base.withColumn("tss_from_input", f.lit(None).cast("long")) - ) - - joined = base_with_fallback.alias("iv").join( - ti_with_tss.alias("ti"), on="geneId", how="inner" - ) - - tss = f.coalesce(f.col("ti.tss_from_target_index"), f.col("iv.tss_from_input")) - - dist_core = f.when( - (tss >= f.col("iv.start")) & (tss <= f.col("iv.end")), f.lit(0) - ).otherwise( - f.least(f.abs(tss - f.col("iv.start")), f.abs(tss - f.col("iv.end"))) - ) - distance_expr = ( - f.when(f.col("iv.intervalType") == "promoter", f.lit(0)) - .when(tss.isNull(), f.lit(None).cast("long")) - .otherwise(dist_core) - ) - - parsed = joined.withColumn( - "distanceToTss", distance_expr.cast("double") - ).withColumn( - "intervalId", - f.sha1( - f.concat_ws( - "_", - f.col("iv.chromosome"), + joined = base.alias("iv").join(tss_lut.alias("ti"), on="geneId", how="left") + parsed = ( + joined.withColumn( + "distanceToTss", + Intervals.distance_to_tss( f.col("iv.start"), f.col("iv.end"), - f.col("iv.geneId"), - f.lit(cls.DATASET_NAME), - ) - ), + f.col("iv.intervalType"), + f.col("ti.tss"), + ), + ) + .withColumn( + "intervalId", + Intervals.generate_identifier(Intervals.id_cols), + ) + .withColumn("qualityControls", f.array().cast("array")) ) return Intervals( _df=( parsed.select( - f.col("iv.chromosome").alias("chromosome"), - f.col("iv.start").cast("string").alias("start"), - f.col("iv.end").cast("string").alias("end"), - f.col("iv.geneId").alias("geneId"), - f.col("iv.biosampleName").alias("biosampleName"), - f.col("iv.intervalType").alias("intervalType"), - f.col("distanceToTss").cast("integer").alias("distanceToTss"), - f.col("iv.score").cast("double").alias("score"), - f.col("iv.resourceScore").alias("resourceScore"), - f.lit(cls.DATASET_NAME).alias("datasourceId"), + f.col("chromosome"), + f.col("start"), + 
f.col("end"), + f.col("geneId"), + f.col("score"), + f.col("distanceToTss"), + f.col("resourceScore"), + f.lit(IntervalDataSource.EPIRACTION.value).alias("datasourceId"), + f.col("intervalType"), f.lit(cls.PMID).alias("pmid"), + f.lit(None).cast(t.StringType()).alias("biofeature"), + f.col("biosampleName"), + f.lit(None).cast(t.StringType()).alias("biosampleFromSourceId"), + f.lit(None).cast(t.StringType()).alias("biosampleId"), + f.col("studyId"), + f.col("intervalId"), + f.col("qualityControls"), ) ), _schema=Intervals.get_schema(), diff --git a/src/gentropy/intervals.py b/src/gentropy/intervals.py index 03c9b4f3a..a87a4a44e 100644 --- a/src/gentropy/intervals.py +++ b/src/gentropy/intervals.py @@ -4,6 +4,7 @@ from gentropy.common.session import Session from gentropy.dataset.biosample_index import BiosampleIndex +from gentropy.dataset.contig_index import ContigIndex from gentropy.dataset.target_index import TargetIndex from gentropy.datasource.intervals.e2g import IntervalsE2G from gentropy.datasource.intervals.epiraction import IntervalsEpiraction @@ -13,7 +14,6 @@ class IntervalE2GStep: """Interval E2G step. This step generates a dataset that contains interval evidence supporting the functional associations of variants with genes. - """ def __init__( @@ -22,8 +22,13 @@ def __init__( target_index_path: str, biosample_mapping_path: str, biosample_index_path: str, + chromosome_contig_index_path: str, interval_source: str, - interval_e2g_path: str, + valid_output_path: str, + invalid_output_path: str, + min_valid_score: float = 0.6, + max_valid_score: float = 1.0, + invalid_qc_reasons: list[str] | None = None, ) -> None: """Run intervals step. @@ -32,23 +37,44 @@ def __init__( target_index_path (str): Input target index path. biosample_mapping_path (str): Input biosample mapping path. biosample_index_path (str): Input biosample index path. + chromosome_contig_index_path (str): Input chromosome contig index path. interval_source (str): Input intervals source path. 
- interval_e2g_path (str): Output processed e2g intervals path. + valid_output_path (str): Output valid intervals path. + invalid_output_path (str): Output invalid intervals path. + min_valid_score (float): Minimum valid score for interval QC. + max_valid_score (float): Maximum valid score for interval QC. + invalid_qc_reasons (list[str] | None): List of invalid quality check reason names from `IntervalQualityCheck` (e.g. ['INVALID_CHROMOSOME']). """ - target_index = TargetIndex.from_parquet( - session, - target_index_path, - ).persist() - biosample_mapping = session.spark.read.option("header", "true").csv( - biosample_mapping_path - ) + invalid_qc_reasons = invalid_qc_reasons or [] + + biosample_mapping = session.spark.read.csv(biosample_mapping_path, header=True) + target_index = TargetIndex.from_parquet(session, target_index_path).persist() biosample_index = BiosampleIndex.from_parquet(session, biosample_index_path) + contig_index = ContigIndex.from_parquet(session, chromosome_contig_index_path) data = IntervalsE2G.read(session.spark, interval_source) - interval_e2g = IntervalsE2G.parse( - data, biosample_mapping, target_index, biosample_index + interval_e2g = IntervalsE2G.parse(data, biosample_mapping, target_index) + valid, invalid = interval_e2g.qc( + contig_index=contig_index, + target_index=target_index, + biosample_index=biosample_index, + min_valid_score=min_valid_score, + max_valid_score=max_valid_score, + invalid_qc_reasons=invalid_qc_reasons, + ) + ( + valid.df.repartitionByRange( + session.output_partitions, "chromosome", "start" + ) + .write.mode(session.write_mode) + .parquet(valid_output_path) + ) + ( + invalid.df.repartitionByRange( + session.output_partitions, "chromosome", "start" + ) + .write.mode(session.write_mode) + .parquet(invalid_output_path) ) - - interval_e2g.df.write.mode(session.write_mode).parquet(interval_e2g_path) class IntervalEpiractionStep: @@ -62,24 +88,54 @@ def __init__( self, session: Session, target_index_path: str, + 
biosample_index_path: str, + chromosome_contig_index_path: str, interval_source: str, - interval_epiraction_path: str, + valid_output_path: str, + invalid_output_path: str, + min_valid_score: float = 0.6, + max_valid_score: float = 1.0, + invalid_qc_reasons: list[str] | None = None, ) -> None: """Run intervals step. Args: session (Session): Session object. target_index_path (str): Input target index path. + biosample_index_path (str): Input biosample index path. + chromosome_contig_index_path (str): Input chromosome contig index path. interval_source (str): Input intervals source path. - interval_epiraction_path (str): Output processed interval epiraction path. + valid_output_path (str): Output valid intervals path. + invalid_output_path (str): Output invalid intervals path. + min_valid_score (float): Minimum valid score for interval QC. + max_valid_score (float): Maximum valid score for interval QC. + invalid_qc_reasons (list[str] | None): List of invalid quality check reason names from `IntervalQualityCheck` (e.g. ['INVALID_CHROMOSOME']). 
""" - target_index = TargetIndex.from_parquet( - session, - target_index_path, - ).persist() + invalid_qc_reasons = invalid_qc_reasons or [] + target_index = TargetIndex.from_parquet(session, target_index_path).persist() data = IntervalsEpiraction.read(session.spark, interval_source) interval_epiraction = IntervalsEpiraction.parse(data, target_index) - - interval_epiraction.df.write.mode(session.write_mode).parquet( - interval_epiraction_path + biosample_index = BiosampleIndex.from_parquet(session, biosample_index_path) + contig_index = ContigIndex.from_parquet(session, chromosome_contig_index_path) + valid, invalid = interval_epiraction.qc( + contig_index=contig_index, + target_index=target_index, + biosample_index=biosample_index, + min_valid_score=min_valid_score, + max_valid_score=max_valid_score, + invalid_qc_reasons=invalid_qc_reasons, + ) + ( + valid.df.repartitionByRange( + session.output_partitions, "chromosome", "start" + ) + .write.mode(session.write_mode) + .parquet(valid_output_path) + ) + ( + invalid.df.repartitionByRange( + session.output_partitions, "chromosome", "start" + ) + .write.mode(session.write_mode) + .parquet(invalid_output_path) ) diff --git a/src/gentropy/study_locus_validation.py b/src/gentropy/study_locus_validation.py index 31cfa1064..0564bdb1b 100644 --- a/src/gentropy/study_locus_validation.py +++ b/src/gentropy/study_locus_validation.py @@ -57,22 +57,31 @@ def __init__( .filter_credible_set(credible_interval=CredibleInterval.IS95) # Flagging credible sets with PIP > 1 or PIP < 0.95 .qc_abnormal_pips( - sum_pips_lower_threshold=0.95, sum_pips_upper_threshold=1.0001 + sum_pips_lower_threshold=0.95, + sum_pips_upper_threshold=1.0001, ) # Annotate credible set confidence: .assign_confidence() # Flagging trans qtls: .flag_trans_qtls(study_index, target_index, trans_qtl_threshold) - ).persist() # we will need this for 2 types of outputs + .persist() # we will need this for 2 types of outputs + ) - # Valid study locus partitioned to 
simplify the finding of overlaps - study_locus_with_qc.valid_rows(invalid_qc_reasons).df.repartitionByRange( - session.output_partitions, "chromosome", "position" - ).sortWithinPartitions("chromosome", "position").write.mode( - session.write_mode - ).parquet(valid_study_locus_path) + result = study_locus_with_qc.valid_rows(invalid_qc_reasons) - # Invalid study locus - study_locus_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.coalesce( - session.output_partitions - ).write.mode(session.write_mode).parquet(invalid_study_locus_path) + ( + # Valid study locus partitioned to simplify the finding of overlaps + result.valid.df.repartitionByRange( + session.output_partitions, + "chromosome", + "position", + ) + .sortWithinPartitions("chromosome", "position") + .write.mode(session.write_mode) + .parquet(valid_study_locus_path) + ) + ( + result.invalid.df.coalesce(session.output_partitions) + .write.mode(session.write_mode) + .parquet(invalid_study_locus_path) + ) diff --git a/src/gentropy/study_validation.py b/src/gentropy/study_validation.py index d32937d06..95c70a110 100644 --- a/src/gentropy/study_validation.py +++ b/src/gentropy/study_validation.py @@ -53,7 +53,7 @@ def __init__( target_index = TargetIndex.from_parquet(session, target_index_path) biosample_index = BiosampleIndex.from_parquet(session, biosample_index_path) # Reading disease index and pre-process. - # This logic does not belong anywhere, but gentorpy has no disease dataset yet. + # This logic does not belong anywhere, but gentropy has no disease dataset yet. 
disease_index = ( session.spark.read.parquet(disease_index_path) .select( @@ -76,16 +76,19 @@ def __init__( .validate_project_id(deprecated_project_ids) # Flag obsolete projectIds .validate_target(target_index) # Flagging QTL studies with invalid targets .validate_disease(disease_index) # Flagging invalid EFOs - .validate_biosample( - biosample_index - ) # Flagging QTL studies with invalid biosamples + .validate_biosample(biosample_index) # Flagging invalid biosample in QTLs .validate_analysis_flags() # Flagging studies with case case design - ).persist() # we will need this for 2 types of outputs - - study_index_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.coalesce( - session.output_partitions - ).write.mode(session.write_mode).parquet(invalid_study_index_path) + .persist() # we will need this for 2 types of outputs + ) - study_index_with_qc.valid_rows(invalid_qc_reasons).df.coalesce( - session.output_partitions - ).write.mode(session.write_mode).parquet(valid_study_index_path) + result = study_index_with_qc.valid_rows(invalid_qc_reasons) + ( + result.valid.df.coalesce(session.output_partitions) + .write.mode(session.write_mode) + .parquet(valid_study_index_path) + ) + ( + result.invalid.df.coalesce(session.output_partitions) + .write.mode(session.write_mode) + .parquet(invalid_study_index_path) + ) diff --git a/tests/gentropy/dataset/test_dataset_exclusion.py b/tests/gentropy/dataset/test_dataset_exclusion.py index 1b6fce967..fb6d66633 100644 --- a/tests/gentropy/dataset/test_dataset_exclusion.py +++ b/tests/gentropy/dataset/test_dataset_exclusion.py @@ -62,9 +62,7 @@ def test_valid_rows( """Test valid rows.""" passing_studies = [ study["studyId"] - for study in self.study_index.valid_rows( - filter_, invalid=False - ).df.collect() + for study in self.study_index.valid_rows(filter_).valid.df.collect() ] assert passing_studies == expected @@ -82,7 +80,7 @@ def test_invalid_rows( """Test invalid rows.""" failing_studies = [ study["studyId"] - for study 
in self.study_index.valid_rows(filter_, invalid=True).df.collect() + for study in self.study_index.valid_rows(filter_).invalid.df.collect() ] assert failing_studies == expected @@ -90,7 +88,7 @@ def test_invalid_rows( def test_failing_quality_flag(self: TestDataExclusion) -> None: """Test invalid quality flag.""" with pytest.raises(ValueError): - self.study_index.valid_rows(self.INCORRECT_FLAG, invalid=True).df.collect() + self.study_index.valid_rows(self.INCORRECT_FLAG).invalid.df.collect() with pytest.raises(ValueError): - self.study_index.valid_rows(self.INCORRECT_FLAG, invalid=False).df.collect() + self.study_index.valid_rows(self.INCORRECT_FLAG).valid.df.collect() diff --git a/tests/gentropy/dataset/test_intervals.py b/tests/gentropy/dataset/test_intervals.py new file mode 100644 index 000000000..ebbf1e606 --- /dev/null +++ b/tests/gentropy/dataset/test_intervals.py @@ -0,0 +1,92 @@ +"""Test Interval dataset.""" + +import pytest +from pyspark.sql import DataFrame + +from gentropy import Session +from gentropy.dataset.biosample_index import BiosampleIndex +from gentropy.dataset.contig_index import ContigIndex +from gentropy.dataset.intervals import Intervals +from gentropy.dataset.target_index import TargetIndex + + +@pytest.fixture +def interval_dataframe(session: Session) -> DataFrame: + """Get the interval dataframe for testing.""" + data = [ + ("1", 100, 200, "ENSG1", "E2G", "promoter", "1", "biosample1", 0.1, "1"), + ("1", 150, 250, "ENSG2", "E2G", "enhancer", "2", "biosample2", 0.2, "1"), + ("2", 300, 400, "ENSG3", "E2G", "intragenic", "3", "biosample1", 0.3, "1"), + ("2", 300, 400, "ENSG3", "E2G", "promoter", "4", "biosample2", 0.4, "1"), + ("2", 400, 500, "ENSG4", "epiraction", "other", "5", "biosample1", 0.5, "1"), + ("3", 450, 550, "ENSG5", "other", "6", "interval6", "biosample2", 0.6, "1"), + ] + schema = "chromosome STRING, start LONG, end LONG, geneId STRING, datasourceId STRING, intervalType STRING, intervalId STRING, biosampleId STRING, score 
DOUBLE, studyId STRING" + return session.spark.createDataFrame(data, schema=schema) + + +@pytest.fixture +def contig_index(session: Session) -> ContigIndex: + """Get a mock contig index.""" + data = [("1", 0, 200), ("2", 0, 300)] + schema = "id STRING, start LONG, end LONG" + df = session.spark.createDataFrame(data, schema=schema) + return ContigIndex(df) + + +@pytest.fixture +def target_index(session: Session) -> TargetIndex: + """Get a mock target index.""" + data = [ + ("ENSG1", "Gene1"), + ("ENSG2", "Gene2"), + ("ENSG3", "Gene3"), + ("ENSG4", "Gene4"), + ("ENSG5", "Gene5"), + ] + schema = "id STRING, approvedSymbol STRING" + df = session.spark.createDataFrame(data, schema=schema) + return TargetIndex(df) + + +@pytest.fixture +def biosample_index(session: Session) -> BiosampleIndex: + """Get a mock biosample index.""" + data = [("biosample1", "CellType1"), ("biosample2", "CellType2")] + schema = "biosampleId STRING, biosampleName STRING" + df = session.spark.createDataFrame(data, schema=schema) + return BiosampleIndex(df) + + +class TestIntervalDataset: + """Test Interval dataset functionalities.""" + + def test_qc( + self, + interval_dataframe: DataFrame, + contig_index: ContigIndex, + target_index: TargetIndex, + biosample_index: BiosampleIndex, + ) -> None: + """Test QC method of Interval dataset.""" + intervals = Intervals(interval_dataframe) + + valid, invalid = intervals.qc( + contig_index=contig_index, + target_index=target_index, + biosample_index=biosample_index, + min_valid_score=0.5, + max_valid_score=1.0, + invalid_qc_reasons=["INVALID_CHROMOSOME"], + ) + + # Validate the results + assert isinstance(valid, Intervals) + assert isinstance(invalid, Intervals) + + # Assert the total count remains the same + assert valid.df.count() + invalid.df.count() == interval_dataframe.count() + + # Assert there are no INVALID_CHROMOSOME in valid intervals + assert valid.df.count() == 5, "Expected 5 valid intervals" + assert invalid.df.count() == 1, "Expected 1 
interval with INVALID_CHROMOSOME" diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index b805ec92b..9f5247af3 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -757,7 +757,7 @@ def _setup(self, spark: SparkSession) -> None: "biosampleName": None, "biosampleId": None, "studyId": None, - "intervalId": None, + "intervalId": "1", }, { "chromosome": "1", @@ -774,7 +774,7 @@ def _setup(self, spark: SparkSession) -> None: "biosampleName": None, "biosampleId": None, "studyId": None, - "intervalId": None, + "intervalId": "2", }, # Another gene overlapping the kept variant position { @@ -792,7 +792,7 @@ def _setup(self, spark: SparkSession) -> None: "biosampleName": None, "biosampleId": None, "studyId": None, - "intervalId": None, + "intervalId": "3", }, # Interval near low-PP variant (won't contribute because PP threshold filters it before the join) { @@ -810,7 +810,7 @@ def _setup(self, spark: SparkSession) -> None: "biosampleName": None, "biosampleId": None, "studyId": None, - "intervalId": None, + "intervalId": "4", }, ] self.sample_intervals = Intervals( diff --git a/tests/gentropy/datasource/intervals/test_e2g.py b/tests/gentropy/datasource/intervals/test_e2g.py index 0cc7d00eb..fce5fdf5f 100644 --- a/tests/gentropy/datasource/intervals/test_e2g.py +++ b/tests/gentropy/datasource/intervals/test_e2g.py @@ -5,7 +5,6 @@ import pytest from pyspark.sql import DataFrame, SparkSession -from gentropy.dataset.biosample_index import BiosampleIndex from gentropy.dataset.intervals import Intervals from gentropy.dataset.target_index import TargetIndex from gentropy.datasource.intervals.e2g import IntervalsE2G @@ -36,7 +35,6 @@ def test_e2g_intervals_from_source( sample_intervals_e2g: DataFrame, sample_biosample_mapping: DataFrame, mock_target_index: TargetIndex, - mock_biosample_index: BiosampleIndex, ) -> None: """Test E2GIntervals creation with mock data.""" assert 
isinstance( @@ -44,7 +42,6 @@ def test_e2g_intervals_from_source( sample_intervals_e2g, sample_biosample_mapping, mock_target_index, - mock_biosample_index, ), Intervals, ) diff --git a/tests/gentropy/step/test_interval_step.py b/tests/gentropy/step/test_interval_step.py new file mode 100644 index 000000000..65d2bc055 --- /dev/null +++ b/tests/gentropy/step/test_interval_step.py @@ -0,0 +1,201 @@ +"""Integration test for interval step.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from gentropy import Session +from gentropy.dataset.dataset import DatasetValidationResult +from gentropy.intervals import IntervalE2GStep, IntervalEpiractionStep + + +class TestIntervalE2GStep: + """Test Interval E2G Step.""" + + @pytest.mark.step_test + @patch("gentropy.intervals.BiosampleIndex") + @patch("gentropy.intervals.ContigIndex") + @patch("gentropy.intervals.TargetIndex") + @patch("gentropy.intervals.IntervalsE2G") + @patch("pyspark.sql.readwriter.DataFrameReader.csv") + def test_interval_e2g_step( + self, + spark_read: MagicMock, + intervals_e2g: MagicMock, + target: MagicMock, + contig: MagicMock, + biosample: MagicMock, + session: Session, + tmp_path: Path, + ) -> None: + """Test Interval E2G Step callstack.""" + target_index_path = (tmp_path / "target").as_posix() + biosample_mapping_path = (tmp_path / "biosample_mapping.tsv").as_posix() + biosample_index_path = (tmp_path / "biosample").as_posix() + contig_index_path = (tmp_path / "contig").as_posix() + interval_source = (tmp_path / "e2g_source").as_posix() + valid_output_path = (tmp_path / "valid_intervals").as_posix() + invalid_output_path = (tmp_path / "invalid_intervals").as_posix() + min_valid_score = 0.5 + max_valid_score = 1.0 + invalid_qc_reasons = ["INVALID_CHROMOSOME", "INVALID_SCORE"] + + dummy_df = session.spark.createDataFrame( + [("A", 1), ("C", 2)], schema="chromosome STRING, start LONG" + ) + + # Mock the 
`Dataset.from_parquet` methods to return MagicMock instances + target.from_parquet = MagicMock() + target.from_parquet.persist = MagicMock() + biosample.from_parquet = MagicMock() + contig.from_parquet = MagicMock() + + # Mock spark_read return value to return dummy_df + spark_read.return_value = dummy_df + + # Mock E2G.read to return dummy_df + intervals_e2g.read = MagicMock(return_value=dummy_df) + intervals_instance = MagicMock() + intervals_instance.df = dummy_df + intervals_instance.qc = MagicMock( + return_value=DatasetValidationResult( + valid=intervals_instance, invalid=intervals_instance + ) + ) + intervals_e2g.parse = MagicMock(return_value=intervals_instance) + + IntervalE2GStep( + session=session, + target_index_path=target_index_path, + biosample_mapping_path=biosample_mapping_path, + biosample_index_path=biosample_index_path, + chromosome_contig_index_path=contig_index_path, + interval_source=interval_source, + valid_output_path=valid_output_path, + invalid_output_path=invalid_output_path, + min_valid_score=min_valid_score, + max_valid_score=max_valid_score, + invalid_qc_reasons=invalid_qc_reasons, + ) + + # Assert the biosample mapping was read via csv file + spark_read.assert_called_once_with(biosample_mapping_path, header=True) + # Assert the datasets were read + target.from_parquet.assert_called_once_with(session, target_index_path) + target.from_parquet.return_value.persist.assert_called_once() + + biosample.from_parquet.assert_called_once_with(session, biosample_index_path) + contig.from_parquet.assert_called_once_with(session, contig_index_path) + # Assert the read method was called for E2G data + intervals_e2g.read.assert_called_once_with(session.spark, interval_source) + # Assert the parse method was called for E2G data + intervals_e2g.parse.assert_called_once_with( + intervals_e2g.read.return_value, + spark_read.return_value, + target.from_parquet.return_value.persist.return_value, + ) + # Assert the qc method was called + 
intervals_instance.qc.assert_called_once_with( + contig_index=contig.from_parquet.return_value, + target_index=target.from_parquet.return_value.persist.return_value, + biosample_index=biosample.from_parquet.return_value, + min_valid_score=min_valid_score, + max_valid_score=max_valid_score, + invalid_qc_reasons=invalid_qc_reasons, + ) + # Assert that valid and invalid datasets were written + assert Path(valid_output_path).exists() + assert Path(invalid_output_path).exists() + + +class TestIntervalEpiractionStep: + """Test Interval Epiraction Step.""" + + @pytest.mark.step_test + @patch("gentropy.intervals.BiosampleIndex") + @patch("gentropy.intervals.ContigIndex") + @patch("gentropy.intervals.TargetIndex") + @patch("gentropy.intervals.IntervalsEpiraction") + def test_interval_epiraction_step( + self, + intervals_epiraction: MagicMock, + target: MagicMock, + contig: MagicMock, + biosample: MagicMock, + session: Session, + tmp_path: Path, + ) -> None: + """Test Interval Epiraction Step callstack.""" + target_index_path = (tmp_path / "target").as_posix() + biosample_index_path = (tmp_path / "biosample").as_posix() + contig_index_path = (tmp_path / "contig").as_posix() + interval_source = (tmp_path / "epiraction_source").as_posix() + valid_output_path = (tmp_path / "valid_intervals").as_posix() + invalid_output_path = (tmp_path / "invalid_intervals").as_posix() + min_valid_score = 0.5 + max_valid_score = 1.0 + invalid_qc_reasons = ["INVALID_CHROMOSOME", "INVALID_SCORE"] + + dummy_df = session.spark.createDataFrame( + [("A", 0), ("C", 1)], schema="chromosome STRING, start LONG" + ) + + # Mock the `Dataset.from_parquet` methods to return MagicMock instances + target.from_parquet = MagicMock() + target.from_parquet.persist = MagicMock() + biosample.from_parquet = MagicMock() + contig.from_parquet = MagicMock() + + # Mock Epiraction.read to return dummy_df + intervals_epiraction.read = MagicMock(return_value=dummy_df) + intervals_instance = MagicMock() + 
intervals_instance.df = dummy_df + intervals_instance.qc = MagicMock( + return_value=DatasetValidationResult( + valid=intervals_instance, invalid=intervals_instance + ) + ) + intervals_epiraction.parse = MagicMock(return_value=intervals_instance) + + IntervalEpiractionStep( + session=session, + target_index_path=target_index_path, + biosample_index_path=biosample_index_path, + chromosome_contig_index_path=contig_index_path, + interval_source=interval_source, + valid_output_path=valid_output_path, + invalid_output_path=invalid_output_path, + min_valid_score=min_valid_score, + max_valid_score=max_valid_score, + invalid_qc_reasons=invalid_qc_reasons, + ) + + # Assert the datasets were read + target.from_parquet.assert_called_once_with(session, target_index_path) + target.from_parquet.return_value.persist.assert_called_once() + biosample.from_parquet.assert_called_once_with(session, biosample_index_path) + contig.from_parquet.assert_called_once_with(session, contig_index_path) + # Assert the read method was called for Epiraction data + intervals_epiraction.read.assert_called_once_with( + session.spark, interval_source + ) + # Assert the parse method was called for Epiraction data + intervals_epiraction.parse.assert_called_once_with( + intervals_epiraction.read.return_value, + target.from_parquet.return_value.persist.return_value, + ) + # Assert the qc method was called + intervals_instance.qc.assert_called_once_with( + contig_index=contig.from_parquet.return_value, + target_index=target.from_parquet.return_value.persist.return_value, + biosample_index=biosample.from_parquet.return_value, + min_valid_score=min_valid_score, + max_valid_score=max_valid_score, + invalid_qc_reasons=invalid_qc_reasons, + ) + # Assert that valid and invalid datasets were written + assert Path(valid_output_path).exists() + assert Path(invalid_output_path).exists() diff --git a/tests/gentropy/step/test_study_qc.py b/tests/gentropy/step/test_study_qc.py index 5368528bf..40ae16a87 100644 --- 
a/tests/gentropy/step/test_study_qc.py +++ b/tests/gentropy/step/test_study_qc.py @@ -98,5 +98,4 @@ def test_step( study_index.validate_analysis_flags.assert_called_once() # Assert valid_rows - study_index.valid_rows.assert_any_call(invalid_qc_reasons, invalid=True) study_index.valid_rows.assert_any_call(invalid_qc_reasons) From f547b03c6efe31cb017318c09c56d79848eaad29 Mon Sep 17 00:00:00 2001 From: Daniel-Considine <113430683+Daniel-Considine@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:53:49 +0000 Subject: [PATCH 03/16] fix: bug with allele matching in LD index variant liftover (#1185) * fix: bug with allele matching in liftover * test: fixing tests and minor bugfix in code * fix: adding logic to ensure no unexpected contig ids --------- Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- .gitignore | 1 + src/gentropy/datasource/gnomad/ld.py | 74 +++++++++++-------- .../datasource/gnomad/test_gnomad_ld.py | 24 +++++- 3 files changed, 64 insertions(+), 35 deletions(-) diff --git a/.gitignore b/.gitignore index 8e4ee8af9..830cc9f16 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ wandb/ hail*.log .python-version .idea +.venv/ diff --git a/src/gentropy/datasource/gnomad/ld.py b/src/gentropy/datasource/gnomad/ld.py index 9550d76ac..b37d854ce 100644 --- a/src/gentropy/datasource/gnomad/ld.py +++ b/src/gentropy/datasource/gnomad/ld.py @@ -12,7 +12,6 @@ from hail.linalg import BlockMatrix from pyspark.sql import Window -from gentropy.common.genomic_region import liftover_loci from gentropy.common.spark import get_top_ranked_in_window, get_value_from_row from gentropy.common.types import LD_Population from gentropy.config import LDIndexConfig @@ -123,7 +122,7 @@ def _create_ldindex_for_population( population_id: str, ld_matrix_path: str, ld_index_raw_path: str, - grch37_to_grch38_chain_path: str, + liftover_ht_path: str, min_r2: float, ) -> DataFrame: """Create LDIndex for a specific population. 
@@ -132,7 +131,7 @@ def _create_ldindex_for_population( population_id (str): Population ID ld_matrix_path (str): Path to the LD matrix ld_index_raw_path (str): Path to the LD index - grch37_to_grch38_chain_path (str): Path to the chain file used to lift over the coordinates + liftover_ht_path (str): Path to gnomad hail table used to lift over the coordinates min_r2 (float): Minimum r2 value to keep in the table Returns: @@ -146,7 +145,7 @@ def _create_ldindex_for_population( # Prepare table with variant indices ld_index = GnomADLDMatrix._process_variant_indices( hl.read_table(ld_index_raw_path), - grch37_to_grch38_chain_path, + liftover_ht_path, ) return GnomADLDMatrix._resolve_variant_indices(ld_index, ld_matrix).select( @@ -156,7 +155,8 @@ def _create_ldindex_for_population( @staticmethod def _process_variant_indices( - ld_index_raw: hl.Table, grch37_to_grch38_chain_path: str + ld_index_raw: hl.Table, + liftover_ht_path: str, ) -> DataFrame: """Creates a look up table between variants and their coordinates in the LD Matrix. 
@@ -164,35 +164,36 @@ def _process_variant_indices( Args: ld_index_raw (hl.Table): LD index table from GnomAD - grch37_to_grch38_chain_path (str): Path to the chain file used to lift over the coordinates + liftover_ht_path (str): Path to gnomad hail table used to lift over the coordinates Returns: DataFrame: Look up table between variants in build hg38 and their coordinates in the LD Matrix """ - ld_index_38 = liftover_loci(ld_index_raw, grch37_to_grch38_chain_path, "GRCh38") + ht_37 = ld_index_raw + ht_lo = hl.read_table(liftover_ht_path).key_by( + "original_locus", "original_alleles" + ) + ld_index_38 = ht_37.join(ht_lo, how="inner") return ( ld_index_38.to_spark() # Filter out variants where the liftover failed - .filter(f.col("`locus_GRCh38.position`").isNotNull()) + .filter(f.col("`locus_1.position`").isNotNull()) .select( - f.regexp_replace("`locus_GRCh38.contig`", "chr", "").alias( - "chromosome" - ), - f.col("`locus_GRCh38.position`").alias("position"), + f.regexp_replace("`locus_1.contig`", "chr", "").alias("chromosome"), + f.col("`locus_1.position`").alias("position"), f.concat_ws( "_", - f.regexp_replace("`locus_GRCh38.contig`", "chr", ""), - f.col("`locus_GRCh38.position`"), - f.col("`alleles`").getItem(0), - f.col("`alleles`").getItem(1), + f.regexp_replace("`locus_1.contig`", "chr", ""), + f.col("`locus_1.position`"), + f.col("`alleles_1`").getItem(0), + f.col("`alleles_1`").getItem(1), ).alias("variantId"), f.col("idx"), ) - # Filter out ambiguous liftover results: multiple indices for the same variant - .withColumn("count", f.count("*").over(Window.partitionBy(["variantId"]))) - .filter(f.col("count") == 1) - .drop("count") + .filter( + f.col("chromosome").isin([str(i) for i in range(1, 23)] + ["X", "Y"]) + ) ) @staticmethod @@ -286,7 +287,7 @@ def as_ld_index( pop, ld_matrix_path, ld_index_raw_path.format(pop), - self.grch37_to_grch38_chain_path, + self.liftover_ht_path, min_r2, ) ld_indices_unaggregated.append(pop_ld_index) @@ -328,7 +329,7 @@ def 
get_ld_variants( ld_index_df = ( self._process_variant_indices( hl.read_table(self.ld_index_raw_template.format(POP=gnomad_ancestry)), - self.grch37_to_grch38_chain_path, + self.liftover_ht_path, ) .filter( (f.col("chromosome") == chromosome) @@ -468,8 +469,8 @@ def get_locus_index( & (liftover_ht.locus.position <= end) ) .key_by() - .select("locus", "alleles", "original_locus") - .key_by("original_locus", "alleles") + .select("locus", "alleles", "original_locus", "original_alleles") + .key_by("original_locus", "original_alleles") .naive_coalesce(20) ) @@ -500,9 +501,7 @@ def get_numpy_matrix( idx = [row["idx"] for row in locus_index.select("idx").collect()] half_matrix = ( - BlockMatrix.read( - self.ld_matrix_template.format(POP=gnomad_ancestry) - ) + BlockMatrix.read(self.ld_matrix_template.format(POP=gnomad_ancestry)) .filter(idx, idx) .to_numpy() ) @@ -544,12 +543,23 @@ def get_locus_index_boundaries( return joined_index def _filter_liftover_by_locus( - self, + self: GnomADLDMatrix, liftover_ht: hl.Table, chromosome: str, start: int, - end: int - ) -> hl.Table: + end: int, + ) -> hl.Table: + """Filter liftover hail table by locus. 
+ + Args: + liftover_ht (hl.Table): Liftover hail table + chromosome (str): Chromosome of locus + start (int): Start position for locus + end (int): End position for locus + + Returns: + hl.Table: Filtered liftover hail table + """ liftover_ht = ( liftover_ht.filter( (liftover_ht.locus.contig == chromosome) @@ -557,8 +567,8 @@ def _filter_liftover_by_locus( & (liftover_ht.locus.position <= end) ) .key_by() - .select("locus", "alleles", "original_locus") - .key_by("original_locus", "alleles") + .select("locus", "alleles", "original_locus", "original_alleles") + .key_by("original_locus", "original_alleles") .naive_coalesce(20) ) diff --git a/tests/gentropy/datasource/gnomad/test_gnomad_ld.py b/tests/gentropy/datasource/gnomad/test_gnomad_ld.py index ac9567bfc..4784d747d 100644 --- a/tests/gentropy/datasource/gnomad/test_gnomad_ld.py +++ b/tests/gentropy/datasource/gnomad/test_gnomad_ld.py @@ -3,6 +3,7 @@ from __future__ import annotations from math import sqrt +from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch @@ -92,16 +93,33 @@ def test_get_ld_variants__square( assert sqrt(self.ld_slice.count()) == int(sqrt(self.ld_slice.count())) @pytest.fixture(autouse=True) - def _setup(self: TestGnomADLDMatrixVariants, spark: SparkSession) -> None: + def _setup( + self: TestGnomADLDMatrixVariants, spark: SparkSession, tmp_path: Path + ) -> None: """Prepares fixtures for the test.""" hl.init(sc=spark.sparkContext, log="/dev/null", idempotent=True) ld_test_population = "test-pop" + liftover_path = str(tmp_path / "mock_liftover.ht") + ld_index_path = str(tmp_path / "example_test-pop.ht") + + # Create a mock liftover table that maps 1:1 and ensure LD index is keyed + ht = hl.read_table("tests/gentropy/data_samples/example_test-pop.ht") + ht = ht.key_by("locus", "alleles") + ht.write(ld_index_path, overwrite=True) + + ht_lo = ht.annotate( + original_locus=ht.locus, + original_alleles=ht.alleles, + locus_1=ht.locus, + alleles_1=ht.alleles, + ) 
+ ht_lo.write(liftover_path, overwrite=True) gnomad_ld_matrix = GnomADLDMatrix( ld_matrix_template="tests/gentropy/data_samples/example_{POP}.bm", - ld_index_raw_template="tests/gentropy/data_samples/example_{POP}.ht", - grch37_to_grch38_chain_path="tests/gentropy/data_samples/grch37_to_grch38.over.chain", + ld_index_raw_template=str(tmp_path) + "/example_{POP}.ht", + liftover_ht_path=liftover_path, ) self.ld_slice = gnomad_ld_matrix.get_ld_variants( gnomad_ancestry=ld_test_population, From 262c904d850451aa9e16d2f6d3223408c50ef9fb Mon Sep 17 00:00:00 2001 From: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> Date: Fri, 13 Feb 2026 14:57:51 +0000 Subject: [PATCH 04/16] feat: Session refactoring (#1174) * feat(dataset): universal reader * refactor(datasource): eqtl catalogue parser allow for parquet files * feat: updates to gentropy session * chore: update doctests --------- Co-authored-by: project-defiant --- .github/workflows/pr.yaml | 2 +- Makefile | 19 +- docs/python_api/common/session.md | 2 +- pyproject.toml | 5 +- src/gentropy/__init__.py | 8 +- src/gentropy/assets/log4j.properties | 12 + src/gentropy/common/session.py | 723 +++++++++++++++--- src/gentropy/config.py | 11 +- src/gentropy/dataset/colocalisation.py | 6 +- src/gentropy/dataset/dataset.py | 28 +- .../datasource/eqtl_catalogue/__init__.py | 24 + .../datasource/eqtl_catalogue/finemapping.py | 30 +- .../datasource/eqtl_catalogue/study_index.py | 42 +- .../finngen_meta/summary_statistics.py | 20 +- src/gentropy/eqtl_catalogue.py | 11 +- src/gentropy/l2g.py | 2 +- src/utils/spark.py | 1 + tests/gentropy/common/test_session.py | 131 +++- tests/gentropy/conftest.py | 30 +- .../eqtl_catalogue/test_eqtl_catalogue.py | 2 +- .../finngen/test_finngen_finemapping.py | 7 +- .../test_finngen_meta_summary_statistics.py | 40 +- .../datasource/gnomad/test_gnomad_ld.py | 9 +- tests/gentropy/no_spark/test_no_spark.py | 145 ++++ 24 files changed, 1076 insertions(+), 234 deletions(-) create mode 
100644 src/gentropy/assets/log4j.properties create mode 100644 tests/gentropy/no_spark/test_no_spark.py diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index f17b04e9b..47f71da66 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -39,7 +39,7 @@ jobs: - name: Check dependencies run: uv run deptry . - name: Run tests - run: uv run pytest + run: make test - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: diff --git a/Makefile b/Makefile index bc83f9485..2ff17f5ba 100644 --- a/Makefile +++ b/Makefile @@ -38,9 +38,22 @@ check: ## Lint and format code @uv run pydoclint --config=pyproject.toml src @uv run pydoclint --config=pyproject.toml --skip-checking-short-docstrings=true tests -test: ## Run tests - @echo "Running Tests..." - @uv run pytest +test-no-shared-spark-session: ## Run tests that can not rely on shared SparkSession. + @echo "Running tests that can not rely on shared SparkSession fixture..." + @COVERAGE_FILE=.coverage.no_shared_spark uv run pytest -m "no_shared_spark and not download_jars_from_web" -n0 --cov-report= + +test-shared-spark-session: ## Run tests that can use shared SparkSession fixture. + @echo "Running tests that can share SparkSession fixture..." + @COVERAGE_FILE=.coverage.shared_spark uv run pytest --cov-report= + +test-no-shared-spark-session-web-dependencies: ## Run tests that require to download spark dependency jars from the web (not run by default). + @echo "Running tests that can not rely on shared SparkSession and require downloading jar dependencies from web..." 
+ @COVERAGE_FILE=.coverage.no_shared_spark_web_deps uv run pytest -n0 -m "download_jars_from_web" --cov-report= + +test: test-no-shared-spark-session test-shared-spark-session ## Run default test suite + @uv run coverage combine .coverage.shared_spark .coverage.no_shared_spark + @uv run coverage xml + @rm -f .coverage.shared_spark .coverage.no_shared_spark build-documentation: ## Create local server with documentation @echo "Building Documentation..." diff --git a/docs/python_api/common/session.md b/docs/python_api/common/session.md index de45e1ad9..c3de28224 100644 --- a/docs/python_api/common/session.md +++ b/docs/python_api/common/session.md @@ -5,4 +5,4 @@ title: session ## Spark Session wrapper for gentropy :::gentropy.common.session.Session -:::gentropy.common.session.Log4j +:::gentropy.common.session.SparkWriteMode diff --git a/pyproject.toml b/pyproject.toml index 54995d99d..26b6b4e0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,13 +130,12 @@ color = true exclude = ["dist"] [tool.pytest.ini_options] -addopts = "-n auto --doctest-modules --cov=src/ --cov-report=xml --cache-clear" +addopts = "-n auto --doctest-modules --cov=src/ --cov-report=xml --cache-clear -m 'not download_jars_from_web and not no_shared_spark'" pythonpath = ["."] testpaths = ["tests/gentropy", "src/gentropy"] -markers = ["step_test", "long_test"] +markers = ["step_test", "download_jars_from_web", "no_shared_spark"] filterwarnings = [ "ignore:.*it is preferred to specify type hints for pandas UDF.*:UserWarning" - ] # Semi-strict mode for mypy diff --git a/src/gentropy/__init__.py b/src/gentropy/__init__.py index c2c9cda5d..58944ca3d 100644 --- a/src/gentropy/__init__.py +++ b/src/gentropy/__init__.py @@ -4,12 +4,18 @@ import warnings -# NOTE: Suppress DeprecationWarnings from pyspark related to pandas API on Spark due to LooseVersion being deprecated in Python 3.12+ +# NOTE: Suppress DeprecationWarnings and UserWarnings from pyspark related to pandas API on Spark due to 
LooseVersion being deprecated in Python 3.12+ warnings.filterwarnings( "ignore", category=DeprecationWarning, module="pyspark.sql.pandas.utils", ) +warnings.filterwarnings( + "ignore", + category=UserWarning, + module="pyspark.sql.pandas.functions", +) + from gentropy.common.session import Session from gentropy.dataset.biosample_index import BiosampleIndex diff --git a/src/gentropy/assets/log4j.properties b/src/gentropy/assets/log4j.properties new file mode 100644 index 000000000..0bb6fdb5f --- /dev/null +++ b/src/gentropy/assets/log4j.properties @@ -0,0 +1,12 @@ +# Source - https://stackoverflow.com/a/76196464 + +# Set everything to be logged to the console +log4j.rootCategory=ERROR, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Set the log level to ERROR for everything +log4j.logger.org.apache=ERROR +log4j.logger.org.apache.spark=ERROR diff --git a/src/gentropy/common/session.py b/src/gentropy/common/session.py index e0bcf18d0..499ad53a6 100644 --- a/src/gentropy/common/session.py +++ b/src/gentropy/common/session.py @@ -2,9 +2,12 @@ from __future__ import annotations +import os +from enum import StrEnum +from pathlib import Path from typing import TYPE_CHECKING, Any, Protocol -import hail as hl +import pandas as pd from pyspark.conf import SparkConf from pyspark.sql import SparkSession @@ -13,75 +16,470 @@ from pyspark.sql.types import StructType +class NativeFileFormat(StrEnum): + """Enum for supported file formats.""" + + PARQUET = "parquet" + CSV = "csv" + TSV = "tsv" + JSON = "json" + + +class SparkWriteMode(StrEnum): + """Enum for Spark write modes.""" + + APPEND = "append" + OVERWRITE = "overwrite" + IGNORE = "ignore" + ERROR_IF_EXISTS = "errorifexists" + + @classmethod + def ensure(cls, v: str | None) -> str: + """Ensure the writeMode is 
correct. + + Args: + v (str | None): input value + + Returns: + str: mapping + + Raises: + ValueError: when the value is not found. + """ + match v: + case "append": + return cls.APPEND.value + case "overwrite": + return cls.OVERWRITE.value + case "ignore": + return cls.IGNORE.value + case "errorifexists" | None: + return cls.ERROR_IF_EXISTS.value + case _: + raise ValueError("Incorrect writeMode specified to Session object.") + + class Session: - """This class provides a Spark session and logger.""" + """This class provides a wrapper around SparkSession object with custom parameters. + + The wrapper has a few default sets of configurations. See constructor for references. + + !!! info "Custom Spark Configuration" + - **Output configuration**: write_mode and output_partitions, these set of parameters is stored respectively + under `spark.gentropy.writeMode` and `spark.gentropy.outputPartitions`. + Both parameters are used when writing datasets in gentropy steps. The `writeMode` will reflect on how Spark should handle existing data at the output path, + while `outputPartitions` will determine the number of partitions to use when writing out datasets (typically, excluding studyIndex datasets). For exact usage check the respective step implementation. + - **Hail configuration**: If `start_hail` is set to True, the Spark session will be configured with hail. + By default the path to the Hail jar will be inferred from the installed Hail package location. + Note that custom Hail configuration parameters can be passed through the `extended_hail_conf` argument. + - **Dynamic allocation configuration**: If `dynamic_allocation` is set to True, the Spark session will include + `spark.dynamicAllocation.enabled`, `spark.dynamicAllocation.minExecutors`, `spark.dynamicAllocation.initialExecutors` and `spark.shuffle.service.enabled` configurations with 2 executors as minimum. 
+ - **Enhanced BGZF codec configuration**: If `use_enhanced_bgzip_codec` is set to True, the Spark session will be configured to use the `BGZFEnhancedGzipCodec` for reading block gzipped files. + + Note: + The custom configuration parameters for gentropy are prefixed with `spark.gentropy.` to avoid conflicts with other Spark applications. + + Examples: + Create a new Spark Session on local machine with 4 executors, 4 cores and 8g of memory per executor + + >>> from gentropy.common.session import Session + >>> session = Session( + ... spark_uri="local[4]", + ... extended_spark_conf={ + ... "spark.executor.instances": "4", + ... "spark.executor.cores": "4", + ... "spark.executor.memory": "8g", + ... }, + ... ) # doctest: +SKIP + + Find existing session (if any exists) + + >>> session = Session.find() # doctest: +SKIP + + Create a new Spark Session with Hail support + + >>> session = Session(start_hail=True) # doctest: +SKIP + + Connect to running Spark cluster (yarn) + + >>> session = Session(spark_uri="yarn") # doctest: +SKIP + + Specify custom Hail configuration parameters + + >>> session = Session( + ... start_hail=True, + ... extended_hail_conf={"min_block_size": "32MB"} + ... ) # doctest: +SKIP + + Specify custom output parameters + + >>> session = Session( + ... output_partitions=100, + ... write_mode=SparkWriteMode.OVERWRITE + ... ) # doctest: +SKIP + + Specify via string (auto-converted to SparkWriteMode) if possible + + >>> session = Session( + ... output_partitions=100, + ... write_mode="overwrite" + ... ) # doctest: +SKIP + + Stop the session + + >>> session.spark.stop() # doctest: +SKIP + + View the path to spark ui + + >>> session.spark.sparkContext.uiWebUrl # doctest: +SKIP + + Example session with hadoop connector for S3 compatible storage + + >>> session = Session( + ... extended_spark_conf={ + ... # Executor + ... 'spark.executor.memory': '32g', + ... 'spark.executor.cores': '8', + ... 'spark.excutor.memoryOverhead': '4g', + ... 
'spark.dynamicAllocation.enabled': 'true', + ... 'spark.sql.files.maxPartitionBytes': '512m', + ... # Driver + ... 'spark.driver.memory': '25g', + ... 'spark.executor.extraJavaOptions': '-XX:+UseG1GC -XX:MaxGCPauseMillis=200 -XX:+ParallelRefProcEnabled -XX:+AlwaysPreTouch', + ... 'spark.jars.packages': 'org.apache.hadoop:hadoop-aws:3.3.6,com.amazonaws:aws-java-sdk-bundle:1.12.367', + ... 'spark.hadoop.fs.s3a.impl': 'org.apache.hadoop.fs.s3a.S3AFileSystem', + ... 'spark.hadoop.fs.s3a.endpoint': f'https://{credentials.s3_host_url}:{credentials.s3_host_port}', + ... 'spark.hadoop.fs.s3a.path.style.access': 'true', + ... 'spark.hadoop.fs.s3a.connection.ssl.enabled': 'true', + ... 'spark.hadoop.fs.s3a.access.key': f'{credentials.access_key_id}', + ... 'spark.hadoop.fs.s3a.secret.key': f'{credentials.secret_access_key}', + ... # Throughput tuning + ... 'spark.hadoop.fs.s3a.connection.maximum': '1000', + ... 'spark.hadoop.fs.s3a.threads.max': '1024', + ... 'spark.hadoop.fs.s3a.attempts.maximum': '20', + ... 'spark.hadoop.fs.s3a.connection.timeout': '600000', # 10min + ... } + ... ) # doctest: +SKIP + + Example session with hadoop connector for Google Cloud Storage + + >>> session = Session( + ... extended_spark_conf={ + ... 'spark.driver.maxResultSize': '0', + ... 'spark.debug.maxToStringFields': '2000', + ... 'spark.sql.broadcastTimeout': '3000', + ... 'spark.sql.adaptive.enabled': 'true', + ... 'spark.sql.adaptive.coalescePartitions.enabled': 'true', + ... 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', + ... # google cloud storage connector + ... 'spark.jars.packages': 'com.google.cloud.bigdataoss:gcs-connector:hadoop3-2.2.21', + ... 'spark.network.timeout': '10s', + ... 'spark.network.timeoutInterval': '10s', + ... 'spark.executor.heartbeatInterval': '6s', + ... 'spark.hadoop.fs.gs.block.size': '134217728', + ... 'spark.hadoop.fs.gs.inputstream.buffer.size': '8388608', + ... 'spark.hadoop.fs.gs.outputstream.buffer.size': '8388608', + ... 
'spark.hadoop.fs.gs.outputstream.sync.min.interval.ms': '2000', + ... 'spark.hadoop.fs.gs.status.parallel.enable': 'true', + ... 'spark.hadoop.fs.gs.glob.algorithm': 'CONCURRENT', + ... 'spark.hadoop.fs.gs.copy.with.rewrite.enable': 'true', + ... 'spark.hadoop.fs.gs.metadata.cache.enable': 'false', + ... 'spark.hadoop.fs.gs.auth.type': 'APPLICATION_DEFAULT', + ... 'spark.hadoop.fs.gs.impl': 'com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem', + ... 'spark.hadoop.fs.AbstractFileSystem.gs.impl': 'com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS', + ... } + ... ) # doctest: +SKIP + + """ def __init__( self: Session, spark_uri: str = "local[*]", - write_mode: str = "errorifexists", app_name: str = "gentropy", + write_mode: str = SparkWriteMode.ERROR_IF_EXISTS.value, hail_home: str | None = None, start_hail: bool = False, extended_spark_conf: dict[str, str] | None = None, + extended_hail_conf: dict[str, Any] | None = None, output_partitions: int = 200, use_enhanced_bgzip_codec: bool = False, + dynamic_allocation: bool = True, + log_level: str | None = "INFO", ) -> None: """Initialises spark session and logger. + The wrapper over SparkSession will either connect to an existing active Spark session or create a new one with the provided configuration. + + If spark session already exists, the provided configuration will have no effect on the session. + If any parameters will be different between existing session config and requested config, + a warning will be logged to suggest rebuilding the session with the new configuration. + Args: spark_uri (str): Spark URI. Defaults to "local[*]". - write_mode (str): Spark write mode. Defaults to "errorifexists". app_name (str): Spark application name. Defaults to "gentropy". + write_mode (str): Spark write mode. Defaults to SparkWriteMode.ERROR_IF_EXISTS. hail_home (str | None): Path to Hail installation. Defaults to None. start_hail (bool): Whether to start Hail. Defaults to False. 
extended_spark_conf (dict[str, str] | None): Extended Spark configuration. Defaults to None. + extended_hail_conf (dict[str, Any] | None): Extended Hail configuration. Defaults to None. output_partitions (int): Number of partitions for output datasets. Defaults to 200. use_enhanced_bgzip_codec (bool): Whether to use the BGZFEnhancedGzipCodec for reading block gzipped files. Defaults to False. + dynamic_allocation (bool): Whether to enable Spark dynamic allocation. Defaults to True. + log_level (str | None): Spark log level. Defaults to "INFO". """ - merged_conf = self._create_merged_config( - start_hail, - use_enhanced_bgzip_codec, - extended_spark_conf, - hail_home, + # Provide sane defaults for extended configurations + + self._extended_hail_conf = extended_hail_conf or {} + self._extended_spark_conf = extended_spark_conf or {} + self._write_mode = SparkWriteMode.ensure(write_mode) + self._output_partitions = output_partitions or 200 + self._hail_home = hail_home + # Build the requested config, small overhead, but we + # can report if existing session is up to date with provided configuration. 
+ _c = self._build_config( + dynamic_allocation=dynamic_allocation, + start_hail=start_hail, + use_enhanced_bgzip_codec=use_enhanced_bgzip_codec, ) + # Create or retrieve the Spark session + _spark_exists = isinstance(SparkSession.getActiveSession(), SparkSession) + if _spark_exists: + self.spark = ( + SparkSession.Builder().master(spark_uri).appName(app_name).getOrCreate() + ) + self.logger = Log4j(self.spark, level=log_level) + self.conf = self.spark.sparkContext.getConf() + # Check existing configuration against requested + self._compare_conf(current=self.conf, requested=_c) + else: + # The sparkSession does not exist yet, initialize the spark session with new configuration + self.spark = ( + SparkSession.Builder() + .config(conf=_c) + .master(spark_uri) + .appName(app_name) + .getOrCreate() + ) + # Initialize Hail if requested + if start_hail: + import hail as hl - self.spark = ( - SparkSession.Builder() - .config(conf=merged_conf) - .master(spark_uri) - .appName(app_name) - .getOrCreate() - ) - self.logger = Log4j(self.spark) + self._extended_hail_conf.setdefault("log", "/dev/null") + self._extended_hail_conf.setdefault("quiet", True) + self._extended_hail_conf.setdefault("idempotent", True) + hl.init(sc=self.spark.sparkContext, **self._extended_hail_conf) - self.write_mode = write_mode + self.logger = Log4j(self.spark, level=log_level) + self.conf = self.spark.sparkContext.getConf() - self.hail_home = hail_home - self.start_hail = start_hail - self.use_enhanced_bgzip_codec = use_enhanced_bgzip_codec + def _build_config( + self, + dynamic_allocation: bool, + start_hail: bool, + use_enhanced_bgzip_codec: bool, + ) -> SparkConf: + """Prepare the SparkConf object with the requested configuration. + + Args: + dynamic_allocation (bool): Whether to enable Spark dynamic allocation. + start_hail (bool): Whether to include Hail configuration. + use_enhanced_bgzip_codec (bool): Whether to include enhanced BGZIP codec configuration. 
+ + Returns: + SparkConf: SparkConf object with the requested configuration. + + """ + # Create a fresh SparkConf object... + _c = SparkConf(loadDefaults=False) + # ...and update it with requested parameters + _c = self._setup_output_config(_c, self._output_partitions, self._write_mode) + _c = self._setup_log4j_config(_c) + if dynamic_allocation: + _c = self._setup_dynamic_allocation_config(_c) if start_hail: - hl.init(sc=self.spark.sparkContext, log="/dev/null") - self.output_partitions = output_partitions + _c = self._setup_hail_config(_c, self._hail_home) + if use_enhanced_bgzip_codec: + _c = self._setup_enhanced_bgzip_config(_c) + # If any additional packages or jars, ensure they are included along existing ones instead of overwritten + if self._extended_spark_conf: + _c = self._setup_extended_spark_conf(self._extended_spark_conf, _c) + return _c + + def _compare_conf(self, current: SparkConf, requested: SparkConf) -> None: + """Compare current Spark configuration with the requested configuration. + + This method will log a warning for each configuration key that is present in the requested configuration but has a different value in the current configuration. - def _default_config(self: Session) -> SparkConf: - """Default spark configuration. + Args: + current (SparkConf): Current Spark configuration. + requested (SparkConf): Requested Spark configuration. + """ + for key, value in requested.getAll(): + current_value = current.get(key, None) + if current_value != value: + self.logger.warning( + f"Consider rebuilding SparkSession to apply requested configuration: '{key}' has value '{current_value}' but '{value}' was requested." + ) + + @property + def use_enhanced_bgzip_codec(self) -> bool: + """Check if the session is configured to use the BGZFEnhancedGzipCodec for reading block gzipped files. + + Returns: + bool: True if the session is configured to use the BGZFEnhancedGzipCodec, False otherwise. 
+ """ + return ( + self.conf.get("spark.gentropy.useEnhancedBgzipCodec", "false").lower() + == "true" + ) + + @property + def output_partitions(self) -> int: + """Get the number of output partitions. + + Returns: + int: Number of output partitions. + """ + return int(self.conf.get("spark.gentropy.outputPartitions", "200")) + + @property + def write_mode(self) -> SparkWriteMode: + """Get the Spark write mode. + + Returns: + SparkWriteMode: Spark write mode. + """ + return SparkWriteMode( + self.conf.get( + "spark.gentropy.writeMode", SparkWriteMode.ERROR_IF_EXISTS.value + ) + ) + + @classmethod + def find(cls) -> Session: + """Finds the current active Spark session. + + If no active Spark session is found, the method will raise an AttributeError. + + Returns: + Session: Current active Spark session. + + Raises: + AttributeError: If no active Spark session is found. + """ + active_spark = SparkSession.getActiveSession() + if active_spark is None: + raise AttributeError("Active Spark not found.") + return Session() + + @classmethod + def _setup_extended_spark_conf( + cls, extended_spark_conf: dict[str, str], _c: SparkConf + ) -> SparkConf: + """Append extended spark configuration to the existing SparkConf object. + + This method ensures that packages and jars are included instead of overwritten. + + Args: + extended_spark_conf (dict[str, str]): Extended Spark configuration to include in the session. + _c (SparkConf): Existing SparkConf object to update. + + Returns: + SparkConf: Updated SparkConf object with extended configuration included. 
+ """ + for key, value in extended_spark_conf.items(): + match key: + case "spark.jars": + _c = Session._append_jar(_c, value) + case "spark.jars.packages": + _c = Session._append_package(_c, value) + case "spark.driver.extraClassPath": + _c = Session._append_to_driver_classpath(_c, value) + case "spark.executor.extraClassPath": + _c = Session._append_to_executor_classpath(_c, value) + case _: + _c = _c.set(key, value) + return _c + + @staticmethod + def _setup_output_config( + c: SparkConf, output_partitions: int, write_mode: str + ) -> SparkConf: + """Output spark configuration. + + Args: + c (SparkConf): Existing Spark configuration. + output_partitions (int): Number of output partitions. + write_mode (str): Spark write mode. + + Returns: + SparkConf: adjusted spark configuration with output settings. + """ + return c.set("spark.gentropy.outputPartitions", str(output_partitions)).set( + "spark.gentropy.writeMode", str(write_mode) + ) + + @staticmethod + def _setup_dynamic_allocation_config(c: SparkConf) -> SparkConf: + """Setup Spark dynamic allocation configuration. + + Args: + c (SparkConf): Existing Spark configuration. Returns: - SparkConf: Default spark configuration. + SparkConf: Adjusted spark configuration with dynamic allocation settings. """ return ( - SparkConf() - # Dynamic allocation - .set("spark.dynamicAllocation.enabled", "true") + c.set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.minExecutors", "2") .set("spark.dynamicAllocation.initialExecutors", "2") .set("spark.shuffle.service.enabled", "true") ) - def _bgzip_config(self: Session) -> SparkConf: + @staticmethod + def _setup_hail_config( + c: SparkConf, + hail_home: str | None = None, + ) -> SparkConf: + """Setup Hail Spark configuration. + + Args: + c (SparkConf): Existing Spark configuration. + hail_home (str | None): Path to Hail installation. + + Returns: + SparkConf: Adjusted spark configuration with Hail settings. 
+ """ + if not hail_home: + import hail as hl + + hail_home = Path(hl.__file__).parent.as_posix() + jar_path = f"{hail_home}/backend/hail-all-spark.jar" + if not Path(jar_path).exists(): + raise FileNotFoundError( + f"Hail jar not found at {jar_path}. Please set hail_home in Session." + ) + c = Session._append_jar(c, jar_path) + c = Session._append_to_driver_classpath(c, jar_path) + # NOTE: the docs mention to not use full path for exectuor classPath + c = Session._append_to_executor_classpath(c, "./hail-all-spark.jar") + return ( + c.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", "is.hail.kryo.HailKryoRegistrator") + .set("spark.gentropy.enableHail", "true") + .set("spark.gentropy.hailHome", hail_home) + ) + + @staticmethod + def _setup_enhanced_bgzip_config(c: SparkConf) -> SparkConf: """Spark configuration for reading block gzipped files. + Args: + c (SparkConf): Existing Spark configuration. + + Returns: + SparkConf: Adjusted spark configuration with BGZFEnhancedGzipCodec settings. + Configuration that adds the hadoop-bam package and sets the BGZFEnhancedGzipCodec. Based on hadoop-bam jar artifact from [maven](https://mvnrepository.com/artifact/org.seqdoop/hadoop-bam/7.10.0). @@ -89,77 +487,119 @@ def _bgzip_config(self: Session) -> SparkConf: Full details of the codec can be found in [hadoop-bam](https://github.com/HadoopGenomics/Hadoop-BAM/blob/7.10.0/src/main/java/org/seqdoop/hadoop_bam/util/BGZFEnhancedGzipCodec.java) This codec implements: - (1) SplittableCompressionCodec allowing parallel reading of bgzip files. - (2) GzipCodec allowing reading of standard gzip files. + (1) SplittableCompressionCodec allowing parallel reading of bgzip files. + (2) GzipCodec allowing reading of standard gzip files. 
+ """ + c = Session._append_package(c, "org.seqdoop:hadoop-bam:7.10.0") + return c.set( + "spark.hadoop.io.compression.codecs", + "org.seqdoop.hadoop_bam.util.BGZFEnhancedGzipCodec", + ).set("spark.gentropy.useEnhancedBgzipCodec", "true") + + @staticmethod + def _append_jar(c: SparkConf, jar: str) -> SparkConf: + """Append a jar to the existing spark.jars configuration. + + Args: + c (SparkConf): Existing Spark configuration. + jar (str): Jar to add to the configuration. Returns: - SparkConf: Spark configuration for reading block gzipped files. + SparkConf: Adjusted spark configuration with the new jar included in the spark.jars setting. """ - return ( - SparkConf() - .set("spark.jars.packages", "org.seqdoop:hadoop-bam:7.10.0") - .set( - "spark.hadoop.io.compression.codecs", - "org.seqdoop.hadoop_bam.util.BGZFEnhancedGzipCodec", - ) - ) + existing_jars = c.get("spark.jars", "") + if jar not in existing_jars: + new_jars = f"{existing_jars},{jar}" if existing_jars else jar + return c.set("spark.jars", new_jars) + return c - def _hail_config(self: Session, hail_home: str) -> SparkConf: - """Returns the Hail specific Spark configuration. + @staticmethod + def _append_package(c: SparkConf, package: str) -> SparkConf: + """Append a package to the existing spark.jars.packages configuration. Args: - hail_home (str): Path to Hail installation. + c (SparkConf): Existing Spark configuration. + package (str): Package to add to the configuration. Returns: - SparkConf: Hail specific Spark configuration. + SparkConf: Adjusted spark configuration with the new package included in the spark.jars.packages setting. 
""" - return ( - SparkConf() - .set("spark.jars", f"{hail_home}/backend/hail-all-spark.jar") - .set( - "spark.driver.extraClassPath", f"{hail_home}/backend/hail-all-spark.jar" + existing_packages = c.get("spark.jars.packages", "") + if package not in existing_packages: + new_packages = ( + f"{existing_packages},{package}" if existing_packages else package ) - .set("spark.executor.extraClassPath", "./hail-all-spark.jar") - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "is.hail.kryo.HailKryoRegistrator") + return c.set("spark.jars.packages", new_packages) + return c + + @staticmethod + def _append_to_executor_classpath(c: SparkConf, jar: str) -> SparkConf: + """Append a jar to the existing driver and executor classpath. + + Args: + c (SparkConf): Existing Spark configuration. + jar (str): Jar to add to the classpath. + + Returns: + SparkConf: Adjusted spark configuration with the new jar included in the driver and executor classpath. + """ + existing_executor_cp = c.get("spark.executor.extraClassPath", "") + # NOTE: use os.pathsep, as it should default to ';' on windows and ':' on unix based systems. + new_executor_cp = ( + f"{existing_executor_cp}{os.pathsep}{jar}" if existing_executor_cp else jar ) + if jar not in existing_executor_cp: + return c.set("spark.executor.extraClassPath", new_executor_cp) + return c - def _create_merged_config( - self: Session, - start_hail: bool, - use_enhanced_bgzip_codec: bool, - extended_spark_conf: dict[str, str] | None, - hail_home: str | None = None, - ) -> SparkConf: - """Merges the default, and optionally the Hail and extended configurations if provided. + @staticmethod + def _append_to_driver_classpath(c: SparkConf, jar: str) -> SparkConf: + """Append a jar to the existing driver classpath. Args: - start_hail (bool): Whether to start Hail. - use_enhanced_bgzip_codec (bool): Whether to use the BGZFEnhancedGzipCodec for reading block gzipped files. 
- extended_spark_conf (dict[str, str] | None): Extended Spark configuration. - hail_home (str | None): Path to Hail installation. + c (SparkConf): Existing Spark configuration. + jar (str): Jar to add to the classpath. - Raises: - ValueError: If Hail home is not specified but Hail is requested. + Returns: + SparkConf: Adjusted spark configuration with the new jar included in the driver classpath. + """ + existing_driver_cp = c.get("spark.driver.extraClassPath", "") + # NOTE: use os.pathsep, as it should default to ';' on windows and ':' on unix based systems. + new_driver_cp = ( + f"{existing_driver_cp}{os.pathsep}{jar}" if existing_driver_cp else jar + ) + if jar not in existing_driver_cp: + return c.set("spark.driver.extraClassPath", new_driver_cp) + return c + + @staticmethod + def _setup_log4j_config(c: SparkConf) -> SparkConf: + """Setup Log4j Spark configuration. + + Args: + c (SparkConf): Existing Spark configuration. Returns: - SparkConf: Merged Spark configuration. + SparkConf: Adjusted spark configuration with log4j settings. + + !!! info "Log4j Configuration": + This method points to the static log4j properties file included in the gentropy assets. + The default configuration sets the log level to ERROR for all Spark logs. This is done to + prevent the excessive logging from Spark initialization, the actual log level can be adjusted + post initialization using the Log4j class. 
""" - all_settings = self._default_config().getAll() - if start_hail: - if not hail_home: - raise ValueError("Hail home must be specified to start Hail.") - all_settings += self._hail_config(hail_home).getAll() - if use_enhanced_bgzip_codec: - all_settings += self._bgzip_config().getAll() - if extended_spark_conf is not None: - all_settings += list(extended_spark_conf.items()) - return SparkConf().setAll(all_settings) + import importlib.resources as pkg_resources + + from gentropy import assets as asf + + prop = str(pkg_resources.files(asf).joinpath("log4j.properties")) + c.set("spark.driver.extraJavaOptions", f"-Dlog4j.configuration=file:{prop}") + return c def load_data( self: Session, path: str | list[str], - format: str = "parquet", + fmt: str = "parquet", schema: StructType | str | None = None, **kwargs: bool | float | int | str | None, ) -> DataFrame: @@ -169,19 +609,122 @@ def load_data( Args: path (str | list[str]): path to the dataset - format (str): file format. Defaults to parquet. + fmt (str): file format. Defaults to parquet. schema (StructType | str | None): Schema to use when reading the data. - **kwargs (bool | float | int | str | None): Additional arguments to pass to spark.read.load. `mergeSchema` is set to True, `recursiveFileLookup` is set to False by default. + **kwargs (bool | float | int | str | None): Additional arguments to pass to spark.read.load. Returns: - DataFrame: Dataframe + DataFrame: Dataframe containing the loaded data. + + !!! note "Default options for supported formats" + By default: + - `mergeSchema` is set to True for parquet format. + - `recursiveFileLookup` is set to False. + - For `tsv` format `sep` and `header` options are set to tab and `True` respectively. + - For `csv` format `header` is set to `True`. + + !!! 
warning "Loading data from URL" + If the provided path is a URL (starting with http:// or https://), the method will attempt to load the data + and parallelize it for processing, this can be very slow it the file is large. Consider downloading the data + to a distributed file system and loading it from there instead. Only supported formats for loading from URL are `csv` and `tsv`. + Loading does not allow for recursive file lookup, nor supports multiple URLs. + + !!! note "Supported formats" + Supported file formats are + - parquet + - csv + - tsv + - json (including jsonl/jsonlines) + + Examples: + Load single tsv file from url, the header is expected at the 0-th row + + >>> session.load_data('https://some_file.tsv', fmt='tsv') # doctest: +SKIP + + Load single csv file from url, no header, expected schema + + >>> session.load_data('https://some_file.csv', fmt='csv', header=False, schema="A int, B int") # doctest: +SKIP + + Load the parquet dataset from google cloud storage, note that the Hadoop connector is required in Session + + >>> session.load_data('gs://your_bucket/dataset') # doctest: +SKIP + + Load multiple json files from s3 storage, note that the Hadoop connector is required in Session + + >>> session.load_data(['s3a://some_bucket/file1.jsonl', 's3a://some_bucket/file2.jsonl'], fmt='json') # doctest: +SKIP """ # Set default kwargs + _format = fmt.lower() + kwargs.setdefault("recursiveFileLookup", False) + + match _format: + case "parquet": + _fmt = NativeFileFormat.PARQUET.value + kwargs.setdefault("mergeSchema", True) + case "tsv": + _fmt = NativeFileFormat.CSV.value + kwargs.setdefault("sep", "\t") + kwargs.setdefault("header", True) + if not schema: + kwargs.setdefault("inferSchema", "true") + case "csv": + _fmt = NativeFileFormat.CSV.value + kwargs.setdefault("header", True) + if not schema: + kwargs.setdefault("inferSchema", "true") + case "json" | "jsonl" | "jsonlines": + _fmt = NativeFileFormat.JSON.value + case _: + raise ValueError(f"Unsupported 
file format: {_format}") + + match path: + case list(): + all_strings = len(path) > 0 and all(isinstance(p, str) for p in path) + assert all_strings, "Path must be a non-empty list of strings." + case str(): + if path.startswith(("http://", "https://")): + return self._load_from_url(path, fmt=_fmt, schema=schema, **kwargs) + case _: + raise ValueError("Path must be a string or a list of strings.") + return self.spark.read.load(path, format=_fmt, schema=schema, **kwargs) + + def _load_from_url( + self: Session, + url: str, + fmt: str, + schema: StructType | str | None = None, + **kwargs: Any, + ) -> DataFrame: + """Load CSV/TSV/JSON data from a URL into a Spark DataFrame. + + Args: + url (str): single URL to load data from. + fmt (str): File format. Currently only 'csv', 'tsv' or 'json' are supported for loading from URL. + schema (StructType | str | None): Schema to use when reading the data. + **kwargs (Any): Additional arguments to pass to spark.read.csv. + + Returns: + DataFrame: Dataframe containing the loaded data. + """ + self.logger.warning( + "Reading data over HTTP/HTTPS. This may be slow for large datasets. Consider downloading the data to a distributed file system." 
+ ) + + match fmt: + case "csv": + _header = kwargs.get("header", False) + header = 0 if _header else None + df = pd.read_csv(url, header=header, sep=kwargs.get("sep")) + case "json": + df = pd.read_json(url) + case _: + raise ValueError("Only csv, tsv and json are URL supported formats") if schema is None: - kwargs["inferSchema"] = kwargs.get("inferSchema", True) - kwargs["mergeSchema"] = kwargs.get("mergeSchema", True) - kwargs["recursiveFileLookup"] = kwargs.get("recursiveFileLookup", False) - return self.spark.read.load(path, format=format, schema=schema, **kwargs) + return self.spark.createDataFrame( + data=df, + samplingRatio=kwargs.get("samplingRation", 0.4), + ) + return self.spark.createDataFrame(data=df, schema=schema, verifySchema=True) class JavaLogger(Protocol): @@ -193,7 +736,6 @@ def error(self, message: str) -> None: Args: message (str): The error message to log. """ - ... def warn(self, message: str) -> None: """Log a warning message. @@ -201,7 +743,6 @@ def warn(self, message: str) -> None: Args: message (str): The error message to log. """ - ... def info(self, message: str) -> None: """Log an info message. @@ -209,19 +750,21 @@ def info(self, message: str) -> None: Args: message (str): The error message to log. """ - ... class Log4j: """Log4j logger class.""" - def __init__(self, spark: SparkSession) -> None: + def __init__(self, spark: SparkSession, level: str | None = None) -> None: """Log4j logger class. This class provides a wrapper around the Log4j logging system. Args: spark (SparkSession): The Spark session used to access Spark context and Log4j logging. + level (str | None): Logging level. 
Defaults to provided by spark """ log4j: Any = spark.sparkContext._jvm.org.apache.log4j # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + if level: + spark.sparkContext.setLogLevel(level) # Cast to our protocol type for type safety self.logger: JavaLogger = log4j.LogManager.getLogger(__name__) diff --git a/src/gentropy/config.py b/src/gentropy/config.py index ca0c8fefe..45b1545d3 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -13,13 +13,15 @@ class SessionConfig: """Session configuration.""" + spark_uri: str = "local[*]" start_hail: bool = False write_mode: str = "errorifexists" - spark_uri: str = "local[*]" hail_home: str = os.path.dirname(hail_location) extended_spark_conf: dict[str, str] | None = field(default_factory=dict[str, str]) - use_enhanced_bgzip_codec: bool = False + extended_hail_conf: dict[str, str] | None = field(default_factory=dict[str, str]) output_partitions: int = 200 + use_enhanced_bgzip_codec: bool = False + dynamic_allocation: bool = True _target_: str = "gentropy.common.session.Session" @@ -129,6 +131,7 @@ class EqtlCatalogueConfig(StepConfig): eqtl_catalogue_paths_imported: str = MISSING eqtl_catalogue_study_index_out: str = MISSING eqtl_catalogue_credible_sets_out: str = MISSING + eqtl_catalogue_metadata_path: str = MISSING mqtl_quantification_methods_blacklist: list[str] = field(default_factory=lambda: []) eqtl_lead_pvalue_threshold: float = 1e-3 _target_: str = "gentropy.eqtl_catalogue.EqtlCatalogueStep" @@ -144,13 +147,13 @@ class FinngenStudiesConfig(StepConfig): } ) finngen_study_index_out: str = MISSING - finngen_phenotype_table_url: str = "https://r11.finngen.fi/api/phenos" + finngen_phenotype_table_url: str = MISSING finngen_release_prefix: str = "FINNGEN_R11_" finngen_summary_stats_url_prefix: str = ( "gs://finngen-public-data-r11/summary_stats/finngen_R11_" ) finngen_summary_stats_url_suffix: str = ".gz" - efo_curation_mapping_url: str = 
"https://raw.githubusercontent.com/opentargets/curation/24.09.1/mappings/disease/manual_string.tsv" + efo_curation_mapping_url: str = MISSING # https://www.finngen.fi/en/access_results#:~:text=Total%20sample%20size%3A%C2%A0453%2C733%C2%A0(254%2C618%C2%A0females%20and%C2%A0199%2C115%20males) sample_size: int = 453733 _target_: str = "gentropy.finngen_studies.FinnGenStudiesStep" diff --git a/src/gentropy/dataset/colocalisation.py b/src/gentropy/dataset/colocalisation.py index a6b4b364c..615762137 100644 --- a/src/gentropy/dataset/colocalisation.py +++ b/src/gentropy/dataset/colocalisation.py @@ -8,10 +8,8 @@ import pyspark.sql.functions as f from gentropy.common.schemas import parse_spark_schema -from gentropy.common.spark import get_record_with_maximum_value from gentropy.dataset.dataset import Dataset from gentropy.dataset.study_locus import StudyLocus -from gentropy.datasource.eqtl_catalogue.study_index import EqtlCatalogueStudyIndex if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -59,6 +57,10 @@ def extract_maximum_coloc_probability_per_region_and_gene( ValueError: if filter_by_qtl is not in the list of valid QTL types or is not in the list of valid colocalisation methods """ from gentropy.colocalisation import ColocalisationStep + from gentropy.common.spark import get_record_with_maximum_value + from gentropy.datasource.eqtl_catalogue.study_index import ( + EqtlCatalogueStudyIndex, + ) valid_qtls = list( set(EqtlCatalogueStudyIndex.method_to_qtl_type_mapping.values()) diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py index 1b6dc66be..dedd3234b 100644 --- a/src/gentropy/dataset/dataset.py +++ b/src/gentropy/dataset/dataset.py @@ -15,6 +15,7 @@ from pyspark.sql.window import Window from gentropy.common.schemas import SchemaValidationError, compare_struct_schemas +from gentropy.common.session import Session if TYPE_CHECKING: from pyspark.sql import Column @@ -163,6 +164,31 @@ def get_QC_mappings(cls: type[Self]) -> dict[str, 
str]: """ return {} + @classmethod + def read(cls, path: str | list[str], fmt: str = "parquet", **kwargs: Any) -> Self: + """Reads dataset into a Dataset with a given schema. + + All kwargs are passed to the spark.read method, so they can be used to specify format, schema, etc. + + Args: + path (str | list[str]): Path to the dataset + fmt (str): Format of the dataset, default is "parquet" + **kwargs (Any): Additional arguments to pass to spark.read + Returns: + Self: Dataset with the file contents + """ + session = Session.find() + schema = cls.get_schema() + return cls( + _df=session.load_data( + path=path, + fmt=fmt, + schema=schema, + **kwargs, + ), + _schema=schema, + ) + @classmethod def from_parquet( cls: type[Self], @@ -188,7 +214,7 @@ def from_parquet( # Separate class params from spark params class_params, spark_params = cls._process_class_params(kwargs) - df = session.load_data(path, format="parquet", schema=schema, **spark_params) + df = session.load_data(path, fmt="parquet", schema=schema, **spark_params) if df.isEmpty(): raise ValueError(f"Parquet file is empty: {path}") return cls(_df=df, _schema=schema, **class_params) diff --git a/src/gentropy/datasource/eqtl_catalogue/__init__.py b/src/gentropy/datasource/eqtl_catalogue/__init__.py index 9632698b0..20ac4ede9 100644 --- a/src/gentropy/datasource/eqtl_catalogue/__init__.py +++ b/src/gentropy/datasource/eqtl_catalogue/__init__.py @@ -1,3 +1,27 @@ """eQTL Catalogue datasource classes.""" from __future__ import annotations + +from enum import StrEnum + + +class QuantificationMethod(StrEnum): + """QTL quantification methods.""" + + GE = "ge" + EXON = "exon" + TX = "tx" + MICROARRAY = "microarray" + LEAFCUTTER = "leafcutter" + APTAMER = "aptamer" + TXREV = "txrev" + MAJIQ = "majiq" + + +class StudyType(StrEnum): + """QTL study types.""" + + PQTL = "pqtl" + EQTL = "eqtl" + SQTL = "sqtl" + TUQTL = "tuqtl" diff --git a/src/gentropy/datasource/eqtl_catalogue/finemapping.py 
b/src/gentropy/datasource/eqtl_catalogue/finemapping.py index d2cab4a2b..7b076b562 100644 --- a/src/gentropy/datasource/eqtl_catalogue/finemapping.py +++ b/src/gentropy/datasource/eqtl_catalogue/finemapping.py @@ -15,14 +15,15 @@ StructType, ) -from gentropy.common.session import Session +from gentropy.common.processing import normalize_chromosome +from gentropy.common.session import NativeFileFormat, Session from gentropy.common.spark import clean_strings_from_symbols from gentropy.common.stats import split_pvalue_column from gentropy.dataset.study_locus import FinemappingMethod, StudyLocus from gentropy.datasource.eqtl_catalogue.study_index import EqtlCatalogueStudyIndex if TYPE_CHECKING: - from pyspark.sql import DataFrame + pass @dataclass @@ -124,6 +125,7 @@ def parse_susie_results( credible_sets: DataFrame, lbf: DataFrame, studies_metadata: DataFrame, + ss_ftp_path_template: str = "https://ftp.ebi.ac.uk/pub/databases/spot/eQTL/sumstats", ) -> DataFrame: """Parse the SuSIE results into a DataFrame containing the finemapping statistics and metadata about the studies. @@ -131,11 +133,11 @@ def parse_susie_results( credible_sets (DataFrame): DataFrame containing raw statistics of all variants in the credible sets. lbf (DataFrame): DataFrame containing the raw log Bayes Factors for all variants. studies_metadata (DataFrame): DataFrame containing the study metadata. + ss_ftp_path_template (str, optional): eQTL Catalogue FTP path template for summary statistics. Defaults to "https://ftp.ebi.ac.uk/pub/databases/spot/eQTL/sumstats". Returns: DataFrame: Processed SuSIE results to contain metadata about the studies and the finemapping statistics. 
""" - ss_ftp_path_template = "https://ftp.ebi.ac.uk/pub/databases/spot/eQTL/sumstats" return ( lbf.join( credible_sets.join(f.broadcast(studies_metadata), on="dataset_id"), @@ -158,7 +160,7 @@ def parse_susie_results( .select( f.regexp_replace(f.col("variant"), r"chr", "").alias("variantId"), f.col("region"), - f.col("chromosome"), + normalize_chromosome(f.col("chromosome")).alias("chromosome"), f.col("position"), f.col("pip").alias("posteriorProbability"), *split_pvalue_column(f.col("pvalue")), @@ -261,24 +263,24 @@ def from_susie_results( @classmethod def read_credible_set_from_source( cls: type[EqtlCatalogueFinemapping], - session: Session, credible_set_path: str | list[str], + session: Session | None = None, ) -> DataFrame: """Load raw credible sets from eQTL Catalogue. Args: - session (Session): Spark session. credible_set_path (str | list[str]): Path to raw table(s) containing finemapping results for any variant belonging to a credible set. + session (Session | None, optional): Session object. If not provided, the method will try to find an active session. Defaults to None. Returns: DataFrame: Credible sets DataFrame. """ + session = session or Session.find() return ( - session.spark.read.csv( + session.load_data( credible_set_path, - sep="\t", - header=True, schema=cls.raw_credible_set_schema, + fmt=NativeFileFormat.TSV.value, ) .withColumns( { @@ -298,24 +300,24 @@ def read_credible_set_from_source( @classmethod def read_lbf_from_source( cls: type[EqtlCatalogueFinemapping], - session: Session, lbf_path: str | list[str], + session: Session | None = None, ) -> DataFrame: """Load raw log Bayes Factors from eQTL Catalogue. Args: - session (Session): Spark session. lbf_path (str | list[str]): Path to raw table(s) containing Log Bayes Factors for each variant. + session (Session | None, optional): Session object. If not provided, the method will try to find an active session. Defaults to None. Returns: DataFrame: Log Bayes Factors DataFrame. 
""" + session = session or Session.find() return ( - session.spark.read.csv( + session.load_data( lbf_path, - sep="\t", - header=True, schema=cls.raw_lbf_schema, + fmt=NativeFileFormat.TSV.value, ) .withColumn( "dataset_id", diff --git a/src/gentropy/datasource/eqtl_catalogue/study_index.py b/src/gentropy/datasource/eqtl_catalogue/study_index.py index 00940d803..7fa5f58fe 100644 --- a/src/gentropy/datasource/eqtl_catalogue/study_index.py +++ b/src/gentropy/datasource/eqtl_catalogue/study_index.py @@ -5,12 +5,12 @@ from itertools import chain from typing import TYPE_CHECKING -import pandas as pd import pyspark.sql.functions as f from pyspark.sql.types import IntegerType, StringType, StructField, StructType from gentropy.common.session import Session from gentropy.dataset.study_index import StudyIndex +from gentropy.datasource.eqtl_catalogue import QuantificationMethod, StudyType if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -45,15 +45,15 @@ class EqtlCatalogueStudyIndex: StructField("study_type", StringType(), True), ] ) - raw_studies_metadata_path = "https://raw.githubusercontent.com/eQTL-Catalogue/eQTL-Catalogue-resources/fe3c4b4ed911b3a184271a6aadcd8c8769a66aba/data_tables/dataset_metadata.tsv" method_to_qtl_type_mapping = { - "ge": "eqtl", - "exon": "eqtl", - "tx": "eqtl", - "microarray": "eqtl", - "leafcutter": "sqtl", - "aptamer": "pqtl", - "txrev": "tuqtl", + QuantificationMethod.GE.value: StudyType.EQTL.value, + QuantificationMethod.EXON.value: StudyType.EQTL.value, + QuantificationMethod.TX.value: StudyType.EQTL.value, + QuantificationMethod.MICROARRAY.value: StudyType.EQTL.value, + QuantificationMethod.LEAFCUTTER.value: StudyType.SQTL.value, + QuantificationMethod.APTAMER.value: StudyType.PQTL.value, + QuantificationMethod.TXREV.value: StudyType.TUQTL.value, + QuantificationMethod.MAJIQ.value: StudyType.SQTL.value, } @classmethod @@ -131,20 +131,32 @@ def from_susie_results( @classmethod def read_studies_from_source( cls: 
type[EqtlCatalogueStudyIndex], - session: Session, + metadata_path: str, mqtl_quantification_methods_blacklist: list[str], + session: Session | None = None, ) -> DataFrame: """Read raw studies metadata from eQTL Catalogue. Args: - session (Session): Spark session. + metadata_path (str): Path to the studies metadata file. mqtl_quantification_methods_blacklist (list[str]): Molecular trait quantification methods that we don't want to ingest. Available options in https://github.com/eQTL-Catalogue/eQTL-Catalogue-resources/blob/master/data_tables/dataset_metadata.tsv + session (Session | None, optional): Session object. If not provided, the method will try to find an active session. Defaults to None. Returns: DataFrame: raw studies metadata. + + Raises: + ValueError: If an invalid quantification method is provided in the blacklist. + + Example metadata_path: "https://raw.githubusercontent.com/eQTL-Catalogue/eQTL-Catalogue-resources/fe3c4b4ed911b3a184271a6aadcd8c8769a66aba/data_tables/dataset_metadata.tsv" """ - pd.DataFrame.iteritems = pd.DataFrame.items - return session.spark.createDataFrame( - pd.read_csv(cls.raw_studies_metadata_path, sep="\t"), - schema=cls.raw_studies_metadata_schema, + session = session or Session.find() + for method in mqtl_quantification_methods_blacklist: + if method not in cls.method_to_qtl_type_mapping: + raise ValueError( + f"Quantification method '{method}' is not supported. 
" + + f"Available options are: {list(cls.method_to_qtl_type_mapping.keys())}" + ) + return session.load_data( + metadata_path, schema=cls.raw_studies_metadata_schema, fmt="tsv" ).filter(~(f.col("quant_method").isin(mqtl_quantification_methods_blacklist))) diff --git a/src/gentropy/datasource/finngen_meta/summary_statistics.py b/src/gentropy/datasource/finngen_meta/summary_statistics.py index 9f4299eda..a1e550452 100644 --- a/src/gentropy/datasource/finngen_meta/summary_statistics.py +++ b/src/gentropy/datasource/finngen_meta/summary_statistics.py @@ -207,7 +207,7 @@ def bgzip_to_parquet( "The use_enhanced_bgzip_codec is set to False. This will lead to inefficient reading of block gzipped files." ) raise KeyError( - "Please set `session.spark.use_enhanced_bgzip_codec` to True in the Session configuration." + "Please set `use_enhanced_bgzip_codec` to True in the Session configuration." ) # Handle n_threads limits and warnings @@ -365,13 +365,13 @@ def from_source( SummaryStatistics: Processed summary statistics dataset. """ if perform_min_allele_count_filter: - assert ( - min_allele_count_threshold > 0 - ), "Allele count threshold should be positive." + assert min_allele_count_threshold > 0, ( + "Allele count threshold should be positive." + ) if perform_min_allele_frequency_filter: - assert ( - 0.0 <= min_allele_frequency_threshold <= 0.5 - ), "MAF needs to be between 0 and 0.5." + assert 0.0 <= min_allele_frequency_threshold <= 0.5, ( + "MAF needs to be between 0 and 0.5." + ) if perform_min_allele_count_filter and perform_min_allele_frequency_filter: # NOTE - MAC filter would be more stringent at low allele frequencies, so no @@ -469,9 +469,9 @@ def from_source( # Filter out variants with low INFO score if perform_imputation_score_filter: - assert ( - imputation_score_threshold >= 0.0 - ), "Imputation score threshold should be positive." + assert imputation_score_threshold >= 0.0, ( + "Imputation score threshold should be positive." 
+ ) sumstats = ( sumstats.withColumn( "hasLowImputationScore", diff --git a/src/gentropy/eqtl_catalogue.py b/src/gentropy/eqtl_catalogue.py index 3ad61ddea..2edf5d3bc 100644 --- a/src/gentropy/eqtl_catalogue.py +++ b/src/gentropy/eqtl_catalogue.py @@ -21,6 +21,7 @@ def __init__( eqtl_catalogue_paths_imported: str, eqtl_catalogue_study_index_out: str, eqtl_catalogue_credible_sets_out: str, + eqtl_catalogue_metadata_path: str = EqtlCatalogueConfig().eqtl_catalogue_metadata_path, eqtl_lead_pvalue_threshold: float = EqtlCatalogueConfig().eqtl_lead_pvalue_threshold, ) -> None: """Run eQTL Catalogue ingestion step. @@ -31,11 +32,14 @@ def __init__( eqtl_catalogue_paths_imported (str): Input eQTL Catalogue fine mapping results path. eqtl_catalogue_study_index_out (str): Output eQTL Catalogue study index path. eqtl_catalogue_credible_sets_out (str): Output eQTL Catalogue credible sets path. + eqtl_catalogue_metadata_path (str): Path to the data_table hosted on the eQTL Catalogue github. Defaults to EqtlCatalogueConfig().eqtl_catalogue_metadata_path eqtl_lead_pvalue_threshold (float, optional): Lead p-value threshold. Defaults to EqtlCatalogueConfig().eqtl_lead_pvalue_threshold. """ # Extract studies_metadata = EqtlCatalogueStudyIndex.read_studies_from_source( - session, list(mqtl_quantification_methods_blacklist) + eqtl_catalogue_metadata_path, + list(mqtl_quantification_methods_blacklist), + session=session, ) # Load raw data only for the studies we are interested in ingestion. This makes the proces much lighter. 
@@ -43,18 +47,18 @@ def __init__( studies_metadata ) credible_sets_df = EqtlCatalogueFinemapping.read_credible_set_from_source( - session, credible_set_path=[ f"{eqtl_catalogue_paths_imported}/{qtd_id}.credible_sets.tsv" for qtd_id in studies_to_ingest ], + session=session, ) lbf_df = EqtlCatalogueFinemapping.read_lbf_from_source( - session, lbf_path=[ f"{eqtl_catalogue_paths_imported}/{qtd_id}.lbf_variable.txt" for qtd_id in studies_to_ingest ], + session=session, ) # Transform @@ -65,6 +69,7 @@ def __init__( ( EqtlCatalogueStudyIndex.from_susie_results(processed_susie_df) # Writing the output: + .coalesce(1) .df.write.mode(session.write_mode) .parquet(eqtl_catalogue_study_index_out) ) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 46872a490..3ba7e559d 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -211,7 +211,7 @@ def __init__( session, credible_set_path, recursiveFileLookup=True ) self.feature_matrix = L2GFeatureMatrix( - _df=session.load_data(feature_matrix_path), + _df=session.load_data(feature_matrix_path, "parquet"), ) if run_mode == "predict": diff --git a/src/utils/spark.py b/src/utils/spark.py index b5099bca7..78c673920 100644 --- a/src/utils/spark.py +++ b/src/utils/spark.py @@ -1,4 +1,5 @@ """Spark utilities.""" + from __future__ import annotations from pathlib import Path diff --git a/tests/gentropy/common/test_session.py b/tests/gentropy/common/test_session.py index 6afc640ab..50123dd56 100644 --- a/tests/gentropy/common/test_session.py +++ b/tests/gentropy/common/test_session.py @@ -4,35 +4,124 @@ from typing import TYPE_CHECKING +import pandas as pd +import pytest +from pyspark.errors import PySparkException +from pyspark.sql import SparkSession + from gentropy.common.session import Log4j, Session if TYPE_CHECKING: - from pyspark.sql import SparkSession + from collections.abc import Generator + from pathlib import Path + from typing import Any -def test_session_creation() -> None: - """Test sessio creation with mock 
data.""" - assert isinstance(Session(spark_uri="local[1]"), Session) +def test_log4j_creation(spark: SparkSession) -> None: + """Test session log4j.""" + assert isinstance(Log4j(spark=spark), Log4j) -def test_hail_configuration(hail_home: str) -> None: - """Assert that Hail configuration is set when start_hail is True.""" - session = Session(spark_uri="local[1]", hail_home=hail_home, start_hail=True) +@pytest.fixture(scope="function") +def mock_data_files(tmp_path: Path) -> Generator[Path, None, None]: + """Create mock data files for testing.""" + data = pd.DataFrame( + { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + ) + tmp_files = [ + tmp_path / "test_path.parquet", + tmp_path / "test_path.csv", + tmp_path / "test_path.tsv", + tmp_path / "test_path.json", + ] + tmp_dir = tmp_path / "test_path" + tmp_dir.mkdir(exist_ok=True) - expected_hail_conf = { - "spark.jars": f"{hail_home}/backend/hail-all-spark.jar", - "spark.driver.extraClassPath": f"{hail_home}/backend/hail-all-spark.jar", - "spark.executor.extraClassPath": "./hail-all-spark.jar", - "spark.serializer": "org.apache.spark.serializer.KryoSerializer", - "spark.kryo.registrator": "is.hail.kryo.HailKryoRegistrator", - } + data.to_parquet(tmp_files[0]) + data.to_csv(tmp_files[1], index=False) + data.to_csv(tmp_files[2], index=False, sep="\t") + data.to_json(tmp_files[3], orient="records", lines=True) - observed_conf = dict(session.spark.sparkContext.getConf().getAll()) - # sourcery skip: no-loop-in-tests - for key, value in expected_hail_conf.items(): - assert observed_conf.get(key) == value, f"Expected {key} to be set to {value}" + tmp_files_nested = [tmp_dir / "part.1.parquet", tmp_dir / "part.2.parquet"] + for d in tmp_files_nested: + d.parent.mkdir(exist_ok=True) + data.to_parquet(d) + yield tmp_path -def test_log4j_creation(spark: SparkSession) -> None: - """Test session log4j.""" - assert isinstance(Log4j(spark=spark), Log4j) + for f in tmp_files + tmp_files_nested: + f.unlink(missing_ok=True) + 
tmp_dir.rmdir() + + +@pytest.mark.parametrize( + ["path", "fmt", "kwargs"], + [ + pytest.param("test_path.parquet", "parquet", {}, id="parquet"), + pytest.param("test_path.csv", "csv", {}, id="csv"), + pytest.param("test_path.tsv", "tsv", {}, id="tsv"), + pytest.param("test_path.json", "json", {}, id="json"), + pytest.param("test_path/", "parquet", {}, id="dataset fallback to parquet"), + ], +) +def test_load_data( + session: Session, + path: str, + fmt: str, + kwargs: dict[str, str], + mock_data_files: Path, +) -> None: + """Test Session.load_data method.""" + full_path = mock_data_files / path + try: + session.load_data(full_path.as_posix(), fmt=fmt, **kwargs) + except PySparkException as e: + pytest.fail(f"Session.load_data raised an exception: {e}") + + +@pytest.mark.parametrize( + ["url", "fmt", "error"], + [ + pytest.param( + "https://some_example.com/data.parquet", + "parquet", + "Only csv, tsv and json are URL supported formats", + id="unsupported format parquet", + ), + pytest.param( + "https://some_example.com/data.json", + "json", + None, + id="supported format json", + ), + pytest.param( + "http://some_example.com/data.csv", + "csv", + None, + id="supported format csv", + ), + ], +) +def test_load_from_url(url: str, fmt: str, error: str, session: Session) -> None: + """Test Session.load_data method with URL input.""" + + def mock_read(*args: Any, **kwargs: Any) -> pd.DataFrame: + return pd.DataFrame( + [ + ("val1", "val2"), + ], + columns=["col1", "col2"], + ) + + with pytest.MonkeyPatch.context() as m: + m.setattr("gentropy.common.session.pd.read_csv", mock_read) + m.setattr("gentropy.common.session.pd.read_json", mock_read) + if error: + with pytest.raises(ValueError, match=error): + session.load_data(url, fmt=fmt) + else: + df = session.load_data(url, fmt=fmt) + assert df.count() == 1 diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 0919059b6..85fcb00a3 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py 
@@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Generator from pathlib import Path import dbldatagen as dg @@ -33,26 +34,23 @@ from utils.spark import get_spark_testing_conf -@pytest.fixture(scope="session", autouse=True) -def spark(tmp_path_factory: pytest.TempPathFactory) -> SparkSession: - """Local spark session for testing purposes. - - Args: - tmp_path_factory (pytest.TempPathFactory): pytest fixture - - Returns: - SparkSession: local spark session - """ - return ( - SparkSession.builder.config(conf=get_spark_testing_conf()) +@pytest.fixture(scope="session") +def spark() -> Generator[SparkSession, None, None]: + """Local spark session for testing purposes.""" + spark = ( + SparkSession.Builder() + .config(conf=get_spark_testing_conf()) .master("local[1]") .appName("test") .getOrCreate() ) + yield spark + + spark.stop() @pytest.fixture() -def session() -> Session: +def session(spark: SparkSession) -> Session: """Return gentropy Session object.""" return Session() @@ -584,8 +582,8 @@ def sample_finngen_studies(spark: SparkSession) -> DataFrame: def sample_eqtl_catalogue_finemapping_credible_sets(session: Session) -> DataFrame: """Sample raw eQTL Catalogue credible sets outputted by SuSIE.""" return EqtlCatalogueFinemapping.read_credible_set_from_source( - session, credible_set_path=["tests/gentropy/data_samples/QTD000584.credible_sets.tsv"], + session=session, ) @@ -593,8 +591,8 @@ def sample_eqtl_catalogue_finemapping_credible_sets(session: Session) -> DataFra def sample_eqtl_catalogue_finemapping_lbf(session: Session) -> DataFrame: """Sample raw eQTL Catalogue table with logBayesFactors outputted by SuSIE.""" return EqtlCatalogueFinemapping.read_lbf_from_source( - session, lbf_path=["tests/gentropy/data_samples/QTD000584.lbf_variable.txt"], + session=session, ) @@ -621,7 +619,7 @@ def sample_ukbiobank_studies(spark: SparkSession) -> DataFrame: @pytest.fixture() -def study_locus_sample_for_colocalisation(spark: SparkSession) -> 
DataFrame: +def study_locus_sample_for_colocalisation(spark: SparkSession) -> StudyLocus: """Sample study locus data for colocalisation.""" return StudyLocus( _df=spark.read.parquet("tests/gentropy/data_samples/coloc_test.parquet"), diff --git a/tests/gentropy/datasource/eqtl_catalogue/test_eqtl_catalogue.py b/tests/gentropy/datasource/eqtl_catalogue/test_eqtl_catalogue.py index 6aa24662f..970730976 100644 --- a/tests/gentropy/datasource/eqtl_catalogue/test_eqtl_catalogue.py +++ b/tests/gentropy/datasource/eqtl_catalogue/test_eqtl_catalogue.py @@ -46,7 +46,7 @@ class TestEqtlCatalogueStudyLocus: """Test the correctness of the study locus dataset from eQTL Catalogue.""" @pytest.fixture(autouse=True) - def _setup(self, processed_finemapping_df: DataFrame) -> DataFrame: + def _setup(self, processed_finemapping_df: DataFrame) -> None: """Set up the test.""" self.study_locus = EqtlCatalogueFinemapping.from_susie_results( processed_finemapping_df diff --git a/tests/gentropy/datasource/finngen/test_finngen_finemapping.py b/tests/gentropy/datasource/finngen/test_finngen_finemapping.py index 4c7e12bf5..4d6deaa22 100644 --- a/tests/gentropy/datasource/finngen/test_finngen_finemapping.py +++ b/tests/gentropy/datasource/finngen/test_finngen_finemapping.py @@ -6,7 +6,6 @@ import hail as hl import pytest -from pyspark.sql import SparkSession from gentropy.common.session import Session from gentropy.dataset.study_locus import StudyLocus @@ -33,15 +32,15 @@ ], ) def test_finngen_finemapping_from_finngen_susie_finemapping( - spark: SparkSession, + session: Session, finngen_susie_finemapping_snp_files: str, finngen_susie_finemapping_cs_summary_files: str, ) -> None: """Test finemapping results (SuSie) from source.""" - hl.init(sc=spark.sparkContext, log="/dev/null", idempotent=True) + hl.init(sc=session.spark.sparkContext, log="/dev/null", idempotent=True) assert isinstance( FinnGenFinemapping.from_finngen_susie_finemapping( - spark=spark, + spark=session.spark, 
finngen_susie_finemapping_snp_files=finngen_susie_finemapping_snp_files, finngen_susie_finemapping_cs_summary_files=finngen_susie_finemapping_cs_summary_files, finngen_release_prefix="FINNGEN_R11", diff --git a/tests/gentropy/datasource/finngen_meta/test_finngen_meta_summary_statistics.py b/tests/gentropy/datasource/finngen_meta/test_finngen_meta_summary_statistics.py index 8c0e32e0b..a54317800 100644 --- a/tests/gentropy/datasource/finngen_meta/test_finngen_meta_summary_statistics.py +++ b/tests/gentropy/datasource/finngen_meta/test_finngen_meta_summary_statistics.py @@ -292,45 +292,7 @@ def test_bgzip_from_parquet(self, tmp_path: Path, session: Session) -> None: raw_summary_statistics_output_path=output_path.as_posix(), ) - assert "session.spark.use_enhanced_bgzip_codec" in str(e.value) - - @pytest.mark.long_test() - def test_bgzip_from_parquet_with_codec(self, tmp_path: Path) -> None: - """Test bgzip codec usage on multiple tsv.gz files with different schemas. - - Note: - This test requires access to the internet to download the necessary jar files - for the enhanced bgzip codec. It may take longer to run due to this setup. - Because of this, the test is marked as `long_test` and does not run by default. - """ - # Path to store the jar dependencies for spark enhanced bgzip codec - ivy_cache_path = tmp_path / "ivy_cache" - session = Session( - extended_spark_conf={"spark.jars.ivy": ivy_cache_path.as_posix()}, - use_enhanced_bgzip_codec=True, - ) - # Create inputs with different schemas - input_path_1 = "tests/gentropy/data_samples/bgzip_tests/A.tsv.gz" # contains only chr, pos, ref, alt, snp - input_path_2 = "tests/gentropy/data_samples/bgzip_tests/B.tsv.gz" # contains only chr, pos, ref, alt, fg_beta, - output_path = (tmp_path / "output").as_posix() - - # Assert that test files & tbi indices exist - for p in [input_path_1, input_path_2]: - assert Path(p).exists(), f"Test file {p} does not exist." 
- assert Path(p + ".tbi").exists(), f"Index file {p}.tbi does not exist." - FinnGenUkbMvpMetaSummaryStatistics.bgzip_to_parquet( - session, - summary_statistics_list=[input_path_1, input_path_2], - datasource=MetaAnalysisDataSource.FINNGEN_UKBB, - raw_summary_statistics_output_path=output_path, - ) - # Now read back the parquet files and check if schema is equal to raw schema - df = session.spark.read.parquet(output_path) - expected_schema = FinnGenUkbMvpMetaSummaryStatistics.raw_schema - expected_schema = expected_schema.add( - "studyId", t.StringType(), nullable=True - ) # studyId is added during bgzip_to_parquet - assert df.schema == expected_schema, "Schemas do not match after conversion." + assert "use_enhanced_bgzip_codec" in str(e.value) @pytest.mark.parametrize( ["params"], diff --git a/tests/gentropy/datasource/gnomad/test_gnomad_ld.py b/tests/gentropy/datasource/gnomad/test_gnomad_ld.py index 4784d747d..e28d09f4d 100644 --- a/tests/gentropy/datasource/gnomad/test_gnomad_ld.py +++ b/tests/gentropy/datasource/gnomad/test_gnomad_ld.py @@ -12,6 +12,7 @@ from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql import functions as f +from gentropy import Session from gentropy.datasource.gnomad.ld import GnomADLDMatrix @@ -94,10 +95,10 @@ def test_get_ld_variants__square( @pytest.fixture(autouse=True) def _setup( - self: TestGnomADLDMatrixVariants, spark: SparkSession, tmp_path: Path + self: TestGnomADLDMatrixVariants, session: Session, tmp_path: Path ) -> None: """Prepares fixtures for the test.""" - hl.init(sc=spark.sparkContext, log="/dev/null", idempotent=True) + hl.init(sc=session.spark.sparkContext, log="/dev/null", idempotent=True) ld_test_population = "test-pop" liftover_path = str(tmp_path / "mock_liftover.ht") @@ -189,9 +190,9 @@ def test_get_ld_matrix_slice__symmetry( ), "The matrix is not symmetric." 
@pytest.fixture(autouse=True) - def _setup(self: TestGnomADLDMatrixSlice, spark: SparkSession) -> None: + def _setup(self: TestGnomADLDMatrixSlice, session: Session) -> None: """Prepares fixtures for the test.""" - hl.init(sc=spark.sparkContext, log="/dev/null", idempotent=True) + hl.init(sc=session.spark.sparkContext, log="/dev/null", idempotent=True) gnomad_ld_matrix = GnomADLDMatrix( ld_matrix_template="tests/gentropy/data_samples/example_{POP}.bm" ) diff --git a/tests/gentropy/no_spark/test_no_spark.py b/tests/gentropy/no_spark/test_no_spark.py new file mode 100644 index 000000000..c6fcf5acb --- /dev/null +++ b/tests/gentropy/no_spark/test_no_spark.py @@ -0,0 +1,145 @@ +"""Tests that need to create their own Spark session.""" + +from collections.abc import Generator +from pathlib import Path + +import pytest +from pyspark.sql import SparkSession +from pyspark.sql import types as t + +from gentropy import Session +from gentropy.common.session import SparkWriteMode +from gentropy.datasource.finngen_meta import MetaAnalysisDataSource +from gentropy.datasource.finngen_meta.summary_statistics import ( + FinnGenUkbMvpMetaSummaryStatistics, +) +from utils.spark import get_spark_testing_conf + + +def _stop_active_spark() -> None: + """Stop any active Spark session and clear cached references.""" + spark = SparkSession.getActiveSession() + if spark is not None: + spark.stop() + + +@pytest.fixture(scope="function") +def _no_spark_session() -> Generator[None, None, None]: + """Clean up any active spark session.""" + _stop_active_spark() + yield + _stop_active_spark() + + +@pytest.mark.no_shared_spark +class TestNoSpark: + """Test functionalities that require the spark session stopped.""" + + ex_conf = dict(get_spark_testing_conf().getAll()) + + @pytest.mark.usefixtures("_no_spark_session") + def test_session_creation(self) -> None: + """Test session creation with mock data.""" + session = Session(spark_uri="local[1]", extended_spark_conf=self.ex_conf) + assert 
isinstance(session, Session) + + @pytest.mark.usefixtures("_no_spark_session") + def test_output_partition(self) -> None: + """Test output partition setting.""" + session = Session( + spark_uri="local[1]", + output_partitions=5, + write_mode=SparkWriteMode.OVERWRITE, + extended_spark_conf=self.ex_conf, + ) + assert session.output_partitions == 5 + assert session.write_mode == SparkWriteMode.OVERWRITE + + @pytest.mark.usefixtures("_no_spark_session") + def test_bgzip_configuration(self) -> None: + """Assert that the bgzip codec configuration is set when use_enhanced_bgzip_codec is True.""" + session = Session( + spark_uri="local[1]", + use_enhanced_bgzip_codec=True, + extended_spark_conf=self.ex_conf, + ) + + expected_bgzip_conf = { + "spark.jars.packages": "org.seqdoop:hadoop-bam:7.10.0", + "spark.hadoop.io.compression.codecs": "org.seqdoop.hadoop_bam.util.BGZFEnhancedGzipCodec", + "spark.gentropy.useEnhancedBgzipCodec": "true", + } + + observed_conf = dict(session.spark.sparkContext.getConf().getAll()) + for key, value in expected_bgzip_conf.items(): + assert observed_conf.get(key) == value, ( + f"Expected {key} to be set to {value}" + ) + + @pytest.mark.usefixtures("_no_spark_session") + def test_hail_configuration(self, hail_home: str) -> None: + """Assert that Hail configuration is set when start_hail is True.""" + session = Session( + spark_uri="local[1]", + hail_home=hail_home, + start_hail=True, + ) + + expected_hail_conf = { + "spark.jars": f"{hail_home}/backend/hail-all-spark.jar", + "spark.driver.extraClassPath": f"{hail_home}/backend/hail-all-spark.jar", + "spark.executor.extraClassPath": "./hail-all-spark.jar", + "spark.serializer": "org.apache.spark.serializer.KryoSerializer", + "spark.kryo.registrator": "is.hail.kryo.HailKryoRegistrator", + } + + observed_conf = dict(session.spark.sparkContext.getConf().getAll()) + for key, value in expected_hail_conf.items(): + assert observed_conf.get(key) == value, ( + f"Expected {key} to be set to {value}" + ) + + 
@pytest.mark.download_jars_from_web + @pytest.mark.usefixtures("_no_spark_session") + def test_bgzip_from_parquet_with_codec(self, tmp_path: Path) -> None: + """Test bgzip codec usage on multiple tsv.gz files with different schemas. + + Note: + This test downloads hadoop-bam and its dependencies from Maven Central + to a temporary ivy cache on first run. This test needs to be run in complete isolation + with access to the internet to download the dependencies. Hence this test is marked with 'download_jars_from_web' and should be run separately from other tests. + """ + ivy_cache_dir = tmp_path / "ivy_cache" + ivy_cache_dir.mkdir(parents=True, exist_ok=True) + conf = self.ex_conf.copy() + conf["spark.jars.ivy"] = ivy_cache_dir.as_posix() + session = Session( + spark_uri="local[1]", + extended_spark_conf=conf, + use_enhanced_bgzip_codec=True, + dynamic_allocation=False, + ) + + # Create inputs with different schemas + input_path_1 = "tests/gentropy/data_samples/bgzip_tests/A.tsv.gz" # contains only chr, pos, ref, alt, snp + input_path_2 = "tests/gentropy/data_samples/bgzip_tests/B.tsv.gz" # contains only chr, pos, ref, alt, fg_beta, + output_path = (tmp_path / "output").as_posix() + + # Assert that test files & tbi indices exist + for p in [input_path_1, input_path_2]: + assert Path(p).exists(), f"Test file {p} does not exist." + assert Path(p + ".tbi").exists(), f"Index file {p}.tbi does not exist." 
+ FinnGenUkbMvpMetaSummaryStatistics.bgzip_to_parquet( + session, + summary_statistics_list=[input_path_1, input_path_2], + datasource=MetaAnalysisDataSource.FINNGEN_UKBB, + raw_summary_statistics_output_path=output_path, + ) + # Now read back the parquet files and check if schema is equal to raw schema + df = session.spark.read.parquet(output_path) + expected_schema = FinnGenUkbMvpMetaSummaryStatistics.raw_schema + expected_schema = expected_schema.add( + "studyId", t.StringType(), nullable=True + ) # studyId is added during bgzip_to_parquet + assert df.schema == expected_schema, "Schemas do not match after conversion." + session.spark.stop() From c031373c9e6fcdf4057ba10f494af903a9f60a32 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 13 Feb 2026 15:12:16 +0000 Subject: [PATCH 05/16] build(deps-dev): update pre-commit requirement (#1171) Updates the requirements on [pre-commit](https://github.com/pre-commit/pre-commit) to permit the latest version. - [Release notes](https://github.com/pre-commit/pre-commit/releases) - [Changelog](https://github.com/pre-commit/pre-commit/blob/main/CHANGELOG.md) - [Commits](https://github.com/pre-commit/pre-commit/compare/v4.0.0...v4.5.0) --- updated-dependencies: - dependency-name: pre-commit dependency-version: 4.5.0 dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 26b6b4e0f..5bfd04ab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ dev = [ "prettier >=0.0.7, <0.1.0", "deptry >=0.22.0, <0.25.0", "yamllint >=1.33.0, <1.38.0", - "pre-commit >=4.0.0, <4.4.0", + "pre-commit >=4.0.0, <4.6.0", "mypy >=1.13, <1.19", "pep8-naming >=0.14.1, <0.16.0", "interrogate >=1.7.0, <1.8.0", From 22341b07fc7c3741829f12fab0d18e705bf5e8e4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 13 Feb 2026 15:27:41 +0000 Subject: [PATCH 06/16] build(deps-dev): update ipython requirement (#1169) Updates the requirements on [ipython](https://github.com/ipython/ipython) to permit the latest version. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/8.19.0...9.7.0) --- updated-dependencies: - dependency-name: ipython dependency-version: 9.7.0 dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5bfd04ab5..263a07a82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ test = [ ] dev = [ - "ipython >=8.19.0, <8.38.0", + "ipython >=8.19.0, <9.9.0", "pydoclint >=0.3.8,<0.9.0", "ipykernel >=6.28.0, <6.31.0", "prettier >=0.0.7, <0.1.0", From a4935d7d95befa511879d3c29f2e2fc5da58665a Mon Sep 17 00:00:00 2001 From: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> Date: Thu, 26 Feb 2026 13:47:36 +0000 Subject: [PATCH 07/16] refactor: remove biofeature from interval schema (#1193) --- src/gentropy/assets/schemas/intervals.json | 6 ------ src/gentropy/datasource/intervals/e2g.py | 1 - src/gentropy/datasource/intervals/epiraction.py | 1 - tests/gentropy/conftest.py | 1 - tests/gentropy/dataset/test_l2g_feature.py | 4 ---- uv.lock | 8 +++----- 6 files changed, 3 insertions(+), 18 deletions(-) diff --git a/src/gentropy/assets/schemas/intervals.json b/src/gentropy/assets/schemas/intervals.json index 209171367..7959bd253 100644 --- a/src/gentropy/assets/schemas/intervals.json +++ b/src/gentropy/assets/schemas/intervals.json @@ -80,12 +80,6 @@ "nullable": true, "type": "string" }, - { - "metadata": {}, - "name": "biofeature", - "nullable": true, - "type": "string" - }, { "metadata": {}, "name": "biosampleName", diff --git a/src/gentropy/datasource/intervals/e2g.py b/src/gentropy/datasource/intervals/e2g.py index b2e0cc19d..ba46603b9 100644 --- a/src/gentropy/datasource/intervals/e2g.py +++ b/src/gentropy/datasource/intervals/e2g.py @@ -127,7 +127,6 @@ def parse( f.lit(IntervalDataSource.E2G.value).alias("datasourceId"), f.col("intervalType"), f.lit(cls.PMID).alias("pmid"), - 
f.lit(None).cast(t.StringType()).alias("biofeature"), f.col("biosampleName"), f.lit(None).cast(t.StringType()).alias("biosampleFromSourceId"), f.col("biosampleId"), diff --git a/src/gentropy/datasource/intervals/epiraction.py b/src/gentropy/datasource/intervals/epiraction.py index 920fd86f0..5e17fd583 100644 --- a/src/gentropy/datasource/intervals/epiraction.py +++ b/src/gentropy/datasource/intervals/epiraction.py @@ -134,7 +134,6 @@ def parse( f.lit(IntervalDataSource.EPIRACTION.value).alias("datasourceId"), f.col("intervalType"), f.lit(cls.PMID).alias("pmid"), - f.lit(None).cast(t.StringType()).alias("biofeature"), f.col("biosampleName"), f.lit(None).cast(t.StringType()).alias("biosampleFromSourceId"), f.lit(None).cast(t.StringType()).alias("biosampleId"), diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 85fcb00a3..3a53a6213 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -345,7 +345,6 @@ def mock_intervals(spark: SparkSession) -> Intervals: .withColumnSpec("pmid", percentNulls=0.1) .withColumnSpec("resourceScore", percentNulls=0.1) .withColumnSpec("score", percentNulls=0.1) - .withColumnSpec("biofeature", percentNulls=0.1) ) return Intervals(_df=data_spec.build(), _schema=interval_schema) diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index 9f5247af3..c1c09ca53 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -753,7 +753,6 @@ def _setup(self, spark: SparkSession) -> None: "datasourceId": "dummy", "intervalType": "pchic", "pmid": None, - "biofeature": None, "biosampleName": None, "biosampleId": None, "studyId": None, @@ -770,7 +769,6 @@ def _setup(self, spark: SparkSession) -> None: "datasourceId": "dummy", "intervalType": "pchic", "pmid": None, - "biofeature": None, "biosampleName": None, "biosampleId": None, "studyId": None, @@ -788,7 +786,6 @@ def _setup(self, spark: SparkSession) -> None: 
"datasourceId": "dummy", "intervalType": "pchic", "pmid": None, - "biofeature": None, "biosampleName": None, "biosampleId": None, "studyId": None, @@ -806,7 +803,6 @@ def _setup(self, spark: SparkSession) -> None: "datasourceId": "dummy", "intervalType": "pchic", "pmid": None, - "biofeature": None, "biosampleName": None, "biosampleId": None, "studyId": None, diff --git a/uv.lock b/uv.lock index 5b824ff61..6787dfa53 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11, <3.14" resolution-markers = [ "(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and platform_machine == 'arm64' and sys_platform == 'linux')", @@ -1091,11 +1091,11 @@ dev = [ { name = "deptry", specifier = ">=0.22.0,<0.25.0" }, { name = "interrogate", specifier = ">=1.7.0,<1.8.0" }, { name = "ipykernel", specifier = ">=6.28.0,<6.31.0" }, - { name = "ipython", specifier = ">=8.19.0,<8.38.0" }, + { name = "ipython", specifier = ">=8.19.0,<9.9.0" }, { name = "isort", specifier = ">=5.13.2,<6.1.0" }, { name = "mypy", specifier = ">=1.13,<1.19" }, { name = "pep8-naming", specifier = ">=0.14.1,<0.16.0" }, - { name = "pre-commit", specifier = ">=4.0.0,<4.4.0" }, + { name = "pre-commit", specifier = ">=4.0.0,<4.6.0" }, { name = "prettier", specifier = ">=0.0.7,<0.1.0" }, { name = "pydoclint", specifier = ">=0.3.8,<0.9.0" }, { name = "ruff", specifier = ">=0.8.1,<0.15.0" }, @@ -4294,8 +4294,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/70/57ac2bc1078b497c562527dc89709db4f6d489676927c2331d9b47ab6bea/xgboost-3.1.1-py3-none-macosx_10_15_x86_64.whl", hash = "sha256:a51a2e488102a007b8c222d58bf855415002e8cdf06d104eea24b08dbf4eec4f", size = 2377615, upload-time = "2025-10-21T23:08:33.851Z" }, { url = "https://files.pythonhosted.org/packages/af/7f/c8bde020171c900fcc808a16fe643c16d5ef96fd1516a24478bc54a428b0/xgboost-3.1.1-py3-none-macosx_12_0_arm64.whl", 
hash = "sha256:fac06c989f2cf11af7aa546b3bb78e7fa87595891e5dfde28edf3e7492e5440a", size = 2210835, upload-time = "2025-10-21T23:08:54.084Z" }, { url = "https://files.pythonhosted.org/packages/22/87/731d92a92aa848d0bf47ac55925d93f2c34fc3843e80f2e952b34c209888/xgboost-3.1.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:4347671aa8a495595f17135171aeae5f6d9ab4b4e7b02f191864cf2202e3c902", size = 4952406, upload-time = "2025-10-21T23:10:39.009Z" }, - { url = "https://files.pythonhosted.org/packages/56/b0/e3efafd9c97ed931f6453bd71aa8feaffc9217e6121af65fda06cf32f608/xgboost-3.1.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:405e48a201495fe9474f7aa27419f937794726a1bc7d2c2f3208b351c816580a", size = 115884000, upload-time = "2025-10-21T23:11:59.974Z" }, - { url = "https://files.pythonhosted.org/packages/b8/90/f082b89dd74da8ca27f8a3c7b3e38fc8529a4a14eb2c5b0937c7d66aa922/xgboost-3.1.1-py3-none-win_amd64.whl", hash = "sha256:2e1067489688ad99a410e8f2acdfe9d21a299c2f3b4b25dc8f094eae709c7447", size = 71978587, upload-time = "2025-10-21T23:09:50.488Z" }, ] [[package]] From bfed77a93e19c4d1d7158ff8359be5b7646d634d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 13:59:42 +0000 Subject: [PATCH 08/16] build(deps-dev): update ipykernel requirement (#1192) Updates the requirements on [ipykernel](https://github.com/ipython/ipykernel) to permit the latest version. - [Release notes](https://github.com/ipython/ipykernel/releases) - [Changelog](https://github.com/ipython/ipykernel/blob/main/CHANGELOG.md) - [Commits](https://github.com/ipython/ipykernel/compare/v6.28.0...v7.2.0) --- updated-dependencies: - dependency-name: ipykernel dependency-version: 7.2.0 dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 263a07a82..16af9eb1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ test = [ dev = [ "ipython >=8.19.0, <9.9.0", "pydoclint >=0.3.8,<0.9.0", - "ipykernel >=6.28.0, <6.31.0", + "ipykernel >=6.28.0, <7.3.0", "prettier >=0.0.7, <0.1.0", "deptry >=0.22.0, <0.25.0", "yamllint >=1.33.0, <1.38.0", From 209cc0c7b8e0c2544c11022858b9b6073df36715 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 14:47:27 +0000 Subject: [PATCH 09/16] build(deps-dev): update ruff requirement (#1190) Updates the requirements on [ruff](https://github.com/astral-sh/ruff) to permit the latest version. - [Release notes](https://github.com/astral-sh/ruff/releases) - [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md) - [Commits](https://github.com/astral-sh/ruff/compare/0.8.1...0.15.0) --- updated-dependencies: - dependency-name: ruff dependency-version: 0.15.0 dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 16af9eb1f..56fba771c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ dev = [ "interrogate >=1.7.0, <1.8.0", "isort >=5.13.2, <6.1.0", "darglint >=1.8.1, <1.9.0", - "ruff >=0.8.1, <0.15.0", + "ruff >=0.8.1, <0.16.0", ] [tool.semantic_release] logging_use_named_masks = true From 43441081a2b1fe63eeb76d792a0e8176f2090309 Mon Sep 17 00:00:00 2001 From: Yakov Date: Thu, 26 Feb 2026 18:10:59 +0000 Subject: [PATCH 10/16] fix: add more tests using AI agent (#1184) * fix: add more tests using AI agent * fix: add more tests * fix: removing duplicates * fix: update AI-generated tests to reflect latest API changes and remove no-op tests (#1194) * Initial plan * fix: update tests to reflect latest changes and remove meaningless tests Co-authored-by: project-defiant <69353402+project-defiant@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: project-defiant <69353402+project-defiant@users.noreply.github.com> * chore: update tests --------- Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> --- tests/gentropy/common/test_genomic_region.py | 143 +++++++++ tests/gentropy/common/test_processing.py | 95 ++++++ tests/gentropy/common/test_stats.py | 81 +++++ tests/gentropy/common/test_types.py | 73 +++++ .../step/test_biosample_index_step.py | 58 ++++ tests/gentropy/step/test_interval_e2g_step.py | 66 ++++ tests/gentropy/test_config.py | 152 +++++++++ tests/gentropy/test_ld_based_clumping.py | 115 +++++++ tests/gentropy/test_susie_finemapper.py | 290 
++++++++++++++++++ tests/gentropy/test_variant_index.py | 148 +++++++++ uv.lock | 4 +- 11 files changed, 1223 insertions(+), 2 deletions(-) create mode 100644 tests/gentropy/common/test_genomic_region.py create mode 100644 tests/gentropy/common/test_processing.py create mode 100644 tests/gentropy/common/test_stats.py create mode 100644 tests/gentropy/common/test_types.py create mode 100644 tests/gentropy/step/test_biosample_index_step.py create mode 100644 tests/gentropy/step/test_interval_e2g_step.py create mode 100644 tests/gentropy/test_config.py create mode 100644 tests/gentropy/test_ld_based_clumping.py create mode 100644 tests/gentropy/test_susie_finemapper.py create mode 100644 tests/gentropy/test_variant_index.py diff --git a/tests/gentropy/common/test_genomic_region.py b/tests/gentropy/common/test_genomic_region.py new file mode 100644 index 000000000..dd265dc33 --- /dev/null +++ b/tests/gentropy/common/test_genomic_region.py @@ -0,0 +1,143 @@ +"""Tests for genomic region utilities.""" + +from __future__ import annotations + +import pytest + +from gentropy.common.genomic_region import GenomicRegion, KnownGenomicRegions + + +class TestGenomicRegion: + """Test GenomicRegion class.""" + + def test_genomic_region_initialization(self) -> None: + """Test GenomicRegion object creation.""" + region = GenomicRegion(chromosome="6", start=28510120, end=33480577) + assert region.chromosome == "6" + assert region.start == 28510120 + assert region.end == 33480577 + + def test_genomic_region_str_representation(self) -> None: + """Test string representation of GenomicRegion.""" + region = GenomicRegion(chromosome="6", start=28510120, end=33480577) + assert str(region) == "6:28510120-33480577" + + def test_genomic_region_from_string_simple(self) -> None: + """Test parsing simple genomic region string.""" + region = GenomicRegion.from_string("6:28510120-33480577") + assert region.chromosome == "6" + assert region.start == 28510120 + assert region.end == 33480577 + + def 
test_genomic_region_from_string_with_chr_prefix(self) -> None: + """Test parsing genomic region string with chr prefix.""" + region = GenomicRegion.from_string("chr6:28510120-33480577") + assert region.chromosome == "6" + assert region.start == 28510120 + assert region.end == 33480577 + + def test_genomic_region_from_string_with_commas(self) -> None: + """Test parsing genomic region string with commas in positions.""" + region = GenomicRegion.from_string("chr6:28,510,120-33,480,577") + assert region.chromosome == "6" + assert region.start == 28510120 + assert region.end == 33480577 + + def test_genomic_region_from_string_invalid_format(self) -> None: + """Test parsing invalid genomic region string format.""" + with pytest.raises( + ValueError, match="Genomic region should follow a ##:####-#### format" + ): + GenomicRegion.from_string("6:28510120") + + def test_genomic_region_from_string_invalid_positions(self) -> None: + """Test parsing genomic region string with non-integer positions.""" + with pytest.raises( + ValueError, + match="Start and the end position of the region has to be integer", + ): + GenomicRegion.from_string("6:28510120-abc") + + def test_genomic_region_from_string_malformed(self) -> None: + """Test parsing malformed genomic region string.""" + # Note: The implementation treats "invalid:region:format" -> "invalid-region-format" + # and then tries to parse positions, which fails with integer conversion error + with pytest.raises( + ValueError, + match="Start and the end position of the region has to be integer", + ): + GenomicRegion.from_string("invalid:region:format") + + def test_genomic_region_from_known_genomic_region(self) -> None: + """Test creating GenomicRegion from known genomic region.""" + region = GenomicRegion.from_known_genomic_region(KnownGenomicRegions.MHC) + assert region.chromosome == "6" + assert region.start == 25726063 + assert region.end == 33400556 + + def test_genomic_region_from_known_genomic_region_mhc_string(self) -> None: 
+ """Test MHC region string representation.""" + region = GenomicRegion.from_known_genomic_region(KnownGenomicRegions.MHC) + assert str(region) == "6:25726063-33400556" + + def test_genomic_region_equality(self) -> None: + """Test equality of GenomicRegion objects.""" + region1 = GenomicRegion(chromosome="6", start=28510120, end=33480577) + region2 = GenomicRegion.from_string("6:28510120-33480577") + # Note: GenomicRegion doesn't define __eq__, so we compare attributes + assert region1.chromosome == region2.chromosome + assert region1.start == region2.start + assert region1.end == region2.end + + def test_genomic_region_different_chromosomes(self) -> None: + """Test GenomicRegion with different chromosomes.""" + region_chr1 = GenomicRegion.from_string("1:1000-2000") + region_chr22 = GenomicRegion.from_string("22:3000-4000") + assert region_chr1.chromosome == "1" + assert region_chr22.chromosome == "22" + + def test_genomic_region_large_positions(self) -> None: + """Test GenomicRegion with large position numbers.""" + region = GenomicRegion.from_string("chr21:1000000-249250621") + assert region.chromosome == "21" + assert region.start == 1000000 + assert region.end == 249250621 + + @pytest.mark.parametrize( + "region_string,expected_chr,expected_start,expected_end", + [ + ("1:100-200", "1", 100, 200), + ("chr2:1000-2000", "2", 1000, 2000), + ("3:10,000-20,000", "3", 10000, 20000), + ("chrX:100-200", "X", 100, 200), + ("chrY:1-100", "Y", 1, 100), + ("chrMT:1-16569", "MT", 1, 16569), + ], + ) + def test_genomic_region_various_formats( + self, + region_string: str, + expected_chr: str, + expected_start: int, + expected_end: int, + ) -> None: + """Test GenomicRegion parsing with various valid formats.""" + region = GenomicRegion.from_string(region_string) + assert region.chromosome == expected_chr + assert region.start == expected_start + assert region.end == expected_end + + +class TestKnownGenomicRegions: + """Test KnownGenomicRegions enum.""" + + def 
test_mhc_region_value(self) -> None: + """Test MHC region value.""" + assert KnownGenomicRegions.MHC.value == "chr6:25726063-33400556" + + def test_known_regions_enum_members(self) -> None: + """Test that KnownGenomicRegions has expected members.""" + # Verify MHC is available + assert hasattr(KnownGenomicRegions, "MHC") + # Verify it's an enum with string values + assert isinstance(KnownGenomicRegions.MHC.value, str) diff --git a/tests/gentropy/common/test_processing.py b/tests/gentropy/common/test_processing.py new file mode 100644 index 000000000..55b97d8ab --- /dev/null +++ b/tests/gentropy/common/test_processing.py @@ -0,0 +1,95 @@ +"""Tests for common processing functions.""" + +from __future__ import annotations + +import pytest +from pyspark.sql import functions as f + +from gentropy.common.processing import extract_chromosome, parse_efos +from gentropy.common.session import Session + + +@pytest.mark.usefixtures("session") +class TestProcessing: + """Test common processing functions.""" + + def test_parse_efos_single_efo(self, session: Session) -> None: + """Test parsing of a single EFO URI.""" + data = [("http://www.ebi.ac.uk/efo/EFO_0000001",)] + df = session.spark.createDataFrame(data, schema="efos STRING") + result = df.select(parse_efos(f.col("efos")).alias("parsed_efos")).collect() + + assert result[0]["parsed_efos"] == ["EFO_0000001"] + + def test_parse_efos_multiple_efos(self, session: Session) -> None: + """Test parsing of multiple EFO URIs.""" + data = [ + ( + "http://www.ebi.ac.uk/efo/EFO_0000001,http://purl.obolibrary.org/obo/OBA_VT0001253,http://www.orpha.net/ORDO/Orphanet_101953", + ) + ] + df = session.spark.createDataFrame(data, schema="efos STRING") + result = df.select(parse_efos(f.col("efos")).alias("parsed_efos")).collect() + + parsed = sorted(result[0]["parsed_efos"]) + expected = sorted(["EFO_0000001", "OBA_VT0001253", "Orphanet_101953"]) + assert parsed == expected + + def test_parse_efos_duplicate_handling(self, session: Session) -> 
None: + """Test that parse_efos removes duplicates.""" + data = [ + ( + "http://www.ebi.ac.uk/efo/EFO_0000001,http://www.ebi.ac.uk/efo/EFO_0000001,http://www.ebi.ac.uk/efo/EFO_0000002", + ) + ] + df = session.spark.createDataFrame(data, schema="efos STRING") + result = df.select(parse_efos(f.col("efos")).alias("parsed_efos")).collect() + + parsed = sorted(result[0]["parsed_efos"]) + # Should only have 2 unique EFOs + assert len(parsed) == 2 + assert sorted(["EFO_0000001", "EFO_0000002"]) == parsed + + def test_extract_chromosome_simple(self, session: Session) -> None: + """Test extraction of chromosome from simple variant IDs.""" + data = [ + ("1_12345_A_T",), + ("2_54321_G_C",), + ("X_999999_T_A",), + ] + df = session.spark.createDataFrame(data, schema="variantId STRING") + result = df.select( + extract_chromosome(f.col("variantId")).alias("chromosome") + ).collect() + + chromosomes = [row["chromosome"] for row in result] + assert chromosomes == ["1", "2", "X"] + + def test_extract_chromosome_with_prefix(self, session: Session) -> None: + """Test extraction of chromosome with 'chr' prefix.""" + data = [ + ("chr1_12345_A_T",), + ("chrX_54321_G_C",), + ] + df = session.spark.createDataFrame(data, schema="variantId STRING") + result = df.select( + extract_chromosome(f.col("variantId")).alias("chromosome") + ).collect() + + chromosomes = [row["chromosome"] for row in result] + assert chromosomes == ["chr1", "chrX"] + + def test_extract_chromosome_complex(self, session: Session) -> None: + """Test extraction of chromosome from complex variant IDs.""" + data = [ + ("15_KI270850v1_alt_48777_C_T",), + ("GL000220.1_13000_A_G",), + ] + df = session.spark.createDataFrame(data, schema="variantId STRING") + result = df.select( + extract_chromosome(f.col("variantId")).alias("chromosome") + ).collect() + + chromosomes = [row["chromosome"] for row in result] + assert chromosomes[0] == "15_KI270850v1_alt" + assert chromosomes[1] == "GL000220.1" diff --git 
a/tests/gentropy/common/test_stats.py b/tests/gentropy/common/test_stats.py new file mode 100644 index 000000000..15dba227c --- /dev/null +++ b/tests/gentropy/common/test_stats.py @@ -0,0 +1,81 @@ +"""Tests for common stats functions. + +Note: Basic functionality tests for split_pvalue and get_logsum are covered by doctests +in gentropy.common.stats module. This module tests additional edge cases and behaviors +not covered by doctests. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from gentropy.common.stats import get_logsum, split_pvalue + + +class TestStats: + """Test common statistics functions - edge cases and additional coverage.""" + + def test_get_logsum_large_values_overflow_protection(self) -> None: + """Test logsumexp calculation with large values (prevent overflow). + + This test validates the overflow protection mechanism used by get_logsum, + which is critical for numerical stability with very large exponents. + """ + arr = np.array([1000.0, 1000.1, 999.9]) + result = get_logsum(arr) + # Should not be infinity, indicating proper handling of overflow + assert np.isfinite(result) + assert result > 1000.0 # Should be close to max value plus log of sum + + def test_get_logsum_negative_values(self) -> None: + """Test logsumexp calculation with negative values. + + Edge case: function should handle negative log values correctly. + """ + arr = np.array([-1.0, -2.0, -3.0]) + result = get_logsum(arr) + assert np.isfinite(result) + + def test_split_pvalue_very_small_exponent(self) -> None: + """Test split_pvalue with extremely small p-value. + + Tests handling of p-values near the limits of floating-point representation. 
+ """ + mantissa, exponent = split_pvalue(1e-300) + assert mantissa == 1.0 + assert exponent == -300 + + def test_split_pvalue_invalid_below_zero(self) -> None: + """Test split_pvalue with p-value < 0 (invalid input).""" + with pytest.raises(ValueError, match="P-value must be between 0 and 1"): + split_pvalue(-0.01) + + def test_split_pvalue_invalid_above_one(self) -> None: + """Test split_pvalue with p-value > 1 (invalid input).""" + with pytest.raises(ValueError, match="P-value must be between 0 and 1"): + split_pvalue(1.1) + + def test_split_pvalue_zero(self) -> None: + """Test split_pvalue with p-value = 0 (boundary case). + + Zero is a special case where log10(0) is undefined, so the function + should handle it with exponent = 0. + """ + mantissa, exponent = split_pvalue(0.0) + assert mantissa == 0.0 + assert exponent == 0 + + def test_split_pvalue_mantissa_range_comprehensive(self) -> None: + """Test that split_pvalue returns mantissa between 1 and 10 (except zero). + + This validates the invariant that (mantissa, exponent) always satisfies: + pvalue ≈ mantissa * 10^exponent where 1 <= mantissa < 10 (or mantissa = 0). 
+ """ + test_pvalues = [0.001, 0.01, 0.05, 0.1, 0.5, 0.99, 1e-10, 1e-100] + for pval in test_pvalues: + mantissa, exponent = split_pvalue(pval) + if mantissa != 0.0: # Skip zero case + assert 1.0 <= mantissa <= 10.0, ( + f"Mantissa {mantissa} out of range for p-value {pval}" + ) diff --git a/tests/gentropy/common/test_types.py b/tests/gentropy/common/test_types.py new file mode 100644 index 000000000..b38d47b24 --- /dev/null +++ b/tests/gentropy/common/test_types.py @@ -0,0 +1,73 @@ +"""Tests for common types.""" + +from __future__ import annotations + +from gentropy.common.types import ( + GWASEffect, + PValComponents, +) + + +class TestTypes: + """Test type definitions and named tuples.""" + + def test_pval_components_creation(self) -> None: + """Test creation of PValComponents named tuple.""" + # Create mock column objects + mock_mantissa = "mantissa_col" + mock_exponent = "exponent_col" + + pval = PValComponents(mantissa=mock_mantissa, exponent=mock_exponent) + + assert pval.mantissa == mock_mantissa + assert pval.exponent == mock_exponent + + def test_pval_components_tuple_unpacking(self) -> None: + """Test unpacking of PValComponents.""" + mock_mantissa = "m" + mock_exponent = "e" + pval = PValComponents(mantissa=mock_mantissa, exponent=mock_exponent) + + m, e = pval + assert m == mock_mantissa + assert e == mock_exponent + + def test_gwas_effect_creation(self) -> None: + """Test creation of GWASEffect named tuple.""" + mock_beta = "beta_col" + mock_se = "se_col" + + effect = GWASEffect(beta=mock_beta, standard_error=mock_se) + + assert effect.beta == mock_beta + assert effect.standard_error == mock_se + + def test_gwas_effect_tuple_unpacking(self) -> None: + """Test unpacking of GWASEffect.""" + mock_beta = "b" + mock_se = "s" + effect = GWASEffect(beta=mock_beta, standard_error=mock_se) + + b, s = effect + assert b == mock_beta + assert s == mock_se + + def test_pval_components_named_access(self) -> None: + """Test named access to PValComponents fields.""" 
+ pval = PValComponents(mantissa=0.5, exponent=-8) + + # Test both attribute and positional access + assert pval.mantissa == 0.5 + assert pval[0] == 0.5 + assert pval.exponent == -8 + assert pval[1] == -8 + + def test_gwas_effect_named_access(self) -> None: + """Test named access to GWASEffect fields.""" + effect = GWASEffect(beta=0.05, standard_error=0.01) + + # Test both attribute and positional access + assert effect.beta == 0.05 + assert effect[0] == 0.05 + assert effect.standard_error == 0.01 + assert effect[1] == 0.01 diff --git a/tests/gentropy/step/test_biosample_index_step.py b/tests/gentropy/step/test_biosample_index_step.py new file mode 100644 index 000000000..5427e3fc3 --- /dev/null +++ b/tests/gentropy/step/test_biosample_index_step.py @@ -0,0 +1,58 @@ +"""Tests for biosample index step.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from gentropy.biosample_index import BiosampleIndexStep +from gentropy.common.session import Session + + +@pytest.mark.step_test +class TestBiosampleIndexStep: + """Test biosample index step.""" + + def test_biosample_index_step_initialization( + self, session: Session, tmp_path: Path + ) -> None: + """Test that BiosampleIndexStep can be initialized.""" + # Create temporary paths + cell_ontology_input_path = str(tmp_path / "cell_ontology.json") + uberon_input_path = str(tmp_path / "uberon.json") + efo_input_path = str(tmp_path / "efo.json") + biosample_index_path = str(tmp_path / "biosample_index") + + # This test verifies that the step raises an exception when files don't exist + # (which is expected behavior) + with pytest.raises(Exception): + # Expected when data files don't exist - this is normal behavior + BiosampleIndexStep( + session=session, + cell_ontology_input_path=cell_ontology_input_path, + uberon_input_path=uberon_input_path, + efo_input_path=efo_input_path, + biosample_index_path=biosample_index_path, + ) + + def test_biosample_index_step_parameters(self) -> None: + 
"""Test that BiosampleIndexStep has correct expected parameters.""" + import inspect + + from gentropy.biosample_index import BiosampleIndexStep + + sig = inspect.signature(BiosampleIndexStep.__init__) + params = list(sig.parameters.keys()) + + expected_params = [ + "self", + "session", + "cell_ontology_input_path", + "uberon_input_path", + "efo_input_path", + "biosample_index_path", + ] + + for param in expected_params: + assert param in params, f"Missing parameter: {param}" diff --git a/tests/gentropy/step/test_interval_e2g_step.py b/tests/gentropy/step/test_interval_e2g_step.py new file mode 100644 index 000000000..1836f1918 --- /dev/null +++ b/tests/gentropy/step/test_interval_e2g_step.py @@ -0,0 +1,66 @@ +"""Tests for interval steps.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from gentropy.common.session import Session +from gentropy.intervals import IntervalE2GStep + + +@pytest.mark.step_test +class TestIntervalE2GStep: + """Test interval E2G step.""" + + def test_interval_e2g_step_initialization( + self, session: Session, tmp_path: Path + ) -> None: + """Test that IntervalE2GStep raises an exception when files don't exist.""" + # Create temporary paths + target_index_path = str(tmp_path / "target_index") + biosample_mapping_path = str(tmp_path / "biosample_mapping.csv") + biosample_index_path = str(tmp_path / "biosample_index") + chromosome_contig_index_path = str(tmp_path / "chromosome_contig_index") + interval_source = str(tmp_path / "interval_source") + valid_output_path = str(tmp_path / "valid_output") + invalid_output_path = str(tmp_path / "invalid_output") + + # This test verifies that the step raises an exception when files don't exist + with pytest.raises(Exception): + # Expected when data files don't exist - this is normal behavior + IntervalE2GStep( + session=session, + target_index_path=target_index_path, + biosample_mapping_path=biosample_mapping_path, + biosample_index_path=biosample_index_path, + 
chromosome_contig_index_path=chromosome_contig_index_path, + interval_source=interval_source, + valid_output_path=valid_output_path, + invalid_output_path=invalid_output_path, + ) + + def test_interval_e2g_step_parameters(self) -> None: + """Test that IntervalE2GStep has correct expected parameters.""" + import inspect + + from gentropy.intervals import IntervalE2GStep + + sig = inspect.signature(IntervalE2GStep.__init__) + params = list(sig.parameters.keys()) + + expected_params = [ + "self", + "session", + "target_index_path", + "biosample_mapping_path", + "biosample_index_path", + "chromosome_contig_index_path", + "interval_source", + "valid_output_path", + "invalid_output_path", + ] + + for param in expected_params: + assert param in params, f"Missing parameter: {param}" diff --git a/tests/gentropy/test_config.py b/tests/gentropy/test_config.py new file mode 100644 index 000000000..11c09161b --- /dev/null +++ b/tests/gentropy/test_config.py @@ -0,0 +1,152 @@ +"""Tests for configuration module.""" + +from __future__ import annotations + +from dataclasses import fields, is_dataclass + +import pytest + +from gentropy.config import ( + BiosampleIndexConfig, + ColocalisationConfig, + Config, + SessionConfig, + StepConfig, + register_config, +) + + +class TestSessionConfig: + """Test SessionConfig dataclass.""" + + def test_session_config_creation(self) -> None: + """Test creating a SessionConfig.""" + config = SessionConfig() + + assert config.start_hail is False + assert config.write_mode == "errorifexists" + assert config.spark_uri == "local[*]" + assert config.output_partitions == 200 + + def test_session_config_is_dataclass(self) -> None: + """Test that SessionConfig is a dataclass.""" + assert is_dataclass(SessionConfig) + + def test_session_config_fields(self) -> None: + """Test that SessionConfig has expected fields.""" + config_fields = {f.name for f in fields(SessionConfig)} + + expected_fields = { + "start_hail", + "write_mode", + "spark_uri", + 
"hail_home", + "extended_spark_conf", + "use_enhanced_bgzip_codec", + "output_partitions", + "_target_", + } + + for expected in expected_fields: + assert expected in config_fields, f"Missing field: {expected}" + + def test_session_config_custom_values(self) -> None: + """Test creating SessionConfig with custom values.""" + config = SessionConfig( + start_hail=True, + write_mode="overwrite", + spark_uri="local[4]", + output_partitions=100, + ) + + assert config.start_hail is True + assert config.write_mode == "overwrite" + assert config.spark_uri == "local[4]" + assert config.output_partitions == 100 + + +class TestStepConfig: + """Test StepConfig base class.""" + + def test_step_config_is_dataclass(self) -> None: + """Test that StepConfig is a dataclass.""" + assert is_dataclass(StepConfig) + + def test_step_config_has_defaults(self) -> None: + """Test that StepConfig has default defaults.""" + config = StepConfig(session=SessionConfig()) + assert config.session is not None + + def test_step_config_fields(self) -> None: + """Test that StepConfig has expected fields.""" + config_fields = {f.name for f in fields(StepConfig)} + + expected_fields = {"session", "defaults"} + + for expected in expected_fields: + assert expected in config_fields, f"Missing field: {expected}" + + +class TestColocalisationConfig: + """Test ColocalisationConfig.""" + + def test_colocalisation_config_is_dataclass(self) -> None: + """Test that ColocalisationConfig is a dataclass.""" + assert is_dataclass(ColocalisationConfig) + + def test_colocalisation_config_inherits_from_step_config(self) -> None: + """Test that ColocalisationConfig inherits from StepConfig.""" + assert issubclass(ColocalisationConfig, StepConfig) + + def test_colocalisation_config_has_required_fields(self) -> None: + """Test that ColocalisationConfig has expected fields.""" + config_fields = {f.name for f in fields(ColocalisationConfig)} + + expected_fields = { + "credible_set_path", + "coloc_path", + 
"colocalisation_method", + } + + for expected in expected_fields: + assert expected in config_fields, f"Missing field: {expected}" + + +class TestBiosampleIndexConfig: + """Test BiosampleIndexConfig.""" + + def test_biosample_index_config_is_dataclass(self) -> None: + """Test that BiosampleIndexConfig is a dataclass.""" + assert is_dataclass(BiosampleIndexConfig) + + def test_biosample_index_config_inherits_from_step_config(self) -> None: + """Test that BiosampleIndexConfig inherits from StepConfig.""" + assert issubclass(BiosampleIndexConfig, StepConfig) + + def test_biosample_index_config_has_required_fields(self) -> None: + """Test that BiosampleIndexConfig has expected fields.""" + config_fields = {f.name for f in fields(BiosampleIndexConfig)} + + expected_fields = { + "cell_ontology_input_path", + "uberon_input_path", + "efo_input_path", + "biosample_index_path", + } + + for expected in expected_fields: + assert expected in config_fields, f"Missing field: {expected}" + + +def test_register_config() -> None: + """Test that register_config can be called without errors.""" + # This just verifies the function doesn't raise exceptions + try: + register_config() + except Exception as e: + pytest.fail(f"register_config raised unexpected exception: {e}") + + +def test_config_class_exists() -> None: + """Test that Config class exists and is a dataclass.""" + assert is_dataclass(Config) diff --git a/tests/gentropy/test_ld_based_clumping.py b/tests/gentropy/test_ld_based_clumping.py new file mode 100644 index 000000000..c26e6a55d --- /dev/null +++ b/tests/gentropy/test_ld_based_clumping.py @@ -0,0 +1,115 @@ +"""Tests for LD-based clumping step.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from gentropy.common.session import Session +from gentropy.ld_based_clumping import LDBasedClumpingStep + + +@pytest.mark.step_test +class TestLDBasedClumpingStep: + """Test LDBasedClumpingStep 
initialization and parameter validation.""" + + def test_ld_based_clumping_step_initialization( + self, session: Session, tmp_path: Path + ) -> None: + """Test that LDBasedClumpingStep initializes without errors.""" + study_locus_input_path = str(tmp_path / "study_locus") + study_index_path = str(tmp_path / "study_index") + ld_index_path = str(tmp_path / "ld_index") + clumped_output_path = str(tmp_path / "clumped") + + with ( + patch( + "gentropy.ld_based_clumping.StudyLocus.from_parquet" + ) as mock_study_locus, + patch("gentropy.ld_based_clumping.LDIndex.from_parquet") as mock_ld_index, + patch( + "gentropy.ld_based_clumping.StudyIndex.from_parquet" + ) as mock_study_index, + ): + # Mock the dataframe objects + mock_sl = MagicMock() + mock_sl.annotate_ld.return_value = mock_sl + mock_sl.clump.return_value = mock_sl + mock_sl.df = MagicMock() + + mock_study_locus.return_value = mock_sl + mock_ld_index.return_value = MagicMock() + mock_study_index.return_value = MagicMock() + + step = LDBasedClumpingStep( + session=session, + study_locus_input_path=study_locus_input_path, + study_index_path=study_index_path, + ld_index_path=ld_index_path, + clumped_study_locus_output_path=clumped_output_path, + ) + assert step is not None + + def test_ld_based_clumping_step_parameters(self) -> None: + """Test that LDBasedClumpingStep has correct expected parameters.""" + import inspect + + sig = inspect.signature(LDBasedClumpingStep.__init__) + params = list(sig.parameters.keys()) + + expected_params = [ + "self", + "session", + "study_locus_input_path", + "study_index_path", + "ld_index_path", + "clumped_study_locus_output_path", + ] + + for param in expected_params: + assert param in params, f"Missing parameter: {param}" + + def test_ld_based_clumping_step_methods_called( + self, session: Session, tmp_path: Path + ) -> None: + """Test that the step calls the expected methods in correct sequence.""" + study_locus_input_path = str(tmp_path / "study_locus") + study_index_path = 
str(tmp_path / "study_index") + ld_index_path = str(tmp_path / "ld_index") + clumped_output_path = str(tmp_path / "clumped") + + with ( + patch( + "gentropy.ld_based_clumping.StudyLocus.from_parquet" + ) as mock_study_locus, + patch("gentropy.ld_based_clumping.LDIndex.from_parquet") as mock_ld_index, + patch( + "gentropy.ld_based_clumping.StudyIndex.from_parquet" + ) as mock_study_index, + ): + # Setup mock chain + mock_sl = MagicMock() + mock_annotated = MagicMock() + mock_clumped = MagicMock() + mock_clumped.df = MagicMock() + + mock_sl.annotate_ld.return_value = mock_annotated + mock_annotated.clump.return_value = mock_clumped + + mock_study_locus.return_value = mock_sl + mock_ld_index.return_value = MagicMock() + mock_study_index.return_value = MagicMock() + + LDBasedClumpingStep( + session=session, + study_locus_input_path=study_locus_input_path, + study_index_path=study_index_path, + ld_index_path=ld_index_path, + clumped_study_locus_output_path=clumped_output_path, + ) + + # Verify methods were called in correct order + mock_sl.annotate_ld.assert_called_once() + mock_annotated.clump.assert_called_once() diff --git a/tests/gentropy/test_susie_finemapper.py b/tests/gentropy/test_susie_finemapper.py new file mode 100644 index 000000000..09e37bae5 --- /dev/null +++ b/tests/gentropy/test_susie_finemapper.py @@ -0,0 +1,290 @@ +"""Tests for SusieFineMapperStep.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from gentropy.common.session import Session +from gentropy.susie_finemapper import SusieFineMapperStep + + +@pytest.mark.step_test +class TestSusieFineMapperStep: + """Test SusieFineMapperStep initialization and helper methods.""" + + def test_empty_log_mg_creates_file(self, tmp_path: Path) -> None: + """Test that _empty_log_mg creates a CSV file with correct structure.""" + output_path = str(tmp_path / "test_log.tsv") + + SusieFineMapperStep._empty_log_mg( 
+ studyId="STUDY001", + region="chr1:1000-2000", + error_mg="Test error message", + path_out=output_path, + ) + + # Verify file was created + assert Path(output_path).exists() + + # Read and verify content + df = pd.read_csv(output_path, sep="\t") + assert df.shape[0] == 1 + assert df.loc[0, "studyId"] == "STUDY001" + assert df.loc[0, "region"] == "chr1:1000-2000" + assert df.loc[0, "error"] == "Test error message" + + def test_empty_log_mg_column_structure(self, tmp_path: Path) -> None: + """Test that _empty_log_mg creates all expected columns.""" + output_path = str(tmp_path / "test_log_columns.tsv") + + SusieFineMapperStep._empty_log_mg( + studyId="TEST_STUDY", + region="chr10:5000-6000", + error_mg="Some error", + path_out=output_path, + ) + + df = pd.read_csv(output_path, sep="\t") + + expected_columns = { + "studyId", + "region", + "N_gwas_before_dedupl", + "N_gwas", + "N_ld", + "N_overlap", + "N_outliers", + "N_imputed", + "N_final_to_fm", + "elapsed_time", + "number_of_CS", + "error", + } + + assert set(df.columns) == expected_columns + + def test_empty_log_mg_default_numeric_values(self, tmp_path: Path) -> None: + """Test that _empty_log_mg sets all numeric fields to 0.""" + output_path = str(tmp_path / "test_log_values.tsv") + + SusieFineMapperStep._empty_log_mg( + studyId="STUDY_NUM", + region="chr5:1000-2000", + error_mg="Error", + path_out=output_path, + ) + + df = pd.read_csv(output_path, sep="\t") + + numeric_columns = { + "N_gwas_before_dedupl", + "N_gwas", + "N_ld", + "N_overlap", + "N_outliers", + "N_imputed", + "N_final_to_fm", + "elapsed_time", + "number_of_CS", + } + + for col in numeric_columns: + assert df.loc[0, col] == 0, f"Column {col} should be 0" + + def test_empty_log_mg_different_study_ids(self, tmp_path: Path) -> None: + """Test _empty_log_mg with various study IDs.""" + study_ids = ["STUDY_A", "STUDY_123", "ST_XYZ"] + + for study_id in study_ids: + output_path = str(tmp_path / f"{study_id}_log.tsv") + 
SusieFineMapperStep._empty_log_mg( + studyId=study_id, + region="chr1:1-100", + error_mg="Error", + path_out=output_path, + ) + + df = pd.read_csv(output_path, sep="\t") + assert df.loc[0, "studyId"] == study_id + + def test_empty_log_mg_special_characters_in_error(self, tmp_path: Path) -> None: + """Test _empty_log_mg with special characters in error message.""" + output_path = str(tmp_path / "test_special_chars.tsv") + error_msg = "Error: File not found (path=/data/test)" + + SusieFineMapperStep._empty_log_mg( + studyId="STUDY", + region="chr1:1-100", + error_mg=error_msg, + path_out=output_path, + ) + + df = pd.read_csv(output_path, sep="\t") + assert df.loc[0, "error"] == error_msg + + def test_susie_fine_mapper_step_initialization_fails_without_manifest( + self, session: Session, tmp_path: Path + ) -> None: + """Test that SusieFineMapperStep raises error when manifest doesn't exist.""" + missing_manifest = str(tmp_path / "missing_manifest.csv") + + with pytest.raises(FileNotFoundError): + SusieFineMapperStep( + session=session, + study_index_path=str(tmp_path / "study_index"), + study_locus_manifest_path=missing_manifest, + study_locus_index=0, + ld_matrix_paths={}, + ) + + def test_susie_fine_mapper_step_initialization_with_manifest( + self, session: Session, tmp_path: Path + ) -> None: + """Test that SusieFineMapperStep can be initialized with valid manifest.""" + # Create a minimal manifest file + manifest_data = pd.DataFrame( + { + "study_locus_input": [str(tmp_path / "input")], + "study_locus_output": [str(tmp_path / "output")], + } + ) + manifest_path = str(tmp_path / "manifest.csv") + manifest_data.to_csv(manifest_path, index=False) + + with ( + patch( + "gentropy.susie_finemapper.StudyLocus.from_parquet" + ) as mock_study_locus, + patch( + "gentropy.susie_finemapper.StudyIndex.from_parquet" + ) as mock_study_index, + ): + # Mock the study locus and index + mock_sl = MagicMock() + mock_sl.df.withColumn.return_value.collect.return_value = [MagicMock()] + 
mock_study_locus.return_value = mock_sl + + mock_study_index.return_value = MagicMock() + + with patch( + "gentropy.susie_finemapper.SusieFineMapperStep.susie_finemapper_one_sl_row_gathered_boundaries" + ) as mock_finemapper: + mock_finemapper.return_value = None + + step = SusieFineMapperStep( + session=session, + study_index_path=str(tmp_path / "study_index"), + study_locus_manifest_path=manifest_path, + study_locus_index=0, + ld_matrix_paths={}, + ) + + assert step is not None + + def test_susie_fine_mapper_step_invalid_index( + self, session: Session, tmp_path: Path + ) -> None: + """Test that SusieFineMapperStep raises error with invalid index.""" + manifest_data = pd.DataFrame( + { + "study_locus_input": [str(tmp_path / "input")], + "study_locus_output": [str(tmp_path / "output")], + } + ) + manifest_path = str(tmp_path / "manifest.csv") + manifest_data.to_csv(manifest_path, index=False) + + with pytest.raises(Exception): # IndexError or similar + SusieFineMapperStep( + session=session, + study_index_path=str(tmp_path / "study_index"), + study_locus_manifest_path=manifest_path, + study_locus_index=999, # Out of bounds + ld_matrix_paths={}, + ) + + def test_susie_fine_mapper_step_initialization_parameters(self) -> None: + """Test that SusieFineMapperStep has correct expected parameters.""" + import inspect + + sig = inspect.signature(SusieFineMapperStep.__init__) + params = list(sig.parameters.keys()) + + expected_params = [ + "self", + "session", + "study_index_path", + "study_locus_manifest_path", + "study_locus_index", + "ld_matrix_paths", + "max_causal_snps", + "lead_pval_threshold", + "purity_mean_r2_threshold", + "purity_min_r2_threshold", + "cs_lbf_thr", + "sum_pips", + "susie_est_tausq", + "run_carma", + "run_sumstat_imputation", + "carma_time_limit", + "carma_tau", + "imputed_r2_threshold", + "ld_score_threshold", + "ld_min_r2", + "ignore_qc", + ] + + for param in expected_params: + assert param in params, f"Missing parameter: {param}" + + def 
test_susie_fine_mapper_step_default_parameters(self) -> None: + """Test that SusieFineMapperStep has correct default parameter values.""" + import inspect + + sig = inspect.signature(SusieFineMapperStep.__init__) + + # Check default values + assert sig.parameters["max_causal_snps"].default == 10 + assert sig.parameters["lead_pval_threshold"].default == 1e-5 + assert sig.parameters["purity_mean_r2_threshold"].default == 0 + assert sig.parameters["purity_min_r2_threshold"].default == 0.25 + assert sig.parameters["cs_lbf_thr"].default == 2 + assert sig.parameters["sum_pips"].default == 0.99 + assert sig.parameters["susie_est_tausq"].default is False + assert sig.parameters["run_carma"].default is False + assert sig.parameters["run_sumstat_imputation"].default is False + assert sig.parameters["carma_time_limit"].default == 600 + assert sig.parameters["carma_tau"].default == 0.15 + assert sig.parameters["imputed_r2_threshold"].default == 0.9 + assert sig.parameters["ld_score_threshold"].default == 5 + assert sig.parameters["ld_min_r2"].default == 0.8 + assert sig.parameters["ignore_qc"].default is False + + @pytest.mark.parametrize( + "region,study_id", + [ + ("chr1:1000-2000", "STUDY_1"), + ("chr22:500000-600000", "STUDY_22"), + ("chrX:100-200", "STUDY_X"), + ], + ) + def test_empty_log_mg_parametrized( + self, tmp_path: Path, region: str, study_id: str + ) -> None: + """Test _empty_log_mg with various region and study ID combinations.""" + output_path = str(tmp_path / f"{study_id}_log.tsv") + + SusieFineMapperStep._empty_log_mg( + studyId=study_id, + region=region, + error_mg="Test error", + path_out=output_path, + ) + + df = pd.read_csv(output_path, sep="\t") + assert df.loc[0, "studyId"] == study_id + assert df.loc[0, "region"] == region diff --git a/tests/gentropy/test_variant_index.py b/tests/gentropy/test_variant_index.py new file mode 100644 index 000000000..fd7b3e728 --- /dev/null +++ b/tests/gentropy/test_variant_index.py @@ -0,0 +1,148 @@ +"""Tests for 
variant_index step.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from gentropy.common.session import Session +from gentropy.variant_index import ConvertToVcfStep, VariantIndexStep + + +@pytest.mark.step_test +class TestVariantIndexStep: + """Test VariantIndexStep initialization and parameter validation.""" + + def test_variant_index_step_initialization( + self, session: Session, tmp_path: Path + ) -> None: + """Test that VariantIndexStep initializes without errors when files don't exist.""" + vep_output_json_path = str(tmp_path / "vep_output.json") + variant_index_path = str(tmp_path / "variant_index") + + # Mock the VEP parser to avoid file I/O + with patch( + "gentropy.variant_index.VariantEffectPredictorParser.extract_variant_index_from_vep" + ) as mock_vep: + mock_vep.return_value = MagicMock() + mock_vep.return_value.df = MagicMock() + + step = VariantIndexStep( + session=session, + vep_output_json_path=vep_output_json_path, + variant_index_path=variant_index_path, + hash_threshold=20, + ) + assert step is not None + + def test_variant_index_step_initialization_with_annotations( + self, session: Session, tmp_path: Path + ) -> None: + """Test VariantIndexStep with variant annotations.""" + vep_output_json_path = str(tmp_path / "vep_output.json") + variant_index_path = str(tmp_path / "variant_index") + annotation_path = [str(tmp_path / "annotation1")] + + with ( + patch( + "gentropy.variant_index.VariantEffectPredictorParser.extract_variant_index_from_vep" + ) as mock_vep, + patch( + "gentropy.variant_index.VariantIndex.from_parquet" + ) as mock_from_parquet, + ): + mock_variant_index = MagicMock() + mock_variant_index.df = MagicMock() + mock_variant_index.add_annotation.return_value = mock_variant_index + mock_vep.return_value = mock_variant_index + mock_from_parquet.return_value = mock_variant_index + + step = VariantIndexStep( + session=session, + 
vep_output_json_path=vep_output_json_path, + variant_index_path=variant_index_path, + hash_threshold=20, + variant_annotations_path=annotation_path, + ) + assert step is not None + mock_variant_index.add_annotation.assert_called_once() + + def test_variant_index_step_initialization_with_amino_acids( + self, session: Session, tmp_path: Path + ) -> None: + """Test VariantIndexStep with amino acid annotations.""" + vep_output_json_path = str(tmp_path / "vep_output.json") + variant_index_path = str(tmp_path / "variant_index") + amino_acid_path = [str(tmp_path / "amino_acid")] + + with ( + patch( + "gentropy.variant_index.VariantEffectPredictorParser.extract_variant_index_from_vep" + ) as mock_vep, + patch( + "gentropy.variant_index.AminoAcidVariants.from_parquet" + ) as mock_amino_acids, + ): + mock_variant_index = MagicMock() + mock_variant_index.df = MagicMock() + mock_variant_index.annotate_with_amino_acid_consequences.return_value = ( + mock_variant_index + ) + mock_vep.return_value = mock_variant_index + mock_amino_acids.return_value = MagicMock() + + step = VariantIndexStep( + session=session, + vep_output_json_path=vep_output_json_path, + variant_index_path=variant_index_path, + hash_threshold=20, + amino_acid_change_annotations=amino_acid_path, + ) + assert step is not None + mock_variant_index.annotate_with_amino_acid_consequences.assert_called_once() + + def test_variant_index_step_parameters(self) -> None: + """Test that VariantIndexStep has correct expected parameters.""" + import inspect + + sig = inspect.signature(VariantIndexStep.__init__) + params = list(sig.parameters.keys()) + + expected_params = [ + "self", + "session", + "vep_output_json_path", + "variant_index_path", + "hash_threshold", + "variant_annotations_path", + "amino_acid_change_annotations", + ] + + for param in expected_params: + assert param in params, f"Missing parameter: {param}" + + +@pytest.mark.step_test +class TestConvertToVcfStep: + """Test ConvertToVcfStep initialization and 
parameter validation.""" + + def test_convert_to_vcf_step_parameters(self) -> None: + """Test that ConvertToVcfStep has correct expected parameters.""" + import inspect + + sig = inspect.signature(ConvertToVcfStep.__init__) + params = list(sig.parameters.keys()) + + expected_params = [ + "self", + "session", + "source_paths", + "source_formats", + "output_path", + "partition_size", + ] + + for param in expected_params: + assert param in params, f"Missing parameter: {param}" diff --git a/uv.lock b/uv.lock index 6787dfa53..09b8f6716 100644 --- a/uv.lock +++ b/uv.lock @@ -1090,7 +1090,7 @@ dev = [ { name = "darglint", specifier = ">=1.8.1,<1.9.0" }, { name = "deptry", specifier = ">=0.22.0,<0.25.0" }, { name = "interrogate", specifier = ">=1.7.0,<1.8.0" }, - { name = "ipykernel", specifier = ">=6.28.0,<6.31.0" }, + { name = "ipykernel", specifier = ">=6.28.0,<7.3.0" }, { name = "ipython", specifier = ">=8.19.0,<9.9.0" }, { name = "isort", specifier = ">=5.13.2,<6.1.0" }, { name = "mypy", specifier = ">=1.13,<1.19" }, @@ -1098,7 +1098,7 @@ dev = [ { name = "pre-commit", specifier = ">=4.0.0,<4.6.0" }, { name = "prettier", specifier = ">=0.0.7,<0.1.0" }, { name = "pydoclint", specifier = ">=0.3.8,<0.9.0" }, - { name = "ruff", specifier = ">=0.8.1,<0.15.0" }, + { name = "ruff", specifier = ">=0.8.1,<0.16.0" }, { name = "yamllint", specifier = ">=1.33.0,<1.38.0" }, ] docs = [ From 27b7e05a5d3aa3b76aea6b84abde91e74449ccd4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 09:49:20 +0000 Subject: [PATCH 11/16] chore(deps): bump artifact actions to v7 (#1181) * chore(deps): bump actions/upload-artifact from 5 to 6 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 5 to 6. 
- [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] * ci: update artifact actions to v7 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- .github/workflows/release.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index ce19ede08..c0db17a7b 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -46,7 +46,7 @@ jobs: with: github_token: ${{ steps.trigger-token.outputs.token }} tag: ${{ steps.semrelease.outputs.tag }} - - uses: actions/upload-artifact@v5 + - uses: actions/upload-artifact@v7 if: steps.semrelease.outputs.released == 'true' with: name: python-package-distributions @@ -63,7 +63,7 @@ jobs: permissions: id-token: write # IMPORTANT: mandatory for trusted publishing steps: - - uses: actions/download-artifact@v6 + - uses: actions/download-artifact@v7 with: name: python-package-distributions path: dist/ @@ -84,7 +84,7 @@ jobs: permissions: id-token: write # IMPORTANT: mandatory for trusted publishing steps: - - uses: actions/download-artifact@v6 + - uses: actions/download-artifact@v7 with: name: python-package-distributions path: dist/ From 7259d2465146f07ab9a22bf3759bd65f936fd4ea Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 10:50:48 +0000 Subject: [PATCH 12/16] build(deps-dev): update isort requirement (#1168) Updates the requirements on [isort](https://github.com/PyCQA/isort) to permit the latest version. 
- [Release notes](https://github.com/PyCQA/isort/releases) - [Changelog](https://github.com/PyCQA/isort/blob/main/CHANGELOG.md) - [Commits](https://github.com/PyCQA/isort/compare/5.13.2...7.0.0) --- updated-dependencies: - dependency-name: isort dependency-version: 7.0.0 dependency-type: direct:development ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 56fba771c..ecdd9c048 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ dev = [ "mypy >=1.13, <1.19", "pep8-naming >=0.14.1, <0.16.0", "interrogate >=1.7.0, <1.8.0", - "isort >=5.13.2, <6.1.0", + "isort >=5.13.2, <7.1.0", "darglint >=1.8.1, <1.9.0", "ruff >=0.8.1, <0.16.0", ] From 36c60a278186443a5a81eb4a165651b4911e8b9b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:26:55 +0000 Subject: [PATCH 13/16] chore(deps): bump actions/checkout from 3 to 6 (#1166) Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 6. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v6) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- .github/workflows/artifact.yml | 2 +- .github/workflows/pr.yaml | 2 +- .github/workflows/pr_release_trigger.yaml | 2 +- .github/workflows/release.yaml | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/artifact.yml b/.github/workflows/artifact.yml index 46375ee1c..3679604dc 100644 --- a/.github/workflows/artifact.yml +++ b/.github/workflows/artifact.yml @@ -36,7 +36,7 @@ jobs: - id: checkout name: Check out repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 47f71da66..40fdf9eb2 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -16,7 +16,7 @@ jobs: python-version: ["3.11", "3.12", "3.13"] os: [ubuntu-latest, ubuntu-22.04-arm] # macos-latest is arm64 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 1 - name: Set up Python diff --git a/.github/workflows/pr_release_trigger.yaml b/.github/workflows/pr_release_trigger.yaml index d97828754..4f8d18cb1 100644 --- a/.github/workflows/pr_release_trigger.yaml +++ b/.github/workflows/pr_release_trigger.yaml @@ -8,7 +8,7 @@ jobs: pull-request: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: pull-request uses: diillson/auto-pull-request@v1.0.1 with: diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index c0db17a7b..0a2f37e6b 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -30,7 +30,7 @@ jobs: with: app-id: ${{ vars.TRIGGER_WORKFLOW_GH_APP_ID}} private-key: ${{ secrets.TRIGGER_WORKFLOW_GH_APP_KEY }} - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: 
fetch-depth: 0 ref: ${{ github.ref_name }} @@ -95,7 +95,7 @@ jobs: runs-on: ubuntu-latest if: github.ref == 'refs/heads/main' && needs.release.outputs.released == 'true' steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} From 8cc7a9c83f31f58c28f22d4c24b04c305bf58f88 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:46:43 +0000 Subject: [PATCH 14/16] build(deps): update wandb requirement (#1189) Updates the requirements on [wandb](https://github.com/wandb/wandb) to permit the latest version. - [Release notes](https://github.com/wandb/wandb/releases) - [Changelog](https://github.com/wandb/wandb/blob/main/CHANGELOG.md) - [Commits](https://github.com/wandb/wandb/compare/v0.19.4...v0.25.0) --- updated-dependencies: - dependency-name: wandb dependency-version: 0.25.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ecdd9c048..ea27bff53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "xgboost-cpu>=3.0.4 ; (platform_machine == 'amd64' and sys_platform != 'darwin') or (platform_machine == 'x86_64' and sys_platform != 'darwin')", "xgboost>=3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'darwin'", "huggingface-hub>=0.27.1", - "wandb (>=0.19.4, <0.23.0)", + "wandb (>=0.19.4, <0.26.0)", ] classifiers = [ "Programming Language :: Python :: 3.11", From 62dd9b881011b6e32a8cf62d3cf2eea4037a7c78 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 12:18:18 +0000 Subject: [PATCH 15/16] chore: pre-commit 
autoupdate (#1165) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.14.4 → v0.15.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.14.4...v0.15.2) - [github.com/adrienverge/yamllint.git: v1.37.1 → v1.38.0](https://github.com/adrienverge/yamllint.git/compare/v1.37.1...v1.38.0) - [github.com/alessandrojcm/commitlint-pre-commit-hook: v9.23.0 → v9.24.0](https://github.com/alessandrojcm/commitlint-pre-commit-hook/compare/v9.23.0...v9.24.0) - [github.com/pre-commit/mirrors-mypy: v1.18.2 → v1.19.1](https://github.com/pre-commit/mirrors-mypy/compare/v1.18.2...v1.19.1) - [github.com/lovesegfault/beautysh: v6.4.1 → v6.4.2](https://github.com/lovesegfault/beautysh/compare/v6.4.1...v6.4.2) - [github.com/jsh9/pydoclint: 0.8.1 → 0.8.3](https://github.com/jsh9/pydoclint/compare/0.8.1...0.8.3) - [github.com/astral-sh/uv-pre-commit: 0.9.8 → 0.10.4](https://github.com/astral-sh/uv-pre-commit/compare/0.9.8...0.10.4) * chore: pre-commit auto fixes [...] 
* chore: update ruff rules --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> --- .pre-commit-config.yaml | 14 +++++++------- pyproject.toml | 4 +++- .../gwas_catalog/study_index_ot_curation.py | 2 +- .../datasource/gwas_catalog/summary_statistics.py | 2 +- src/gentropy/method/l2g/model.py | 2 +- uv.lock | 2 +- 6 files changed, 14 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0f0cfbf6c..f714c4b4c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ ci: skip: [uv-lock] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.4 + rev: v0.15.2 hooks: - id: ruff args: @@ -35,7 +35,7 @@ repos: - id: debug-statements - id: check-docstring-first - repo: https://github.com/adrienverge/yamllint.git - rev: v1.37.1 + rev: v1.38.0 hooks: - id: yamllint @@ -59,14 +59,14 @@ repos: exclude: "CHANGELOG.md" - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook - rev: v9.23.0 + rev: v9.24.0 hooks: - id: commitlint additional_dependencies: ["@commitlint/config-conventional@18.6.3"] stages: [commit-msg] - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.18.2" + rev: "v1.19.1" hooks: - id: mypy args: @@ -94,16 +94,16 @@ repos: args: [--convention=google, --add-ignore=D107] - repo: https://github.com/lovesegfault/beautysh - rev: v6.4.1 + rev: v6.4.2 hooks: - id: beautysh - repo: https://github.com/jsh9/pydoclint - rev: 0.8.1 + rev: 0.8.3 hooks: - id: pydoclint - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.9.8 + rev: 0.10.4 hooks: - id: uv-lock args: [--check] diff --git a/pyproject.toml b/pyproject.toml index ea27bff53..afd8ce757 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -251,9 +251,11 @@ ignore = [ "PLW2901", # Outer {outer_kind} variable {name} overwritten by inner {inner_kind} target "UP006", # keep type 
annotation style as is "UP007", # keep type annotation style as is + "UP042", # use of the str and Enum (double inheritance) for strEnum + "PLW0108", # Use of lambda expression assigned to variable. Use a def instead. # Ignored due to performance: https://github.com/charliermarsh/ruff/issues/2923 "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` - "G004", # f-string used in logging function + "G004" # f-string used in logging function ] diff --git a/src/gentropy/datasource/gwas_catalog/study_index_ot_curation.py b/src/gentropy/datasource/gwas_catalog/study_index_ot_curation.py index 8d75e4824..a1a7e3182 100644 --- a/src/gentropy/datasource/gwas_catalog/study_index_ot_curation.py +++ b/src/gentropy/datasource/gwas_catalog/study_index_ot_curation.py @@ -85,6 +85,6 @@ def from_url( return cls._parser( session.spark.read.csv( - SparkFiles.get(curation_url.split("/")[-1]), sep="\t", header=True + SparkFiles.get(curation_url.rsplit("/", maxsplit=1)[-1]), sep="\t", header=True ) ) diff --git a/src/gentropy/datasource/gwas_catalog/summary_statistics.py b/src/gentropy/datasource/gwas_catalog/summary_statistics.py index d50acea31..238cb8e81 100644 --- a/src/gentropy/datasource/gwas_catalog/summary_statistics.py +++ b/src/gentropy/datasource/gwas_catalog/summary_statistics.py @@ -42,7 +42,7 @@ def filename_to_study_identifier(path: str) -> str: ... ValueError: Path ("wrong/path") does not contain GWAS Catalog study identifier. 
""" - file_name = path.split("/")[-1] + file_name = path.rsplit("/", maxsplit=1)[-1] study_id_matches = re.search(r"(GCST\d+)", file_name) if not study_id_matches: diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index ac8447c26..f384128c6 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -241,7 +241,7 @@ def save(self: LocusToGeneModel, path: str) -> None: if not path.endswith(".skops"): raise ValueError("Path should end with .skops") if path.startswith("gs://"): - local_path = path.split("/")[-1] + local_path = path.rsplit("/", maxsplit=1)[-1] sio.dump(self.model, local_path) copy_to_gcs(local_path, path) else: diff --git a/uv.lock b/uv.lock index 09b8f6716..3aedf4cff 100644 --- a/uv.lock +++ b/uv.lock @@ -1092,7 +1092,7 @@ dev = [ { name = "interrogate", specifier = ">=1.7.0,<1.8.0" }, { name = "ipykernel", specifier = ">=6.28.0,<7.3.0" }, { name = "ipython", specifier = ">=8.19.0,<9.9.0" }, - { name = "isort", specifier = ">=5.13.2,<6.1.0" }, + { name = "isort", specifier = ">=5.13.2,<7.1.0" }, { name = "mypy", specifier = ">=1.13,<1.19" }, { name = "pep8-naming", specifier = ">=0.14.1,<0.16.0" }, { name = "pre-commit", specifier = ">=4.0.0,<4.6.0" }, From 95f2c7b93c703139e048a7e89b3ae9b602f159ed Mon Sep 17 00:00:00 2001 From: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com> Date: Fri, 27 Feb 2026 12:24:37 +0000 Subject: [PATCH 16/16] chore: update lock file --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 3aedf4cff..cd0fdd297 100644 --- a/uv.lock +++ b/uv.lock @@ -1079,7 +1079,7 @@ requires-dist = [ { name = "scipy", specifier = ">=1.11.4,<1.16.0" }, { name = "shap", specifier = ">=0.50.0" }, { name = "skops", specifier = ">=0.13.0,<0.14.0" }, - { name = "wandb", specifier = ">=0.19.4,<0.23.0" }, + { name = "wandb", specifier = ">=0.19.4,<0.26.0" }, { name = "xgboost", marker = "platform_machine == 'x86_64' and 
sys_platform == 'darwin'", specifier = ">=3.0.4" }, { name = "xgboost", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64'", specifier = ">=3.0.4" }, { name = "xgboost-cpu", marker = "(platform_machine == 'amd64' and sys_platform != 'darwin') or (platform_machine == 'x86_64' and sys_platform != 'darwin')", specifier = ">=3.0.4" },