"""
Module to generate the human gene data as JSON for the KG.
"""
import logging
import textwrap

from airflow.sdk import Variable, asset

from impc_etl.utils.airflow import create_input_asset, create_output_asset
from impc_etl.utils.spark import with_spark_session

task_logger = logging.getLogger("airflow.task")
dr_tag = Variable.get("data_release_tag")

gene_ref_parquet_path_asset = create_input_asset("output/gene_ref_parquet")

# NOTE(review): leading "/" here is inconsistent with the input asset path
# above ("output/..."); confirm create_output_asset treats both the same.
human_gene_output_asset = create_output_asset("/impc_kg/human_gene_json")


@asset.multi(
    schedule=[gene_ref_parquet_path_asset],
    outlets=[human_gene_output_asset],
    dag_id=f"{dr_tag}_impc_kg_human_gene_mapper",
    description=textwrap.dedent(
        """
    PySpark task to create the human gene Knowledge Graph JSON files
    from the HGNC data in the reference database.
    """
    ),
    tags=["impc_kg"],
)
@with_spark_session
def impc_kg_human_gene_mapper():
    """Build the human-gene JSON for the IMPC Knowledge Graph.

    Reads the gene reference parquet, explodes the paired
    ``human_gene_symbol`` / ``human_gene_acc_id`` arrays into one row per
    human gene, derives a stable ``human_gene_id`` from the accession id,
    camelCases the output column names, and writes the distinct rows as a
    single gzipped JSON file to the output asset location.
    """
    # Imports are function-local so the DAG file can be parsed by the
    # scheduler without pyspark / ETL helpers installed.
    from impc_etl.jobs.load.impc_web_api.impc_web_api_helper import to_camel_case
    from impc_etl.jobs.load.impc_kg.impc_kg_helper import add_unique_id

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import (
        explode,
        arrays_zip,
    )

    spark = SparkSession.builder.getOrCreate()

    # arrays_zip pairs the i-th symbol with the i-th accession id, so each
    # exploded row is one (human_gene_symbol, human_gene_acc_id) pair.
    gene_ref_df = spark.read.parquet(gene_ref_parquet_path_asset.uri)
    gene_ref_df = gene_ref_df.withColumn(
        "human_info", explode(arrays_zip("human_gene_symbol", "human_gene_acc_id"))
    ).select("human_info.*")

    # Deterministic surrogate key derived from the accession id alone.
    gene_ref_df = add_unique_id(
        gene_ref_df,
        "human_gene_id",
        ["human_gene_acc_id"],
    )

    # Columns whose output name is not simply the camelCased source name.
    # (Renamed from the misleading "mouse_gene_col_map": these are human
    # gene columns.)
    human_gene_col_map = {
        "human_gene_symbol": "symbol",
    }

    output_cols = [
        "human_gene_id",
        "human_gene_acc_id",
        "human_gene_symbol",
    ]
    output_df = gene_ref_df.select(*output_cols).distinct()
    for col_name in output_df.columns:
        output_df = output_df.withColumnRenamed(
            col_name,
            to_camel_case(human_gene_col_map.get(col_name, col_name)),
        )
    # Rows were already deduplicated above and withColumnRenamed cannot
    # introduce duplicates, so no second distinct() is needed before the
    # write. Coalesce to one partition so a single JSON file is produced.
    output_df.coalesce(1).write.json(
        human_gene_output_asset.uri, mode="overwrite", compression="gzip"
    )