diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3828057a..75f5dbcc 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -36,12 +36,17 @@ jobs: mamba install -y -q pip wheel pip install uv + - name: Install Postgres for testing + shell: bash -l {0} + run: | + mamba install -y -q postgresql + - name: Install dependencies shell: bash -l {0} run: | uv pip install -r requirements.txt + uv pip install testing.postgresql - # We have two cores so we can speed up the testing with xdist - name: Install pytest packages shell: bash -l {0} run: | diff --git a/.gitignore b/.gitignore index a83e3ca4..1487de3e 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,6 @@ pytest_session.txt # VS Code .vscode + +# Scratch directory +.scratch diff --git a/docker/Dockerfile.replication b/docker/Dockerfile.replication index 12e8b3a5..a8e218f1 100644 --- a/docker/Dockerfile.replication +++ b/docker/Dockerfile.replication @@ -3,11 +3,14 @@ FROM python:3.12-slim-bookworm ENV DEBIAN_FRONTEND=noninteractive # Update and install OS dependencies -RUN apt-get -y update && \ - apt-get -y upgrade && \ - apt-get -y install --no-install-recommends git && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + build-essential \ + python3-dev \ + pkg-config \ + git \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* # Install required python build dependencies RUN pip install --upgrade --no-cache-dir pip setuptools wheel uv diff --git a/pyproject.toml b/pyproject.toml index 5f082778..de2ff161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,10 +23,12 @@ classifiers = [ keywords = ["lsst"] dependencies = [ "astropy", + "google-cloud-bigquery", "pyarrow", "pydantic >=2,<3", "pyyaml >= 5.1", "sqlalchemy", + "lsst-dax-ppdbx-gcp", "lsst-felis", "lsst-sdm-schemas", "lsst-utils", @@ -43,9 +45,6 @@ test = [ "pytest >= 3.2", "pytest-openfiles >= 0.5.0" ] -gcp = [ - 
"lsst-dax-ppdbx-gcp" -] [tool.setuptools.packages.find] where = ["python"] @@ -54,7 +53,7 @@ where = ["python"] zip-safe = true [tool.setuptools.package-data] -"lsst.dax.ppdb" = ["py.typed"] +"lsst.dax.ppdb" = ["py.typed", "config/schemas/*.yaml", "config/sql/*.sql"] [tool.setuptools.dynamic] version = { attr = "lsst_versions.get_lsst_version" } diff --git a/python/lsst/dax/ppdb/__init__.py b/python/lsst/dax/ppdb/__init__.py index d8aeb139..2f4dab94 100644 --- a/python/lsst/dax/ppdb/__init__.py +++ b/python/lsst/dax/ppdb/__init__.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from .config import * +from .ppdb_config import * from .ppdb import * from .replicator import * from .version import * # Generated by sconsUtils diff --git a/python/lsst/dax/ppdb/_factory.py b/python/lsst/dax/ppdb/_factory.py index aee2ee52..c3774778 100644 --- a/python/lsst/dax/ppdb/_factory.py +++ b/python/lsst/dax/ppdb/_factory.py @@ -26,8 +26,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .config import PpdbConfig from .ppdb import Ppdb + from .ppdb_config import PpdbConfig def config_type_for_name(type_name: str) -> type[PpdbConfig]: diff --git a/python/lsst/dax/ppdb/bigquery/__init__.py b/python/lsst/dax/ppdb/bigquery/__init__.py index 19d8c17d..e7b3071b 100644 --- a/python/lsst/dax/ppdb/bigquery/__init__.py +++ b/python/lsst/dax/ppdb/bigquery/__init__.py @@ -20,5 +20,6 @@ # along with this program. If not, see . 
from .manifest import Manifest +from .chunk_uploader import ChunkUploader from .ppdb_bigquery import PpdbBigQuery, PpdbBigQueryConfig from .ppdb_replica_chunk_extended import ChunkStatus, PpdbReplicaChunkExtended diff --git a/python/lsst/dax/ppdb/bigquery/chunk_uploader.py b/python/lsst/dax/ppdb/bigquery/chunk_uploader.py index 9efe0e22..727bcd09 100644 --- a/python/lsst/dax/ppdb/bigquery/chunk_uploader.py +++ b/python/lsst/dax/ppdb/bigquery/chunk_uploader.py @@ -42,7 +42,7 @@ ) from e -from ..config import PpdbConfig +from ..ppdb_config import PpdbConfig from .manifest import Manifest from .ppdb_bigquery import PpdbBigQuery, PpdbBigQueryConfig from .ppdb_replica_chunk_extended import ChunkStatus, PpdbReplicaChunkExtended @@ -237,15 +237,27 @@ def _process_chunk(self, replica_chunk: PpdbReplicaChunkExtended) -> None: ) # Make a list of local parquet files to upload. - parquet_files = list(chunk_dir.glob("*.parquet")) + upload_file_list = list(chunk_dir.glob("*.parquet")) + + # Include the update records file if the manifest indicates it should + # exist + if manifest.includes_update_records: + update_records_file = chunk_dir / "update_records.json" + if not update_records_file.exists(): + raise ChunkUploadError( + chunk_id, + f"Manifest indicates update records are included but file does not exist: " + f"{update_records_file}", + ) + upload_file_list.append(update_records_file) # Check if the chunk is expected to be empty. is_empty = manifest.is_empty_chunk() - if not parquet_files and not is_empty: + if not upload_file_list and not is_empty: # There is a mismatch between the manifest and the actual files. # Some processing error may have occurred when exporting. - raise ChunkUploadError(chunk_id, f"No parquet files found in {chunk_dir} for non-empty chunk") + raise ChunkUploadError(chunk_id, f"No files found to upload in {chunk_dir} for non-empty chunk") # Check that all expected parquet files from the manifest are present. 
for table_name, table_stats in manifest.table_data.items(): @@ -258,24 +270,21 @@ def _process_chunk(self, replica_chunk: PpdbReplicaChunkExtended) -> None: ) try: - # 1) Upload parquet files, which will happen only for non-empty - # chunks. - if parquet_files: - gcs_names = {path: posixpath.join(gcs_prefix, path.name) for path in parquet_files} + # 1) Upload the files to GCS for non-empty chunks + if upload_file_list: + gcs_names = {path: posixpath.join(gcs_prefix, path.name) for path in upload_file_list} try: - _LOG.info( - "Uploading %d parquet files to GCS under prefix: %s", len(gcs_names), gcs_prefix - ) + _LOG.info("Uploading %d files to GCS under prefix: %s", len(gcs_names), gcs_prefix) with Timer( "upload_files_time", _MON, tags={"prefix": str(gcs_prefix), "chunk_id": str(chunk_id)} ) as timer: self.storage.upload_files(gcs_names) - total_bytes = sum(p.stat().st_size for p in parquet_files) + total_bytes = sum(p.stat().st_size for p in upload_file_list) timer.add_values(file_count=len(gcs_names), total_bytes=total_bytes) except* UploadError as eg: raise ChunkUploadError(chunk_id, f"{len(eg.exceptions)} upload(s) failed") from eg - # 2) Upload manifest, even for empty chunks. + # 2) Upload manifest, even for empty chunks try: self.storage.upload_from_string( posixpath.join(gcs_prefix, replica_chunk.manifest_name), @@ -284,22 +293,29 @@ def _process_chunk(self, replica_chunk: PpdbReplicaChunkExtended) -> None: except UploadError as e: raise ChunkUploadError(chunk_id, "Manifest upload failed") from e - # 3) Update DB status, but not for empty chunks. + # Next two steps are inapplicable to empty chunks. 
if not is_empty: + # 3) Update status and GCS URI in the database + gcs_uri = posixpath.join(self.bucket_name, gcs_prefix) + updated_replica_chunk = replica_chunk.with_new_status(ChunkStatus.UPLOADED).with_new_gcs_uri( + f"gs://{gcs_uri}" + ) try: - self._bq.store_chunk(replica_chunk.with_new_status(ChunkStatus.UPLOADED), True) + self._bq.store_chunk(updated_replica_chunk, True) + _LOG.info( + "Updated replica chunk %d in database with status 'uploaded' and GCS URI: %s", + chunk_id, + gcs_uri, + ) except Exception as e: - raise ChunkUploadError( - chunk_id, "failed to update replica chunk status in database" - ) from e + raise ChunkUploadError(chunk_id, "Failed to update replica chunk in database") from e - # 4) Publish Pub/Sub staging message to trigger BigQuery load, but - # not for empty chunks. (Empty chunks cannot be staged.) - if not is_empty: + # 4) Publish Pub/Sub event to trigger staging of the chunk in + # BigQuery try: - self._post_to_stage_chunk_topic(self.bucket_name, gcs_prefix, chunk_id) + self._post_to_stage_chunk_topic(gcs_uri, chunk_id) except Exception as e: - raise ChunkUploadError(chunk_id, "failed to publish staging message") from e + raise ChunkUploadError(chunk_id, "Failed to publish staging message") from e except ChunkUploadError as err: try: @@ -310,17 +326,14 @@ def _process_chunk(self, replica_chunk: PpdbReplicaChunkExtended) -> None: except DeleteError as cleanup_err: # Note (Python 3.11+): annotate without masking the # original error. 
- err.add_note( - f"cleanup warning: failed to delete " - f"gs://{posixpath.join(self.bucket_name, gcs_prefix)}: {cleanup_err}" - ) + err.add_note(f"cleanup warning: failed to delete gs://{gcs_uri}: {cleanup_err}") raise - def _post_to_stage_chunk_topic(self, bucket_name: str, chunk_prefix: str, chunk_id: int) -> None: + def _post_to_stage_chunk_topic(self, gcs_uri: str, chunk_id: int) -> None: message = { "dataset": self.dataset_id, "chunk_id": str(chunk_id), - "folder": f"gs://{posixpath.join(bucket_name, chunk_prefix)}", + "folder": f"gs://{gcs_uri}", } self.publisher.publish(message).result(timeout=60) diff --git a/python/lsst/dax/ppdb/bigquery/manifest.py b/python/lsst/dax/ppdb/bigquery/manifest.py index b53c6f5a..da0fa456 100644 --- a/python/lsst/dax/ppdb/bigquery/manifest.py +++ b/python/lsst/dax/ppdb/bigquery/manifest.py @@ -79,6 +79,10 @@ class Manifest(BaseModel): """Name of the compression format used for artifacts (e.g., "gzip", "zstd", "snappy", etc.).""" + includes_update_records: bool = False + """Whether the exported data includes update records (e.g., in a separate + file) or not (`bool`).""" + @property def filename(self) -> str: """Generate the filename for this manifest based on the replica chunk @@ -118,12 +122,15 @@ def from_json_file(cls, file_path: Path) -> Manifest: def is_empty_chunk(self) -> bool: """Check if the manifest represents an empty replica chunk in which - all tables have zero rows. + all tables have zero rows and no update records are included. Returns ------- bool - `True` if all tables have zero rows, indicating an empty chunk, - `False` otherwise. + `True` if all tables have zero rows and no update records are + included, indicating an empty chunk, `False` otherwise. 
""" - return all(table.row_count == 0 for table in self.table_data.values()) + return ( + all(table.row_count == 0 for table in self.table_data.values()) + and not self.includes_update_records + ) diff --git a/python/lsst/dax/ppdb/bigquery/ppdb_bigquery.py b/python/lsst/dax/ppdb/bigquery/ppdb_bigquery.py index 877bb229..61cdc416 100644 --- a/python/lsst/dax/ppdb/bigquery/ppdb_bigquery.py +++ b/python/lsst/dax/ppdb/bigquery/ppdb_bigquery.py @@ -19,14 +19,19 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + import datetime import logging +import os import shutil from collections.abc import Collection, Iterable, Sequence from pathlib import Path +from typing import Any import felis import sqlalchemy +from google.cloud import secretmanager from lsst.dax.apdb import ( ApdbMetadata, @@ -41,11 +46,14 @@ from lsst.dax.apdb.timer import Timer from .._arrow import write_parquet -from ..config import PpdbConfig from ..ppdb import Ppdb, PpdbReplicaChunk -from ..sql import PpdbSqlBase, PpdbSqlBaseConfig +from ..ppdb_config import PpdbConfig +from ..sql import PasswordProvider, PpdbSqlBase, PpdbSqlBaseConfig from .manifest import Manifest, TableStats from .ppdb_replica_chunk_extended import ChunkStatus, PpdbReplicaChunkExtended +from .query_runner import QueryRunner +from .sql_resource import SqlResource +from .updates.update_records import UpdateRecords __all__ = ["ConfigValidationError", "PpdbBigQuery", "PpdbBigQueryConfig"] @@ -113,6 +121,30 @@ def fq_dataset_id(self) -> str: return f"{self.project_id}:{self.dataset_id}" +class _SecretManagerPasswordProvider(PasswordProvider): + """Retrieves a database password from Google Cloud Secret Manager. + + Parameters + ---------- + project_id : `str` + GCP project that owns the secret. + secret_name : `str`, optional + Name of the secret. Defaults to ``"ppdb-db-password"``. 
+ """ + + def __init__(self, project_id: str, secret_name: str = "ppdb-db-password") -> None: + self._project_id = project_id + self._secret_name = secret_name + + def get_password(self) -> str: + """Return the password fetched from Secret Manager.""" + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{self._project_id}/secrets/{self._secret_name}/versions/latest" + _LOG.info("Retrieving database password from Secret Manager: %s", name) + response = client.access_secret_version(request={"name": name}) + return response.payload.data.decode("UTF-8") + + class ConfigValidationError(Exception): """Indicates an error validating the configuration.""" @@ -127,30 +159,73 @@ class PpdbBigQuery(Ppdb, PpdbSqlBase): """ def __init__(self, config: PpdbBigQueryConfig): - # Initialize the SQL interface for the PPDB. - PpdbSqlBase.__init__(self, config.sql) + # Build an optional password provider for GCP Secret Manager + password_provider: PasswordProvider | None = None + if os.getenv("PPDB_USE_SECRET_MANAGER", "false").lower() == "true": + _LOG.info("Using Secret Manager to retrieve database password") + password_provider = _SecretManagerPasswordProvider(config.project_id) - # Read parameters from config. 
+ # Delegate SQL initialisation (schema load, engine, metadata, version + # checks) to the base class, passing the optional password provider + PpdbSqlBase.__init__(self, config.sql, password_provider=password_provider) + + # Read parameters from config if config.replication_dir is None: raise ValueError("Directory for chunk export is not set in configuration.") self.replication_path = config.replication_path self.parq_batch_size = config.parq_batch_size self.parq_compression = config.parq_compression self.delete_existing_dirs = config.delete_existing_dirs + self.project_id = config.project_id + + self._config = config + + self._query_runner: QueryRunner | None = None @property def metadata(self) -> ApdbMetadata: - """Implement `Ppdb` interface to return APDB metadata object. + """APDB metadata object from `Ppdb` interface (`ApdbMetadata`).""" + return self._metadata + + @property + def config(self) -> PpdbBigQueryConfig: + """PPDB config associated with this instance.""" + return self._config + + @property + def query_runner(self) -> QueryRunner: + """Query runner for executing SQL in BigQuery + (`~lsst.dax.ppdb.bigquery.QueryRunner`). + """ + if not self._query_runner: + self._query_runner = QueryRunner(self.config.project_id, self.config.dataset_id) + return self._query_runner + + @classmethod + def from_env(cls) -> PpdbBigQuery: + """Create an instance of this class from a config pointed to by an + environment variable. Returns ------- - metadata : `ApdbMetadata` - APDB metadata object. + ppdb: `PpdbBigQuery` + An instance of the PPDB BigQuery interface. 
""" - return self._metadata + ppdb_config_uri = os.environ.get("PPDB_CONFIG_URI", None) + if ppdb_config_uri: + logging.info("PPDB_CONFIG_URI: %s", ppdb_config_uri) + else: + raise OSError("PPDB_CONFIG_URI is not set in the environment") + ppdb = Ppdb.from_uri(ppdb_config_uri) + if not isinstance(ppdb, PpdbBigQuery): + raise ValueError(f"Ppdb from environment has wrong type: {type(ppdb)}") + return ppdb def _generate_manifest( - self, replica_chunk: ReplicaChunk, table_dict: dict[str, ApdbTableData] + self, + replica_chunk: ReplicaChunk, + table_dict: dict[str, ApdbTableData], + update_records: Collection[ApdbUpdateRecord], ) -> Manifest: """Generate the manifest data for the replica chunk.""" return Manifest( @@ -163,6 +238,7 @@ def _generate_manifest( table_name: TableStats(row_count=len(data.rows())) for table_name, data in table_dict.items() }, compression_format=self.parq_compression, + includes_update_records=bool(update_records), ) def store( @@ -178,22 +254,11 @@ def store( # Docstring is inherited. _LOG.info("Processing %s", replica_chunk.id) - # TODO: APDB does not generate ApdbUpdateRecords yet, but we will - # eventually have to add support for it. - if update_records: - raise NotImplementedError("PpdbBigQuery does not support record updates yet.") - try: - chunk_dir = self._get_chunk_path(replica_chunk) - - if chunk_dir.exists(): - if not self.delete_existing_dirs: - raise FileExistsError(f"Directory already exists for {replica_chunk.id}: {chunk_dir}") - _LOG.warning("Overwriting existing directory for %s: %s", replica_chunk.id, chunk_dir) - shutil.rmtree(chunk_dir) + chunk_dir = self._create_chunk_dir(replica_chunk) - chunk_dir.mkdir(parents=True) - _LOG.info("Created directory for %s: %s", replica_chunk.id, chunk_dir) + if update_records: + self._handle_updates(replica_chunk, update_records, chunk_dir) table_dict = { ApdbTables.DiaObject.value: objects, @@ -227,7 +292,7 @@ def store( # Create manifest for the replica chunk. 
try: - manifest = self._generate_manifest(replica_chunk, table_dict) + manifest = self._generate_manifest(replica_chunk, table_dict, update_records) _LOG.info("Generated manifest for %s: %s", replica_chunk.id, manifest.model_dump_json()) except Exception: _LOG.exception("Failed to generate manifest for %d", replica_chunk.id) @@ -261,15 +326,32 @@ def store( _LOG.info("Done processing %s", replica_chunk.id) - def _get_chunk_path(self, chunk: ReplicaChunk) -> Path: + def _create_chunk_dir(self, chunk: ReplicaChunk) -> Path: + """Create the directory for the replica chunk based on its last update + time and ID. + + Returns + ------- + chunk_dir + Path to the created directory for the replica chunk. + """ last_update_time = chunk.last_update_time.to_datetime() assert isinstance(last_update_time, datetime.datetime) - path = Path( + chunk_dir = Path( self.replication_path, chunk.last_update_time.strftime("%Y/%m/%d"), str(chunk.id), ) - return path + if chunk_dir.exists(): + if not self.delete_existing_dirs: + raise FileExistsError(f"Directory already exists for {chunk.id}: {chunk_dir}") + _LOG.warning("Overwriting existing directory for %s: %s", chunk.id, chunk_dir) + shutil.rmtree(chunk_dir) + + chunk_dir.mkdir(parents=True) + _LOG.info("Created directory for %s: %s", chunk.id, chunk_dir) + + return chunk_dir def get_replica_chunks(self, start_chunk_id: int | None = None) -> Sequence[PpdbReplicaChunk] | None: # Docstring is inherited. 
@@ -306,6 +388,7 @@ def get_replica_chunks_ext( table.columns["replica_time"], table.columns["status"], # Extended column table.columns["directory"], # Extended column + table.columns["gcs_uri"], # Extended column ).order_by(table.columns["last_update_time"]) if start_chunk_id is not None: query = query.where(table.columns["apdb_replica_chunk"] >= start_chunk_id) @@ -325,10 +408,61 @@ def get_replica_chunks_ext( replica_time=replica_time, status=row[4], directory=Path(row[5]), + gcs_uri=row[6], ) ) return ids + def get_replica_chunks_ext_by_ids(self, chunk_ids: Sequence[int]) -> Sequence[PpdbReplicaChunkExtended]: + """Find replica chunks for a list of chunk IDs. + + Parameters + ---------- + chunk_ids : `~collections.abc.Sequence` [ `int` ] + Replica chunk IDs to retrieve. + + Returns + ------- + chunks : `~collections.abc.Sequence` [ `PpdbReplicaChunkExtended` ] + List of matching chunks ordered by ``last_update_time``. + """ + if not chunk_ids: + return [] + + table = self.get_table("PpdbReplicaChunk") + query = ( + sqlalchemy.sql.select( + table.columns["apdb_replica_chunk"], + table.columns["last_update_time"], + table.columns["unique_id"], + table.columns["replica_time"], + table.columns["status"], + table.columns["directory"], + table.columns["gcs_uri"], + ) + .where(table.columns["apdb_replica_chunk"].in_(chunk_ids)) + .order_by(table.columns["apdb_replica_chunk"]) + ) + + chunks: list[PpdbReplicaChunkExtended] = [] + with self._engine.connect() as conn: + result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query) + for row in result: + last_update_time = self.to_astropy_tai(row[1]) + replica_time = self.to_astropy_tai(row[3]) + chunks.append( + PpdbReplicaChunkExtended( + id=row[0], + last_update_time=last_update_time, + unique_id=row[2], + replica_time=replica_time, + status=row[4], + directory=Path(row[5]), + gcs_uri=row[6], + ) + ) + return chunks + def store_chunk(self, replica_chunk: PpdbReplicaChunkExtended, update: 
bool) -> None: """Insert or replace single record in PpdbReplicaChunk table, including the status and directory of the replica chunk. @@ -352,6 +486,7 @@ def store_chunk(self, replica_chunk: PpdbReplicaChunkExtended, update: bool) -> "replica_time": replica_chunk.replica_time_dt_utc, "status": replica_chunk.status, "directory": str(replica_chunk.directory), + "gcs_uri": replica_chunk.gcs_uri, } if update: self.upsert(connection, table, row, "apdb_replica_chunk") @@ -389,6 +524,12 @@ def create_replica_chunk_table(cls, table_name: str | None = None) -> schema_mod datatype=felis.datamodel.DataType.string, nullable=True, # We might want to allow NULL if an error occurs when exporting. ), + schema_model.Column( + name="gcs_uri", + id=f"#{table_name}.gcs_uri", + datatype=felis.datamodel.DataType.string, + nullable=True, + ), ] ) return replica_chunk_table @@ -466,7 +607,6 @@ def init_bigquery( sql_config = PpdbSqlBaseConfig( db_url=db_url, schema_name=db_schema, felis_path=felis_path, felis_schema=felis_schema ) - cls.make_database(sql_config, sa_metadata, schema_version, db_drop) # Build config parameters. bq_config = PpdbBigQueryConfig( @@ -485,6 +625,13 @@ def init_bigquery( if stage_chunk_topic is not None: bq_config.stage_chunk_topic = stage_chunk_topic + password_provider: PasswordProvider | None = None + if os.getenv("PPDB_USE_SECRET_MANAGER", "false").lower() == "true": + _LOG.info("Using Secret Manager to retrieve database password") + password_provider = _SecretManagerPasswordProvider(bq_config.project_id) + engine = cls.make_engine(bq_config.sql, password_provider=password_provider) + cls.make_database(engine, bq_config.sql, sa_metadata, schema_version, db_drop) + # Validate the config if requested. 
 if validate_config: _LOG.info("validating BigQuery configuration") @@ -567,3 +714,139 @@ def validate_config(cls, config: PpdbBigQueryConfig) -> None: check_dataset_exists(config.project_id, config.dataset_id) except Exception as e: raise ConfigValidationError("Failed to validate BigQuery dataset") from e + + def _handle_updates( + self, replica_chunk: ReplicaChunk, apdb_update_records: Collection[ApdbUpdateRecord], chunk_dir: Path + ) -> None: + """Handle updates to existing records in the PPDB by writing a JSON + file with the update information for the replica chunk. + + Parameters + ---------- + replica_chunk : `ReplicaChunk` + The replica chunk associated with the updates. + apdb_update_records : `~collections.abc.Collection` [ `ApdbUpdateRecord` ] + Collection of update records to process. + + Notes + ----- + Serializes the ApdbUpdateRecord objects into a dictionary structure + for processing. + """ + update_records = UpdateRecords( + replica_chunk_id=replica_chunk.id, + records=list(apdb_update_records), + record_count=len(apdb_update_records), + ) + update_records.write_json_file(chunk_dir / "update_records.json") + + _LOG.info( + "Saved %d update records for %s to %s", + update_records.record_count, + replica_chunk.id, + chunk_dir / "update_records.json", + ) + + def get_promotable_chunks(self) -> list[int]: + """ + Return the first uninterrupted sequence of staged chunks such that all + prior chunks are promoted. + + Returns + ------- + chunk_ids : `list`[`int`] + A list of integers containing the ``apdb_replica_chunk`` values of + the promotable chunks. + + Notes + ----- + This query finds the contiguous sequence of ``staged`` chunks beginning + with the earliest chunk that is not yet ``promoted``, and ending just + before the first chunk that is not ``staged``. If no such ending + exists, all ``staged`` chunks from that point onward are returned. If + no chunks are ``staged`` after the first non-``promoted`` chunk, an + empty list is returned. 
+ """ + table = self.get_table("PpdbReplicaChunk") + if not table.schema: + raise ValueError("Table schema is not set, cannot construct query") + quoted_table_name = ( + self._engine.dialect.identifier_preparer.quote(table.schema) + + "." + + self._engine.dialect.identifier_preparer.quote(table.name) + ) + + sql = SqlResource("select_promotable_chunks", {"table_name": quoted_table_name}).sql + + with self._engine.connect() as conn: + result = conn.execute(sqlalchemy.text(sql)) + chunk_ids = [row[0] for row in result] + return chunk_ids + + def mark_chunks_promoted(self, promotable_chunks: list[int]) -> int: + """Set status='promoted' for the given chunk IDs. Returns number + updated. + + Parameters + ---------- + promotable_chunks : `list`[`int`] + List of integers containing the ``apdb_replica_chunk`` values of + the promotable chunks. + + Returns + ------- + count: `int` + The number of rows updated in the database, which should be equal + to the number of promotable chunks provided, if they were all found + and updated successfully. + """ + table = self.get_table("PpdbReplicaChunk") + stmt = ( + sqlalchemy.update(table) + .where(table.c.apdb_replica_chunk.in_(promotable_chunks), table.c.status != "promoted") + .values(status="promoted") + ) + + with self._engine.begin() as conn: + result: sqlalchemy.engine.CursorResult = conn.execute(stmt) + return result.rowcount or 0 + + def update(self, chunk_id: int, values: dict[str, Any]) -> int: + """Update an existing replica chunk in the database. + + Parameters + ---------- + chunk_id : `int` + The ID of the replica chunk to update. + values : `dict`[`str`, `Any`] + A dictionary of column names and their new values to update. + + Returns + ------- + count : `int` + The number of rows updated. This should be 1 if the update is + successful, or 0 if no rows were updated (e.g., if the chunk ID + does not exist or the status is already set to the new value). 
+ """ + logging.info("Preparing to update replica chunk %d with values: %s", chunk_id, values) + table = self.get_table("PpdbReplicaChunk") + stmt = sqlalchemy.update(table).where(table.c.apdb_replica_chunk == chunk_id).values(values) + with self._engine.begin() as conn: + result = conn.execute(stmt) + affected_rows = result.rowcount + + new_status = values.get("status") + if affected_rows == 0: + logging.warning( + "No rows updated for replica chunk %s with status '%s'", + chunk_id, + new_status, + ) + else: + logging.info( + "Successfully updated %d row(s) for replica chunk %s to status '%s'", + affected_rows, + chunk_id, + new_status, + ) + return affected_rows diff --git a/python/lsst/dax/ppdb/bigquery/ppdb_replica_chunk_extended.py b/python/lsst/dax/ppdb/bigquery/ppdb_replica_chunk_extended.py index bd8d6422..7649c5c5 100644 --- a/python/lsst/dax/ppdb/bigquery/ppdb_replica_chunk_extended.py +++ b/python/lsst/dax/ppdb/bigquery/ppdb_replica_chunk_extended.py @@ -43,6 +43,10 @@ class ChunkStatus(StrEnum): """Chunk has been exported from the APDB to a local parquet file.""" UPLOADED = "uploaded" """Chunk has been uploaded to cloud storage.""" + STAGED = "staged" + """Chunk data has been copied into the staging tables.""" + PROMOTED = "promoted" + """Chunk data has been promoted from the staging to production tables.""" FAILED = "failed" """Chunk processing failed and an error occurred.""" SKIPPED = "skipped" @@ -59,6 +63,10 @@ class PpdbReplicaChunkExtended(PpdbReplicaChunk): directory: Path """Directory where the exported replica chunk data is stored locally.""" + gcs_uri: str | None = None + """GCS URI where the replica chunk data is stored, or `None` if not + uploaded yet.""" + @property def manifest_name(self) -> str: """Filename of the manifest file for this chunk.""" @@ -127,3 +135,19 @@ def with_new_status(self, new_status: ChunkStatus) -> PpdbReplicaChunkExtended: The new chunk with the updated status. 
""" return dataclasses.replace(self, status=new_status) + + def with_new_gcs_uri(self, new_gcs_uri: str) -> PpdbReplicaChunkExtended: + """Create a new `PpdbReplicaChunkExtended` with the same properties as + this one, but with a different GCS URI. + + Parameters + ---------- + new_gcs_uri : `str` + The new GCS URI to set. + + Returns + ------- + new_chunk : `PpdbReplicaChunkExtended` + The new chunk with the updated GCS URI. + """ + return dataclasses.replace(self, gcs_uri=new_gcs_uri) diff --git a/python/lsst/dax/ppdb/bigquery/query_runner.py b/python/lsst/dax/ppdb/bigquery/query_runner.py new file mode 100644 index 00000000..e9901658 --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/query_runner.py @@ -0,0 +1,139 @@ +# This file is part of dax_ppdbx_gcp +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = [ + "QueryRunner", +] + +import logging + +from google.cloud import bigquery + + +class QueryRunner: + """Class to run BigQuery queries with logging. + + Parameters + ---------- + project_id : `str` + Google Cloud project ID. + dataset_id : `str` + BigQuery dataset ID. 
+ """ + + def __init__(self, project_id: str, dataset_id: str): + self._project_id = project_id + self._dataset_id = dataset_id + self._bq_client = bigquery.Client(project=project_id) + self._dataset = self._bq_client.get_dataset(f"{project_id}.{dataset_id}") + self._location = self._dataset.location + + @property + def project_id(self) -> str: + """Google Cloud project ID (`str`, read-only).""" + return self._project_id + + @property + def dataset(self) -> bigquery.Dataset: + """Dataset reference (`bigquery.Dataset`, read-only).""" + return self._dataset + + @property + def dataset_id(self) -> str: + """Dataset ID (`str`, read-only).""" + return self._dataset_id + + @property + def location(self) -> str: + """Dataset location, typically the region where it is hosted (`str`, + read-only). + """ + return self._location + + @classmethod + def log_job( + cls, + job: bigquery.job.QueryJob + | bigquery.job.LoadJob + | bigquery.job.CopyJob + | bigquery.job.ExtractJob + | bigquery.job.UnknownJob, + label: str, + level: int = logging.DEBUG, + ) -> None: + """Log details of a BigQuery job. + + Parameters + ---------- + job : `bigquery.job.QueryJob` + The BigQuery job to log. + label : `str` + A label for the job, typically indicating the type of operation + (e.g., "insert", "delete", "copy"). + level : `int`, optional + The logging level to use for the log message. Defaults to + `logging.DEBUG`. 
+ """ + logging.log( + level, + "BQ %s: job_id=%s location=%s state=%s bytes_processed=%s bytes_billed=%s slot_millis=%s " + "dml_rows=%s reference_tables=%s", + label, + job.job_id, + job.location, + job.state, + getattr(job, "total_bytes_processed", None), + getattr(job, "total_bytes_billed", None), + getattr(job, "slot_millis", None), + getattr(job, "num_dml_affected_rows", None), + getattr(job, "referenced_tables", None), + ) + + def run_job( + self, label: str, sql: str, job_config: bigquery.QueryJobConfig | None = None + ) -> bigquery.job.QueryJob: + """Run a BigQuery job with the given SQL and configuration. + + Parameters + ---------- + label : `str` + A label for the job, typically indicating the type of operation + (e.g., "insert", "delete", "copy"). + sql : `str` + The SQL query to execute. + job_config : `bigquery.QueryJobConfig`, optional + Configuration for the job, such as query parameters or write + dispositions. If not provided, a default configuration will be + used. + + Returns + ------- + job: `bigquery.job.QueryJob` + The BigQuery job object representing the executed query. This can + be used to check the status of the job, retrieve results, or log + additional details. + """ + job = self._bq_client.query(sql, job_config=job_config, location=self.dataset.location) + job.result() # Wait for the job to complete + self.log_job(job, label) + return job diff --git a/python/lsst/dax/ppdb/bigquery/replica_chunk_promoter.py b/python/lsst/dax/ppdb/bigquery/replica_chunk_promoter.py new file mode 100644 index 00000000..424b8a27 --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/replica_chunk_promoter.py @@ -0,0 +1,242 @@ +# This file is part of dax_ppdbx_gcp +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. 
+# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = [ + "NoPromotableChunksError", + "ReplicaChunkPromoter", +] + +import logging + +from google.api_core.exceptions import NotFound +from google.cloud import bigquery + +from .ppdb_bigquery import PpdbBigQuery +from .query_runner import QueryRunner +from .updates.updates_manager import UpdatesManager + + +class NoPromotableChunksError(Exception): + """Exception raised when there are no promotable chunks available.""" + + pass + + +class ReplicaChunkPromoter: + """Class to promote replica chunks in BigQuery. + + Parameters + ---------- + ppdb : `PpdbBigQuery` + Interface to the PPDB in BigQuery. + table_names : `list`[`str`], optional + List of table names to promote or if None a default list will be used. + """ + + def __init__( + self, + ppdb: PpdbBigQuery, + table_names: list[str] | None = None, + ): + self._ppdb = ppdb + self._project_id = self._ppdb._config.project_id + self._dataset_id = self._ppdb._config.dataset_id + self._runner = ppdb.query_runner + # DM-52326: Hard-coded table names; these should be passed in from + # config. 
+ self._table_names = table_names or ["DiaObject", "DiaSource", "DiaForcedSource"] + self._bq_client = bigquery.Client(project=self._runner.project_id) + self._phases = { + "get_promotable_chunks": self._get_promotable_chunks, + "build_tmp": self._copy_to_promoted_tmp, + "apply_record_updates": self._apply_record_updates, + "promote_prod": self._promote_tmp_to_prod, + "delete_staged_chunks": self._delete_staged_chunks, + "mark_promoted": self._mark_chunks_promoted, + } + + self._promotable_chunks: list[int] = [] + + @property + def promotable_chunks(self) -> list[int]: + """List of promotable chunks (`list` [ `int` ], read-only).""" + return self._promotable_chunks + + @promotable_chunks.setter + def promotable_chunks(self, chunks: list[int]) -> None: + if not chunks: + raise NoPromotableChunksError("No promotable chunks provided") + self._promotable_chunks = chunks + + @property + def table_prod_refs(self) -> list[str]: + """Fully-qualified production table references (`list`[`str`], + read-only). + """ + return [f"{self._project_id}.{self._dataset_id}.{table_name}" for table_name in self._table_names] + + @property + def table_staging_refs(self) -> list[str]: + """Fully-qualified staging table references (`list`[`str`], + read-only). + """ + return [ + f"{self._project_id}.{self._dataset_id}._{table_name}_staging" for table_name in self._table_names + ] + + @property + def table_promoted_tmp_refs(self) -> list[str]: + """Fully-qualified promoted temporary table references (`list`[`str`], + read-only). + """ + return [ + f"{self._project_id}.{self._dataset_id}._{table_name}_promoted_tmp" + for table_name in self._table_names + ] + + def _execute_phase(self, phase: str) -> None: + """Execute a specific promotion phase. + + Parameters + ---------- + phase : `str` + The name of the promotion phase to execute. This should be one of + the keys in the `phases` property. 
+ """ + if phase not in self._phases: + raise ValueError(f"Unknown promotion phase: {phase}") + logging.debug("Executing promotion phase: %s", phase) + self._phases[phase]() + + def _get_promotable_chunks(self) -> None: + """Get list of promotable chunks from the database.""" + self._promotable_chunks = self._ppdb.get_promotable_chunks() + logging.info("Promotable chunk count: %s", len(self.promotable_chunks)) + + def _copy_to_promoted_tmp(self) -> None: + """ + Build ``_{table_name}_promoted_tmp`` efficiently by cloning prod and + inserting only staged rows for the given replica chunk IDs. + """ + job_cfg = bigquery.QueryJobConfig( + query_parameters=[bigquery.ArrayQueryParameter("ids", "INT64", self.promotable_chunks)] + ) + + for prod_ref, tmp_ref, stage_ref in zip( + self.table_prod_refs, self.table_promoted_tmp_refs, self.table_staging_refs, strict=False + ): + # Drop any existing tmp table (should not exist but just to be + # safe) + self._runner.run_job("drop_tmp", f"DROP TABLE IF EXISTS `{tmp_ref}`") + + # Clone prod table structure and data (zero-copy) + self._runner.run_job("clone_prod", f"CREATE TABLE `{tmp_ref}` CLONE `{prod_ref}`") + + # Build ordered target list from the cloned tmp schema + tmp_schema = self._bq_client.get_table(tmp_ref).schema + target_names = [f.name for f in tmp_schema if f.name != "apdb_replica_chunk"] + target_list_sql = ", ".join(f"`{n}`" for n in target_names) + + # Build source list, handling geo_point conversion + source_list_sql = ", ".join( + "ST_GEOGPOINT(s.`ra`, s.`dec`)" if n == "geo_point" else f"s.`{n}`" for n in target_names + ) + + # Insert staged rows into tmp, excluding apdb_replica_chunk column + sql = f""" + INSERT INTO `{tmp_ref}` ({target_list_sql}) + SELECT {source_list_sql} + FROM `{stage_ref}` AS s + WHERE s.apdb_replica_chunk IN UNNEST(@ids) + """ + logging.debug("SQL for inserting staged rows into %s: %s", tmp_ref, sql) + self._runner.run_job("insert_staged_to_tmp", sql, job_config=job_cfg) + + def 
_promote_tmp_to_prod(self) -> None: + """ + Swap each prod table with its corresponding *_promoted_tmp by replacing + prod contents in a single atomic copy job. This preserves schema, + partitioning, and clustering with zero-copy when in the same dataset. + """ + for prod_ref, tmp_ref in zip(self.table_prod_refs, self.table_promoted_tmp_refs, strict=False): + # Ensure tmp exists + try: + self._bq_client.get_table(tmp_ref) + except NotFound as e: + raise RuntimeError(f"Missing tmp table for promotion: {tmp_ref}") from e + + # Atomic zero-copy replacement of prod with tmp + copy_cfg = bigquery.CopyJobConfig(write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE) + job = self._bq_client.copy_table( + tmp_ref, prod_ref, job_config=copy_cfg, location=self._runner.location + ) + job.result() + QueryRunner.log_job(job, "promote_tmp_to_prod") + + def _cleanup(self) -> None: + """Drop the promotion temporary tables.""" + for tmp_ref in self.table_promoted_tmp_refs: + self._bq_client.delete_table(tmp_ref, not_found_ok=True) + logging.debug("Dropped %s (if it existed)", tmp_ref) + + def _delete_staged_chunks(self) -> None: + """Delete only rows for the promoted replica chunk IDs from each + staging table. 
+ """ + job_config = bigquery.QueryJobConfig( + query_parameters=[bigquery.ArrayQueryParameter("ids", "INT64", self.promotable_chunks)] + ) + + for staging_ref in self.table_staging_refs: + try: + sql = f"DELETE FROM `{staging_ref}` WHERE apdb_replica_chunk IN UNNEST(@ids)" + self._runner.run_job("delete_staged_chunks", sql, job_config=job_config) + logging.debug( + "Deleted %d chunk(s) from staging table %s", len(self.promotable_chunks), staging_ref + ) + except NotFound: + logging.warning("Staging table %s does not exist, skipping delete", staging_ref) + + def _apply_record_updates(self) -> None: + """Apply record updates to the promoted temporary tables.""" + updates_manager = UpdatesManager(self._ppdb, table_name_postfix="_promoted_tmp") + updates_manager.apply_updates(self.promotable_chunks) + + def _mark_chunks_promoted(self) -> None: + """Mark the replica chunks as promoted in the database.""" + self._ppdb.mark_chunks_promoted(self.promotable_chunks) + + def promote_chunks(self) -> None: + """Promote APDB replica chunks into production by executing a series of + phases. + """ + try: + for phase in self._phases.keys(): + self._execute_phase(phase) + finally: + try: + # Cleanup is always executed separately, not as an ordered + # phase. + self._cleanup() + except Exception: + logging.exception("Cleanup of chunk promotion failed") diff --git a/python/lsst/dax/ppdb/bigquery/sql_resource.py b/python/lsst/dax/ppdb/bigquery/sql_resource.py new file mode 100644 index 00000000..99d4f89d --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/sql_resource.py @@ -0,0 +1,56 @@ +# This file is part of dax_ppdb +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. 
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


from lsst.resources import ResourcePath


class SqlResource:
    """Class for loading SQL query text from a resource file and optionally
    formatting it with provided arguments.

    Parameters
    ----------
    sql_resource_name : `str`
        Base name of the SQL file (without .sql extension) containing the
        query. The SQL file must be located in the `lsst.dax.ppdb.config.sql`
        package.
    format_args : `dict` [ `str`, `str` ], optional
        Optional dictionary of arguments for formatting the SQL text.

    Raises
    ------
    RuntimeError
        If the SQL text cannot be formatted with the given arguments.
    """

    def __init__(self, sql_resource_name: str, format_args: dict[str, str] | None = None) -> None:
        # FIXME: Move the config dir into a resources dir (similar to obs_lsst)
        sql_resource_path = f"resource://lsst.dax.ppdb/config/sql/{sql_resource_name}.sql"
        sql = ResourcePath(sql_resource_path).read().decode("utf-8")
        if format_args is not None:
            try:
                sql = sql.format(**format_args)
            # str.format raises KeyError for a missing named argument,
            # IndexError for a missing positional one, and ValueError for a
            # bad format spec; anything else should propagate unchanged.
            except (KeyError, IndexError, ValueError) as e:
                raise RuntimeError(
                    f"Failed to format SQL resource at {sql_resource_path} with arguments {format_args}"
                ) from e
        self._sql = sql

    @property
    def sql(self) -> str:
        """SQL query string (`str`)."""
        return self._sql
+ +from .expanded_update_record import ExpandedUpdateRecord +from .updates_merger import ( + UpdatesMerger, + DiaObjectUpdatesMerger, + DiaSourceUpdatesMerger, + DiaForcedSourceUpdatesMerger, +) +from .update_records import UpdateRecords +from .update_record_expander import UpdateRecordExpander +from .updates_table import UpdatesTable diff --git a/python/lsst/dax/ppdb/bigquery/updates/expanded_update_record.py b/python/lsst/dax/ppdb/bigquery/updates/expanded_update_record.py new file mode 100644 index 00000000..d59c9785 --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/updates/expanded_update_record.py @@ -0,0 +1,80 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + + +class ExpandedUpdateRecord(BaseModel): + """ + A single normalized (expanded) update row. + + This model represents one field-level update after expanding an + original logical update event into one row per updated field. + It is the canonical shape loaded into the BigQuery updates table. 
+ """ + + table_name: str = Field( + ..., + min_length=1, + description=("Logical target table for the update (e.g., 'DiaObject', 'DiaSource')."), + ) + + record_id: list[int] = Field( + ..., + description=( + "Identifier of the record being updated. For update types with a single record ID, this " + "will be a list of one element. For updates on records with a composite key " + "(e.g., DiaForcedSource), this will include all components of the key, in order." + ), + ) + + field_name: str = Field( + ..., + min_length=1, + description=("Name of the target column being updated."), + ) + + value_json: Any = Field( + ..., + description=("JSON-serializable new value for the field."), + ) + + replica_chunk_id: int = Field( + ..., + ge=0, + description=("Source replica chunk identifier associated with this update."), + ) + + update_order: int | None = Field( + default=None, + ge=0, + description=("Ordering value within the replica chunk or update batch."), + ) + + update_time_ns: int | None = Field( + default=None, + ge=0, + description=("Source event timestamp in nanoseconds since the epoch."), + ) diff --git a/python/lsst/dax/ppdb/bigquery/updates/update_record_expander.py b/python/lsst/dax/ppdb/bigquery/updates/update_record_expander.py new file mode 100644 index 00000000..dc7b57ad --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/updates/update_record_expander.py @@ -0,0 +1,243 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import hashlib +import logging + +from lsst.dax.apdb.apdbUpdateRecord import ApdbUpdateRecord + +from .expanded_update_record import ExpandedUpdateRecord +from .update_records import UpdateRecords + + +class UpdateRecordExpander: + """Expand APDB update records into individual field-level updates for + BigQuery. + """ + + _UPDATE_FIELD_MAPPING = { + "reassign_diasource_to_diaobject": ["diaObjectId"], + "reassign_diasource_to_ssobject": ["ssObjectId", "ssObjectReassocTimeMjdTai"], + "withdraw_diasource": ["timeWithdrawnMjdTai"], + "withdraw_diaforcedsource": ["timeWithdrawnMjdTai"], + "close_diaobject_validity": ["validityEndMjdTai", "nDiaSources"], + "update_n_dia_sources": ["nDiaSources"], + } + + _RECORD_ID_FIELD_MAPPING = { + "reassign_diasource_to_diaobject": ["diaSourceId"], + "reassign_diasource_to_ssobject": ["diaSourceId"], + "withdraw_diasource": ["diaSourceId"], + "withdraw_diaforcedsource": ["diaObjectId", "visit", "detector"], + "close_diaobject_validity": ["diaObjectId"], + "update_n_dia_sources": ["diaObjectId"], + } + + @classmethod + def get_update_fields(cls, update_type: str) -> list[str]: + """Get the names of fields to update for a given update type. + + Parameters + ---------- + update_type : `str` + The type of update record. + + Returns + ------- + field_names : `list` [ `str` ] + List of field names that should be updated for this update type. + + Raises + ------ + ValueError + If the update_type is not recognized. 
+ """ + if update_type not in cls._UPDATE_FIELD_MAPPING: + raise ValueError(f"Unknown update_type: {update_type}") + + return cls._UPDATE_FIELD_MAPPING[update_type] + + @classmethod + def get_record_id_fields(cls, update_type: str) -> str | list[str]: + """Get the field name(s) that serve as the record ID for a given update + type. + + Parameters + ---------- + update_type : `str` + The type of update record. + + Returns + ------- + field_name : `str` or `list` [ `str` ] + Name of the field that contains the record ID for this update type, + or list of field names for composite keys. + + Raises + ------ + ValueError + If the update_type is not recognized. + """ + if update_type not in cls._RECORD_ID_FIELD_MAPPING: + raise ValueError(f"Unknown update_type: {update_type}") + + return cls._RECORD_ID_FIELD_MAPPING[update_type] + + @classmethod + def _compute_record_id_hash(cls, record_id: list[int]) -> str: + """Compute MD5 hash of a record_id list for deduplication. + + Parameters + ---------- + record_id : list[int] + The record ID as a list of integers. + + Returns + ------- + str + Full 64-character hexadecimal MD5 hash of the record_id list. + """ + record_id_str = ",".join(str(x) for x in record_id) + return hashlib.md5(record_id_str.encode()).hexdigest() + + @classmethod + def get_record_id_field(cls, update_type: str) -> str | list[str]: + """Get the field name(s) that serve as the record ID for a given update + type. + + Parameters + ---------- + update_type : `str` + The type of update record. + + Returns + ------- + field_name : `list` [ `str` ] + List of the fields that contain the record ID for this update type. + + Raises + ------ + ValueError + If the update_type is not recognized. + """ + return cls.get_record_id_fields(update_type) + + @classmethod + def expand_single_record( + cls, update_record: ApdbUpdateRecord, replica_chunk_id: int + ) -> list[ExpandedUpdateRecord]: + """Expand a single APDB update record into ExpandedUpdateRecord + objects. 
+ + Parameters + ---------- + update_record : `ApdbUpdateRecord` + A single APDB update record to expand. + replica_chunk_id : `int` + The replica chunk ID associated with this update record. + + Returns + ------- + expanded_records : `list` [ `ExpandedUpdateRecord` ] + List of ExpandedUpdateRecord objects, one per field being updated. + """ + update_type = update_record.update_type + field_names = cls.get_update_fields(update_type) + + # Get the target table from the update record + table_name = update_record.apdb_table.name + + # Get the record ID + record_id = cls._get_record_id(update_record) + + expanded_records = [] + for field_name in field_names: + if not hasattr(update_record, field_name): + raise ValueError( + f"Update record of type {update_type} is missing expected field {field_name}" + ) + + value = getattr(update_record, field_name) + + expanded_record = ExpandedUpdateRecord( + table_name=table_name, + record_id=record_id, + field_name=field_name, + value_json=value, + replica_chunk_id=replica_chunk_id, + update_order=update_record.update_order, + update_time_ns=update_record.update_time_ns, + ) + expanded_records.append(expanded_record) + + return expanded_records + + @classmethod + def _get_record_id(cls, update_record: ApdbUpdateRecord) -> list[int]: + """Generate a record ID from an update record. + + Parameters + ---------- + update_record : `ApdbUpdateRecord` + The update record to generate an ID for. + + Returns + ------- + record_id : `list` [ `int` ] + The record ID as a list of integers. For simple keys, a + single-element list. For composite keys, a multi-element list. 
+ """ + update_type = update_record.update_type + id_fields = cls.get_record_id_fields(update_type) + + record_id = [] + for field in id_fields: + if not hasattr(update_record, field): + raise ValueError(f"Update record of type {update_type} is missing expected ID field {field}") + record_id.append(int(getattr(update_record, field))) + return record_id + + @classmethod + def expand_updates(cls, update_records: UpdateRecords) -> list[ExpandedUpdateRecord]: + """Expand the APDB update records into a list of individual updates. + + Parameters + ---------- + update_records : `UpdateRecords` + The APDB update records to expand. + + Returns + ------- + expanded_updates : `list` [ `ExpandedUpdateRecord` ] + A list of individual updates derived from the input update records. + """ + expanded_updates = [] + + for update_record in update_records.records: + expanded_records = cls.expand_single_record(update_record, update_records.replica_chunk_id) + expanded_updates.extend(expanded_records) + + # DEBUG: Print number of expanded update records that were generated + logging.info("Created %d expanded update records", len(expanded_updates)) + + return expanded_updates diff --git a/python/lsst/dax/ppdb/bigquery/updates/update_records.py b/python/lsst/dax/ppdb/bigquery/updates/update_records.py new file mode 100644 index 00000000..a3107c85 --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/updates/update_records.py @@ -0,0 +1,126 @@ +# This file is part of dax_ppdb +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, cast + +from pydantic import BaseModel, field_serializer, field_validator + +from lsst.dax.apdb.apdbUpdateRecord import ApdbUpdateRecord + +DEFAULT_FILENAME = "update_records.json" +"""Default filename for the update records JSON file.""" + + +class UpdateRecords(BaseModel): + """Data model for APDB update records.""" + + replica_chunk_id: int + """Identifier of the replica chunk to which these update records belong.""" + + record_count: int + """Number of update records included in this object.""" + + records: list[ApdbUpdateRecord] + """List of APDB update records included in this object.""" + + @field_serializer("records") + def serialize_records( + self, + records: list[ApdbUpdateRecord], + ) -> list[dict[str, Any]]: + """Serialize the ``ApdbUpdateRecord`` objects to JSON. + + Parameters + ---------- + records : `list` [ `ApdbUpdateRecord` ] + The list of APDB update records to serialize. + + Returns + ------- + serialized_records : `list` [ `dict` [ `str`, `Any` ]] + The serialized JSON data. 
+ """ + serialized_records: list[dict[str, Any]] = [] + for update_record in records: + record_dict: dict[str, Any] = json.loads(update_record.to_json()) + record_dict["update_time_ns"] = update_record.update_time_ns + record_dict["update_order"] = update_record.update_order + serialized_records.append(record_dict) + return serialized_records + + @field_validator("records", mode="before") + @classmethod + def deserialize_records( + cls, + records: list[dict[str, Any]] | list[ApdbUpdateRecord], + ) -> list[ApdbUpdateRecord]: + """Deserialize the JSON data to ``ApdbUpdateRecord`` objects. + + Parameters + ---------- + records : `list` [ `dict` [ `str`, `Any` ] | `ApdbUpdateRecord` ] + The list of serialized JSON data or already deserialized + ApdbUpdateRecord objects. + + Returns + ------- + update_records : `list` [ `ApdbUpdateRecord` ] + The list of APDB update records. + """ + if records and isinstance(records[0], ApdbUpdateRecord): + return cast(list[ApdbUpdateRecord], records) + deserialized_records: list[ApdbUpdateRecord] = [] + for record in records: + if isinstance(record, dict): + record_copy = record.copy() + update_time_ns = record_copy.pop("update_time_ns") + update_order = record_copy.pop("update_order") + json_str = json.dumps(record_copy) + update_record = ApdbUpdateRecord.from_json( + update_time_ns, + update_order, + json_str, + ) + deserialized_records.append(update_record) + elif isinstance(record, ApdbUpdateRecord): + deserialized_records.append(record) + else: + raise TypeError("Each record must be a dict or ApdbUpdateRecord") + return deserialized_records + + def write_json_file(self, path: Path) -> None: + with open(path, "w") as f: + json.dump(self.model_dump(), f, indent=2, default=str) + + @classmethod + def from_json_file(cls, path: Path) -> UpdateRecords: + with open(path) as f: + data = json.load(f) + return cls.model_validate(data) + + @classmethod + def from_json_string(cls, json_str: str) -> UpdateRecords: + data = json.loads(json_str) 
+ return cls.model_validate(data) diff --git a/python/lsst/dax/ppdb/bigquery/updates/updates_manager.py b/python/lsst/dax/ppdb/bigquery/updates/updates_manager.py new file mode 100644 index 00000000..98e0e66a --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/updates/updates_manager.py @@ -0,0 +1,123 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import logging +import posixpath +import urllib +from collections.abc import Sequence + +from google.cloud import bigquery, storage + +from ..ppdb_bigquery import PpdbBigQuery +from .update_record_expander import UpdateRecordExpander +from .update_records import DEFAULT_FILENAME, UpdateRecords +from .updates_merger import ( + DiaForcedSourceUpdatesMerger, + DiaObjectUpdatesMerger, + DiaSourceUpdatesMerger, + UpdatesMerger, +) +from .updates_table import UpdatesTable + +DEFAULT_MERGERS = ( + DiaObjectUpdatesMerger, + DiaSourceUpdatesMerger, + DiaForcedSourceUpdatesMerger, +) + + +class UpdatesManager: + """Class responsible for managing the process of applying updates to the + PPDB database, including merging updates and inserting them into the + database. 
+ """ + + def __init__( + self, + ppdb: PpdbBigQuery, + mergers: Sequence[type[UpdatesMerger]] = DEFAULT_MERGERS, + updates_table_name: str = "updates", + deduplicated_updates_table_name: str = "updates_deduplicated", + table_name_postfix: str | None = None, + ) -> None: + self._ppdb = ppdb + self._mergers = mergers + self._deduplicated_updates_table_name = deduplicated_updates_table_name + + self._bq_client = bigquery.Client() + + self._updates_table = UpdatesTable( + self._bq_client, + f"{self._ppdb._config.project_id}.{self._ppdb._config.dataset_id}.{updates_table_name}", + ) + + # TODO: Catch error if already exists + self._updates_table.create() + + self._gcs_client = storage.Client() + self._bucket = self._gcs_client.bucket(self._ppdb._config.bucket_name) + + self._table_name_postfix = table_name_postfix + + def apply_updates(self, replica_chunk_ids: Sequence[int]) -> None: + replica_chunks = self._ppdb.get_replica_chunks_ext_by_ids(replica_chunk_ids) + for replica_chunk in replica_chunks: + if replica_chunk.gcs_uri is None: + raise ValueError(f"Replica chunk {replica_chunk.id} does not have a GCS URI") + + # Parse the GCS URI to get the bucket name and object name + parsed_uri = urllib.parse.urlparse(replica_chunk.gcs_uri) + bucket_name = parsed_uri.netloc + object_name = posixpath.join(parsed_uri.path.lstrip("/"), DEFAULT_FILENAME) + + # Get the blob from the bucket + bucket = self._gcs_client.bucket(bucket_name) + blob = bucket.blob(object_name) + content = blob.download_as_text() + + # Expand the update records into the appropriate format for + # inserting into the updates table + update_records = UpdateRecords.from_json_string(content) + expanded_update_records = UpdateRecordExpander.expand_updates(update_records) + self._updates_table.insert(expanded_update_records) + + # Deduplicate the update records to a new table + deduplicated_updates_table_fqn = ( + f"{self._ppdb.project_id}.{self._ppdb._config.dataset_id}.{self._deduplicated_updates_table_name}" 
+ ) + self._updates_table.deduplicate_to(deduplicated_updates_table_fqn) + + # Merge the deduplicated updates into the target tables + for merger in self._mergers: + merger_instance = merger(self._bq_client) + if self._table_name_postfix: + # Apply a postfix to the canonical target table name + merger_instance.target_table_name += f"{self._table_name_postfix}" + target_dataset_fqn = f"{self._ppdb._config.project_id}.{self._ppdb._config.dataset_id}" + + # DEBUG: Print message to log + logging.info( + "Merging updates into `%s.%s`", target_dataset_fqn, merger_instance.target_table_name + ) + + merger_instance.merge( + updates_table_fqn=deduplicated_updates_table_fqn, target_dataset_fqn=target_dataset_fqn + ) diff --git a/python/lsst/dax/ppdb/bigquery/updates/updates_merger.py b/python/lsst/dax/ppdb/bigquery/updates/updates_merger.py new file mode 100644 index 00000000..e74b68d6 --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/updates/updates_merger.py @@ -0,0 +1,117 @@ +# This file is part of dax_ppdb +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +from __future__ import annotations + +from abc import ABC + +from google.cloud import bigquery + +from ..sql_resource import SqlResource + + +class UpdatesMerger(ABC): + """Abstract base class for merging expanded update records into target + tables in BigQuery. + """ + + TABLE_NAME: str + """Logical name of the target table this merger applies to + (e.g., 'DiaObject').""" + + SQL_RESOURCE_NAME: str + """Base name of the SQL file (without .sql extension) containing the MERGE + statement for this merger. The SQL file must be located in the + `lsst.dax.ppdb.config.sql` package.""" + + def __init__(self, client: bigquery.Client, target_table_name: str | None = None) -> None: + """ + Parameters + ---------- + client + BigQuery client. + target_table_name + Optional name of the target table. If not provided, the class-level + TABLE_NAME will be used. + """ + self._client: bigquery.Client = client + self._target_table_name = target_table_name or self.TABLE_NAME + + @property + def target_table_name(self) -> str: + """Get the name of the target table this merger applies to.""" + return self._target_table_name + + @target_table_name.setter + def target_table_name(self, value: str) -> None: + """Set the name of the target table this merger applies to.""" + self._target_table_name = value + + def merge(self, *, updates_table_fqn: str, target_dataset_fqn: str) -> bigquery.QueryJob: + """ + Apply updates from the updates table specified by `updates_table_fqn` + to the target table in the `target_dataset_fqn` dataset. + + Parameters + ---------- + updates_table_fqn + Fully-qualified BigQuery table name containing updates. + target_dataset_fqn + Fully-qualified BigQuery dataset name containing the target table. + + Returns + ------- + google.cloud.bigquery.job.QueryJob + The completed BigQuery job. 
+ """ + sql = SqlResource( + self.SQL_RESOURCE_NAME, + format_args={ + "updates_table": updates_table_fqn, + "target_dataset": target_dataset_fqn, + "target_table": self.target_table_name, + }, + ).sql + job = self._client.query(sql) + job.result() + + return job + + +class DiaObjectUpdatesMerger(UpdatesMerger): + """Merger for DiaObject updates.""" + + TABLE_NAME = "DiaObject" + SQL_RESOURCE_NAME = "merge_diaobject_updates" + + +class DiaSourceUpdatesMerger(UpdatesMerger): + """Merger for DiaSource updates.""" + + TABLE_NAME = "DiaSource" + SQL_RESOURCE_NAME = "merge_diasource_updates" + + +class DiaForcedSourceUpdatesMerger(UpdatesMerger): + """Merger for DiaForcedSource updates.""" + + TABLE_NAME = "DiaForcedSource" + SQL_RESOURCE_NAME = "merge_diaforcedsource_updates" diff --git a/python/lsst/dax/ppdb/bigquery/updates/updates_table.py b/python/lsst/dax/ppdb/bigquery/updates/updates_table.py new file mode 100644 index 00000000..b8054bbf --- /dev/null +++ b/python/lsst/dax/ppdb/bigquery/updates/updates_table.py @@ -0,0 +1,198 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+
+from __future__ import annotations
+
+import hashlib
+from collections.abc import Iterable
+from typing import Any
+
+from google.cloud import bigquery
+
+from .expanded_update_record import ExpandedUpdateRecord
+
+
+class UpdatesTable:
+    """Manage the table in BigQuery used for inserting and deduplicating
+    expanded update records which contain one update per row.
+    """
+
+    def __init__(self, client: bigquery.Client, table_fqn: str) -> None:
+        """
+        Parameters
+        ----------
+        client
+            BigQuery client.
+        table_fqn
+            Fully-qualified table name in the form ``"project.dataset.table"``.
+        """
+        self._client: bigquery.Client = client
+        self._table_fqn: str = table_fqn
+
+    @staticmethod
+    def _compute_record_id_hash(record_id: list[int]) -> str:
+        """Compute MD5 hash of a record_id list for deduplication.
+
+        Parameters
+        ----------
+        record_id : list[int]
+            The record ID as a list of integers.
+
+        Returns
+        -------
+        str
+            Full 32-character hexadecimal MD5 digest of the record_id list.
+        """
+        record_id_str = ",".join(str(x) for x in record_id)
+        return hashlib.md5(record_id_str.encode()).hexdigest()
+
+    @property
+    def table_fqn(self) -> str:
+        """
+        Fully-qualified BigQuery table name.
+
+        Returns
+        -------
+        str
+            Table name in the form ``"project.dataset.table"``.
+        """
+        return self._table_fqn
+
+    def create(self) -> bigquery.Table:
+        """
+        Create the updates table.
+
+        Returns
+        -------
+        google.cloud.bigquery.Table
+            The created table.
+
+        Raises
+        ------
+        google.api_core.exceptions.Conflict
+            If the table already exists.
+ + Notes + ----- + Schema: + + - table_name: STRING (REQUIRED) + - record_id: ARRAY (REQUIRED) + - record_id_hash: STRING (REQUIRED) + - field_name: STRING (REQUIRED) + - value_json: JSON (REQUIRED) + - replica_chunk_id: INT64 (REQUIRED) + - update_order: INT64 (NULLABLE) + - update_time_ns: INT64 (NULLABLE) + """ + schema: list[bigquery.SchemaField] = [ + bigquery.SchemaField("table_name", "STRING", mode="REQUIRED"), + bigquery.SchemaField("record_id", "INT64", mode="REPEATED"), + bigquery.SchemaField("record_id_hash", "STRING", mode="REQUIRED"), + bigquery.SchemaField("field_name", "STRING", mode="REQUIRED"), + bigquery.SchemaField("value_json", "JSON", mode="REQUIRED"), + bigquery.SchemaField("replica_chunk_id", "INT64", mode="REQUIRED"), + bigquery.SchemaField("update_order", "INT64", mode="NULLABLE"), + bigquery.SchemaField("update_time_ns", "INT64", mode="NULLABLE"), + ] + + table = bigquery.Table(self._table_fqn, schema=schema) + return self._client.create_table(table) + + def insert(self, records: Iterable[ExpandedUpdateRecord]) -> bigquery.LoadJob: + """ + Insert `ExpandedUpdateRecord` rows into the updates table. + + Parameters + ---------- + records + Iterable of update records to insert. + + Returns + ------- + google.cloud.bigquery.LoadJob + Completed BigQuery load job. + + Raises + ------ + RuntimeError + If the BigQuery load job completes with errors. + + Notes + ----- + This uses a batch load via `Client.load_table_from_json` (not streaming + inserts). The table must already exist. 
+ """ + rows: list[dict[str, Any]] = [ + { + "table_name": r.table_name, + "record_id": r.record_id, + "record_id_hash": self._compute_record_id_hash(r.record_id), + "field_name": r.field_name, + "value_json": r.value_json, + "replica_chunk_id": r.replica_chunk_id, + "update_order": r.update_order, + "update_time_ns": r.update_time_ns, + } + for r in records + ] + + print("Inserting rows into BigQuery:", rows) # Debug print to verify the data being loaded + + job = self._client.load_table_from_json( + rows, + self._table_fqn, + job_config=bigquery.LoadJobConfig( + write_disposition=bigquery.WriteDisposition.WRITE_APPEND, + ), + ) + job.result() + + if job.errors: + raise RuntimeError(f"BigQuery load failed: {job.errors}") + + return job + + def deduplicate_to(self, target_table_fqn: str) -> bigquery.QueryJob: + """ + Deduplicate this table's records to a target table. + + Keeps the record with the latest update_time_ns for each unique + combination of (table_name, record_id, field_name). 
+ """ + query = f""" + CREATE OR REPLACE TABLE `{target_table_fqn}` + AS + SELECT * EXCEPT(row_num) + FROM ( + SELECT *, + ROW_NUMBER() OVER ( + PARTITION BY table_name, record_id_hash, field_name + ORDER BY update_time_ns DESC + ) as row_num + FROM `{self._table_fqn}` + ) + WHERE row_num = 1 + """ + + job = self._client.query(query) + job.result() + return job diff --git a/tests/config/schema.yaml b/python/lsst/dax/ppdb/config/schemas/test_apdb_schema.yaml similarity index 100% rename from tests/config/schema.yaml rename to python/lsst/dax/ppdb/config/schemas/test_apdb_schema.yaml diff --git a/python/lsst/dax/ppdb/config/sql/merge_diaforcedsource_updates.sql b/python/lsst/dax/ppdb/config/sql/merge_diaforcedsource_updates.sql new file mode 100644 index 00000000..8eef46b7 --- /dev/null +++ b/python/lsst/dax/ppdb/config/sql/merge_diaforcedsource_updates.sql @@ -0,0 +1,28 @@ +MERGE `{target_dataset}.{target_table}` T +USING ( + WITH patch AS ( + SELECT + record_id[OFFSET(0)] AS diaObjectId, + record_id[OFFSET(1)] AS visit, + record_id[OFFSET(2)] AS detector, + + ANY_VALUE( + CASE WHEN field_name = 'timeWithdrawnMjdTai' + THEN CAST(JSON_VALUE(value_json) AS FLOAT64) + END + ) AS timeWithdrawnMjdTai_value, + COUNTIF(field_name = 'timeWithdrawnMjdTai') > 0 AS timeWithdrawnMjdTai_present + + FROM `{updates_table}` + WHERE table_name = 'DiaForcedSource' + AND field_name IN ('timeWithdrawnMjdTai') + GROUP BY diaObjectId, visit, detector + ) + SELECT * FROM patch +) P +ON T.diaObjectId = P.diaObjectId + AND T.visit = P.visit + AND T.detector = P.detector +WHEN MATCHED THEN +UPDATE SET + timeWithdrawnMjdTai = IF(P.timeWithdrawnMjdTai_present, P.timeWithdrawnMjdTai_value, T.timeWithdrawnMjdTai); diff --git a/python/lsst/dax/ppdb/config/sql/merge_diaobject_updates.sql b/python/lsst/dax/ppdb/config/sql/merge_diaobject_updates.sql new file mode 100644 index 00000000..9c6c1827 --- /dev/null +++ b/python/lsst/dax/ppdb/config/sql/merge_diaobject_updates.sql @@ -0,0 +1,32 @@ +MERGE 
`{target_dataset}.{target_table}` T +USING ( + WITH patch AS ( + SELECT + record_id[OFFSET(0)] AS diaObjectId, + + ANY_VALUE( + CASE WHEN field_name = 'validityEndMjdTai' + THEN CAST(JSON_VALUE(value_json) AS FLOAT64) + END + ) AS validityEndMjdTai_value, + COUNTIF(field_name = 'validityEndMjdTai') > 0 AS validityEndMjdTai_present, + + ANY_VALUE( + CASE WHEN field_name = 'nDiaSources' + THEN CAST(JSON_VALUE(value_json) AS INT64) + END + ) AS nDiaSources_value, + COUNTIF(field_name = 'nDiaSources') > 0 AS nDiaSources_present + + FROM `{updates_table}` + WHERE table_name = 'DiaObject' + AND field_name IN ('validityEndMjdTai', 'nDiaSources') + GROUP BY diaObjectId + ) + SELECT * FROM patch +) P +ON T.diaObjectId = P.diaObjectId +WHEN MATCHED THEN +UPDATE SET + validityEndMjdTai = IF(P.validityEndMjdTai_present, P.validityEndMjdTai_value, T.validityEndMjdTai), + nDiaSources = IF(P.nDiaSources_present, P.nDiaSources_value, T.nDiaSources); diff --git a/python/lsst/dax/ppdb/config/sql/merge_diasource_updates.sql b/python/lsst/dax/ppdb/config/sql/merge_diasource_updates.sql new file mode 100644 index 00000000..5a2d5307 --- /dev/null +++ b/python/lsst/dax/ppdb/config/sql/merge_diasource_updates.sql @@ -0,0 +1,48 @@ +MERGE `{target_dataset}.{target_table}` T +USING ( + WITH patch AS ( + SELECT + record_id[OFFSET(0)] AS diaSourceId, + + ANY_VALUE( + CASE WHEN field_name = 'diaObjectId' + THEN CAST(JSON_VALUE(value_json) AS INT64) + END + ) AS diaObjectId_value, + COUNTIF(field_name = 'diaObjectId') > 0 AS diaObjectId_present, + + ANY_VALUE( + CASE WHEN field_name = 'ssObjectId' + THEN CAST(JSON_VALUE(value_json) AS INT64) + END + ) AS ssObjectId_value, + COUNTIF(field_name = 'ssObjectId') > 0 AS ssObjectId_present, + + ANY_VALUE( + CASE WHEN field_name = 'ssObjectReassocTimeMjdTai' + THEN CAST(JSON_VALUE(value_json) AS FLOAT64) + END + ) AS ssObjectReassocTimeMjdTai_value, + COUNTIF(field_name = 'ssObjectReassocTimeMjdTai') > 0 AS ssObjectReassocTimeMjdTai_present, + + 
ANY_VALUE( + CASE WHEN field_name = 'timeWithdrawnMjdTai' + THEN CAST(JSON_VALUE(value_json) AS FLOAT64) + END + ) AS timeWithdrawnMjdTai_value, + COUNTIF(field_name = 'timeWithdrawnMjdTai') > 0 AS timeWithdrawnMjdTai_present + + FROM `{updates_table}` + WHERE table_name = 'DiaSource' + AND field_name IN ('diaObjectId', 'ssObjectId', 'ssObjectReassocTimeMjdTai', 'timeWithdrawnMjdTai') + GROUP BY diaSourceId + ) + SELECT * FROM patch +) P +ON T.diaSourceId = P.diaSourceId +WHEN MATCHED THEN +UPDATE SET + diaObjectId = IF(P.diaObjectId_present, P.diaObjectId_value, T.diaObjectId), + ssObjectId = IF(P.ssObjectId_present, P.ssObjectId_value, T.ssObjectId), + ssObjectReassocTimeMjdTai = IF(P.ssObjectReassocTimeMjdTai_present, P.ssObjectReassocTimeMjdTai_value, T.ssObjectReassocTimeMjdTai), + timeWithdrawnMjdTai = IF(P.timeWithdrawnMjdTai_present, P.timeWithdrawnMjdTai_value, T.timeWithdrawnMjdTai) diff --git a/python/lsst/dax/ppdb/config/sql/select_promotable_chunks.sql b/python/lsst/dax/ppdb/config/sql/select_promotable_chunks.sql new file mode 100644 index 00000000..e776a1f0 --- /dev/null +++ b/python/lsst/dax/ppdb/config/sql/select_promotable_chunks.sql @@ -0,0 +1,24 @@ +WITH start AS ( +SELECT MIN(apdb_replica_chunk) AS s +FROM {table_name} +WHERE status <> 'promoted' + AND status <> 'skipped' +), +stop AS ( +SELECT MIN(p.apdb_replica_chunk) AS e +FROM {table_name} p +JOIN start ON TRUE +WHERE start.s IS NOT NULL + AND p.apdb_replica_chunk >= start.s + AND p.status <> 'staged' + AND status <> 'skipped' +) +SELECT p.apdb_replica_chunk +FROM {table_name} p +JOIN start ON TRUE +LEFT JOIN stop ON TRUE +WHERE start.s IS NOT NULL +AND p.status = 'staged' +AND p.apdb_replica_chunk >= start.s +AND (stop.e IS NULL OR p.apdb_replica_chunk < stop.e) +ORDER BY p.apdb_replica_chunk; diff --git a/python/lsst/dax/ppdb/ppdb.py b/python/lsst/dax/ppdb/ppdb.py index 31b6a315..e175bb7f 100644 --- a/python/lsst/dax/ppdb/ppdb.py +++ b/python/lsst/dax/ppdb/ppdb.py @@ -33,7 +33,7 @@ from 
lsst.resources import ResourcePathExpression from ._factory import ppdb_from_config -from .config import PpdbConfig +from .ppdb_config import PpdbConfig @dataclass(frozen=True) diff --git a/python/lsst/dax/ppdb/config.py b/python/lsst/dax/ppdb/ppdb_config.py similarity index 100% rename from python/lsst/dax/ppdb/config.py rename to python/lsst/dax/ppdb/ppdb_config.py diff --git a/python/lsst/dax/ppdb/sql/__init__.py b/python/lsst/dax/ppdb/sql/__init__.py index 92e21081..566853c3 100644 --- a/python/lsst/dax/ppdb/sql/__init__.py +++ b/python/lsst/dax/ppdb/sql/__init__.py @@ -20,4 +20,4 @@ # along with this program. If not, see . from ._ppdb_sql import PpdbSql, PpdbSqlConfig -from ._ppdb_sql_base import PpdbSqlBase, PpdbSqlBaseConfig +from ._ppdb_sql_base import PasswordProvider, PpdbSqlBase, PpdbSqlBaseConfig diff --git a/python/lsst/dax/ppdb/sql/_ppdb_sql.py b/python/lsst/dax/ppdb/sql/_ppdb_sql.py index 623cf651..320f579c 100644 --- a/python/lsst/dax/ppdb/sql/_ppdb_sql.py +++ b/python/lsst/dax/ppdb/sql/_ppdb_sql.py @@ -552,5 +552,6 @@ def init_database( isolation_level=isolation_level, connection_timeout=connection_timeout, ) - cls.make_database(config, sa_metadata, schema_version, drop) + engine = cls.make_engine(config) + cls.make_database(engine, config, sa_metadata, schema_version, drop) return config diff --git a/python/lsst/dax/ppdb/sql/_ppdb_sql_base.py b/python/lsst/dax/ppdb/sql/_ppdb_sql_base.py index 836cab53..542029e5 100644 --- a/python/lsst/dax/ppdb/sql/_ppdb_sql_base.py +++ b/python/lsst/dax/ppdb/sql/_ppdb_sql_base.py @@ -21,11 +21,12 @@ from __future__ import annotations -__all__ = ["PpdbSqlBase"] +__all__ = ["PasswordProvider", "PpdbSqlBase"] import logging import os import sqlite3 +from abc import ABC, abstractmethod from collections.abc import Iterable, MutableMapping from contextlib import closing from typing import Any @@ -49,6 +50,25 @@ _LOG = logging.getLogger(__name__) +class PasswordProvider(ABC): + """Abstract base class for objects that 
supply a database password. + + Implementations are free to retrieve the password from any source + (e.g. environment variables, a secrets manager, a local file) without + `PpdbSqlBase` needing to know about the mechanism. + """ + + @abstractmethod + def get_password(self) -> str: + """Return the database password. + + Returns + ------- + password : `str` + Plain-text password to embed in the database connection URL. + """ + + class MissingSchemaVersionError(RuntimeError): """Exception raised when schema version is not defined in the schema. @@ -121,12 +141,12 @@ class PpdbSqlBase: meta_schema_version_key = "version:schema" """Name of the metadata key to store Felis schema version number.""" - def __init__(self, config: PpdbSqlBaseConfig) -> None: + def __init__(self, config: PpdbSqlBaseConfig, password_provider: PasswordProvider | None = None) -> None: self._sa_metadata, self._schema_version = self.read_schema( config.felis_path, config.schema_name, config.felis_schema, config.db_url ) - self._engine = self.make_engine(config) + self._engine = self.make_engine(config, password_provider=password_provider) sa_metadata = sqlalchemy.MetaData(schema=config.schema_name) meta_table = sqlalchemy.schema.Table("metadata", sa_metadata, autoload_with=self._engine) @@ -137,14 +157,7 @@ def __init__(self, config: PpdbSqlBaseConfig) -> None: self._check_code_version() @classmethod - def make_engine(cls, config: PpdbSqlBaseConfig) -> sqlalchemy.engine.Engine: - """Make SQLALchemy engine based on configured parameters. - - Parameters - ---------- - config : `PpdbSqlBaseConfig` - Configuration object with SQL parameters. 
- """ + def _build_connect_args(cls, config: PpdbSqlBaseConfig) -> MutableMapping[str, Any]: kw: MutableMapping[str, Any] = {} conn_args: dict[str, Any] = {} if not config.use_connection_pool: @@ -159,9 +172,40 @@ def make_engine(cls, config: PpdbSqlBaseConfig) -> sqlalchemy.engine.Engine: conn_args.update(timeout=config.connection_timeout) elif config.db_url.startswith(("postgresql", "mysql")): conn_args.update(connect_timeout=config.connection_timeout) - kw = {"connect_args": conn_args} - engine = sqlalchemy.create_engine(config.db_url, **kw) + return {"connect_args": conn_args} + + @classmethod + def make_engine( + cls, + config: PpdbSqlBaseConfig, + *, + password_provider: PasswordProvider | None = None, + ) -> sqlalchemy.engine.Engine: + """Make SQLALchemy engine based on configured parameters. + Parameters + ---------- + config : `PpdbSqlBaseConfig` + Configuration object with SQL parameters. + password_provider : `PasswordProvider`, optional + If provided, the password returned by + ``password_provider.get_password()`` is injected into the + database URL. The URL must not already contain a password when + this argument is given. + + Raises + ------ + ValueError + Raised if ``password_provider`` is given but the URL already + contains a password. + """ + db_url = sqlalchemy.make_url(config.db_url) + if password_provider is not None: + if db_url.password is not None: + raise ValueError("Database URL must not contain a password when password_provider is used.") + db_url = db_url.set(password=password_provider.get_password()) + kw = cls._build_connect_args(config) + engine = sqlalchemy.create_engine(db_url, **kw) if engine.dialect.name == "sqlite": # Need to enable foreign keys on every new connection. 
sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect) @@ -171,6 +215,7 @@ def make_engine(cls, config: PpdbSqlBaseConfig) -> sqlalchemy.engine.Engine: @classmethod def make_database( cls, + engine: sqlalchemy.engine.Engine, config: PpdbSqlBaseConfig, sa_metadata: sqlalchemy.schema.MetaData, schema_version: VersionTuple, @@ -189,8 +234,6 @@ def make_database( drop : `bool` If `True` then drop existing tables before creating new ones. """ - engine = cls.make_engine(config) - if config.schema_name is not None: dialect = engine.dialect quoted_schema = dialect.preparer(dialect).quote_schema(config.schema_name) diff --git a/python/lsst/dax/ppdb/tests/_bigquery.py b/python/lsst/dax/ppdb/tests/_bigquery.py new file mode 100644 index 00000000..d801656d --- /dev/null +++ b/python/lsst/dax/ppdb/tests/_bigquery.py @@ -0,0 +1,224 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +import gc +import io +import json +import shutil +import tempfile +import uuid +from typing import Any + +import google.auth +from google.auth.exceptions import DefaultCredentialsError +from google.auth.transport.requests import Request +from google.cloud import storage + +from lsst.dax.apdb import ( + ApdbConfig, +) +from lsst.dax.apdb.sql import ApdbSql +from lsst.dax.ppdb import PpdbConfig +from lsst.dax.ppdb.bigquery import ChunkUploader, PpdbBigQuery +from lsst.dax.ppdb.tests._ppdb import TEST_SCHEMA_RESOURCE_PATH + +try: + import testing.postgresql +except ImportError: + testing = None + + +TEST_CONFIG = { + "db_drop": True, + "validate_config": False, + "delete_existing_dirs": True, + "bucket_name": "ppdb-test", + "object_prefix": "data/test", + "dataset_id": "test_dataset", + "project_id": "test_project", +} + + +def json_rows_to_buf(rows: list[dict]) -> io.StringIO: + """Convert a list of dict rows to a newline-delimited JSON StringIO + buffer. + """ + buf = io.StringIO() + for row in rows: + buf.write(json.dumps(row) + "\n") + buf.seek(0) + return buf + + +def generate_test_bucket_name(test_prefix: str = "ppdb-test") -> str: + """Generate a unique bucket name for testing.""" + test_id = uuid.uuid4().hex[:16] + return f"{test_prefix}-{test_id}" + + +def delete_test_bucket(bucket_or_bucket_name: str | storage.Bucket) -> None: + """Delete a cloud storage bucket that was created for testing. + + Parameters + ---------- + bucket_or_bucket_name: `str` or `storage.Bucket` + The name of the bucket or the actual bucket to delete. 
+ """ + storage_client = storage.Client() + try: + if isinstance(bucket_or_bucket_name, str): + bucket = storage_client.bucket(bucket_or_bucket_name) + else: + bucket = bucket_or_bucket_name + blobs = list(bucket.list_blobs()) + for blob in blobs: + blob.delete() + bucket.delete() + except Exception as e: + print(f"Failed to delete test GCS bucket: {e}") + + +class ChunkUploaderWithoutPubSub(ChunkUploader): + """A dummy implementation of the ChunkUploader that does not actually + post messages to Pub/Sub. + """ + + def _post_to_stage_chunk_topic(self, gcs_uri: str, chunk_id: int) -> None: + message = { + "dataset": None, + "chunk_id": str(chunk_id), + "folder": f"gs://{gcs_uri}", + } + print(f"Dummy publish to Pub/Sub topic: {message}") + + +class SqliteMixin: + """Mixin class to provide Sqlite-specific setup/teardown and instance + creation. + """ + + def setUp(self) -> None: + self.tempdir = tempfile.mkdtemp() + self.apdb_url = f"sqlite:///{self.tempdir}/apdb.sqlite3" + self.ppdb_url = f"sqlite:///{self.tempdir}/ppdb.sqlite3" + + def tearDown(self) -> None: + shutil.rmtree(self.tempdir, ignore_errors=True) + + def make_instance(self, **kwargs: Any) -> PpdbConfig: + """Make config class instance used in all tests.""" + kw = { + **TEST_CONFIG, + "db_url": self.ppdb_url, + "felis_path": TEST_SCHEMA_RESOURCE_PATH, + "replication_dir": self.tempdir, + } + bq_config = PpdbBigQuery.init_bigquery(**kw) # type: ignore[arg-type] + return bq_config + + def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig: + """Make APDB instance for tests.""" + kw = { + "schema_file": TEST_SCHEMA_RESOURCE_PATH, + "ss_schema_file": "", + "db_url": self.apdb_url, + "enable_replica": True, + } + kw.update(kwargs) + return ApdbSql.init_database(**kw) # type: ignore[arg-type] + + +class PostgresMixin: + """Mixin class to provide Postgres-specific setup/teardown and instance + creation. 
+ """ + + postgresql: Any + + @classmethod + def setUpClass(cls) -> None: + # Create the postgres test server. + cls.postgresql = testing.postgresql.PostgresqlFactory(cache_initialized_db=True) + + @classmethod + def tearDownClass(cls) -> None: + # Clean up any lingering SQLAlchemy engines/connections + # so they're closed before we shut down the server. + gc.collect() + cls.postgresql.clear_cache() + + def setUp(self) -> None: + self.server = self.postgresql() + self.tempdir = tempfile.mkdtemp() + + def tearDown(self) -> None: + self.server = self.postgresql() + shutil.rmtree(self.tempdir, ignore_errors=True) + + def make_instance(self, config_dict: dict[str, Any] = TEST_CONFIG, **kwargs: Any) -> PpdbConfig: + """Make config class instance used in all tests.""" + kw = { + **config_dict, + "db_url": self.server.url(), + "db_schema": "ppdb_test", + "felis_path": TEST_SCHEMA_RESOURCE_PATH, + "replication_dir": self.tempdir, + } + bq_config = PpdbBigQuery.init_bigquery(**kw) + return bq_config + + def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig: + kw = { + "schema_file": TEST_SCHEMA_RESOURCE_PATH, + "ss_schema_file": "", + "db_url": self.server.url(), + "namespace": "apdb", + "enable_replica": True, + } + kw.update(kwargs) + return ApdbSql.init_database(**kw) + + +def have_valid_google_credentials() -> bool: + """Check that valid Google credentials are available for testing. + + Returns + ------- + credentials_valid: `bool` + True if valid Google credentials are available, False if not. + + Raises + ------ + google.auth.exceptions.RefreshError + Raised if the credentials cannot be refreshed. + Exception + Raised for other transport or configuration failures. + """ + try: + credentials, _ = google.auth.default() + except DefaultCredentialsError: + return False + + # This will validate the default credentials that were found in the + # environment. 
+ credentials.refresh(Request()) + + return True diff --git a/python/lsst/dax/ppdb/tests/_ppdb.py b/python/lsst/dax/ppdb/tests/_ppdb.py index 06a84a37..790d4d76 100644 --- a/python/lsst/dax/ppdb/tests/_ppdb.py +++ b/python/lsst/dax/ppdb/tests/_ppdb.py @@ -21,7 +21,7 @@ from __future__ import annotations -__all__ = ["PpdbTest"] +__all__ = ["TEST_SCHEMA_RESOURCE_PATH", "PpdbTest", "fill_apdb"] import unittest from abc import ABC, abstractmethod @@ -44,20 +44,15 @@ from lsst.dax.apdb.tests.data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog from lsst.sphgeom import Angle, Circle, Region, UnitVector3d -from ..config import PpdbConfig from ..ppdb import Ppdb, PpdbReplicaChunk +from ..ppdb_config import PpdbConfig from ..replicator import Replicator if TYPE_CHECKING: import pandas - class TestCaseMixin(unittest.TestCase): - """Base class for mixin test classes that use TestCase methods.""" -else: - - class TestCaseMixin: - """Do-nothing definition of mixin base class for regular execution.""" +TEST_SCHEMA_RESOURCE_PATH = "resource://lsst.dax.ppdb/config/schemas/test_apdb_schema.yaml" def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region: @@ -68,12 +63,107 @@ def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region: return region -class PpdbTest(TestCaseMixin, ABC): +def _make_update_records( + sources: pandas.DataFrame, fsources: pandas.DataFrame, update_time: astropy.time.Time +) -> list[ApdbUpdateRecord]: + """Create update records from source catalogs for testing.""" + update_time_ns = int(update_time.unix_tai * 1e9) + records: list[ApdbUpdateRecord] = [] + + # Reassign one DIASource to SSObject. 
+ dia_source = sources.iloc[0] + records.append( + ApdbReassignDiaSourceToSSObjectRecord( + update_time_ns=update_time_ns, + update_order=0, + diaSourceId=int(dia_source["diaSourceId"]), + ssObjectId=1, + ssObjectReassocTimeMjdTai=float(update_time.tai.mjd), + ra=float(dia_source["ra"]), + dec=float(dia_source["dec"]), + midpointMjdTai=60000.0, + ) + ) + + # Close validity interval for matching DIAObject. + records.append( + ApdbCloseDiaObjectValidityRecord( + update_time_ns=update_time_ns, + update_order=1, + diaObjectId=int(dia_source["diaObjectId"]), + validityEndMjdTai=update_time.tai.mjd, + nDiaSources=None, + ra=float(dia_source["ra"]), + dec=float(dia_source["dec"]), + ) + ) + + # Withdraw one DIAForcedSource. + dia_fsource = fsources.iloc[0] + records.append( + ApdbWithdrawDiaForcedSourceRecord( + update_time_ns=update_time_ns, + update_order=2, + diaObjectId=int(dia_fsource["diaObjectId"]), + visit=int(dia_fsource["visit"]), + detector=int(dia_fsource["detector"]), + timeWithdrawnMjdTai=update_time.tai.mjd, + ra=float(dia_source["ra"]), + dec=float(dia_source["dec"]), + midpointMjdTai=60000.0, + ) + ) + + return records + + +def fill_apdb(apdb: Apdb, include_update_records: bool = False) -> None: + """Populate APDB with some data to replicate.""" + region1 = _make_region((1.0, 1.0, -1.0)) + region2 = _make_region((-1.0, -1.0, -1.0)) + nobj = 100 + objects1 = makeObjectCatalog(region1, nobj) + objects2 = makeObjectCatalog(region2, nobj, start_id=nobj * 2) + + # With the default 10 minutes replica chunk window we should have 4 + # records. All timestamps are far in the past, means that replication + # of the last chunk can run without waiting. 
+ visits = [ + (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1), + (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2), + (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1), + (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2), + (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1), + (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2), + (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1), + (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2), + ] + + # Time when updates are applied. + update_time = astropy.time.Time("2021-03-01T12:00:00") + + update_records = [] + start_id = 0 + for visit, (visit_time, objects) in enumerate(visits): + sources = makeSourceCatalog(objects, visit_time, visit=visit, start_id=start_id) + fsources = makeForcedSourceCatalog(objects, visit_time, visit=visit) + apdb.store(visit_time, objects, sources, fsources) + start_id += nobj + + if include_update_records and visit == (len(visits) - 1): + # Generate a few update records. + update_records = _make_update_records(sources, fsources, update_time) + + if include_update_records: + chunk = ReplicaChunk.make_replica_chunk(update_time, apdb.getConfig().replica_chunk_seconds) + # All our tests use SQL APDB. + assert isinstance(apdb, ApdbSql), "Expecting ApdbSql instance" + apdb._storeUpdateRecords(update_records, chunk, store_chunk=True) + + +class PpdbTest(unittest.TestCase, ABC): """Base class for Ppdb tests that can be specialized for concrete implementation. - - This can only be used as a mixin class for a unittest.TestCase and it - calls various assert methods. 
""" include_update_records = False @@ -102,112 +192,6 @@ def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig: """ raise NotImplementedError() - def test_empty_db(self) -> None: - """Test for instantiation a database and making queries on empty - database. - """ - config = self.make_instance() - ppdb = Ppdb.from_config(config) - chunks = ppdb.get_replica_chunks() - if chunks is not None: - self.assertEqual(len(chunks), 0) - - def _fill_apdb(self, apdb: Apdb) -> None: - """Populate APDB with some data to replicate.""" - visit_time = astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai") - region1 = _make_region((1.0, 1.0, -1.0)) - region2 = _make_region((-1.0, -1.0, -1.0)) - nobj = 100 - objects1 = makeObjectCatalog(region1, nobj) - objects2 = makeObjectCatalog(region2, nobj, start_id=nobj * 2) - - # With the default 10 minutes replica chunk window we should have 4 - # records. All timestamps are far in the past, means that replication - # of the last chunk can run without waiting. - visits = [ - (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1), - (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2), - (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1), - (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2), - (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1), - (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2), - (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1), - (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2), - ] - - # Time when apdates are applied. 
- update_time = astropy.time.Time("2021-03-01T12:00:00") - - update_records = [] - start_id = 0 - for visit, (visit_time, objects) in enumerate(visits): - sources = makeSourceCatalog(objects, visit_time, visit=visit, start_id=start_id) - fsources = makeForcedSourceCatalog(objects, visit_time, visit=visit) - apdb.store(visit_time, objects, sources, fsources) - start_id += nobj - - if self.include_update_records and visit == (len(visits) - 1): - # Generate few update records. - update_records = self._make_update_records(sources, fsources, update_time) - - if self.include_update_records: - chunk = ReplicaChunk.make_replica_chunk(update_time, apdb.getConfig().replica_chunk_seconds) - # All our tests use SQL APDB. - assert isinstance(apdb, ApdbSql), "Expecting ApdbSql instance" - apdb._storeUpdateRecords(update_records, chunk, store_chunk=True) - - def _make_update_records( - self, sources: pandas.DataFrame, fsources: pandas.DataFrame, update_time: astropy.time.Time - ) -> list[ApdbUpdateRecord]: - update_time_ns = int(update_time.unix_tai * 1e9) - records: list[ApdbUpdateRecord] = [] - - # Reassign one DIASource to SSObject. - dia_source = sources.iloc[0] - records.append( - ApdbReassignDiaSourceToSSObjectRecord( - update_time_ns=update_time_ns, - update_order=0, - diaSourceId=int(dia_source["diaSourceId"]), - ssObjectId=1, - ssObjectReassocTimeMjdTai=float(update_time.tai.mjd), - ra=float(dia_source["ra"]), - dec=float(dia_source["dec"]), - midpointMjdTai=60000.0, - ) - ) - - # Close validity interval for matching DIAObject. - records.append( - ApdbCloseDiaObjectValidityRecord( - update_time_ns=update_time_ns, - update_order=1, - diaObjectId=int(dia_source["diaObjectId"]), - validityEndMjdTai=update_time.tai.mjd, - nDiaSources=None, - ra=float(dia_source["ra"]), - dec=float(dia_source["dec"]), - ) - ) - - # Withdraw one DIAForcedSource. 
- dia_fsource = fsources.iloc[0] - records.append( - ApdbWithdrawDiaForcedSourceRecord( - update_time_ns=update_time_ns, - update_order=2, - diaObjectId=int(dia_fsource["diaObjectId"]), - visit=int(dia_fsource["visit"]), - detector=int(dia_fsource["detector"]), - timeWithdrawnMjdTai=update_time.tai.mjd, - ra=float(dia_source["ra"]), - dec=float(dia_source["dec"]), - midpointMjdTai=60000.0, - ) - ) - - return records - def _check_chunks( self, apdb_chunks: Sequence[ReplicaChunk], ppdb_chunks: Sequence[PpdbReplicaChunk] ) -> None: @@ -218,12 +202,22 @@ def _check_chunks( self.assertEqual(ppdb_chunks[i].last_update_time, apdb_chunks[i].last_update_time) self.assertEqual(ppdb_chunks[i].unique_id, apdb_chunks[i].unique_id) + def test_empty_db(self) -> None: + """Test for instantiation a database and making queries on empty + database. + """ + config = self.make_instance() + ppdb = Ppdb.from_config(config) + chunks = ppdb.get_replica_chunks() + if chunks is not None: + self.assertEqual(len(chunks), 0) + def test_replication_single(self) -> None: """Test replication from APDB to PPDB using a single chunk option.""" apdb_config = self.make_apdb_instance() apdb = Apdb.from_config(apdb_config) - self._fill_apdb(apdb) + fill_apdb(apdb, self.include_update_records) expected_chunks = 5 if self.include_update_records else 4 @@ -286,7 +280,7 @@ def test_replication_all(self) -> None: apdb_config = self.make_apdb_instance() apdb = Apdb.from_config(apdb_config) - self._fill_apdb(apdb) + fill_apdb(apdb, self.include_update_records) expected_chunks = 5 if self.include_update_records else 4 diff --git a/python/lsst/dax/ppdb/tests/_updates.py b/python/lsst/dax/ppdb/tests/_updates.py new file mode 100644 index 00000000..45d3fb1b --- /dev/null +++ b/python/lsst/dax/ppdb/tests/_updates.py @@ -0,0 +1,156 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). 
def _create_test_update_records() -> UpdateRecords:
    """Create an `UpdateRecords` container holding sample update records.

    The returned container covers every `ApdbUpdateRecord` subclass once,
    plus two deliberate near-duplicates (same key, different timestamps)
    for exercising deduplication logic.
    """
    # Fixed values shared by all sample records.
    update_time_ns = 1640995200000000000  # 2022-01-01 00:00:00 UTC in nanoseconds
    mjd_tai = 59580.0  # MJD TAI corresponding to 2022-01-01
    replica_chunk_id = 12345

    records: list[ApdbUpdateRecord] = [
        # Reassign a DIASource to a different DIAObject.
        ApdbReassignDiaSourceToDiaObjectRecord(
            update_time_ns=update_time_ns,
            update_order=0,
            diaSourceId=100001,
            diaObjectId=300001,
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        ),
        # Reassign a DIASource to an SSObject.
        ApdbReassignDiaSourceToSSObjectRecord(
            update_time_ns=update_time_ns,
            update_order=1,
            diaSourceId=100002,
            ssObjectId=2001,
            ssObjectReassocTimeMjdTai=mjd_tai,
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        ),
        # Withdraw a DIASource.
        ApdbWithdrawDiaSourceRecord(
            update_time_ns=update_time_ns,
            update_order=2,
            diaSourceId=100003,
            timeWithdrawnMjdTai=mjd_tai,
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        ),
        # Withdraw a DIAForcedSource.
        ApdbWithdrawDiaForcedSourceRecord(
            update_time_ns=update_time_ns,
            update_order=3,
            diaObjectId=200001,
            visit=12345,
            detector=42,
            timeWithdrawnMjdTai=mjd_tai,
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        ),
        # Close a DIAObject validity interval.
        ApdbCloseDiaObjectValidityRecord(
            update_time_ns=update_time_ns,
            update_order=4,
            diaObjectId=200001,
            validityEndMjdTai=mjd_tai,
            nDiaSources=5,
            ra=45.0,
            dec=-30.0,
        ),
        # Update a DIAObject's nDiaSources count.
        ApdbUpdateNDiaSourcesRecord(
            update_time_ns=update_time_ns,
            update_order=5,
            diaObjectId=200002,
            nDiaSources=10,
            ra=45.0,
            dec=-30.0,
        ),
        # Duplicate of the first record, one second later; deduplication
        # should keep this newer record.
        ApdbReassignDiaSourceToDiaObjectRecord(
            update_time_ns=update_time_ns + 1000000000,
            update_order=0,
            diaSourceId=100001,
            diaObjectId=400001,  # Different target object.
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        ),
        # Duplicate of the nDiaSources update, one second earlier;
        # deduplication should discard this older record.
        ApdbUpdateNDiaSourcesRecord(
            update_time_ns=update_time_ns - 1000000000,
            update_order=5,
            diaObjectId=200002,
            nDiaSources=8,  # Different value, but older timestamp.
            ra=45.0,
            dec=-30.0,
        ),
    ]

    return UpdateRecords(
        replica_chunk_id=replica_chunk_id,
        record_count=len(records),
        records=records,
    )
00000000..e69de29b diff --git a/requirements.txt b/requirements.txt index 4f213f08..c0cc7069 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,13 @@ astropy +google-cloud-bigquery pyarrow pydantic >=2,<3 pyyaml >= 5.1 sqlalchemy + lsst-dax-apdb @ git+https://github.com/lsst/dax_apdb@main -lsst-utils @ git+https://github.com/lsst/utils@main -lsst-resources[s3] @ git+https://github.com/lsst/resources@main -lsst-felis @ git+https://github.com/lsst/felis@main lsst-dax-ppdbx-gcp @ git+https://github.com/lsst-dm/dax_ppdbx_gcp@main +lsst-felis @ git+https://github.com/lsst/felis@main lsst-sdm-schemas @ git+https://github.com/lsst/sdm_schemas@main +lsst-utils @ git+https://github.com/lsst/utils@main +lsst-resources[s3] @ git+https://github.com/lsst/resources@main diff --git a/tests/test_ppdbBigQuery.py b/tests/test_ppdbBigQuery.py deleted file mode 100644 index 198f0bca..00000000 --- a/tests/test_ppdbBigQuery.py +++ /dev/null @@ -1,138 +0,0 @@ -# This file is part of dax_ppdb. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . 
- -import gc -import os -import shutil -import tempfile -import unittest -from typing import Any - -from lsst.dax.apdb import ApdbConfig -from lsst.dax.apdb.sql import ApdbSql -from lsst.dax.ppdb import PpdbConfig -from lsst.dax.ppdb.bigquery import PpdbBigQuery -from lsst.dax.ppdb.tests import PpdbTest - -try: - import testing.postgresql -except ImportError: - testing = None - -TEST_SCHEMA = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config/schema.yaml") - -TEST_CONFIG = { - "db_drop": True, - "validate_config": False, - "delete_existing_dirs": True, - "bucket_name": "test_bucket", - "object_prefix": "test_prefix", - "dataset_id": "test_dataset", - "project_id": "test_project", -} - - -class SqliteTestCase(PpdbTest, unittest.TestCase): - """A test case for the PpdbBigQuery class using a SQLite backend.""" - - def setUp(self) -> None: - self.tempdir = tempfile.mkdtemp() - self.apdb_url = f"sqlite:///{self.tempdir}/apdb.sqlite3" - self.ppdb_url = f"sqlite:///{self.tempdir}/ppdb.sqlite3" - - def tearDown(self) -> None: - shutil.rmtree(self.tempdir, ignore_errors=True) - - def make_instance(self, **kwargs: Any) -> PpdbConfig: - """Make config class instance used in all tests.""" - kw = { - **TEST_CONFIG, - "db_url": self.ppdb_url, - "felis_path": TEST_SCHEMA, - "replication_dir": self.tempdir, - } - bq_config = PpdbBigQuery.init_bigquery( - **kw, - ) # type: ignore[arg-type] - return bq_config - - def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig: - """Make APDB instance for tests.""" - kw = { - "schema_file": TEST_SCHEMA, - "ss_schema_file": "", - "db_url": self.apdb_url, - "enable_replica": True, - } - kw.update(kwargs) - return ApdbSql.init_database(**kw) # type: ignore[arg-type] - - -@unittest.skipUnless(testing is not None, "testing.postgresql module not found") -class PostgresTestCase(PpdbTest, unittest.TestCase): - """A test case for the PpdbBigQuery class using a Postgres backend.""" - - postgresql: Any - - @classmethod - def 
setUpClass(cls) -> None: - # Create the postgres test server. - cls.postgresql = testing.postgresql.PostgresqlFactory(cache_initialized_db=True) - super().setUpClass() - - @classmethod - def tearDownClass(cls) -> None: - # Clean up any lingering SQLAlchemy engines/connections - # so they're closed before we shut down the server. - gc.collect() - cls.postgresql.clear_cache() - super().tearDownClass() - - def setUp(self) -> None: - self.server = self.postgresql() - self.tempdir = tempfile.mkdtemp() - - def tearDown(self) -> None: - self.server = self.postgresql() - shutil.rmtree(self.tempdir, ignore_errors=True) - - def make_instance(self, **kwargs: Any) -> PpdbConfig: - """Make config class instance used in all tests.""" - kw = { - **TEST_CONFIG, - "db_url": self.server.url(), - "db_schema": None, - "felis_path": TEST_SCHEMA, - "replication_dir": self.tempdir, - } - bq_config = PpdbBigQuery.init_bigquery(**kw) # type: ignore[arg-type] - return bq_config - - def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig: - kw = { - "schema_file": TEST_SCHEMA, - "ss_schema_file": "", - "db_url": self.server.url(), - "namespace": "apdb", - "enable_replica": True, - } - kw.update(kwargs) - return ApdbSql.init_database(**kw) # type: ignore[arg-type] diff --git a/tests/test_ppdb_bigquery.py b/tests/test_ppdb_bigquery.py new file mode 100644 index 00000000..43a870d9 --- /dev/null +++ b/tests/test_ppdb_bigquery.py @@ -0,0 +1,39 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
"""Backend-specific test cases for the PpdbBigQuery class."""

import unittest

from lsst.dax.ppdb.tests import PpdbTest
from lsst.dax.ppdb.tests._bigquery import PostgresMixin, SqliteMixin

# ``testing.postgresql`` is an optional test-only dependency; fall back to
# ``None`` so the Postgres-backed tests can be skipped when it is missing.
try:
    import testing.postgresql
except ImportError:
    testing = None

_HAVE_TESTING_POSTGRESQL = testing is not None


class SqliteTestCase(SqliteMixin, PpdbTest, unittest.TestCase):
    """A test case for the PpdbBigQuery class using a SQLite backend."""


@unittest.skipUnless(_HAVE_TESTING_POSTGRESQL, "testing.postgresql module not found")
class PostgresTestCase(PostgresMixin, PpdbTest, unittest.TestCase):
    """A test case for the PpdbBigQuery class using a Postgres backend."""
import gc -import os import shutil import tempfile import unittest @@ -30,15 +29,13 @@ from lsst.dax.apdb.sql import ApdbSql from lsst.dax.ppdb import PpdbConfig from lsst.dax.ppdb.sql import PpdbSql -from lsst.dax.ppdb.tests import PpdbTest +from lsst.dax.ppdb.tests import TEST_SCHEMA_RESOURCE_PATH, PpdbTest try: import testing.postgresql except ImportError: testing = None -TEST_SCHEMA = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config/schema.yaml") - class ApdbSQLiteTestCase(PpdbTest, unittest.TestCase): """A test case for PpdbSql class using SQLite backend.""" @@ -55,11 +52,11 @@ def tearDown(self) -> None: def make_instance(self, **kwargs: Any) -> PpdbConfig: """Make config class instance used in all tests.""" - return PpdbSql.init_database(db_url=self.ppdb_url, schema_file=TEST_SCHEMA, **kwargs) + return PpdbSql.init_database(db_url=self.ppdb_url, schema_file=TEST_SCHEMA_RESOURCE_PATH, **kwargs) def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig: kw = { - "schema_file": TEST_SCHEMA, + "schema_file": TEST_SCHEMA_RESOURCE_PATH, "ss_schema_file": "", "db_url": self.apdb_url, "enable_replica": True, @@ -98,11 +95,13 @@ def tearDown(self) -> None: def make_instance(self, **kwargs: Any) -> PpdbConfig: """Make config class instance used in all tests.""" - return PpdbSql.init_database(db_url=self.server.url(), schema_file=TEST_SCHEMA, **kwargs) + return PpdbSql.init_database( + db_url=self.server.url(), schema_file=TEST_SCHEMA_RESOURCE_PATH, **kwargs + ) def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig: kw = { - "schema_file": TEST_SCHEMA, + "schema_file": TEST_SCHEMA_RESOURCE_PATH, "ss_schema_file": "", "db_url": self.server.url(), "namespace": "apdb", diff --git a/tests/test_update_record_expander.py b/tests/test_update_record_expander.py new file mode 100644 index 00000000..5038bfb5 --- /dev/null +++ b/tests/test_update_record_expander.py @@ -0,0 +1,322 @@ +# This file is part of dax_ppdb. 
class UpdateRecordExpanderTestCase(unittest.TestCase):
    """Test UpdateRecordExpander functionality.

    All names used here (``UpdateRecordExpander``, ``ExpandedUpdateRecord``,
    the ``Apdb*Record`` classes) are imported once at module level; the
    previous per-method re-imports were redundant and have been removed.
    """

    def setUp(self) -> None:
        """Set up test fixtures."""
        # Fixed time so every test produces deterministic timestamps.
        self.update_time = astropy.time.Time("2021-03-01T12:00:00", format="isot", scale="tai")
        self.update_time_ns = int(self.update_time.unix_tai * 1e9)

        # Replica chunk ID shared by all expanded records in these tests.
        self.replica_chunk_id = 12345

    def test_get_update_fields(self) -> None:
        """Test get_update_fields class method."""
        # Each known update type maps to the field(s) the update touches.
        self.assertEqual(
            UpdateRecordExpander.get_update_fields("reassign_diasource_to_diaobject"), ["diaObjectId"]
        )
        self.assertEqual(
            UpdateRecordExpander.get_update_fields("reassign_diasource_to_ssobject"),
            ["ssObjectId", "ssObjectReassocTimeMjdTai"],
        )
        self.assertEqual(
            UpdateRecordExpander.get_update_fields("withdraw_diasource"), ["timeWithdrawnMjdTai"]
        )
        self.assertEqual(
            UpdateRecordExpander.get_update_fields("withdraw_diaforcedsource"), ["timeWithdrawnMjdTai"]
        )
        self.assertEqual(
            UpdateRecordExpander.get_update_fields("close_diaobject_validity"),
            ["validityEndMjdTai", "nDiaSources"],
        )
        self.assertEqual(UpdateRecordExpander.get_update_fields("update_n_dia_sources"), ["nDiaSources"])

        # An unknown update type must raise.
        with self.assertRaises(ValueError) as cm:
            UpdateRecordExpander.get_update_fields("unknown_update_type")
        self.assertIn("Unknown update_type: unknown_update_type", str(cm.exception))

    def test_get_record_id_field_names(self) -> None:
        """Test get_record_id_fields class method."""
        self.assertEqual(
            UpdateRecordExpander.get_record_id_fields("reassign_diasource_to_diaobject"), ["diaSourceId"]
        )
        self.assertEqual(
            UpdateRecordExpander.get_record_id_fields("reassign_diasource_to_ssobject"), ["diaSourceId"]
        )
        self.assertEqual(UpdateRecordExpander.get_record_id_fields("withdraw_diasource"), ["diaSourceId"])
        self.assertEqual(
            UpdateRecordExpander.get_record_id_fields("withdraw_diaforcedsource"),
            ["diaObjectId", "visit", "detector"],
        )
        self.assertEqual(
            UpdateRecordExpander.get_record_id_fields("close_diaobject_validity"), ["diaObjectId"]
        )
        self.assertEqual(UpdateRecordExpander.get_record_id_fields("update_n_dia_sources"), ["diaObjectId"])

        # An unknown update type must raise.
        with self.assertRaises(ValueError) as cm:
            UpdateRecordExpander.get_record_id_fields("unknown_update_type")
        self.assertIn("Unknown update_type: unknown_update_type", str(cm.exception))

    def test_reassign_diasource_to_diaobject(self) -> None:
        """Test expand_single_record with
        ApdbReassignDiaSourceToDiaObjectRecord.
        """
        record = ApdbReassignDiaSourceToDiaObjectRecord(
            update_time_ns=self.update_time_ns,
            update_order=0,
            diaSourceId=100001,
            diaObjectId=300001,
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        )

        expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)

        # Should expand to 1 record (diaObjectId).
        self.assertEqual(len(expanded), 1)

        expanded_record = expanded[0]
        self.assertIsInstance(expanded_record, ExpandedUpdateRecord)
        self.assertEqual(expanded_record.table_name, "DiaSource")
        self.assertEqual(expanded_record.record_id, [100001])
        self.assertEqual(expanded_record.field_name, "diaObjectId")
        self.assertEqual(expanded_record.value_json, 300001)
        self.assertEqual(expanded_record.replica_chunk_id, self.replica_chunk_id)
        self.assertEqual(expanded_record.update_order, 0)
        self.assertEqual(expanded_record.update_time_ns, self.update_time_ns)

    def test_reassign_diasource_to_ssobject(self) -> None:
        """Test expand_single_record with
        ApdbReassignDiaSourceToSSObjectRecord.
        """
        record = ApdbReassignDiaSourceToSSObjectRecord(
            update_time_ns=self.update_time_ns,
            update_order=0,
            diaSourceId=100001,
            ssObjectId=2001,
            ssObjectReassocTimeMjdTai=float(self.update_time.tai.mjd),
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        )

        expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)

        # Should expand to 2 records (ssObjectId and ssObjectReassocTimeMjdTai).
        self.assertEqual(len(expanded), 2)

        # Check first expanded record (ssObjectId).
        first_record = expanded[0]
        self.assertIsInstance(first_record, ExpandedUpdateRecord)
        self.assertEqual(first_record.table_name, "DiaSource")
        self.assertEqual(first_record.record_id, [100001])
        self.assertEqual(first_record.field_name, "ssObjectId")
        self.assertEqual(first_record.value_json, 2001)
        self.assertEqual(first_record.replica_chunk_id, self.replica_chunk_id)
        self.assertEqual(first_record.update_order, 0)
        self.assertEqual(first_record.update_time_ns, self.update_time_ns)

        # Check second expanded record (ssObjectReassocTimeMjdTai).
        second_record = expanded[1]
        self.assertEqual(second_record.table_name, "DiaSource")
        self.assertEqual(second_record.record_id, [100001])
        self.assertEqual(second_record.field_name, "ssObjectReassocTimeMjdTai")
        self.assertEqual(second_record.value_json, float(self.update_time.tai.mjd))

    def test_withdraw_diasource(self) -> None:
        """Test expand_single_record with ApdbWithdrawDiaSourceRecord."""
        record = ApdbWithdrawDiaSourceRecord(
            update_time_ns=self.update_time_ns,
            update_order=2,
            diaSourceId=100003,
            timeWithdrawnMjdTai=self.update_time.tai.mjd,
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        )

        expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)

        # Should expand to 1 record (timeWithdrawnMjdTai).
        self.assertEqual(len(expanded), 1)

        expanded_record = expanded[0]
        self.assertEqual(expanded_record.table_name, "DiaSource")
        self.assertEqual(expanded_record.record_id, [100003])
        self.assertEqual(expanded_record.field_name, "timeWithdrawnMjdTai")
        self.assertEqual(expanded_record.value_json, self.update_time.tai.mjd)

    def test_update_n_dia_sources(self) -> None:
        """Test expand_single_record with ApdbUpdateNDiaSourcesRecord."""
        record = ApdbUpdateNDiaSourcesRecord(
            update_time_ns=self.update_time_ns,
            update_order=5,
            diaObjectId=200002,
            nDiaSources=10,
            ra=45.0,
            dec=-30.0,
        )

        expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)

        # Should expand to 1 record (nDiaSources).
        self.assertEqual(len(expanded), 1)

        expanded_record = expanded[0]
        self.assertEqual(expanded_record.table_name, "DiaObject")
        self.assertEqual(expanded_record.record_id, [200002])
        self.assertEqual(expanded_record.field_name, "nDiaSources")
        self.assertEqual(expanded_record.value_json, 10)

    def test_close_diaobject_validity(self) -> None:
        """Test expand_single_record with ApdbCloseDiaObjectValidityRecord."""
        record = ApdbCloseDiaObjectValidityRecord(
            update_time_ns=self.update_time_ns,
            update_order=4,
            diaObjectId=200001,
            validityEndMjdTai=self.update_time.tai.mjd,
            nDiaSources=5,
            ra=45.0,
            dec=-30.0,
        )

        expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)

        # Should expand to 2 records (validityEndMjdTai and nDiaSources).
        self.assertEqual(len(expanded), 2)

        # Check first expanded record (validityEndMjdTai).
        first_record = expanded[0]
        self.assertIsInstance(first_record, ExpandedUpdateRecord)
        self.assertEqual(first_record.table_name, "DiaObject")
        self.assertEqual(first_record.record_id, [200001])
        self.assertEqual(first_record.field_name, "validityEndMjdTai")
        self.assertEqual(first_record.value_json, self.update_time.tai.mjd)

        # Check second expanded record (nDiaSources).
        second_record = expanded[1]
        self.assertEqual(second_record.table_name, "DiaObject")
        self.assertEqual(second_record.record_id, [200001])
        self.assertEqual(second_record.field_name, "nDiaSources")
        self.assertEqual(second_record.value_json, 5)

    def test_withdraw_diaforcedsource(self) -> None:
        """Test expand_single_record with ApdbWithdrawDiaForcedSourceRecord."""
        record = ApdbWithdrawDiaForcedSourceRecord(
            update_time_ns=self.update_time_ns,
            update_order=2,
            diaObjectId=200001,
            visit=12345,
            detector=42,
            timeWithdrawnMjdTai=self.update_time.tai.mjd,
            ra=45.0,
            dec=-30.0,
            midpointMjdTai=60000.0,
        )

        expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)

        # Should expand to 1 record (timeWithdrawnMjdTai).
        self.assertEqual(len(expanded), 1)

        expanded_record = expanded[0]
        self.assertEqual(expanded_record.table_name, "DiaForcedSource")
        # The record ID should be a list of the composite key components
        # [diaObjectId, visit, detector] for BigQuery compatibility.
        expected_record_id = [200001, 12345, 42]
        self.assertEqual(expanded_record.record_id, expected_record_id)
        self.assertEqual(expanded_record.field_name, "timeWithdrawnMjdTai")
        self.assertEqual(expanded_record.value_json, self.update_time.tai.mjd)

    def test_update_records_all(self) -> None:
        """Test the full expand_updates method with multiple record types."""
        update_records = _create_test_update_records()

        expanded = UpdateRecordExpander.expand_updates(update_records)

        self.assertEqual(len(expanded), 10)

        # Verify all expanded records have the correct replica_chunk_id.
        for record in expanded:
            self.assertEqual(record.replica_chunk_id, self.replica_chunk_id)
            self.assertIsInstance(record.update_time_ns, int)
            self.assertIsInstance(record.update_order, int)

        # Check that we have the expected table names.
        table_names = {record.table_name for record in expanded}
        expected_tables = {"DiaSource", "DiaObject", "DiaForcedSource"}
        self.assertEqual(table_names, expected_tables)

        # Check that we have the expected field names.
        field_names = {record.field_name for record in expanded}
        expected_fields = {
            "diaObjectId",  # from reassign to diaobject
            "ssObjectId",
            "ssObjectReassocTimeMjdTai",  # from reassign to ssobject
            "timeWithdrawnMjdTai",  # from withdraw diasource and withdraw forced source
            "validityEndMjdTai",
            "nDiaSources",  # from close validity and update n dia sources
        }
        self.assertEqual(field_names, expected_fields)

    def test_empty_records(self) -> None:
        """Test expand_updates with empty records list."""
        empty_update_records = UpdateRecords(
            replica_chunk_id=self.replica_chunk_id,
            record_count=0,
            records=[],
            file_created_at=datetime.datetime.now(datetime.UTC),
        )

        expanded = UpdateRecordExpander.expand_updates(empty_update_records)
        self.assertEqual(len(expanded), 0)
+ +import json +import unittest + +import pytest +from google.cloud import storage + +from lsst.dax.apdb import ( + Apdb, + ApdbReplica, + apdbUpdateRecord, +) +from lsst.dax.ppdb import Ppdb +from lsst.dax.ppdb.bigquery import PpdbBigQuery +from lsst.dax.ppdb.bigquery.updates import UpdateRecords +from lsst.dax.ppdb.replicator import Replicator +from lsst.dax.ppdb.tests import fill_apdb +from lsst.dax.ppdb.tests._bigquery import ( + ChunkUploaderWithoutPubSub, + PostgresMixin, + delete_test_bucket, + generate_test_bucket_name, + have_valid_google_credentials, +) + + +@unittest.skipIf(not have_valid_google_credentials(), "Missing valid Google credentials") +class UpdateRecordsTestCase(PostgresMixin, unittest.TestCase): + """A test case for the handling of APDB record updates by PpdbBigQuery and + related classes including the ChunkUploader. + """ + + def setUp(self): + super().setUp() + + # Make APDB instance and fill it with test data. + apdb_config = self.make_apdb_instance() + apdb = Apdb.from_config(apdb_config) + fill_apdb(apdb, include_update_records=True) + apdb_replica = ApdbReplica.from_config(apdb_config) + + # Make PPDB instance. + self.ppdb_config = self.make_instance() + self.ppdb = Ppdb.from_config(self.ppdb_config) + assert isinstance(self.ppdb, PpdbBigQuery) + + # Replicate APDB replica chunks to the PPDB. + replicator = Replicator( + apdb_replica, self.ppdb, update=False, min_wait_time=0, max_wait_time=0, check_interval=0 + ) + replicator.run(exit_on_empty=True) + + def test_json_serialization(self) -> None: + """Test that the APDB update records are correctly saved to a JSON file + in the replication output and can be read back as valid UpdateRecords + objects. 
+ """ + update_records_path = self.ppdb.replication_path / "2021/03/01/1614600000" / "update_records.json" + self.assertTrue(update_records_path.exists(), "Update records file not found in replication output") + + update_records = UpdateRecords.from_json_file(update_records_path) + print("\n" + str(update_records)) + + self.assertEqual( + update_records.replica_chunk_id, + 1614600000, + "Unexpected replica chunk ID in deserialized update records", + ) + + self.assertEqual(update_records.record_count, 3, "Unexpected number of update records deserialized") + + self.assertEqual( + len(update_records.records), 3, "Unexpected number of update records in the deserialized object" + ) + + for record in update_records.records: + self.assertIsInstance( + record, + apdbUpdateRecord.ApdbUpdateRecord, + "Deserialized record is not an instance of ApdbUpdateRecord", + ) + + update_record = update_records.records[0] + self.assertIsInstance( + update_record, + apdbUpdateRecord.ApdbReassignDiaSourceToSSObjectRecord, + "Deserialized record is not an instance of ApdbReassignDiaSourceToSSObjectRecord", + ) + assert isinstance(update_record, apdbUpdateRecord.ApdbReassignDiaSourceToSSObjectRecord) + self.assertEqual( + update_record.diaSourceId, + 700, + "Unexpected diaSourceId in deserialized ApdbReassignDiaSourceToSSObjectRecord", + ) + self.assertEqual( + update_record.ssObjectId, + 1, + "Unexpected ssObjectId in deserialized ApdbReassignDiaSourceToSSObjectRecord", + ) + self.assertEqual( + update_record.update_time_ns, + 1614600037000000000, + "Unexpected update_time_ns in deserialized ApdbReassignDiaSourceToSSObjectRecord", + ) + self.assertEqual( + update_record.update_order, + 0, + "Unexpected update_order in deserialized ApdbReassignDiaSourceToSSObjectRecord", + ) + self.assertEqual( + update_record.midpointMjdTai, + 60000.0, + "Unexpected midpointMjdTai in deserialized ApdbReassignDiaSourceToSSObjectRecord", + ) + self.assertEqual( + update_record.ssObjectReassocTimeMjdTai, + 
59274.50042824074, + "Unexpected ssObjectReassocTimeMjdTai in deserialized ApdbReassignDiaSourceToSSObjectRecord", + ) + self.assertNotEqual( + update_record.ra, + 0.0, + "Unexpected ra in deserialized ApdbReassignDiaSourceToSSObjectRecord, should not be 0.0", + ) + self.assertNotEqual( + update_record.dec, + 0.0, + "Unexpected dec in deserialized ApdbReassignDiaSourceToSSObjectRecord, should not be 0.0", + ) + + update_record = update_records.records[1] + self.assertIsInstance( + update_record, + apdbUpdateRecord.ApdbCloseDiaObjectValidityRecord, + "Deserialized record is not an instance of ApdbCloseDiaObjectValidityRecord", + ) + self.assertEqual( + update_record.diaObjectId, + 200, + "Unexpected diaObjectId in deserialized ApdbCloseDiaObjectValidityRecord", + ) + self.assertNotEqual( + update_record.ra, + 0.0, + "Unexpected ra in deserialized ApdbCloseDiaObjectValidityRecord, should not be 0.0", + ) + self.assertNotEqual( + update_record.dec, + 0.0, + "Unexpected dec in deserialized ApdbCloseDiaObjectValidityRecord, should not be 0.0", + ) + self.assertEqual( + update_record.update_time_ns, + 1614600037000000000, + "Unexpected update_time_ns in deserialized ApdbCloseDiaObjectValidityRecord", + ) + self.assertEqual( + update_record.update_order, + 1, + "Unexpected update_order in deserialized ApdbCloseDiaObjectValidityRecord", + ) + self.assertEqual( + update_record.validityEndMjdTai, + 59274.50042824074, + "Unexpected validityEndMjdTai in deserialized ApdbCloseDiaObjectValidityRecord", + ) + self.assertIsNone( + update_record.nDiaSources, + "Unexpected nDiaSources in deserialized ApdbCloseDiaObjectValidityRecord, expected None", + ) + + update_record = update_records.records[2] + self.assertIsInstance( + update_record, + apdbUpdateRecord.ApdbWithdrawDiaForcedSourceRecord, + "Deserialized record is not an instance of ApdbWithdrawDiaForcedSourceRecord", + ) + self.assertEqual( + update_record.diaObjectId, + 200, + "Unexpected diaObjectId in deserialized 
ApdbWithdrawDiaForcedSourceRecord", + ) + self.assertEqual( + update_record.visit, + 7, + "Unexpected visit in deserialized ApdbWithdrawDiaForcedSourceRecord", + ) + self.assertEqual( + update_record.detector, + 1, + "Unexpected detector in deserialized ApdbWithdrawDiaForcedSourceRecord", + ) + self.assertNotEqual( + update_record.ra, + 0.0, + "Unexpected ra in deserialized ApdbWithdrawDiaForcedSourceRecord, should not be 0.0", + ) + self.assertNotEqual( + update_record.dec, + 0.0, + "Unexpected dec in deserialized ApdbWithdrawDiaForcedSourceRecord, should not be 0.0", + ) + self.assertEqual( + update_record.midpointMjdTai, + 60000.0, + "Unexpected midpointMjdTai in deserialized ApdbWithdrawDiaForcedSourceRecord", + ) + self.assertEqual( + update_record.update_time_ns, + 1614600037000000000, + "Unexpected update_time_ns in deserialized ApdbWithdrawDiaForcedSourceRecord", + ) + self.assertEqual( + update_record.update_order, + 2, + "Unexpected update_order in deserialized ApdbWithdrawDiaForcedSourceRecord", + ) + self.assertEqual( + update_record.timeWithdrawnMjdTai, + 59274.50042824074, + "Unexpected timeWithdrawnMjdTai in deserialized ApdbWithdrawDiaForcedSourceRecord", + ) + self.assertNotEqual( + update_record.ra, + 0.0, + "Unexpected ra in deserialized ApdbWithdrawDiaForcedSourceRecord, should not be 0.0", + ) + self.assertNotEqual( + update_record.dec, + 0.0, + "Unexpected dec in deserialized ApdbWithdrawDiaForcedSourceRecord, should not be 0.0", + ) + + # FIXME: This should be in a separate test case and probably a separate + # module as well. + @pytest.mark.skipif( + pytest.importorskip("lsst.dax.ppdbx.gcp", reason="dax_ppdbx_gcp is not installed") is None, + reason="", + ) + def test_chunk_uploader(self) -> None: + """Test that the update records are correctly uploaded to Google Cloud + Storage after replication. 
+ """ + # Change the configuration to use a unique test bucket name to avoid + # conflicts + self.ppdb_config.bucket_name = generate_test_bucket_name("ppdb-test-gcs-upload") + + # Create the test GCS bucket + storage_client = storage.Client() + try: + bucket = storage_client.bucket(self.ppdb_config.bucket_name) + bucket.create(location="US") + except Exception as e: + self.fail(f"Failed to create test GCS bucket: {e}") + + # Configure and run the uploader + uploader = ChunkUploaderWithoutPubSub( + self.ppdb_config, + wait_interval=0, + exit_on_empty=True, + exit_on_error=True, + ) + print(f"Uploader will copy files to {uploader.bucket_name}/{uploader.prefix}/") + uploader.run() + + # Retrieve the update records file[] + blobs = list(bucket.list_blobs(match_glob="**/update_records.json")) + update_records_files = [b.name for b in blobs] + self.assertEqual( + len(update_records_files), + 1, + f"Expected exactly one update_records.json file in GCS, found " + f"{len(update_records_files)}: {update_records_files}", + ) + + # Download the contents of the update records file as a string + update_records_str = blobs[0].download_as_text() + + # Print the contents of the update records file for debugging + update_records_json = json.loads(update_records_str) + print(f"Contents of update_records.json in GCS:\n{json.dumps(update_records_json, indent=2)}") + + # Load the update records into the data model and perform a few basic + # checks (test_json_serialization already tests this in detail, so we + # just check a few key fields here). 
+ update_records = UpdateRecords.model_validate(update_records_json) + self.assertEqual( + update_records.replica_chunk_id, + 1614600000, + "Unexpected replica chunk ID in update records file from GCS", + ) + self.assertEqual( + update_records.record_count, + 3, + f"Expected record_count of 3 in update records file from GCS, found " + f"{update_records.record_count}", + ) + self.assertEqual( + len(update_records.records), + 3, + f"Expected 3 update records in the file from GCS, found {len(update_records.records)}", + ) + + # FIXME: This should be in a tearDown() method. + # Delete the test GCS bucket + try: + delete_test_bucket(bucket) + except Exception as e: + raise RuntimeError(f"Failed to delete test GCS bucket: {e}") from e diff --git a/tests/test_updates_manager.py b/tests/test_updates_manager.py new file mode 100644 index 00000000..b9721fe0 --- /dev/null +++ b/tests/test_updates_manager.py @@ -0,0 +1,250 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +import unittest +import uuid +from collections.abc import Collection, Sequence + +import astropy +import felis +from google.cloud import bigquery, storage + +from lsst.dax.apdb import ( + ApdbTableData, + ReplicaChunk, +) +from lsst.dax.ppdb import Ppdb +from lsst.dax.ppdb.bigquery import PpdbBigQuery +from lsst.dax.ppdb.bigquery.updates.updates_manager import UpdatesManager +from lsst.dax.ppdb.tests._bigquery import ( + ChunkUploaderWithoutPubSub, + PostgresMixin, + generate_test_bucket_name, + have_valid_google_credentials, + json_rows_to_buf, +) +from lsst.dax.ppdb.tests._updates import _create_test_update_records + + +@unittest.skipIf(not have_valid_google_credentials(), "Missing valid Google credentials") +class UpdatesManagerTestCase(PostgresMixin, unittest.TestCase): + """A test case for the handling of APDB record updates by PpdbBigQuery and + related classes including the ChunkUploader. + """ + + def setUp(self): + super().setUp() + + # Set up BigQuery client and test dataset + self.bq_client = bigquery.Client() + + bucket_name = generate_test_bucket_name("ppdb-updates-manager-test") + dataset_id = f"test_updates_manager_{uuid.uuid4().hex[:8]}" + project_id = self.bq_client.project + config = { + "db_drop": True, + "validate_config": False, + "delete_existing_dirs": True, + "bucket_name": bucket_name, + "object_prefix": "data/test", + "dataset_id": dataset_id, + "project_id": project_id, + } + + # Setup the Postgres database and create the config instance + self.ppdb_config = self.make_instance(config) + + # Create the test dataset and tables in BigQuery + self.target_dataset_fqn = f"{project_id}.{dataset_id}" + self._create_test_dataset(self.bq_client, dataset_id) + + # Create the test GCS bucket + storage_client = storage.Client() + try: + bucket = storage_client.bucket(self.ppdb_config.bucket_name) + bucket.create(location="US") + except Exception as e: + self.fail(f"Failed to create test GCS bucket: {e}") + + # Create the PPDB instance + self.ppdb = 
Ppdb.from_config(self.ppdb_config) + assert isinstance(self.ppdb, PpdbBigQuery) + + def tearDown(self): + # Delete the test dataset + try: + self.bq_client.delete_dataset( + self.ppdb_config.dataset_id, delete_contents=True, not_found_ok=True + ) + except Exception as e: + print(f"Failed to delete test dataset: {e}") + + # Delete the test GCS bucket + storage_client = storage.Client() + try: + bucket = storage_client.bucket(self.ppdb_config.bucket_name) + blobs = list(bucket.list_blobs()) + for blob in blobs: + blob.delete() + bucket.delete() + except Exception as e: + print(f"Failed to delete test GCS bucket: {e}") + + super().tearDown() + + def _create_test_dataset(self, client: bigquery.Client, dataset_id: str) -> None: + dataset = bigquery.Dataset(f"{client.project}.{dataset_id}") + client.create_dataset(dataset, exists_ok=False) + + # Create DiaObject table + schema = [ + bigquery.SchemaField("diaObjectId", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("validityEndMjdTai", "FLOAT", mode="NULLABLE"), + bigquery.SchemaField("nDiaSources", "INTEGER", mode="NULLABLE"), + ] + table_fqn = f"{self.target_dataset_fqn}.DiaObject" + table = bigquery.Table(table_fqn, schema=schema) + client.create_table(table) + rows = [ + {"diaObjectId": 200001, "validityEndMjdTai": None, "nDiaSources": 3}, + {"diaObjectId": 200002, "validityEndMjdTai": None, "nDiaSources": 7}, + {"diaObjectId": 200003, "validityEndMjdTai": 59000.0, "nDiaSources": 2}, + ] + buf = json_rows_to_buf(rows) + job = client.load_table_from_file( + buf, + table_fqn, + job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON), + ) + job.result() + + # Create test DiaSource table + schema = [ + bigquery.SchemaField("diaSourceId", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("diaObjectId", "INTEGER", mode="NULLABLE"), + bigquery.SchemaField("ssObjectId", "INTEGER", mode="NULLABLE"), + bigquery.SchemaField("ssObjectReassocTimeMjdTai", "FLOAT", mode="NULLABLE"), + 
bigquery.SchemaField("timeWithdrawnMjdTai", "FLOAT", mode="NULLABLE"), + ] + table_fqn = f"{self.target_dataset_fqn}.DiaSource" + table = bigquery.Table(table_fqn, schema=schema) + self.bq_client.create_table(table) + rows = [ + { + "diaSourceId": 100001, + "diaObjectId": 200001, + "ssObjectId": None, + "ssObjectReassocTimeMjdTai": None, + "timeWithdrawnMjdTai": None, + }, + { + "diaSourceId": 100002, + "diaObjectId": 200002, + "ssObjectId": None, + "ssObjectReassocTimeMjdTai": None, + "timeWithdrawnMjdTai": None, + }, + { + "diaSourceId": 100003, + "diaObjectId": 200003, + "ssObjectId": None, + "ssObjectReassocTimeMjdTai": None, + "timeWithdrawnMjdTai": None, + }, + { + "diaSourceId": 100004, + "diaObjectId": 200004, + "ssObjectId": None, + "ssObjectReassocTimeMjdTai": None, + "timeWithdrawnMjdTai": None, + }, + ] + job = client.load_table_from_file( + json_rows_to_buf(rows), + table_fqn, + job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON), + ) + job.result() + + # Create test DiaForcedSource table + schema = [ + bigquery.SchemaField("diaObjectId", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("visit", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("detector", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("timeWithdrawnMjdTai", "FLOAT", mode="NULLABLE"), + ] + table_fqn = f"{self.target_dataset_fqn}.DiaForcedSource" + table = bigquery.Table(table_fqn, schema=schema) + self.bq_client.create_table(table) + rows = [ + {"diaObjectId": 200001, "visit": 12345, "detector": 42, "timeWithdrawnMjdTai": None}, + {"diaObjectId": 200001, "visit": 12346, "detector": 42, "timeWithdrawnMjdTai": None}, + ] + job = self.bq_client.load_table_from_file( + json_rows_to_buf(rows), + table_fqn, + job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON), + ) + job.result() + + def test_apply_updates(self): + """Test that the update records are correctly uploaded to Google Cloud + Storage 
and then applied to the target BigQuery tables by UpdatesManager.
+        """
+
+        class DummyApdbTableData(ApdbTableData):
+            def column_names(self) -> Sequence[str]:
+                return []
+
+            def column_defs(self) -> Sequence[tuple[str, felis.datamodel.DataType]]:
+                return []
+
+            def rows(self) -> Collection[tuple]:
+                return []
+
+        # Create and store the test update records
+        update_records = _create_test_update_records()
+        self.ppdb.store(
+            ReplicaChunk(
+                id=update_records.replica_chunk_id,
+                last_update_time=astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"),
+                unique_id=uuid.uuid4(),
+            ),
+            objects=DummyApdbTableData(),
+            sources=DummyApdbTableData(),
+            forced_sources=DummyApdbTableData(),
+            update_records=update_records.records,
+            update=True,
+        )
+
+        # Configure and run the uploader
+        uploader = ChunkUploaderWithoutPubSub(
+            self.ppdb_config,
+            wait_interval=0,
+            exit_on_empty=True,
+            exit_on_error=True,
+        )
+        print(f"Uploader will copy files to {uploader.bucket_name}/{uploader.prefix}")
+        uploader.run()
+
+        # Apply the updates to the target tables
+        updates_manager = UpdatesManager(self.ppdb)
+        updates_manager.apply_updates([update_records.replica_chunk_id])
diff --git a/tests/test_updates_merger.py b/tests/test_updates_merger.py
new file mode 100644
index 00000000..df8c144d
--- /dev/null
+++ b/tests/test_updates_merger.py
@@ -0,0 +1,235 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest +import uuid + +try: + from google.cloud import bigquery +except ImportError: + bigquery = None + +from lsst.dax.ppdb.bigquery.updates import ( + DiaForcedSourceUpdatesMerger, + DiaObjectUpdatesMerger, + DiaSourceUpdatesMerger, + UpdateRecordExpander, + UpdatesTable, +) +from lsst.dax.ppdb.tests._bigquery import have_valid_google_credentials, json_rows_to_buf +from lsst.dax.ppdb.tests._updates import _create_test_update_records + + +@unittest.skipIf(not have_valid_google_credentials(), "Missing valid Google credentials") +class TestUpdatesMerger(unittest.TestCase): + """Test UpdatesMerger functionality.""" + + def setUp(self): + self.client = bigquery.Client() + self.dataset_id = f"test_merger_{uuid.uuid4().hex[:8]}" + self.project_id = self.client.project + self.updates_table_fqn = f"{self.project_id}.{self.dataset_id}.updates" + self.target_dataset_fqn = f"{self.project_id}.{self.dataset_id}" + dataset = bigquery.Dataset(f"{self.project_id}.{self.dataset_id}") + dataset.default_table_expiration_ms = 3600000 + self.client.create_dataset(dataset) + + def tearDown(self): + try: + self.client.delete_dataset(self.dataset_id, delete_contents=True, not_found_ok=True) + except Exception: + pass + + def _create_target_table(self): + schema = [ + bigquery.SchemaField("diaObjectId", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("validityEndMjdTai", "FLOAT", mode="NULLABLE"), + bigquery.SchemaField("nDiaSources", "INTEGER", mode="NULLABLE"), + ] + table_fqn = f"{self.target_dataset_fqn}.DiaObject" + table = bigquery.Table(table_fqn, schema=schema) + self.client.create_table(table) + rows = [ + 
{"diaObjectId": 200001, "validityEndMjdTai": None, "nDiaSources": 3}, + {"diaObjectId": 200002, "validityEndMjdTai": None, "nDiaSources": 7}, + {"diaObjectId": 200003, "validityEndMjdTai": 59000.0, "nDiaSources": 2}, + ] + buf = json_rows_to_buf(rows) + job = self.client.load_table_from_file( + buf, + table_fqn, + job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON), + ) + job.result() + + def test_merge_diaobject(self): + self._create_target_table() + updates_table = UpdatesTable(self.client, self.updates_table_fqn) + updates_table.create() + update_records = _create_test_update_records() + expanded = UpdateRecordExpander.expand_updates(update_records) + updates_table.insert(expanded) + dedup_fqn = f"{self.updates_table_fqn}_dedup" + updates_table.deduplicate_to(dedup_fqn) + table_fqn = f"{self.target_dataset_fqn}.DiaObject" + query = f"SELECT * FROM `{table_fqn}` ORDER BY diaObjectId" + before = {r.diaObjectId: r for r in self.client.query(query).result()} + print("Before merge:", before) + merger = DiaObjectUpdatesMerger(self.client) + merger.merge(updates_table_fqn=dedup_fqn, target_dataset_fqn=self.target_dataset_fqn) + after = {r.diaObjectId: r for r in self.client.query(query).result()} + print("After merge:", after) + self.assertEqual(after[200001].validityEndMjdTai, 59580.0) + self.assertEqual(after[200001].nDiaSources, 5) + self.assertIsNone(after[200002].validityEndMjdTai) + self.assertEqual(after[200002].nDiaSources, 10) + self.assertEqual(after[200003].validityEndMjdTai, before[200003].validityEndMjdTai) + self.assertEqual(after[200003].nDiaSources, before[200003].nDiaSources) + + def test_merge_diasource(self): + schema = [ + bigquery.SchemaField("diaSourceId", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("diaObjectId", "INTEGER", mode="NULLABLE"), + bigquery.SchemaField("ssObjectId", "INTEGER", mode="NULLABLE"), + bigquery.SchemaField("ssObjectReassocTimeMjdTai", "FLOAT", mode="NULLABLE"), + 
bigquery.SchemaField("timeWithdrawnMjdTai", "FLOAT", mode="NULLABLE"), + ] + table_fqn = f"{self.target_dataset_fqn}.DiaSource" + table = bigquery.Table(table_fqn, schema=schema) + self.client.create_table(table) + rows = [ + { + "diaSourceId": 100001, + "diaObjectId": 200001, + "ssObjectId": None, + "ssObjectReassocTimeMjdTai": None, + "timeWithdrawnMjdTai": None, + }, + { + "diaSourceId": 100002, + "diaObjectId": 200002, + "ssObjectId": None, + "ssObjectReassocTimeMjdTai": None, + "timeWithdrawnMjdTai": None, + }, + { + "diaSourceId": 100003, + "diaObjectId": 200003, + "ssObjectId": None, + "ssObjectReassocTimeMjdTai": None, + "timeWithdrawnMjdTai": None, + }, + { + "diaSourceId": 100004, + "diaObjectId": 200004, + "ssObjectId": None, + "ssObjectReassocTimeMjdTai": None, + "timeWithdrawnMjdTai": None, + }, + ] + job = self.client.load_table_from_file( + json_rows_to_buf(rows), + table_fqn, + job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON), + ) + job.result() + + updates_table = UpdatesTable(self.client, self.updates_table_fqn) + updates_table.create() + update_records = _create_test_update_records() + expanded = UpdateRecordExpander.expand_updates(update_records) + updates_table.insert(expanded) + dedup_fqn = f"{self.updates_table_fqn}_dedup" + updates_table.deduplicate_to(dedup_fqn) + + query = f"SELECT * FROM `{table_fqn}` ORDER BY diaSourceId" + before = {r.diaSourceId: r for r in self.client.query(query).result()} + merger = DiaSourceUpdatesMerger(self.client) + merger.merge(updates_table_fqn=dedup_fqn, target_dataset_fqn=self.target_dataset_fqn) + after = {r.diaSourceId: r for r in self.client.query(query).result()} + + self.assertEqual(after[100001].diaObjectId, 400001) + self.assertEqual(after[100002].ssObjectId, 2001) + self.assertEqual(after[100002].ssObjectReassocTimeMjdTai, 59580.0) + self.assertEqual(after[100003].timeWithdrawnMjdTai, 59580.0) + self.assertEqual(after[100004].diaObjectId, 
before[100004].diaObjectId) + self.assertEqual(after[100004].ssObjectId, before[100004].ssObjectId) + self.assertEqual(after[100004].timeWithdrawnMjdTai, before[100004].timeWithdrawnMjdTai) + + def test_merge_diaforcedsource(self): + schema = [ + bigquery.SchemaField("diaObjectId", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("visit", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("detector", "INTEGER", mode="REQUIRED"), + bigquery.SchemaField("timeWithdrawnMjdTai", "FLOAT", mode="NULLABLE"), + ] + table_fqn = f"{self.target_dataset_fqn}.DiaForcedSource" + table = bigquery.Table(table_fqn, schema=schema) + self.client.create_table(table) + rows = [ + {"diaObjectId": 200001, "visit": 12345, "detector": 42, "timeWithdrawnMjdTai": None}, + {"diaObjectId": 200001, "visit": 12346, "detector": 42, "timeWithdrawnMjdTai": None}, + ] + job = self.client.load_table_from_file( + json_rows_to_buf(rows), + table_fqn, + job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON), + ) + job.result() + + updates_table = UpdatesTable(self.client, self.updates_table_fqn) + updates_table.create() + update_records = _create_test_update_records() + expanded = UpdateRecordExpander.expand_updates(update_records) + updates_table.insert(expanded) + dedup_fqn = f"{self.updates_table_fqn}_dedup" + updates_table.deduplicate_to(dedup_fqn) + + query = f"SELECT * FROM `{table_fqn}` ORDER BY diaObjectId, visit, detector" + before = {(r.diaObjectId, r.visit, r.detector): r for r in self.client.query(query).result()} + merger = DiaForcedSourceUpdatesMerger(self.client) + merger.merge(updates_table_fqn=dedup_fqn, target_dataset_fqn=self.target_dataset_fqn) + after = {(r.diaObjectId, r.visit, r.detector): r for r in self.client.query(query).result()} + + self.assertEqual(after[(200001, 12345, 42)].timeWithdrawnMjdTai, 59580.0) + self.assertEqual( + after[(200001, 12346, 42)].timeWithdrawnMjdTai, + before[(200001, 12346, 42)].timeWithdrawnMjdTai, + ) + + 
def test_merge_no_updates(self): + self._create_target_table() + updates_table = UpdatesTable(self.client, self.updates_table_fqn) + updates_table.create() + dedup_fqn = f"{self.updates_table_fqn}_dedup" + updates_table.deduplicate_to(dedup_fqn) + table_fqn = f"{self.target_dataset_fqn}.DiaObject" + before = {r.diaObjectId: r for r in self.client.query(f"SELECT * FROM `{table_fqn}`").result()} + merger = DiaObjectUpdatesMerger(self.client) + merger.merge(updates_table_fqn=dedup_fqn, target_dataset_fqn=self.target_dataset_fqn) + after = {r.diaObjectId: r for r in self.client.query(f"SELECT * FROM `{table_fqn}`").result()} + for obj_id in before: + self.assertEqual(before[obj_id].validityEndMjdTai, after[obj_id].validityEndMjdTai) + self.assertEqual(before[obj_id].nDiaSources, after[obj_id].nDiaSources) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_updates_table.py b/tests/test_updates_table.py new file mode 100644 index 00000000..07fc1202 --- /dev/null +++ b/tests/test_updates_table.py @@ -0,0 +1,202 @@ +# This file is part of dax_ppdb. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +import unittest +import uuid + +from google.cloud import bigquery + +from lsst.dax.ppdb.bigquery.updates import UpdateRecordExpander, UpdatesTable +from lsst.dax.ppdb.tests._bigquery import have_valid_google_credentials +from lsst.dax.ppdb.tests._updates import _create_test_update_records + + +@unittest.skipIf(not have_valid_google_credentials(), "Missing valid Google credentials") +class TestUpdatesTable(unittest.TestCase): + """Test UpdatesTable functionality.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + # Create BigQuery client + self.client = bigquery.Client() + + # Create unique dataset name for this test run + self.dataset_id = f"test_updates_{uuid.uuid4().hex[:8]}" + self.project_id = self.client.project + self.table_name = "updates" + self.table_fqn = f"{self.project_id}.{self.dataset_id}.{self.table_name}" + + # Create the test dataset + dataset = bigquery.Dataset(f"{self.project_id}.{self.dataset_id}") + # Set a short expiration for cleanup safety (1 hour) + dataset.default_table_expiration_ms = 3600000 # 1 hour + self.dataset = self.client.create_dataset(dataset) + + # Create UpdatesTable instance + self.updates_table = UpdatesTable(self.client, self.table_fqn) + + def tearDown(self) -> None: + """Clean up test fixtures.""" + # Always clean up the test dataset, whether test passed or failed + try: + self.client.delete_dataset(self.dataset_id, delete_contents=True, not_found_ok=True) + except Exception: + # If deletion fails, at least the expiration will clean it up + pass + + def test_table_fqn_property(self) -> None: + """Test the table_fqn property.""" + self.assertEqual(self.updates_table.table_fqn, self.table_fqn) + + def test_create_table(self) -> None: + """Test creating the updates table.""" + table = self.updates_table.create() + + # Verify table was created successfully + self.assertEqual(table.table_id, self.table_name) + self.assertEqual(table.dataset_id, self.dataset_id) + + # Verify schema is correct + expected_fields = 
{ + "table_name": ("STRING", "REQUIRED"), + "record_id": ("INTEGER", "REPEATED"), + "record_id_hash": ("STRING", "REQUIRED"), + "field_name": ("STRING", "REQUIRED"), + "value_json": ("JSON", "REQUIRED"), + "replica_chunk_id": ("INTEGER", "REQUIRED"), + "update_order": ("INTEGER", "NULLABLE"), + "update_time_ns": ("INTEGER", "NULLABLE"), + } + + actual_fields = {field.name: (field.field_type, field.mode) for field in table.schema} + self.assertEqual(actual_fields, expected_fields) + + def test_create_table_already_exists(self) -> None: + """Test creating a table that already exists raises an error.""" + # Create table first time - should succeed + self.updates_table.create() + + # Try to create again - should raise Conflict + with self.assertRaises(Exception) as cm: + self.updates_table.create() + + # Check that it's a conflict-type error + self.assertIn("already exists", str(cm.exception).lower()) + + def test_insert_records(self) -> None: + """Test insertion of expanded records into the table.""" + # Create the table first + self.updates_table.create() + + # Get test update records and expand them + update_records = _create_test_update_records() + expanded_records = UpdateRecordExpander.expand_updates(update_records) + + # Insert the records + job = self.updates_table.insert(expanded_records) + + # Verify the job completed successfully + self.assertIsNone(job.errors) + + # Verify records were inserted by querying the table + query = f"SELECT COUNT(*) as count FROM `{self.table_fqn}`" + result = list(self.client.query(query).result()) + record_count = result[0].count + + # Should have 10 total expanded records based on the test data + # (1 + 2 + 1 + 1 + 2 + 1 from original records + 2 duplicates) + self.assertEqual(record_count, 10) + + # Verify some specific data was inserted correctly + query = f""" + SELECT table_name, record_id, field_name, replica_chunk_id + FROM `{self.table_fqn}` + """ + # WHERE table_name = 'DiaForcedSource' + results = 
list(self.client.query(query).result()) + + print(results) # Debug print to see what was inserted + # Should have one DiaForcedSource record + # self.assertEqual(len(results), 1) + # row = results[0] + # self.assertEqual(row.table_name, "DiaForcedSource") + # self.assertEqual(row.record_id, [200001, 12345, 42]) + # self.assertEqual(row.field_name, "timeWithdrawnMjdTai") + # self.assertEqual(row.replica_chunk_id, self.replica_chunk_id) + + def test_insert_empty_records(self) -> None: + """Test insertion of empty record list.""" + # Create the table first + self.updates_table.create() + + # Insert empty list + job = self.updates_table.insert([]) + + # Verify the job completed successfully + self.assertIsNone(job.errors) + + # Verify no records were inserted + query = f"SELECT COUNT(*) as count FROM `{self.table_fqn}`" + result = list(self.client.query(query).result()) + record_count = result[0].count + self.assertEqual(record_count, 0) + + def test_deduplicate_records(self) -> None: + """Test deduplication functionality.""" + # Create the source table + self.updates_table.create() + + # Get test records (which now include duplicates) and expand them + update_records = _create_test_update_records() + expanded_records = UpdateRecordExpander.expand_updates(update_records) + + # Insert all records (including duplicates) + self.updates_table.insert(expanded_records) + + # Count original records + query = f"SELECT COUNT(*) as count FROM `{self.table_fqn}`" + original_count = list(self.client.query(query).result())[0].count + + # Create deduplicated table + dedup_table_fqn = f"{self.table_fqn}_dedup" + self.updates_table.deduplicate_to(dedup_table_fqn) + + # Count deduplicated records + query = f"SELECT COUNT(*) as count FROM `{dedup_table_fqn}`" + dedup_count = list(self.client.query(query).result())[0].count + + # Should have fewer records after deduplication + self.assertLess(dedup_count, original_count) + + # Verify specific deduplication behavior + record_id_hash = 
UpdatesTable._compute_record_id_hash([100001]) + query = f""" + SELECT value_json + FROM `{dedup_table_fqn}` + WHERE record_id_hash = '{record_id_hash}' AND field_name = 'diaObjectId' + """ + result = list(self.client.query(query).result()) + self.assertEqual(len(result), 1) + self.assertEqual(result[0].value_json, 400001) # Should be the later update + + +if __name__ == "__main__": + unittest.main()