diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 3828057a..75f5dbcc 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -36,12 +36,17 @@ jobs:
mamba install -y -q pip wheel
pip install uv
+ - name: Install Postgres for testing
+ shell: bash -l {0}
+ run: |
+ mamba install -y -q postgresql
+
- name: Install dependencies
shell: bash -l {0}
run: |
uv pip install -r requirements.txt
+ uv pip install testing.postgresql
- # We have two cores so we can speed up the testing with xdist
- name: Install pytest packages
shell: bash -l {0}
run: |
diff --git a/.gitignore b/.gitignore
index a83e3ca4..1487de3e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,6 @@ pytest_session.txt
# VS Code
.vscode
+
+# Scratch directory
+.scratch
diff --git a/docker/Dockerfile.replication b/docker/Dockerfile.replication
index 12e8b3a5..a8e218f1 100644
--- a/docker/Dockerfile.replication
+++ b/docker/Dockerfile.replication
@@ -3,11 +3,14 @@ FROM python:3.12-slim-bookworm
ENV DEBIAN_FRONTEND=noninteractive
# Update and install OS dependencies
-RUN apt-get -y update && \
- apt-get -y upgrade && \
- apt-get -y install --no-install-recommends git && \
- apt-get clean && \
- rm -rf /var/lib/apt/lists/*
+RUN apt-get update \
+ && apt-get install -y --no-install-recommends \
+ build-essential \
+ python3-dev \
+ pkg-config \
+ git \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
# Install required python build dependencies
RUN pip install --upgrade --no-cache-dir pip setuptools wheel uv
diff --git a/pyproject.toml b/pyproject.toml
index 5f082778..de2ff161 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,10 +23,12 @@ classifiers = [
keywords = ["lsst"]
dependencies = [
"astropy",
+ "google-cloud-bigquery",
"pyarrow",
"pydantic >=2,<3",
"pyyaml >= 5.1",
"sqlalchemy",
+ "lsst-dax-ppdbx-gcp",
"lsst-felis",
"lsst-sdm-schemas",
"lsst-utils",
@@ -43,9 +45,6 @@ test = [
"pytest >= 3.2",
"pytest-openfiles >= 0.5.0"
]
-gcp = [
- "lsst-dax-ppdbx-gcp"
-]
[tool.setuptools.packages.find]
where = ["python"]
@@ -54,7 +53,7 @@ where = ["python"]
zip-safe = true
[tool.setuptools.package-data]
-"lsst.dax.ppdb" = ["py.typed"]
+"lsst.dax.ppdb" = ["py.typed", "config/schemas/*.yaml", "config/sql/*.sql"]
[tool.setuptools.dynamic]
version = { attr = "lsst_versions.get_lsst_version" }
diff --git a/python/lsst/dax/ppdb/__init__.py b/python/lsst/dax/ppdb/__init__.py
index d8aeb139..2f4dab94 100644
--- a/python/lsst/dax/ppdb/__init__.py
+++ b/python/lsst/dax/ppdb/__init__.py
@@ -19,7 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-from .config import *
+from .ppdb_config import *
from .ppdb import *
from .replicator import *
from .version import * # Generated by sconsUtils
diff --git a/python/lsst/dax/ppdb/_factory.py b/python/lsst/dax/ppdb/_factory.py
index aee2ee52..c3774778 100644
--- a/python/lsst/dax/ppdb/_factory.py
+++ b/python/lsst/dax/ppdb/_factory.py
@@ -26,8 +26,8 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from .config import PpdbConfig
from .ppdb import Ppdb
+ from .ppdb_config import PpdbConfig
def config_type_for_name(type_name: str) -> type[PpdbConfig]:
diff --git a/python/lsst/dax/ppdb/bigquery/__init__.py b/python/lsst/dax/ppdb/bigquery/__init__.py
index 19d8c17d..e7b3071b 100644
--- a/python/lsst/dax/ppdb/bigquery/__init__.py
+++ b/python/lsst/dax/ppdb/bigquery/__init__.py
@@ -20,5 +20,6 @@
# along with this program. If not, see .
from .manifest import Manifest
+from .chunk_uploader import ChunkUploader
from .ppdb_bigquery import PpdbBigQuery, PpdbBigQueryConfig
from .ppdb_replica_chunk_extended import ChunkStatus, PpdbReplicaChunkExtended
diff --git a/python/lsst/dax/ppdb/bigquery/chunk_uploader.py b/python/lsst/dax/ppdb/bigquery/chunk_uploader.py
index 9efe0e22..727bcd09 100644
--- a/python/lsst/dax/ppdb/bigquery/chunk_uploader.py
+++ b/python/lsst/dax/ppdb/bigquery/chunk_uploader.py
@@ -42,7 +42,7 @@
) from e
-from ..config import PpdbConfig
+from ..ppdb_config import PpdbConfig
from .manifest import Manifest
from .ppdb_bigquery import PpdbBigQuery, PpdbBigQueryConfig
from .ppdb_replica_chunk_extended import ChunkStatus, PpdbReplicaChunkExtended
@@ -237,15 +237,27 @@ def _process_chunk(self, replica_chunk: PpdbReplicaChunkExtended) -> None:
)
# Make a list of local parquet files to upload.
- parquet_files = list(chunk_dir.glob("*.parquet"))
+ upload_file_list = list(chunk_dir.glob("*.parquet"))
+
+ # Include the update records file if the manifest indicates it should
+ # exist
+ if manifest.includes_update_records:
+ update_records_file = chunk_dir / "update_records.json"
+ if not update_records_file.exists():
+ raise ChunkUploadError(
+ chunk_id,
+ f"Manifest indicates update records are included but file does not exist: "
+ f"{update_records_file}",
+ )
+ upload_file_list.append(update_records_file)
# Check if the chunk is expected to be empty.
is_empty = manifest.is_empty_chunk()
- if not parquet_files and not is_empty:
+ if not upload_file_list and not is_empty:
# There is a mismatch between the manifest and the actual files.
# Some processing error may have occurred when exporting.
- raise ChunkUploadError(chunk_id, f"No parquet files found in {chunk_dir} for non-empty chunk")
+ raise ChunkUploadError(chunk_id, f"No files found to upload in {chunk_dir} for non-empty chunk")
# Check that all expected parquet files from the manifest are present.
for table_name, table_stats in manifest.table_data.items():
@@ -258,24 +270,21 @@ def _process_chunk(self, replica_chunk: PpdbReplicaChunkExtended) -> None:
)
try:
- # 1) Upload parquet files, which will happen only for non-empty
- # chunks.
- if parquet_files:
- gcs_names = {path: posixpath.join(gcs_prefix, path.name) for path in parquet_files}
+ # 1) Upload the files to GCS for non-empty chunks
+ if upload_file_list:
+ gcs_names = {path: posixpath.join(gcs_prefix, path.name) for path in upload_file_list}
try:
- _LOG.info(
- "Uploading %d parquet files to GCS under prefix: %s", len(gcs_names), gcs_prefix
- )
+ _LOG.info("Uploading %d files to GCS under prefix: %s", len(gcs_names), gcs_prefix)
with Timer(
"upload_files_time", _MON, tags={"prefix": str(gcs_prefix), "chunk_id": str(chunk_id)}
) as timer:
self.storage.upload_files(gcs_names)
- total_bytes = sum(p.stat().st_size for p in parquet_files)
+ total_bytes = sum(p.stat().st_size for p in upload_file_list)
timer.add_values(file_count=len(gcs_names), total_bytes=total_bytes)
except* UploadError as eg:
raise ChunkUploadError(chunk_id, f"{len(eg.exceptions)} upload(s) failed") from eg
- # 2) Upload manifest, even for empty chunks.
+ # 2) Upload manifest, even for empty chunks
try:
self.storage.upload_from_string(
posixpath.join(gcs_prefix, replica_chunk.manifest_name),
@@ -284,22 +293,29 @@ def _process_chunk(self, replica_chunk: PpdbReplicaChunkExtended) -> None:
except UploadError as e:
raise ChunkUploadError(chunk_id, "Manifest upload failed") from e
- # 3) Update DB status, but not for empty chunks.
+ # Next two steps are inapplicable to empty chunks.
if not is_empty:
+ # 3) Update status and GCS URI in the database
+ gcs_uri = posixpath.join(self.bucket_name, gcs_prefix)
+ updated_replica_chunk = replica_chunk.with_new_status(ChunkStatus.UPLOADED).with_new_gcs_uri(
+ f"gs://{gcs_uri}"
+ )
try:
- self._bq.store_chunk(replica_chunk.with_new_status(ChunkStatus.UPLOADED), True)
+ self._bq.store_chunk(updated_replica_chunk, True)
+ _LOG.info(
+ "Updated replica chunk %d in database with status 'uploaded' and GCS URI: %s",
+ chunk_id,
+ gcs_uri,
+ )
except Exception as e:
- raise ChunkUploadError(
- chunk_id, "failed to update replica chunk status in database"
- ) from e
+ raise ChunkUploadError(chunk_id, "Failed to update replica chunk in database") from e
- # 4) Publish Pub/Sub staging message to trigger BigQuery load, but
- # not for empty chunks. (Empty chunks cannot be staged.)
- if not is_empty:
+ # 4) Publish Pub/Sub event to trigger staging of the chunk in
+ # BigQuery
try:
- self._post_to_stage_chunk_topic(self.bucket_name, gcs_prefix, chunk_id)
+ self._post_to_stage_chunk_topic(gcs_uri, chunk_id)
except Exception as e:
- raise ChunkUploadError(chunk_id, "failed to publish staging message") from e
+ raise ChunkUploadError(chunk_id, "Failed to publish staging message") from e
except ChunkUploadError as err:
try:
@@ -310,17 +326,14 @@ def _process_chunk(self, replica_chunk: PpdbReplicaChunkExtended) -> None:
except DeleteError as cleanup_err:
# Note (Python 3.11+): annotate without masking the
# original error.
- err.add_note(
- f"cleanup warning: failed to delete "
- f"gs://{posixpath.join(self.bucket_name, gcs_prefix)}: {cleanup_err}"
- )
+ err.add_note(f"cleanup warning: failed to delete gs://{gcs_uri}: {cleanup_err}")
raise
- def _post_to_stage_chunk_topic(self, bucket_name: str, chunk_prefix: str, chunk_id: int) -> None:
+ def _post_to_stage_chunk_topic(self, gcs_uri: str, chunk_id: int) -> None:
message = {
"dataset": self.dataset_id,
"chunk_id": str(chunk_id),
- "folder": f"gs://{posixpath.join(bucket_name, chunk_prefix)}",
+ "folder": f"gs://{gcs_uri}",
}
self.publisher.publish(message).result(timeout=60)
diff --git a/python/lsst/dax/ppdb/bigquery/manifest.py b/python/lsst/dax/ppdb/bigquery/manifest.py
index b53c6f5a..da0fa456 100644
--- a/python/lsst/dax/ppdb/bigquery/manifest.py
+++ b/python/lsst/dax/ppdb/bigquery/manifest.py
@@ -79,6 +79,10 @@ class Manifest(BaseModel):
"""Name of the compression format used for artifacts (e.g., "gzip",
"zstd", "snappy", etc.)."""
+ includes_update_records: bool = False
+ """Whether the exported data includes update records (e.g., in a separate
+ file) or not (`bool`)."""
+
@property
def filename(self) -> str:
"""Generate the filename for this manifest based on the replica chunk
@@ -118,12 +122,15 @@ def from_json_file(cls, file_path: Path) -> Manifest:
def is_empty_chunk(self) -> bool:
"""Check if the manifest represents an empty replica chunk in which
- all tables have zero rows.
+ all tables have zero rows and no update records are included.
Returns
-------
bool
- `True` if all tables have zero rows, indicating an empty chunk,
- `False` otherwise.
+ `True` if all tables have zero rows and no update records are
+ included, indicating an empty chunk, `False` otherwise.
"""
- return all(table.row_count == 0 for table in self.table_data.values())
+ return (
+ all(table.row_count == 0 for table in self.table_data.values())
+ and not self.includes_update_records
+ )
diff --git a/python/lsst/dax/ppdb/bigquery/ppdb_bigquery.py b/python/lsst/dax/ppdb/bigquery/ppdb_bigquery.py
index 877bb229..61cdc416 100644
--- a/python/lsst/dax/ppdb/bigquery/ppdb_bigquery.py
+++ b/python/lsst/dax/ppdb/bigquery/ppdb_bigquery.py
@@ -19,14 +19,19 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
+from __future__ import annotations
+
import datetime
import logging
+import os
import shutil
from collections.abc import Collection, Iterable, Sequence
from pathlib import Path
+from typing import Any
import felis
import sqlalchemy
+from google.cloud import secretmanager
from lsst.dax.apdb import (
ApdbMetadata,
@@ -41,11 +46,14 @@
from lsst.dax.apdb.timer import Timer
from .._arrow import write_parquet
-from ..config import PpdbConfig
from ..ppdb import Ppdb, PpdbReplicaChunk
-from ..sql import PpdbSqlBase, PpdbSqlBaseConfig
+from ..ppdb_config import PpdbConfig
+from ..sql import PasswordProvider, PpdbSqlBase, PpdbSqlBaseConfig
from .manifest import Manifest, TableStats
from .ppdb_replica_chunk_extended import ChunkStatus, PpdbReplicaChunkExtended
+from .query_runner import QueryRunner
+from .sql_resource import SqlResource
+from .updates.update_records import UpdateRecords
__all__ = ["ConfigValidationError", "PpdbBigQuery", "PpdbBigQueryConfig"]
@@ -113,6 +121,30 @@ def fq_dataset_id(self) -> str:
return f"{self.project_id}:{self.dataset_id}"
+class _SecretManagerPasswordProvider(PasswordProvider):
+ """Retrieves a database password from Google Cloud Secret Manager.
+
+ Parameters
+ ----------
+ project_id : `str`
+ GCP project that owns the secret.
+ secret_name : `str`, optional
+ Name of the secret. Defaults to ``"ppdb-db-password"``.
+ """
+
+ def __init__(self, project_id: str, secret_name: str = "ppdb-db-password") -> None:
+ self._project_id = project_id
+ self._secret_name = secret_name
+
+ def get_password(self) -> str:
+ """Return the password fetched from Secret Manager."""
+ client = secretmanager.SecretManagerServiceClient()
+ name = f"projects/{self._project_id}/secrets/{self._secret_name}/versions/latest"
+ _LOG.info("Retrieving database password from Secret Manager: %s", name)
+ response = client.access_secret_version(request={"name": name})
+ return response.payload.data.decode("UTF-8")
+
+
class ConfigValidationError(Exception):
"""Indicates an error validating the configuration."""
@@ -127,30 +159,73 @@ class PpdbBigQuery(Ppdb, PpdbSqlBase):
"""
def __init__(self, config: PpdbBigQueryConfig):
- # Initialize the SQL interface for the PPDB.
- PpdbSqlBase.__init__(self, config.sql)
+ # Build an optional password provider for GCP Secret Manager
+ password_provider: PasswordProvider | None = None
+ if os.getenv("PPDB_USE_SECRET_MANAGER", "false").lower() == "true":
+ _LOG.info("Using Secret Manager to retrieve database password")
+ password_provider = _SecretManagerPasswordProvider(config.project_id)
- # Read parameters from config.
+ # Delegate SQL initialisation (schema load, engine, metadata, version
+ # checks) to the base class, passing the optional password provider
+ PpdbSqlBase.__init__(self, config.sql, password_provider=password_provider)
+
+ # Read parameters from config
if config.replication_dir is None:
raise ValueError("Directory for chunk export is not set in configuration.")
self.replication_path = config.replication_path
self.parq_batch_size = config.parq_batch_size
self.parq_compression = config.parq_compression
self.delete_existing_dirs = config.delete_existing_dirs
+ self.project_id = config.project_id
+
+ self._config = config
+
+ self._query_runner: QueryRunner | None = None
@property
def metadata(self) -> ApdbMetadata:
- """Implement `Ppdb` interface to return APDB metadata object.
+ """APDB metadata object from `Ppdb` interface (`ApdbMetadata`)."""
+ return self._metadata
+
+ @property
+ def config(self) -> PpdbBigQueryConfig:
+ """PPDB config associated with this instance."""
+ return self._config
+
+ @property
+ def query_runner(self) -> QueryRunner:
+ """Query runner for executing SQL in BigQuery
+ (`~lsst.dax.ppdb.bigquery.QueryRunner`).
+ """
+ if not self._query_runner:
+ self._query_runner = QueryRunner(self.config.project_id, self.config.dataset_id)
+ return self._query_runner
+
+ @classmethod
+ def from_env(cls) -> PpdbBigQuery:
+ """Create an instance of this class from a config pointed to by an
+ environment variable.
Returns
-------
- metadata : `ApdbMetadata`
- APDB metadata object.
+ ppdb : `PpdbBigQuery`
+ An instance of the PPDB BigQuery interface.
"""
- return self._metadata
+ ppdb_config_uri = os.environ.get("PPDB_CONFIG_URI", None)
+ if ppdb_config_uri:
+ logging.info("PPDB_CONFIG_URI: %s", ppdb_config_uri)
+ else:
+ raise OSError("PPDB_CONFIG_URI is not set in the environment")
+ ppdb = Ppdb.from_uri(ppdb_config_uri)
+ if not isinstance(ppdb, PpdbBigQuery):
+ raise ValueError(f"Ppdb from environment has wrong type: {type(ppdb)}")
+ return ppdb
def _generate_manifest(
- self, replica_chunk: ReplicaChunk, table_dict: dict[str, ApdbTableData]
+ self,
+ replica_chunk: ReplicaChunk,
+ table_dict: dict[str, ApdbTableData],
+ update_records: Collection[ApdbUpdateRecord],
) -> Manifest:
"""Generate the manifest data for the replica chunk."""
return Manifest(
@@ -163,6 +238,7 @@ def _generate_manifest(
table_name: TableStats(row_count=len(data.rows())) for table_name, data in table_dict.items()
},
compression_format=self.parq_compression,
+ includes_update_records=bool(update_records),
)
def store(
@@ -178,22 +254,11 @@ def store(
# Docstring is inherited.
_LOG.info("Processing %s", replica_chunk.id)
- # TODO: APDB does not generate ApdbUpdateRecords yet, but we will
- # eventually have to add support for it.
- if update_records:
- raise NotImplementedError("PpdbBigQuery does not support record updates yet.")
-
try:
- chunk_dir = self._get_chunk_path(replica_chunk)
-
- if chunk_dir.exists():
- if not self.delete_existing_dirs:
- raise FileExistsError(f"Directory already exists for {replica_chunk.id}: {chunk_dir}")
- _LOG.warning("Overwriting existing directory for %s: %s", replica_chunk.id, chunk_dir)
- shutil.rmtree(chunk_dir)
+ chunk_dir = self._create_chunk_dir(replica_chunk)
- chunk_dir.mkdir(parents=True)
- _LOG.info("Created directory for %s: %s", replica_chunk.id, chunk_dir)
+ if update_records:
+ self._handle_updates(replica_chunk, update_records, chunk_dir)
table_dict = {
ApdbTables.DiaObject.value: objects,
@@ -227,7 +292,7 @@ def store(
# Create manifest for the replica chunk.
try:
- manifest = self._generate_manifest(replica_chunk, table_dict)
+ manifest = self._generate_manifest(replica_chunk, table_dict, update_records)
_LOG.info("Generated manifest for %s: %s", replica_chunk.id, manifest.model_dump_json())
except Exception:
_LOG.exception("Failed to generate manifest for %d", replica_chunk.id)
@@ -261,15 +326,32 @@ def store(
_LOG.info("Done processing %s", replica_chunk.id)
- def _get_chunk_path(self, chunk: ReplicaChunk) -> Path:
+ def _create_chunk_dir(self, chunk: ReplicaChunk) -> Path:
+ """Create the directory for the replica chunk based on its last update
+ time and ID.
+
+ Returns
+ -------
+ chunk_dir : `Path`
+ Path to the created directory for the replica chunk.
+ """
last_update_time = chunk.last_update_time.to_datetime()
assert isinstance(last_update_time, datetime.datetime)
- path = Path(
+ chunk_dir = Path(
self.replication_path,
chunk.last_update_time.strftime("%Y/%m/%d"),
str(chunk.id),
)
- return path
+ if chunk_dir.exists():
+ if not self.delete_existing_dirs:
+ raise FileExistsError(f"Directory already exists for {chunk.id}: {chunk_dir}")
+ _LOG.warning("Overwriting existing directory for %s: %s", chunk.id, chunk_dir)
+ shutil.rmtree(chunk_dir)
+
+ chunk_dir.mkdir(parents=True)
+ _LOG.info("Created directory for %s: %s", chunk.id, chunk_dir)
+
+ return chunk_dir
def get_replica_chunks(self, start_chunk_id: int | None = None) -> Sequence[PpdbReplicaChunk] | None:
# Docstring is inherited.
@@ -306,6 +388,7 @@ def get_replica_chunks_ext(
table.columns["replica_time"],
table.columns["status"], # Extended column
table.columns["directory"], # Extended column
+ table.columns["gcs_uri"], # Extended column
).order_by(table.columns["last_update_time"])
if start_chunk_id is not None:
query = query.where(table.columns["apdb_replica_chunk"] >= start_chunk_id)
@@ -325,10 +408,61 @@ def get_replica_chunks_ext(
replica_time=replica_time,
status=row[4],
directory=Path(row[5]),
+ gcs_uri=row[6],
)
)
return ids
+ def get_replica_chunks_ext_by_ids(self, chunk_ids: Sequence[int]) -> Sequence[PpdbReplicaChunkExtended]:
+ """Find replica chunks for a list of chunk IDs.
+
+ Parameters
+ ----------
+ chunk_ids : `~collections.abc.Sequence` [ `int` ]
+ Replica chunk IDs to retrieve.
+
+ Returns
+ -------
+ chunks : `~collections.abc.Sequence` [ `PpdbReplicaChunkExtended` ]
+ List of matching chunks ordered by ``last_update_time``.
+ """
+ if not chunk_ids:
+ return []
+
+ table = self.get_table("PpdbReplicaChunk")
+ query = (
+ sqlalchemy.sql.select(
+ table.columns["apdb_replica_chunk"],
+ table.columns["last_update_time"],
+ table.columns["unique_id"],
+ table.columns["replica_time"],
+ table.columns["status"],
+ table.columns["directory"],
+ table.columns["gcs_uri"],
+ )
+ .where(table.columns["apdb_replica_chunk"].in_(chunk_ids))
+ .order_by(table.columns["apdb_replica_chunk"])
+ )
+
+ chunks: list[PpdbReplicaChunkExtended] = []
+ with self._engine.connect() as conn:
+ result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
+ for row in result:
+ last_update_time = self.to_astropy_tai(row[1])
+ replica_time = self.to_astropy_tai(row[3])
+ chunks.append(
+ PpdbReplicaChunkExtended(
+ id=row[0],
+ last_update_time=last_update_time,
+ unique_id=row[2],
+ replica_time=replica_time,
+ status=row[4],
+ directory=Path(row[5]),
+ gcs_uri=row[6],
+ )
+ )
+ return chunks
+
def store_chunk(self, replica_chunk: PpdbReplicaChunkExtended, update: bool) -> None:
"""Insert or replace single record in PpdbReplicaChunk table, including
the status and directory of the replica chunk.
@@ -352,6 +486,7 @@ def store_chunk(self, replica_chunk: PpdbReplicaChunkExtended, update: bool) ->
"replica_time": replica_chunk.replica_time_dt_utc,
"status": replica_chunk.status,
"directory": str(replica_chunk.directory),
+ "gcs_uri": replica_chunk.gcs_uri,
}
if update:
self.upsert(connection, table, row, "apdb_replica_chunk")
@@ -389,6 +524,12 @@ def create_replica_chunk_table(cls, table_name: str | None = None) -> schema_mod
datatype=felis.datamodel.DataType.string,
nullable=True, # We might want to allow NULL if an error occurs when exporting.
),
+ schema_model.Column(
+ name="gcs_uri",
+ id=f"#{table_name}.gcs_uri",
+ datatype=felis.datamodel.DataType.string,
+ nullable=True,
+ ),
]
)
return replica_chunk_table
@@ -466,7 +607,6 @@ def init_bigquery(
sql_config = PpdbSqlBaseConfig(
db_url=db_url, schema_name=db_schema, felis_path=felis_path, felis_schema=felis_schema
)
- cls.make_database(sql_config, sa_metadata, schema_version, db_drop)
# Build config parameters.
bq_config = PpdbBigQueryConfig(
@@ -485,6 +625,13 @@ def init_bigquery(
if stage_chunk_topic is not None:
bq_config.stage_chunk_topic = stage_chunk_topic
+ password_provider: PasswordProvider | None = None
+ if os.getenv("PPDB_USE_SECRET_MANAGER", "false").lower() == "true":
+ _LOG.info("Using Secret Manager to retrieve database password")
+ password_provider = _SecretManagerPasswordProvider(bq_config.project_id)
+ engine = cls.make_engine(bq_config.sql, password_provider=password_provider)
+ cls.make_database(engine, bq_config.sql, sa_metadata, schema_version, db_drop)
+
# Validate the config if requested.
if validate_config:
_LOG.info("validating BigQuery configuration")
@@ -567,3 +714,139 @@ def validate_config(cls, config: PpdbBigQueryConfig) -> None:
check_dataset_exists(config.project_id, config.dataset_id)
except Exception as e:
raise ConfigValidationError("Failed to validate BigQuery dataset") from e
+
+ def _handle_updates(
+ self, replica_chunk: ReplicaChunk, apdb_update_records: Collection[ApdbUpdateRecord], chunk_dir: Path
+ ) -> None:
+ """Handle updates to existing records in the PPDB by writing a JSON
+ file with the update information for the replica chunk.
+
+ Parameters
+ ----------
+ replica_chunk : `ReplicaChunk`
+ The replica chunk associated with the updates.
+ apdb_update_records : `~collections.abc.Collection` [ `ApdbUpdateRecord` ]
+ Collection of update records to process.
+
+ Notes
+ -----
+ Serializes the ApdbUpdateRecord objects into a dictionary structure
+ for processing.
+ """
+ update_records = UpdateRecords(
+ replica_chunk_id=replica_chunk.id,
+ records=list(apdb_update_records),
+ record_count=len(apdb_update_records),
+ )
+ update_records.write_json_file(chunk_dir / "update_records.json")
+
+ _LOG.info(
+ "Saved %d update records for %s to %s",
+ update_records.record_count,
+ replica_chunk.id,
+ chunk_dir / "update_records.json",
+ )
+
+ def get_promotable_chunks(self) -> list[int]:
+ """
+ Return the first uninterrupted sequence of staged chunks such that all
+ prior chunks are promoted.
+
+ Returns
+ -------
+ chunk_ids : `list`[`int`]
+ A list of integers containing the ``apdb_replica_chunk`` values of
+ the promotable chunks.
+
+ Notes
+ -----
+ This query finds the contiguous sequence of ``staged`` chunks beginning
+ with the earliest chunk that is not yet ``promoted``, and ending just
+ before the first chunk that is not ``staged``. If no such ending
+ exists, all ``staged`` chunks from that point onward are returned. If
+ no chunks are ``staged`` after the first non-``promoted`` chunk, an
+ empty list is returned.
+ """
+ table = self.get_table("PpdbReplicaChunk")
+ if not table.schema:
+ raise ValueError("Table schema is not set, cannot construct query")
+ quoted_table_name = (
+ self._engine.dialect.identifier_preparer.quote(table.schema)
+ + "."
+ + self._engine.dialect.identifier_preparer.quote(table.name)
+ )
+
+ sql = SqlResource("select_promotable_chunks", {"table_name": quoted_table_name}).sql
+
+ with self._engine.connect() as conn:
+ result = conn.execute(sqlalchemy.text(sql))
+ chunk_ids = [row[0] for row in result]
+ return chunk_ids
+
+ def mark_chunks_promoted(self, promotable_chunks: list[int]) -> int:
+ """Set status='promoted' for the given chunk IDs. Returns number
+ updated.
+
+ Parameters
+ ----------
+ promotable_chunks : `list`[`int`]
+ List of integers containing the ``apdb_replica_chunk`` values of
+ the promotable chunks.
+
+ Returns
+ -------
+ count : `int`
+ The number of rows updated in the database, which should be equal
+ to the number of promotable chunks provided, if they were all found
+ and updated successfully.
+ """
+ table = self.get_table("PpdbReplicaChunk")
+ stmt = (
+ sqlalchemy.update(table)
+ .where(table.c.apdb_replica_chunk.in_(promotable_chunks), table.c.status != "promoted")
+ .values(status="promoted")
+ )
+
+ with self._engine.begin() as conn:
+ result: sqlalchemy.engine.CursorResult = conn.execute(stmt)
+ return result.rowcount or 0
+
+ def update(self, chunk_id: int, values: dict[str, Any]) -> int:
+ """Update an existing replica chunk in the database.
+
+ Parameters
+ ----------
+ chunk_id : `int`
+ The ID of the replica chunk to update.
+ values : `dict`[`str`, `Any`]
+ A dictionary of column names and their new values to update.
+
+ Returns
+ -------
+ count : `int`
+ The number of rows updated. This should be 1 if the update is
+ successful, or 0 if no rows were updated (e.g., if the chunk ID
+ does not exist or the status is already set to the new value).
+ """
+ logging.info("Preparing to update replica chunk %d with values: %s", chunk_id, values)
+ table = self.get_table("PpdbReplicaChunk")
+ stmt = sqlalchemy.update(table).where(table.c.apdb_replica_chunk == chunk_id).values(values)
+ with self._engine.begin() as conn:
+ result = conn.execute(stmt)
+ affected_rows = result.rowcount
+
+ new_status = values.get("status")
+ if affected_rows == 0:
+ logging.warning(
+ "No rows updated for replica chunk %s with status '%s'",
+ chunk_id,
+ new_status,
+ )
+ else:
+ logging.info(
+ "Successfully updated %d row(s) for replica chunk %s to status '%s'",
+ affected_rows,
+ chunk_id,
+ new_status,
+ )
+ return affected_rows
diff --git a/python/lsst/dax/ppdb/bigquery/ppdb_replica_chunk_extended.py b/python/lsst/dax/ppdb/bigquery/ppdb_replica_chunk_extended.py
index bd8d6422..7649c5c5 100644
--- a/python/lsst/dax/ppdb/bigquery/ppdb_replica_chunk_extended.py
+++ b/python/lsst/dax/ppdb/bigquery/ppdb_replica_chunk_extended.py
@@ -43,6 +43,10 @@ class ChunkStatus(StrEnum):
"""Chunk has been exported from the APDB to a local parquet file."""
UPLOADED = "uploaded"
"""Chunk has been uploaded to cloud storage."""
+ STAGED = "staged"
+ """Chunk data has been copied into the staging tables."""
+ PROMOTED = "promoted"
+ """Chunk data has been promoted from the staging to production tables."""
FAILED = "failed"
"""Chunk processing failed and an error occurred."""
SKIPPED = "skipped"
@@ -59,6 +63,10 @@ class PpdbReplicaChunkExtended(PpdbReplicaChunk):
directory: Path
"""Directory where the exported replica chunk data is stored locally."""
+ gcs_uri: str | None = None
+ """GCS URI where the replica chunk data is stored, or `None` if not
+ uploaded yet."""
+
@property
def manifest_name(self) -> str:
"""Filename of the manifest file for this chunk."""
@@ -127,3 +135,19 @@ def with_new_status(self, new_status: ChunkStatus) -> PpdbReplicaChunkExtended:
The new chunk with the updated status.
"""
return dataclasses.replace(self, status=new_status)
+
+ def with_new_gcs_uri(self, new_gcs_uri: str) -> PpdbReplicaChunkExtended:
+ """Create a new `PpdbReplicaChunkExtended` with the same properties as
+ this one, but with a different GCS URI.
+
+ Parameters
+ ----------
+ new_gcs_uri : `str`
+ The new GCS URI to set.
+
+ Returns
+ -------
+ new_chunk : `PpdbReplicaChunkExtended`
+ The new chunk with the updated GCS URI.
+ """
+ return dataclasses.replace(self, gcs_uri=new_gcs_uri)
diff --git a/python/lsst/dax/ppdb/bigquery/query_runner.py b/python/lsst/dax/ppdb/bigquery/query_runner.py
new file mode 100644
index 00000000..e9901658
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/query_runner.py
@@ -0,0 +1,139 @@
+# This file is part of dax_ppdb
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = [
+ "QueryRunner",
+]
+
+import logging
+
+from google.cloud import bigquery
+
+
+class QueryRunner:
+ """Class to run BigQuery queries with logging.
+
+ Parameters
+ ----------
+ project_id : `str`
+ Google Cloud project ID.
+ dataset_id : `str`
+ BigQuery dataset ID.
+ """
+
+ def __init__(self, project_id: str, dataset_id: str):
+ self._project_id = project_id
+ self._dataset_id = dataset_id
+ self._bq_client = bigquery.Client(project=project_id)
+ self._dataset = self._bq_client.get_dataset(f"{project_id}.{dataset_id}")
+ self._location = self._dataset.location
+
+ @property
+ def project_id(self) -> str:
+ """Google Cloud project ID (`str`, read-only)."""
+ return self._project_id
+
+ @property
+ def dataset(self) -> bigquery.Dataset:
+ """Dataset reference (`bigquery.Dataset`, read-only)."""
+ return self._dataset
+
+ @property
+ def dataset_id(self) -> str:
+ """Dataset ID (`str`, read-only)."""
+ return self._dataset_id
+
+ @property
+ def location(self) -> str:
+ """Dataset location, typically the region where it is hosted (`str`,
+ read-only).
+ """
+ return self._location
+
+ @classmethod
+ def log_job(
+ cls,
+ job: bigquery.job.QueryJob
+ | bigquery.job.LoadJob
+ | bigquery.job.CopyJob
+ | bigquery.job.ExtractJob
+ | bigquery.job.UnknownJob,
+ label: str,
+ level: int = logging.DEBUG,
+ ) -> None:
+ """Log details of a BigQuery job.
+
+ Parameters
+ ----------
+ job : `bigquery.job.QueryJob`
+ The BigQuery job to log.
+ label : `str`
+ A label for the job, typically indicating the type of operation
+ (e.g., "insert", "delete", "copy").
+ level : `int`, optional
+ The logging level to use for the log message. Defaults to
+ `logging.DEBUG`.
+ """
+ logging.log(
+ level,
+ "BQ %s: job_id=%s location=%s state=%s bytes_processed=%s bytes_billed=%s slot_millis=%s "
+ "dml_rows=%s referenced_tables=%s",
+ label,
+ job.job_id,
+ job.location,
+ job.state,
+ getattr(job, "total_bytes_processed", None),
+ getattr(job, "total_bytes_billed", None),
+ getattr(job, "slot_millis", None),
+ getattr(job, "num_dml_affected_rows", None),
+ getattr(job, "referenced_tables", None),
+ )
+
+ def run_job(
+ self, label: str, sql: str, job_config: bigquery.QueryJobConfig | None = None
+ ) -> bigquery.job.QueryJob:
+ """Run a BigQuery job with the given SQL and configuration.
+
+ Parameters
+ ----------
+ label : `str`
+ A label for the job, typically indicating the type of operation
+ (e.g., "insert", "delete", "copy").
+ sql : `str`
+ The SQL query to execute.
+ job_config : `bigquery.QueryJobConfig`, optional
+ Configuration for the job, such as query parameters or write
+ dispositions. If not provided, a default configuration will be
+ used.
+
+ Returns
+ -------
+ job : `bigquery.job.QueryJob`
+ The BigQuery job object representing the executed query. This can
+ be used to check the status of the job, retrieve results, or log
+ additional details.
+ """
+ job = self._bq_client.query(sql, job_config=job_config, location=self.dataset.location)
+ job.result() # Wait for the job to complete
+ self.log_job(job, label)
+ return job
diff --git a/python/lsst/dax/ppdb/bigquery/replica_chunk_promoter.py b/python/lsst/dax/ppdb/bigquery/replica_chunk_promoter.py
new file mode 100644
index 00000000..424b8a27
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/replica_chunk_promoter.py
@@ -0,0 +1,242 @@
+# This file is part of dax_ppdb
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = [
+ "NoPromotableChunksError",
+ "ReplicaChunkPromoter",
+]
+
+import logging
+
+from google.api_core.exceptions import NotFound
+from google.cloud import bigquery
+
+from .ppdb_bigquery import PpdbBigQuery
+from .query_runner import QueryRunner
+from .updates.updates_manager import UpdatesManager
+
+
+class NoPromotableChunksError(Exception):
+ """Exception raised when there are no promotable chunks available."""
+
+ pass
+
+
+class ReplicaChunkPromoter:
+ """Class to promote replica chunks in BigQuery.
+
+ Parameters
+ ----------
+ ppdb : `PpdbBigQuery`
+ Interface to the PPDB in BigQuery.
+ table_names : `list`[`str`], optional
+ List of table names to promote or if None a default list will be used.
+ """
+
+ def __init__(
+ self,
+ ppdb: PpdbBigQuery,
+ table_names: list[str] | None = None,
+ ):
+ self._ppdb = ppdb
+ self._project_id = self._ppdb._config.project_id
+ self._dataset_id = self._ppdb._config.dataset_id
+ self._runner = ppdb.query_runner
+ # DM-52326: Hard-coded table names; these should be passed in from
+ # config.
+ self._table_names = table_names or ["DiaObject", "DiaSource", "DiaForcedSource"]
+ self._bq_client = bigquery.Client(project=self._runner.project_id)
+ self._phases = {
+ "get_promotable_chunks": self._get_promotable_chunks,
+ "build_tmp": self._copy_to_promoted_tmp,
+ "apply_record_updates": self._apply_record_updates,
+ "promote_prod": self._promote_tmp_to_prod,
+ "delete_staged_chunks": self._delete_staged_chunks,
+ "mark_promoted": self._mark_chunks_promoted,
+ }
+
+ self._promotable_chunks: list[int] = []
+
+ @property
+ def promotable_chunks(self) -> list[int]:
+ """List of promotable chunks (`list` [ `int` ], read-only)."""
+ return self._promotable_chunks
+
+ @promotable_chunks.setter
+ def promotable_chunks(self, chunks: list[int]) -> None:
+ if not chunks:
+ raise NoPromotableChunksError("No promotable chunks provided")
+ self._promotable_chunks = chunks
+
+ @property
+ def table_prod_refs(self) -> list[str]:
+ """Fully-qualified production table references (`list`[`str`],
+ read-only).
+ """
+ return [f"{self._project_id}.{self._dataset_id}.{table_name}" for table_name in self._table_names]
+
+ @property
+ def table_staging_refs(self) -> list[str]:
+ """Fully-qualified staging table references (`list`[`str`],
+ read-only).
+ """
+ return [
+ f"{self._project_id}.{self._dataset_id}._{table_name}_staging" for table_name in self._table_names
+ ]
+
+ @property
+ def table_promoted_tmp_refs(self) -> list[str]:
+ """Fully-qualified promoted temporary table references (`list`[`str`],
+ read-only).
+ """
+ return [
+ f"{self._project_id}.{self._dataset_id}._{table_name}_promoted_tmp"
+ for table_name in self._table_names
+ ]
+
+ def _execute_phase(self, phase: str) -> None:
+ """Execute a specific promotion phase.
+
+ Parameters
+ ----------
+ phase : `str`
+ The name of the promotion phase to execute. This should be one of
+ the keys in the `phases` property.
+ """
+ if phase not in self._phases:
+ raise ValueError(f"Unknown promotion phase: {phase}")
+ logging.debug("Executing promotion phase: %s", phase)
+ self._phases[phase]()
+
+ def _get_promotable_chunks(self) -> None:
+ """Get list of promotable chunks from the database."""
+ self._promotable_chunks = self._ppdb.get_promotable_chunks()
+ logging.info("Promotable chunk count: %s", len(self.promotable_chunks))
+
+ def _copy_to_promoted_tmp(self) -> None:
+ """
+ Build ``_{table_name}_promoted_tmp`` efficiently by cloning prod and
+ inserting only staged rows for the given replica chunk IDs.
+ """
+ job_cfg = bigquery.QueryJobConfig(
+ query_parameters=[bigquery.ArrayQueryParameter("ids", "INT64", self.promotable_chunks)]
+ )
+
+ for prod_ref, tmp_ref, stage_ref in zip(
+ self.table_prod_refs, self.table_promoted_tmp_refs, self.table_staging_refs, strict=False
+ ):
+ # Drop any existing tmp table (should not exist but just to be
+ # safe)
+ self._runner.run_job("drop_tmp", f"DROP TABLE IF EXISTS `{tmp_ref}`")
+
+ # Clone prod table structure and data (zero-copy)
+ self._runner.run_job("clone_prod", f"CREATE TABLE `{tmp_ref}` CLONE `{prod_ref}`")
+
+ # Build ordered target list from the cloned tmp schema
+ tmp_schema = self._bq_client.get_table(tmp_ref).schema
+ target_names = [f.name for f in tmp_schema if f.name != "apdb_replica_chunk"]
+ target_list_sql = ", ".join(f"`{n}`" for n in target_names)
+
+ # Build source list, handling geo_point conversion
+ source_list_sql = ", ".join(
+ "ST_GEOGPOINT(s.`ra`, s.`dec`)" if n == "geo_point" else f"s.`{n}`" for n in target_names
+ )
+
+ # Insert staged rows into tmp, excluding apdb_replica_chunk column
+ sql = f"""
+ INSERT INTO `{tmp_ref}` ({target_list_sql})
+ SELECT {source_list_sql}
+ FROM `{stage_ref}` AS s
+ WHERE s.apdb_replica_chunk IN UNNEST(@ids)
+ """
+ logging.debug("SQL for inserting staged rows into %s: %s", tmp_ref, sql)
+ self._runner.run_job("insert_staged_to_tmp", sql, job_config=job_cfg)
+
+ def _promote_tmp_to_prod(self) -> None:
+ """
+ Swap each prod table with its corresponding *_promoted_tmp by replacing
+ prod contents in a single atomic copy job. This preserves schema,
+ partitioning, and clustering with zero-copy when in the same dataset.
+ """
+ for prod_ref, tmp_ref in zip(self.table_prod_refs, self.table_promoted_tmp_refs, strict=False):
+ # Ensure tmp exists
+ try:
+ self._bq_client.get_table(tmp_ref)
+ except NotFound as e:
+ raise RuntimeError(f"Missing tmp table for promotion: {tmp_ref}") from e
+
+ # Atomic zero-copy replacement of prod with tmp
+ copy_cfg = bigquery.CopyJobConfig(write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE)
+ job = self._bq_client.copy_table(
+ tmp_ref, prod_ref, job_config=copy_cfg, location=self._runner.location
+ )
+ job.result()
+ QueryRunner.log_job(job, "promote_tmp_to_prod")
+
+ def _cleanup(self) -> None:
+ """Drop the promotion temporary tables."""
+ for tmp_ref in self.table_promoted_tmp_refs:
+ self._bq_client.delete_table(tmp_ref, not_found_ok=True)
+ logging.debug("Dropped %s (if it existed)", tmp_ref)
+
+ def _delete_staged_chunks(self) -> None:
+ """Delete only rows for the promoted replica chunk IDs from each
+ staging table.
+ """
+ job_config = bigquery.QueryJobConfig(
+ query_parameters=[bigquery.ArrayQueryParameter("ids", "INT64", self.promotable_chunks)]
+ )
+
+ for staging_ref in self.table_staging_refs:
+ try:
+ sql = f"DELETE FROM `{staging_ref}` WHERE apdb_replica_chunk IN UNNEST(@ids)"
+ self._runner.run_job("delete_staged_chunks", sql, job_config=job_config)
+ logging.debug(
+ "Deleted %d chunk(s) from staging table %s", len(self.promotable_chunks), staging_ref
+ )
+ except NotFound:
+ logging.warning("Staging table %s does not exist, skipping delete", staging_ref)
+
+ def _apply_record_updates(self) -> None:
+ """Apply record updates to the promoted temporary tables."""
+ updates_manager = UpdatesManager(self._ppdb, table_name_postfix="_promoted_tmp")
+ updates_manager.apply_updates(self.promotable_chunks)
+
+ def _mark_chunks_promoted(self) -> None:
+ """Mark the replica chunks as promoted in the database."""
+ self._ppdb.mark_chunks_promoted(self.promotable_chunks)
+
+ def promote_chunks(self) -> None:
+ """Promote APDB replica chunks into production by executing a series of
+ phases.
+ """
+ try:
+ for phase in self._phases.keys():
+ self._execute_phase(phase)
+ finally:
+ try:
+ # Cleanup is always executed separately, not as an ordered
+ # phase.
+ self._cleanup()
+ except Exception:
+ logging.exception("Cleanup of chunk promotion failed")
diff --git a/python/lsst/dax/ppdb/bigquery/sql_resource.py b/python/lsst/dax/ppdb/bigquery/sql_resource.py
new file mode 100644
index 00000000..99d4f89d
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/sql_resource.py
@@ -0,0 +1,56 @@
+# This file is part of dax_ppdb
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+
+from lsst.resources import ResourcePath
+
+
+class SqlResource:
+ """Class for loading SQL query text from a resource file and optionally
+ formatting it with provided arguments.
+
+ Parameters
+ ----------
+ sql_resource_name : `str`
+ Base name of the SQL file (without .sql extension) containing the
+ query.
+ The SQL file must be located in the `lsst.dax.ppdb.config.sql` package.
+ format_args : `dict` [ `str`, `str` ], optional
+ Optional dictionary of arguments for formatting the SQL text.
+ """
+
+ def __init__(self, sql_resource_name: str, format_args: dict[str, str] | None = None) -> None:
+ # FIXME: Move the config dir into a resources dir (similar to obs_lsst)
+ sql_resource_path = f"resource://lsst.dax.ppdb/config/sql/{sql_resource_name}.sql"
+ sql = ResourcePath(sql_resource_path).read().decode("utf-8")
+ if format_args is not None:
+ try:
+ sql = sql.format(**format_args)
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to format SQL resource at {sql_resource_path} with arguments {format_args}"
+ ) from e
+ self._sql = sql
+
+ @property
+ def sql(self) -> str:
+ """SQL query string (`str`)."""
+ return self._sql
diff --git a/python/lsst/dax/ppdb/bigquery/updates/__init__.py b/python/lsst/dax/ppdb/bigquery/updates/__init__.py
new file mode 100644
index 00000000..1536958d
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/updates/__init__.py
@@ -0,0 +1,31 @@
+# This file is part of dax_ppdb
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from .expanded_update_record import ExpandedUpdateRecord
+from .updates_merger import (
+ UpdatesMerger,
+ DiaObjectUpdatesMerger,
+ DiaSourceUpdatesMerger,
+ DiaForcedSourceUpdatesMerger,
+)
+from .update_records import UpdateRecords
+from .update_record_expander import UpdateRecordExpander
+from .updates_table import UpdatesTable
diff --git a/python/lsst/dax/ppdb/bigquery/updates/expanded_update_record.py b/python/lsst/dax/ppdb/bigquery/updates/expanded_update_record.py
new file mode 100644
index 00000000..d59c9785
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/updates/expanded_update_record.py
@@ -0,0 +1,80 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class ExpandedUpdateRecord(BaseModel):
+ """
+ A single normalized (expanded) update row.
+
+ This model represents one field-level update after expanding an
+ original logical update event into one row per updated field.
+ It is the canonical shape loaded into the BigQuery updates table.
+ """
+
+ table_name: str = Field(
+ ...,
+ min_length=1,
+ description=("Logical target table for the update (e.g., 'DiaObject', 'DiaSource')."),
+ )
+
+ record_id: list[int] = Field(
+ ...,
+ description=(
+ "Identifier of the record being updated. For update types with a single record ID, this "
+ "will be a list of one element. For updates on records with a composite key "
+ "(e.g., DiaForcedSource), this will include all components of the key, in order."
+ ),
+ )
+
+ field_name: str = Field(
+ ...,
+ min_length=1,
+ description=("Name of the target column being updated."),
+ )
+
+ value_json: Any = Field(
+ ...,
+ description=("JSON-serializable new value for the field."),
+ )
+
+ replica_chunk_id: int = Field(
+ ...,
+ ge=0,
+ description=("Source replica chunk identifier associated with this update."),
+ )
+
+ update_order: int | None = Field(
+ default=None,
+ ge=0,
+ description=("Ordering value within the replica chunk or update batch."),
+ )
+
+ update_time_ns: int | None = Field(
+ default=None,
+ ge=0,
+ description=("Source event timestamp in nanoseconds since the epoch."),
+ )
diff --git a/python/lsst/dax/ppdb/bigquery/updates/update_record_expander.py b/python/lsst/dax/ppdb/bigquery/updates/update_record_expander.py
new file mode 100644
index 00000000..dc7b57ad
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/updates/update_record_expander.py
@@ -0,0 +1,243 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import hashlib
+import logging
+
+from lsst.dax.apdb.apdbUpdateRecord import ApdbUpdateRecord
+
+from .expanded_update_record import ExpandedUpdateRecord
+from .update_records import UpdateRecords
+
+
+class UpdateRecordExpander:
+ """Expand APDB update records into individual field-level updates for
+ BigQuery.
+ """
+
+ _UPDATE_FIELD_MAPPING = {
+ "reassign_diasource_to_diaobject": ["diaObjectId"],
+ "reassign_diasource_to_ssobject": ["ssObjectId", "ssObjectReassocTimeMjdTai"],
+ "withdraw_diasource": ["timeWithdrawnMjdTai"],
+ "withdraw_diaforcedsource": ["timeWithdrawnMjdTai"],
+ "close_diaobject_validity": ["validityEndMjdTai", "nDiaSources"],
+ "update_n_dia_sources": ["nDiaSources"],
+ }
+
+ _RECORD_ID_FIELD_MAPPING = {
+ "reassign_diasource_to_diaobject": ["diaSourceId"],
+ "reassign_diasource_to_ssobject": ["diaSourceId"],
+ "withdraw_diasource": ["diaSourceId"],
+ "withdraw_diaforcedsource": ["diaObjectId", "visit", "detector"],
+ "close_diaobject_validity": ["diaObjectId"],
+ "update_n_dia_sources": ["diaObjectId"],
+ }
+
+ @classmethod
+ def get_update_fields(cls, update_type: str) -> list[str]:
+ """Get the names of fields to update for a given update type.
+
+ Parameters
+ ----------
+ update_type : `str`
+ The type of update record.
+
+ Returns
+ -------
+ field_names : `list` [ `str` ]
+ List of field names that should be updated for this update type.
+
+ Raises
+ ------
+ ValueError
+ If the update_type is not recognized.
+ """
+ if update_type not in cls._UPDATE_FIELD_MAPPING:
+ raise ValueError(f"Unknown update_type: {update_type}")
+
+ return cls._UPDATE_FIELD_MAPPING[update_type]
+
+ @classmethod
+ def get_record_id_fields(cls, update_type: str) -> str | list[str]:
+ """Get the field name(s) that serve as the record ID for a given update
+ type.
+
+ Parameters
+ ----------
+ update_type : `str`
+ The type of update record.
+
+ Returns
+ -------
+ field_name : `str` or `list` [ `str` ]
+ Name of the field that contains the record ID for this update type,
+ or list of field names for composite keys.
+
+ Raises
+ ------
+ ValueError
+ If the update_type is not recognized.
+ """
+ if update_type not in cls._RECORD_ID_FIELD_MAPPING:
+ raise ValueError(f"Unknown update_type: {update_type}")
+
+ return cls._RECORD_ID_FIELD_MAPPING[update_type]
+
+ @classmethod
+ def _compute_record_id_hash(cls, record_id: list[int]) -> str:
+ """Compute MD5 hash of a record_id list for deduplication.
+
+ Parameters
+ ----------
+ record_id : list[int]
+ The record ID as a list of integers.
+
+ Returns
+ -------
+ str
+            Full 32-character hexadecimal MD5 hash of the record_id list.
+ """
+ record_id_str = ",".join(str(x) for x in record_id)
+ return hashlib.md5(record_id_str.encode()).hexdigest()
+
+ @classmethod
+ def get_record_id_field(cls, update_type: str) -> str | list[str]:
+ """Get the field name(s) that serve as the record ID for a given update
+ type.
+
+ Parameters
+ ----------
+ update_type : `str`
+ The type of update record.
+
+ Returns
+ -------
+ field_name : `list` [ `str` ]
+ List of the fields that contain the record ID for this update type.
+
+ Raises
+ ------
+ ValueError
+ If the update_type is not recognized.
+ """
+ return cls.get_record_id_fields(update_type)
+
+ @classmethod
+ def expand_single_record(
+ cls, update_record: ApdbUpdateRecord, replica_chunk_id: int
+ ) -> list[ExpandedUpdateRecord]:
+ """Expand a single APDB update record into ExpandedUpdateRecord
+ objects.
+
+ Parameters
+ ----------
+ update_record : `ApdbUpdateRecord`
+ A single APDB update record to expand.
+ replica_chunk_id : `int`
+ The replica chunk ID associated with this update record.
+
+ Returns
+ -------
+ expanded_records : `list` [ `ExpandedUpdateRecord` ]
+ List of ExpandedUpdateRecord objects, one per field being updated.
+ """
+ update_type = update_record.update_type
+ field_names = cls.get_update_fields(update_type)
+
+ # Get the target table from the update record
+ table_name = update_record.apdb_table.name
+
+ # Get the record ID
+ record_id = cls._get_record_id(update_record)
+
+ expanded_records = []
+ for field_name in field_names:
+ if not hasattr(update_record, field_name):
+ raise ValueError(
+ f"Update record of type {update_type} is missing expected field {field_name}"
+ )
+
+ value = getattr(update_record, field_name)
+
+ expanded_record = ExpandedUpdateRecord(
+ table_name=table_name,
+ record_id=record_id,
+ field_name=field_name,
+ value_json=value,
+ replica_chunk_id=replica_chunk_id,
+ update_order=update_record.update_order,
+ update_time_ns=update_record.update_time_ns,
+ )
+ expanded_records.append(expanded_record)
+
+ return expanded_records
+
+ @classmethod
+ def _get_record_id(cls, update_record: ApdbUpdateRecord) -> list[int]:
+ """Generate a record ID from an update record.
+
+ Parameters
+ ----------
+ update_record : `ApdbUpdateRecord`
+ The update record to generate an ID for.
+
+ Returns
+ -------
+ record_id : `list` [ `int` ]
+ The record ID as a list of integers. For simple keys, a
+ single-element list. For composite keys, a multi-element list.
+ """
+ update_type = update_record.update_type
+ id_fields = cls.get_record_id_fields(update_type)
+
+ record_id = []
+ for field in id_fields:
+ if not hasattr(update_record, field):
+ raise ValueError(f"Update record of type {update_type} is missing expected ID field {field}")
+ record_id.append(int(getattr(update_record, field)))
+ return record_id
+
+ @classmethod
+ def expand_updates(cls, update_records: UpdateRecords) -> list[ExpandedUpdateRecord]:
+ """Expand the APDB update records into a list of individual updates.
+
+ Parameters
+ ----------
+ update_records : `UpdateRecords`
+ The APDB update records to expand.
+
+ Returns
+ -------
+ expanded_updates : `list` [ `ExpandedUpdateRecord` ]
+ A list of individual updates derived from the input update records.
+ """
+ expanded_updates = []
+
+ for update_record in update_records.records:
+ expanded_records = cls.expand_single_record(update_record, update_records.replica_chunk_id)
+ expanded_updates.extend(expanded_records)
+
+ # DEBUG: Print number of expanded update records that were generated
+ logging.info("Created %d expanded update records", len(expanded_updates))
+
+ return expanded_updates
diff --git a/python/lsst/dax/ppdb/bigquery/updates/update_records.py b/python/lsst/dax/ppdb/bigquery/updates/update_records.py
new file mode 100644
index 00000000..a3107c85
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/updates/update_records.py
@@ -0,0 +1,126 @@
+# This file is part of dax_ppdb
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, cast
+
+from pydantic import BaseModel, field_serializer, field_validator
+
+from lsst.dax.apdb.apdbUpdateRecord import ApdbUpdateRecord
+
+DEFAULT_FILENAME = "update_records.json"
+"""Default filename for the update records JSON file."""
+
+
+class UpdateRecords(BaseModel):
+ """Data model for APDB update records."""
+
+ replica_chunk_id: int
+ """Identifier of the replica chunk to which these update records belong."""
+
+ record_count: int
+ """Number of update records included in this object."""
+
+ records: list[ApdbUpdateRecord]
+ """List of APDB update records included in this object."""
+
+ @field_serializer("records")
+ def serialize_records(
+ self,
+ records: list[ApdbUpdateRecord],
+ ) -> list[dict[str, Any]]:
+ """Serialize the ``ApdbUpdateRecord`` objects to JSON.
+
+ Parameters
+ ----------
+ records : `list` [ `ApdbUpdateRecord` ]
+ The list of APDB update records to serialize.
+
+ Returns
+ -------
+ serialized_records : `list` [ `dict` [ `str`, `Any` ]]
+ The serialized JSON data.
+ """
+ serialized_records: list[dict[str, Any]] = []
+ for update_record in records:
+ record_dict: dict[str, Any] = json.loads(update_record.to_json())
+ record_dict["update_time_ns"] = update_record.update_time_ns
+ record_dict["update_order"] = update_record.update_order
+ serialized_records.append(record_dict)
+ return serialized_records
+
+ @field_validator("records", mode="before")
+ @classmethod
+ def deserialize_records(
+ cls,
+ records: list[dict[str, Any]] | list[ApdbUpdateRecord],
+ ) -> list[ApdbUpdateRecord]:
+ """Deserialize the JSON data to ``ApdbUpdateRecord`` objects.
+
+ Parameters
+ ----------
+ records : `list` [ `dict` [ `str`, `Any` ] | `ApdbUpdateRecord` ]
+ The list of serialized JSON data or already deserialized
+ ApdbUpdateRecord objects.
+
+ Returns
+ -------
+ update_records : `list` [ `ApdbUpdateRecord` ]
+ The list of APDB update records.
+ """
+ if records and isinstance(records[0], ApdbUpdateRecord):
+ return cast(list[ApdbUpdateRecord], records)
+ deserialized_records: list[ApdbUpdateRecord] = []
+ for record in records:
+ if isinstance(record, dict):
+ record_copy = record.copy()
+ update_time_ns = record_copy.pop("update_time_ns")
+ update_order = record_copy.pop("update_order")
+ json_str = json.dumps(record_copy)
+ update_record = ApdbUpdateRecord.from_json(
+ update_time_ns,
+ update_order,
+ json_str,
+ )
+ deserialized_records.append(update_record)
+ elif isinstance(record, ApdbUpdateRecord):
+ deserialized_records.append(record)
+ else:
+ raise TypeError("Each record must be a dict or ApdbUpdateRecord")
+ return deserialized_records
+
+ def write_json_file(self, path: Path) -> None:
+ with open(path, "w") as f:
+ json.dump(self.model_dump(), f, indent=2, default=str)
+
+ @classmethod
+ def from_json_file(cls, path: Path) -> UpdateRecords:
+ with open(path) as f:
+ data = json.load(f)
+ return cls.model_validate(data)
+
+ @classmethod
+ def from_json_string(cls, json_str: str) -> UpdateRecords:
+ data = json.loads(json_str)
+ return cls.model_validate(data)
diff --git a/python/lsst/dax/ppdb/bigquery/updates/updates_manager.py b/python/lsst/dax/ppdb/bigquery/updates/updates_manager.py
new file mode 100644
index 00000000..98e0e66a
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/updates/updates_manager.py
@@ -0,0 +1,123 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import logging
+import posixpath
+import urllib
+from collections.abc import Sequence
+
+from google.cloud import bigquery, storage
+
+from ..ppdb_bigquery import PpdbBigQuery
+from .update_record_expander import UpdateRecordExpander
+from .update_records import DEFAULT_FILENAME, UpdateRecords
+from .updates_merger import (
+ DiaForcedSourceUpdatesMerger,
+ DiaObjectUpdatesMerger,
+ DiaSourceUpdatesMerger,
+ UpdatesMerger,
+)
+from .updates_table import UpdatesTable
+
+DEFAULT_MERGERS = (
+ DiaObjectUpdatesMerger,
+ DiaSourceUpdatesMerger,
+ DiaForcedSourceUpdatesMerger,
+)
+
+
+class UpdatesManager:
+ """Class responsible for managing the process of applying updates to the
+ PPDB database, including merging updates and inserting them into the
+ database.
+ """
+
+ def __init__(
+ self,
+ ppdb: PpdbBigQuery,
+ mergers: Sequence[type[UpdatesMerger]] = DEFAULT_MERGERS,
+ updates_table_name: str = "updates",
+ deduplicated_updates_table_name: str = "updates_deduplicated",
+ table_name_postfix: str | None = None,
+ ) -> None:
+ self._ppdb = ppdb
+ self._mergers = mergers
+ self._deduplicated_updates_table_name = deduplicated_updates_table_name
+
+ self._bq_client = bigquery.Client()
+
+ self._updates_table = UpdatesTable(
+ self._bq_client,
+ f"{self._ppdb._config.project_id}.{self._ppdb._config.dataset_id}.{updates_table_name}",
+ )
+
+ # TODO: Catch error if already exists
+ self._updates_table.create()
+
+ self._gcs_client = storage.Client()
+ self._bucket = self._gcs_client.bucket(self._ppdb._config.bucket_name)
+
+ self._table_name_postfix = table_name_postfix
+
+ def apply_updates(self, replica_chunk_ids: Sequence[int]) -> None:
+ replica_chunks = self._ppdb.get_replica_chunks_ext_by_ids(replica_chunk_ids)
+ for replica_chunk in replica_chunks:
+ if replica_chunk.gcs_uri is None:
+ raise ValueError(f"Replica chunk {replica_chunk.id} does not have a GCS URI")
+
+ # Parse the GCS URI to get the bucket name and object name
+ parsed_uri = urllib.parse.urlparse(replica_chunk.gcs_uri)
+ bucket_name = parsed_uri.netloc
+ object_name = posixpath.join(parsed_uri.path.lstrip("/"), DEFAULT_FILENAME)
+
+ # Get the blob from the bucket
+ bucket = self._gcs_client.bucket(bucket_name)
+ blob = bucket.blob(object_name)
+ content = blob.download_as_text()
+
+ # Expand the update records into the appropriate format for
+ # inserting into the updates table
+ update_records = UpdateRecords.from_json_string(content)
+ expanded_update_records = UpdateRecordExpander.expand_updates(update_records)
+ self._updates_table.insert(expanded_update_records)
+
+ # Deduplicate the update records to a new table
+ deduplicated_updates_table_fqn = (
+ f"{self._ppdb.project_id}.{self._ppdb._config.dataset_id}.{self._deduplicated_updates_table_name}"
+ )
+ self._updates_table.deduplicate_to(deduplicated_updates_table_fqn)
+
+ # Merge the deduplicated updates into the target tables
+ for merger in self._mergers:
+ merger_instance = merger(self._bq_client)
+ if self._table_name_postfix:
+ # Apply a postfix to the canonical target table name
+ merger_instance.target_table_name += f"{self._table_name_postfix}"
+ target_dataset_fqn = f"{self._ppdb._config.project_id}.{self._ppdb._config.dataset_id}"
+
+ # DEBUG: Print message to log
+ logging.info(
+ "Merging updates into `%s.%s`", target_dataset_fqn, merger_instance.target_table_name
+ )
+
+ merger_instance.merge(
+ updates_table_fqn=deduplicated_updates_table_fqn, target_dataset_fqn=target_dataset_fqn
+ )
diff --git a/python/lsst/dax/ppdb/bigquery/updates/updates_merger.py b/python/lsst/dax/ppdb/bigquery/updates/updates_merger.py
new file mode 100644
index 00000000..e74b68d6
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/updates/updates_merger.py
@@ -0,0 +1,117 @@
+# This file is part of dax_ppdb
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from abc import ABC
+
+from google.cloud import bigquery
+
+from ..sql_resource import SqlResource
+
+
+class UpdatesMerger(ABC):
+ """Abstract base class for merging expanded update records into target
+ tables in BigQuery.
+ """
+
+ TABLE_NAME: str
+ """Logical name of the target table this merger applies to
+ (e.g., 'DiaObject')."""
+
+ SQL_RESOURCE_NAME: str
+ """Base name of the SQL file (without .sql extension) containing the MERGE
+ statement for this merger. The SQL file must be located in the
+ `lsst.dax.ppdb.config.sql` package."""
+
+ def __init__(self, client: bigquery.Client, target_table_name: str | None = None) -> None:
+ """
+ Parameters
+ ----------
+ client
+ BigQuery client.
+ target_table_name
+ Optional name of the target table. If not provided, the class-level
+ TABLE_NAME will be used.
+ """
+ self._client: bigquery.Client = client
+ self._target_table_name = target_table_name or self.TABLE_NAME
+
+ @property
+ def target_table_name(self) -> str:
+ """Get the name of the target table this merger applies to."""
+ return self._target_table_name
+
+ @target_table_name.setter
+ def target_table_name(self, value: str) -> None:
+ """Set the name of the target table this merger applies to."""
+ self._target_table_name = value
+
+ def merge(self, *, updates_table_fqn: str, target_dataset_fqn: str) -> bigquery.QueryJob:
+ """
+ Apply updates from the updates table specified by `updates_table_fqn`
+ to the target table in the `target_dataset_fqn` dataset.
+
+ Parameters
+ ----------
+ updates_table_fqn
+ Fully-qualified BigQuery table name containing updates.
+ target_dataset_fqn
+ Fully-qualified BigQuery dataset name containing the target table.
+
+ Returns
+ -------
+ google.cloud.bigquery.job.QueryJob
+ The completed BigQuery job.
+ """
+ sql = SqlResource(
+ self.SQL_RESOURCE_NAME,
+ format_args={
+ "updates_table": updates_table_fqn,
+ "target_dataset": target_dataset_fqn,
+ "target_table": self.target_table_name,
+ },
+ ).sql
+ job = self._client.query(sql)
+ job.result()
+
+ return job
+
+
+class DiaObjectUpdatesMerger(UpdatesMerger):
+ """Merger for DiaObject updates."""
+
+ TABLE_NAME = "DiaObject"
+ SQL_RESOURCE_NAME = "merge_diaobject_updates"
+
+
+class DiaSourceUpdatesMerger(UpdatesMerger):
+ """Merger for DiaSource updates."""
+
+ TABLE_NAME = "DiaSource"
+ SQL_RESOURCE_NAME = "merge_diasource_updates"
+
+
+class DiaForcedSourceUpdatesMerger(UpdatesMerger):
+ """Merger for DiaForcedSource updates."""
+
+ TABLE_NAME = "DiaForcedSource"
+ SQL_RESOURCE_NAME = "merge_diaforcedsource_updates"
diff --git a/python/lsst/dax/ppdb/bigquery/updates/updates_table.py b/python/lsst/dax/ppdb/bigquery/updates/updates_table.py
new file mode 100644
index 00000000..b8054bbf
--- /dev/null
+++ b/python/lsst/dax/ppdb/bigquery/updates/updates_table.py
@@ -0,0 +1,198 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import hashlib
+from collections.abc import Iterable
+from typing import Any
+
+from google.cloud import bigquery
+
+from .expanded_update_record import ExpandedUpdateRecord
+
+
+class UpdatesTable:
+ """Manage the table in BigQuery used for inserting and deduplicating
+ expanded update records which contain one update per row.
+ """
+
+ def __init__(self, client: bigquery.Client, table_fqn: str) -> None:
+ """
+ Parameters
+ ----------
+ client
+ BigQuery client.
+ table_fqn
+ Fully-qualified table name in the form ``"project.dataset.table"``.
+ """
+ self._client: bigquery.Client = client
+ self._table_fqn: str = table_fqn
+
+ @staticmethod
+ def _compute_record_id_hash(record_id: list[int]) -> str:
+ """Compute MD5 hash of a record_id list for deduplication.
+
+ Parameters
+ ----------
+ record_id : list[int]
+ The record ID as a list of integers.
+
+ Returns
+ -------
+ str
+ Full 32-character hexadecimal MD5 hash of the record_id list.
+ """
+ record_id_str = ",".join(str(x) for x in record_id)
+ return hashlib.md5(record_id_str.encode()).hexdigest()
+
+ @property
+ def table_fqn(self) -> str:
+ """
+ Fully-qualified BigQuery table name.
+
+ Returns
+ -------
+ str
+ Table name in the form ``"project.dataset.table"``.
+ """
+ return self._table_fqn
+
+ def create(self) -> bigquery.Table:
+ """
+ Create the updates table.
+
+ Returns
+ -------
+ google.cloud.bigquery.Table
+ The created table.
+
+ Raises
+ ------
+ google.api_core.exceptions.Conflict
+ If the table already exists.
+
+ Notes
+ -----
+ Schema:
+
+ - table_name: STRING (REQUIRED)
+ - record_id: ARRAY<INT64> (REPEATED)
+ - record_id_hash: STRING (REQUIRED)
+ - field_name: STRING (REQUIRED)
+ - value_json: JSON (REQUIRED)
+ - replica_chunk_id: INT64 (REQUIRED)
+ - update_order: INT64 (NULLABLE)
+ - update_time_ns: INT64 (NULLABLE)
+ """
+ schema: list[bigquery.SchemaField] = [
+ bigquery.SchemaField("table_name", "STRING", mode="REQUIRED"),
+ bigquery.SchemaField("record_id", "INT64", mode="REPEATED"),
+ bigquery.SchemaField("record_id_hash", "STRING", mode="REQUIRED"),
+ bigquery.SchemaField("field_name", "STRING", mode="REQUIRED"),
+ bigquery.SchemaField("value_json", "JSON", mode="REQUIRED"),
+ bigquery.SchemaField("replica_chunk_id", "INT64", mode="REQUIRED"),
+ bigquery.SchemaField("update_order", "INT64", mode="NULLABLE"),
+ bigquery.SchemaField("update_time_ns", "INT64", mode="NULLABLE"),
+ ]
+
+ table = bigquery.Table(self._table_fqn, schema=schema)
+ return self._client.create_table(table)
+
+ def insert(self, records: Iterable[ExpandedUpdateRecord]) -> bigquery.LoadJob:
+ """
+ Insert `ExpandedUpdateRecord` rows into the updates table.
+
+ Parameters
+ ----------
+ records
+ Iterable of update records to insert.
+
+ Returns
+ -------
+ google.cloud.bigquery.LoadJob
+ Completed BigQuery load job.
+
+ Raises
+ ------
+ RuntimeError
+ If the BigQuery load job completes with errors.
+
+ Notes
+ -----
+ This uses a batch load via `Client.load_table_from_json` (not streaming
+ inserts). The table must already exist.
+ """
+ rows: list[dict[str, Any]] = [
+ {
+ "table_name": r.table_name,
+ "record_id": r.record_id,
+ "record_id_hash": self._compute_record_id_hash(r.record_id),
+ "field_name": r.field_name,
+ "value_json": r.value_json,
+ "replica_chunk_id": r.replica_chunk_id,
+ "update_order": r.update_order,
+ "update_time_ns": r.update_time_ns,
+ }
+ for r in records
+ ]
+
+ # NOTE(review): removed leftover debug print of the full row payload (could be large and noisy).
+
+ job = self._client.load_table_from_json(
+ rows,
+ self._table_fqn,
+ job_config=bigquery.LoadJobConfig(
+ write_disposition=bigquery.WriteDisposition.WRITE_APPEND,
+ ),
+ )
+ job.result()
+
+ if job.errors:
+ raise RuntimeError(f"BigQuery load failed: {job.errors}")
+
+ return job
+
+ def deduplicate_to(self, target_table_fqn: str) -> bigquery.QueryJob:
+ """
+ Deduplicate this table's records to a target table.
+
+ Keeps the record with the latest update_time_ns for each unique
+ combination of (table_name, record_id, field_name).
+ """
+ query = f"""
+ CREATE OR REPLACE TABLE `{target_table_fqn}`
+ AS
+ SELECT * EXCEPT(row_num)
+ FROM (
+ SELECT *,
+ ROW_NUMBER() OVER (
+ PARTITION BY table_name, record_id_hash, field_name
+ ORDER BY update_time_ns DESC
+ ) as row_num
+ FROM `{self._table_fqn}`
+ )
+ WHERE row_num = 1
+ """
+
+ job = self._client.query(query)
+ job.result()
+ return job
diff --git a/tests/config/schema.yaml b/python/lsst/dax/ppdb/config/schemas/test_apdb_schema.yaml
similarity index 100%
rename from tests/config/schema.yaml
rename to python/lsst/dax/ppdb/config/schemas/test_apdb_schema.yaml
diff --git a/python/lsst/dax/ppdb/config/sql/merge_diaforcedsource_updates.sql b/python/lsst/dax/ppdb/config/sql/merge_diaforcedsource_updates.sql
new file mode 100644
index 00000000..8eef46b7
--- /dev/null
+++ b/python/lsst/dax/ppdb/config/sql/merge_diaforcedsource_updates.sql
@@ -0,0 +1,28 @@
+MERGE `{target_dataset}.{target_table}` T
+USING (
+ WITH patch AS (
+ SELECT
+ record_id[OFFSET(0)] AS diaObjectId,
+ record_id[OFFSET(1)] AS visit,
+ record_id[OFFSET(2)] AS detector,
+
+ ANY_VALUE(
+ CASE WHEN field_name = 'timeWithdrawnMjdTai'
+ THEN CAST(JSON_VALUE(value_json) AS FLOAT64)
+ END
+ ) AS timeWithdrawnMjdTai_value,
+ COUNTIF(field_name = 'timeWithdrawnMjdTai') > 0 AS timeWithdrawnMjdTai_present
+
+ FROM `{updates_table}`
+ WHERE table_name = 'DiaForcedSource'
+ AND field_name IN ('timeWithdrawnMjdTai')
+ GROUP BY diaObjectId, visit, detector
+ )
+ SELECT * FROM patch
+) P
+ON T.diaObjectId = P.diaObjectId
+ AND T.visit = P.visit
+ AND T.detector = P.detector
+WHEN MATCHED THEN
+UPDATE SET
+ timeWithdrawnMjdTai = IF(P.timeWithdrawnMjdTai_present, P.timeWithdrawnMjdTai_value, T.timeWithdrawnMjdTai);
diff --git a/python/lsst/dax/ppdb/config/sql/merge_diaobject_updates.sql b/python/lsst/dax/ppdb/config/sql/merge_diaobject_updates.sql
new file mode 100644
index 00000000..9c6c1827
--- /dev/null
+++ b/python/lsst/dax/ppdb/config/sql/merge_diaobject_updates.sql
@@ -0,0 +1,32 @@
+MERGE `{target_dataset}.{target_table}` T
+USING (
+ WITH patch AS (
+ SELECT
+ record_id[OFFSET(0)] AS diaObjectId,
+
+ ANY_VALUE(
+ CASE WHEN field_name = 'validityEndMjdTai'
+ THEN CAST(JSON_VALUE(value_json) AS FLOAT64)
+ END
+ ) AS validityEndMjdTai_value,
+ COUNTIF(field_name = 'validityEndMjdTai') > 0 AS validityEndMjdTai_present,
+
+ ANY_VALUE(
+ CASE WHEN field_name = 'nDiaSources'
+ THEN CAST(JSON_VALUE(value_json) AS INT64)
+ END
+ ) AS nDiaSources_value,
+ COUNTIF(field_name = 'nDiaSources') > 0 AS nDiaSources_present
+
+ FROM `{updates_table}`
+ WHERE table_name = 'DiaObject'
+ AND field_name IN ('validityEndMjdTai', 'nDiaSources')
+ GROUP BY diaObjectId
+ )
+ SELECT * FROM patch
+) P
+ON T.diaObjectId = P.diaObjectId
+WHEN MATCHED THEN
+UPDATE SET
+ validityEndMjdTai = IF(P.validityEndMjdTai_present, P.validityEndMjdTai_value, T.validityEndMjdTai),
+ nDiaSources = IF(P.nDiaSources_present, P.nDiaSources_value, T.nDiaSources);
diff --git a/python/lsst/dax/ppdb/config/sql/merge_diasource_updates.sql b/python/lsst/dax/ppdb/config/sql/merge_diasource_updates.sql
new file mode 100644
index 00000000..5a2d5307
--- /dev/null
+++ b/python/lsst/dax/ppdb/config/sql/merge_diasource_updates.sql
@@ -0,0 +1,48 @@
+MERGE `{target_dataset}.{target_table}` T
+USING (
+ WITH patch AS (
+ SELECT
+ record_id[OFFSET(0)] AS diaSourceId,
+
+ ANY_VALUE(
+ CASE WHEN field_name = 'diaObjectId'
+ THEN CAST(JSON_VALUE(value_json) AS INT64)
+ END
+ ) AS diaObjectId_value,
+ COUNTIF(field_name = 'diaObjectId') > 0 AS diaObjectId_present,
+
+ ANY_VALUE(
+ CASE WHEN field_name = 'ssObjectId'
+ THEN CAST(JSON_VALUE(value_json) AS INT64)
+ END
+ ) AS ssObjectId_value,
+ COUNTIF(field_name = 'ssObjectId') > 0 AS ssObjectId_present,
+
+ ANY_VALUE(
+ CASE WHEN field_name = 'ssObjectReassocTimeMjdTai'
+ THEN CAST(JSON_VALUE(value_json) AS FLOAT64)
+ END
+ ) AS ssObjectReassocTimeMjdTai_value,
+ COUNTIF(field_name = 'ssObjectReassocTimeMjdTai') > 0 AS ssObjectReassocTimeMjdTai_present,
+
+ ANY_VALUE(
+ CASE WHEN field_name = 'timeWithdrawnMjdTai'
+ THEN CAST(JSON_VALUE(value_json) AS FLOAT64)
+ END
+ ) AS timeWithdrawnMjdTai_value,
+ COUNTIF(field_name = 'timeWithdrawnMjdTai') > 0 AS timeWithdrawnMjdTai_present
+
+ FROM `{updates_table}`
+ WHERE table_name = 'DiaSource'
+ AND field_name IN ('diaObjectId', 'ssObjectId', 'ssObjectReassocTimeMjdTai', 'timeWithdrawnMjdTai')
+ GROUP BY diaSourceId
+ )
+ SELECT * FROM patch
+) P
+ON T.diaSourceId = P.diaSourceId
+WHEN MATCHED THEN
+UPDATE SET
+ diaObjectId = IF(P.diaObjectId_present, P.diaObjectId_value, T.diaObjectId),
+ ssObjectId = IF(P.ssObjectId_present, P.ssObjectId_value, T.ssObjectId),
+ ssObjectReassocTimeMjdTai = IF(P.ssObjectReassocTimeMjdTai_present, P.ssObjectReassocTimeMjdTai_value, T.ssObjectReassocTimeMjdTai),
+ timeWithdrawnMjdTai = IF(P.timeWithdrawnMjdTai_present, P.timeWithdrawnMjdTai_value, T.timeWithdrawnMjdTai);
diff --git a/python/lsst/dax/ppdb/config/sql/select_promotable_chunks.sql b/python/lsst/dax/ppdb/config/sql/select_promotable_chunks.sql
new file mode 100644
index 00000000..e776a1f0
--- /dev/null
+++ b/python/lsst/dax/ppdb/config/sql/select_promotable_chunks.sql
@@ -0,0 +1,24 @@
+WITH start AS (
+SELECT MIN(apdb_replica_chunk) AS s
+FROM {table_name}
+WHERE status <> 'promoted'
+ AND status <> 'skipped'
+),
+stop AS (
+SELECT MIN(p.apdb_replica_chunk) AS e
+FROM {table_name} p
+JOIN start ON TRUE
+WHERE start.s IS NOT NULL
+ AND p.apdb_replica_chunk >= start.s
+ AND p.status <> 'staged'
+ AND p.status <> 'skipped'
+)
+SELECT p.apdb_replica_chunk
+FROM {table_name} p
+JOIN start ON TRUE
+LEFT JOIN stop ON TRUE
+WHERE start.s IS NOT NULL
+AND p.status = 'staged'
+AND p.apdb_replica_chunk >= start.s
+AND (stop.e IS NULL OR p.apdb_replica_chunk < stop.e)
+ORDER BY p.apdb_replica_chunk;
diff --git a/python/lsst/dax/ppdb/ppdb.py b/python/lsst/dax/ppdb/ppdb.py
index 31b6a315..e175bb7f 100644
--- a/python/lsst/dax/ppdb/ppdb.py
+++ b/python/lsst/dax/ppdb/ppdb.py
@@ -33,7 +33,7 @@
from lsst.resources import ResourcePathExpression
from ._factory import ppdb_from_config
-from .config import PpdbConfig
+from .ppdb_config import PpdbConfig
@dataclass(frozen=True)
diff --git a/python/lsst/dax/ppdb/config.py b/python/lsst/dax/ppdb/ppdb_config.py
similarity index 100%
rename from python/lsst/dax/ppdb/config.py
rename to python/lsst/dax/ppdb/ppdb_config.py
diff --git a/python/lsst/dax/ppdb/sql/__init__.py b/python/lsst/dax/ppdb/sql/__init__.py
index 92e21081..566853c3 100644
--- a/python/lsst/dax/ppdb/sql/__init__.py
+++ b/python/lsst/dax/ppdb/sql/__init__.py
@@ -20,4 +20,4 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from ._ppdb_sql import PpdbSql, PpdbSqlConfig
-from ._ppdb_sql_base import PpdbSqlBase, PpdbSqlBaseConfig
+from ._ppdb_sql_base import PasswordProvider, PpdbSqlBase, PpdbSqlBaseConfig
diff --git a/python/lsst/dax/ppdb/sql/_ppdb_sql.py b/python/lsst/dax/ppdb/sql/_ppdb_sql.py
index 623cf651..320f579c 100644
--- a/python/lsst/dax/ppdb/sql/_ppdb_sql.py
+++ b/python/lsst/dax/ppdb/sql/_ppdb_sql.py
@@ -552,5 +552,6 @@ def init_database(
isolation_level=isolation_level,
connection_timeout=connection_timeout,
)
- cls.make_database(config, sa_metadata, schema_version, drop)
+ engine = cls.make_engine(config)
+ cls.make_database(engine, config, sa_metadata, schema_version, drop)
return config
diff --git a/python/lsst/dax/ppdb/sql/_ppdb_sql_base.py b/python/lsst/dax/ppdb/sql/_ppdb_sql_base.py
index 836cab53..542029e5 100644
--- a/python/lsst/dax/ppdb/sql/_ppdb_sql_base.py
+++ b/python/lsst/dax/ppdb/sql/_ppdb_sql_base.py
@@ -21,11 +21,12 @@
from __future__ import annotations
-__all__ = ["PpdbSqlBase"]
+__all__ = ["PasswordProvider", "PpdbSqlBase"]
import logging
import os
import sqlite3
+from abc import ABC, abstractmethod
from collections.abc import Iterable, MutableMapping
from contextlib import closing
from typing import Any
@@ -49,6 +50,25 @@
_LOG = logging.getLogger(__name__)
+class PasswordProvider(ABC):
+ """Abstract base class for objects that supply a database password.
+
+ Implementations are free to retrieve the password from any source
+ (e.g. environment variables, a secrets manager, a local file) without
+ `PpdbSqlBase` needing to know about the mechanism.
+ """
+
+ @abstractmethod
+ def get_password(self) -> str:
+ """Return the database password.
+
+ Returns
+ -------
+ password : `str`
+ Plain-text password to embed in the database connection URL.
+ """
+
+
class MissingSchemaVersionError(RuntimeError):
"""Exception raised when schema version is not defined in the schema.
@@ -121,12 +141,12 @@ class PpdbSqlBase:
meta_schema_version_key = "version:schema"
"""Name of the metadata key to store Felis schema version number."""
- def __init__(self, config: PpdbSqlBaseConfig) -> None:
+ def __init__(self, config: PpdbSqlBaseConfig, password_provider: PasswordProvider | None = None) -> None:
self._sa_metadata, self._schema_version = self.read_schema(
config.felis_path, config.schema_name, config.felis_schema, config.db_url
)
- self._engine = self.make_engine(config)
+ self._engine = self.make_engine(config, password_provider=password_provider)
sa_metadata = sqlalchemy.MetaData(schema=config.schema_name)
meta_table = sqlalchemy.schema.Table("metadata", sa_metadata, autoload_with=self._engine)
@@ -137,14 +157,7 @@ def __init__(self, config: PpdbSqlBaseConfig) -> None:
self._check_code_version()
@classmethod
- def make_engine(cls, config: PpdbSqlBaseConfig) -> sqlalchemy.engine.Engine:
- """Make SQLALchemy engine based on configured parameters.
-
- Parameters
- ----------
- config : `PpdbSqlBaseConfig`
- Configuration object with SQL parameters.
- """
+ def _build_connect_args(cls, config: PpdbSqlBaseConfig) -> MutableMapping[str, Any]:
kw: MutableMapping[str, Any] = {}
conn_args: dict[str, Any] = {}
if not config.use_connection_pool:
@@ -159,9 +172,40 @@ def make_engine(cls, config: PpdbSqlBaseConfig) -> sqlalchemy.engine.Engine:
conn_args.update(timeout=config.connection_timeout)
elif config.db_url.startswith(("postgresql", "mysql")):
conn_args.update(connect_timeout=config.connection_timeout)
- kw = {"connect_args": conn_args}
- engine = sqlalchemy.create_engine(config.db_url, **kw)
+ return {"connect_args": conn_args}
+
+ @classmethod
+ def make_engine(
+ cls,
+ config: PpdbSqlBaseConfig,
+ *,
+ password_provider: PasswordProvider | None = None,
+ ) -> sqlalchemy.engine.Engine:
+ """Make SQLALchemy engine based on configured parameters.
+ Parameters
+ ----------
+ config : `PpdbSqlBaseConfig`
+ Configuration object with SQL parameters.
+ password_provider : `PasswordProvider`, optional
+ If provided, the password returned by
+ ``password_provider.get_password()`` is injected into the
+ database URL. The URL must not already contain a password when
+ this argument is given.
+
+ Raises
+ ------
+ ValueError
+ Raised if ``password_provider`` is given but the URL already
+ contains a password.
+ """
+ db_url = sqlalchemy.make_url(config.db_url)
+ if password_provider is not None:
+ if db_url.password is not None:
+ raise ValueError("Database URL must not contain a password when password_provider is used.")
+ db_url = db_url.set(password=password_provider.get_password())
+ kw = cls._build_connect_args(config)
+ engine = sqlalchemy.create_engine(db_url, **kw)
if engine.dialect.name == "sqlite":
# Need to enable foreign keys on every new connection.
sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)
@@ -171,6 +215,7 @@ def make_engine(cls, config: PpdbSqlBaseConfig) -> sqlalchemy.engine.Engine:
@classmethod
def make_database(
cls,
+ engine: sqlalchemy.engine.Engine,
config: PpdbSqlBaseConfig,
sa_metadata: sqlalchemy.schema.MetaData,
schema_version: VersionTuple,
@@ -189,8 +234,6 @@ def make_database(
drop : `bool`
If `True` then drop existing tables before creating new ones.
"""
- engine = cls.make_engine(config)
-
if config.schema_name is not None:
dialect = engine.dialect
quoted_schema = dialect.preparer(dialect).quote_schema(config.schema_name)
diff --git a/python/lsst/dax/ppdb/tests/_bigquery.py b/python/lsst/dax/ppdb/tests/_bigquery.py
new file mode 100644
index 00000000..d801656d
--- /dev/null
+++ b/python/lsst/dax/ppdb/tests/_bigquery.py
@@ -0,0 +1,224 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import gc
+import io
+import json
+import shutil
+import tempfile
+import uuid
+from typing import Any
+
+import google.auth
+from google.auth.exceptions import DefaultCredentialsError
+from google.auth.transport.requests import Request
+from google.cloud import storage
+
+from lsst.dax.apdb import (
+ ApdbConfig,
+)
+from lsst.dax.apdb.sql import ApdbSql
+from lsst.dax.ppdb import PpdbConfig
+from lsst.dax.ppdb.bigquery import ChunkUploader, PpdbBigQuery
+from lsst.dax.ppdb.tests._ppdb import TEST_SCHEMA_RESOURCE_PATH
+
+try:
+ import testing.postgresql
+except ImportError:
+ testing = None
+
+
+TEST_CONFIG = {
+ "db_drop": True,
+ "validate_config": False,
+ "delete_existing_dirs": True,
+ "bucket_name": "ppdb-test",
+ "object_prefix": "data/test",
+ "dataset_id": "test_dataset",
+ "project_id": "test_project",
+}
+
+
+def json_rows_to_buf(rows: list[dict]) -> io.StringIO:
+ """Convert a list of dict rows to a newline-delimited JSON StringIO
+ buffer.
+ """
+ buf = io.StringIO()
+ for row in rows:
+ buf.write(json.dumps(row) + "\n")
+ buf.seek(0)
+ return buf
+
+
+def generate_test_bucket_name(test_prefix: str = "ppdb-test") -> str:
+ """Generate a unique bucket name for testing."""
+ test_id = uuid.uuid4().hex[:16]
+ return f"{test_prefix}-{test_id}"
+
+
+def delete_test_bucket(bucket_or_bucket_name: str | storage.Bucket) -> None:
+ """Delete a cloud storage bucket that was created for testing.
+
+ Parameters
+ ----------
+ bucket_or_bucket_name: `str` or `storage.Bucket`
+ The name of the bucket or the actual bucket to delete.
+ """
+ storage_client = storage.Client()
+ try:
+ if isinstance(bucket_or_bucket_name, str):
+ bucket = storage_client.bucket(bucket_or_bucket_name)
+ else:
+ bucket = bucket_or_bucket_name
+ blobs = list(bucket.list_blobs())
+ for blob in blobs:
+ blob.delete()
+ bucket.delete()
+ except Exception as e:
+ print(f"Failed to delete test GCS bucket: {e}")
+
+
+class ChunkUploaderWithoutPubSub(ChunkUploader):
+ """A dummy implementation of the ChunkUploader that does not actually
+ post messages to Pub/Sub.
+ """
+
+ def _post_to_stage_chunk_topic(self, gcs_uri: str, chunk_id: int) -> None:
+ message = {
+ "dataset": None,
+ "chunk_id": str(chunk_id),
+ "folder": f"gs://{gcs_uri}",
+ }
+ print(f"Dummy publish to Pub/Sub topic: {message}")
+
+
+class SqliteMixin:
+ """Mixin class to provide Sqlite-specific setup/teardown and instance
+ creation.
+ """
+
+ def setUp(self) -> None:
+ self.tempdir = tempfile.mkdtemp()
+ self.apdb_url = f"sqlite:///{self.tempdir}/apdb.sqlite3"
+ self.ppdb_url = f"sqlite:///{self.tempdir}/ppdb.sqlite3"
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.tempdir, ignore_errors=True)
+
+ def make_instance(self, **kwargs: Any) -> PpdbConfig:
+ """Make config class instance used in all tests."""
+ kw = {
+ **TEST_CONFIG,
+ "db_url": self.ppdb_url,
+ "felis_path": TEST_SCHEMA_RESOURCE_PATH,
+ "replication_dir": self.tempdir,
+ }
+ bq_config = PpdbBigQuery.init_bigquery(**kw) # type: ignore[arg-type]
+ return bq_config
+
+ def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig:
+ """Make APDB instance for tests."""
+ kw = {
+ "schema_file": TEST_SCHEMA_RESOURCE_PATH,
+ "ss_schema_file": "",
+ "db_url": self.apdb_url,
+ "enable_replica": True,
+ }
+ kw.update(kwargs)
+ return ApdbSql.init_database(**kw) # type: ignore[arg-type]
+
+
+class PostgresMixin:
+ """Mixin class to provide Postgres-specific setup/teardown and instance
+ creation.
+ """
+
+ postgresql: Any
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ # Create the postgres test server.
+ cls.postgresql = testing.postgresql.PostgresqlFactory(cache_initialized_db=True)
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ # Clean up any lingering SQLAlchemy engines/connections
+ # so they're closed before we shut down the server.
+ gc.collect()
+ cls.postgresql.clear_cache()
+
+ def setUp(self) -> None:
+ self.server = self.postgresql()
+ self.tempdir = tempfile.mkdtemp()
+
+ def tearDown(self) -> None:
+ self.server.stop()
+ shutil.rmtree(self.tempdir, ignore_errors=True)
+
+ def make_instance(self, config_dict: dict[str, Any] = TEST_CONFIG, **kwargs: Any) -> PpdbConfig:
+ """Make config class instance used in all tests."""
+ kw = {
+ **config_dict,
+ "db_url": self.server.url(),
+ "db_schema": "ppdb_test",
+ "felis_path": TEST_SCHEMA_RESOURCE_PATH,
+ "replication_dir": self.tempdir,
+ }
+ bq_config = PpdbBigQuery.init_bigquery(**kw)
+ return bq_config
+
+ def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig:
+ kw = {
+ "schema_file": TEST_SCHEMA_RESOURCE_PATH,
+ "ss_schema_file": "",
+ "db_url": self.server.url(),
+ "namespace": "apdb",
+ "enable_replica": True,
+ }
+ kw.update(kwargs)
+ return ApdbSql.init_database(**kw)
+
+
+def have_valid_google_credentials() -> bool:
+ """Check that valid Google credentials are available for testing.
+
+ Returns
+ -------
+ credentials_valid: `bool`
+ True if valid Google credentials are available, False if not.
+
+ Raises
+ ------
+ google.auth.exceptions.RefreshError
+ Raised if the credentials cannot be refreshed.
+ Exception
+ Raised for other transport or configuration failures.
+ """
+ try:
+ credentials, _ = google.auth.default()
+ except DefaultCredentialsError:
+ return False
+
+ # This will validate the default credentials that were found in the
+ # environment.
+ credentials.refresh(Request())
+
+ return True
diff --git a/python/lsst/dax/ppdb/tests/_ppdb.py b/python/lsst/dax/ppdb/tests/_ppdb.py
index 06a84a37..790d4d76 100644
--- a/python/lsst/dax/ppdb/tests/_ppdb.py
+++ b/python/lsst/dax/ppdb/tests/_ppdb.py
@@ -21,7 +21,7 @@
from __future__ import annotations
-__all__ = ["PpdbTest"]
+__all__ = ["TEST_SCHEMA_RESOURCE_PATH", "PpdbTest", "fill_apdb"]
import unittest
from abc import ABC, abstractmethod
@@ -44,20 +44,15 @@
from lsst.dax.apdb.tests.data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog
from lsst.sphgeom import Angle, Circle, Region, UnitVector3d
-from ..config import PpdbConfig
from ..ppdb import Ppdb, PpdbReplicaChunk
+from ..ppdb_config import PpdbConfig
from ..replicator import Replicator
if TYPE_CHECKING:
import pandas
- class TestCaseMixin(unittest.TestCase):
- """Base class for mixin test classes that use TestCase methods."""
-else:
-
- class TestCaseMixin:
- """Do-nothing definition of mixin base class for regular execution."""
+TEST_SCHEMA_RESOURCE_PATH = "resource://lsst.dax.ppdb/config/schemas/test_apdb_schema.yaml"
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
@@ -68,12 +63,107 @@ def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
return region
-class PpdbTest(TestCaseMixin, ABC):
+def _make_update_records(
+ sources: pandas.DataFrame, fsources: pandas.DataFrame, update_time: astropy.time.Time
+) -> list[ApdbUpdateRecord]:
+ """Create update records from source catalogs for testing."""
+ update_time_ns = int(update_time.unix_tai * 1e9)
+ records: list[ApdbUpdateRecord] = []
+
+ # Reassign one DIASource to SSObject.
+ dia_source = sources.iloc[0]
+ records.append(
+ ApdbReassignDiaSourceToSSObjectRecord(
+ update_time_ns=update_time_ns,
+ update_order=0,
+ diaSourceId=int(dia_source["diaSourceId"]),
+ ssObjectId=1,
+ ssObjectReassocTimeMjdTai=float(update_time.tai.mjd),
+ ra=float(dia_source["ra"]),
+ dec=float(dia_source["dec"]),
+ midpointMjdTai=60000.0,
+ )
+ )
+
+ # Close validity interval for matching DIAObject.
+ records.append(
+ ApdbCloseDiaObjectValidityRecord(
+ update_time_ns=update_time_ns,
+ update_order=1,
+ diaObjectId=int(dia_source["diaObjectId"]),
+ validityEndMjdTai=update_time.tai.mjd,
+ nDiaSources=None,
+ ra=float(dia_source["ra"]),
+ dec=float(dia_source["dec"]),
+ )
+ )
+
+ # Withdraw one DIAForcedSource.
+ dia_fsource = fsources.iloc[0]
+ records.append(
+ ApdbWithdrawDiaForcedSourceRecord(
+ update_time_ns=update_time_ns,
+ update_order=2,
+ diaObjectId=int(dia_fsource["diaObjectId"]),
+ visit=int(dia_fsource["visit"]),
+ detector=int(dia_fsource["detector"]),
+ timeWithdrawnMjdTai=update_time.tai.mjd,
+ ra=float(dia_source["ra"]),
+ dec=float(dia_source["dec"]),
+ midpointMjdTai=60000.0,
+ )
+ )
+
+ return records
+
+
+def fill_apdb(apdb: Apdb, include_update_records: bool = False) -> None:
+ """Populate APDB with some data to replicate."""
+ region1 = _make_region((1.0, 1.0, -1.0))
+ region2 = _make_region((-1.0, -1.0, -1.0))
+ nobj = 100
+ objects1 = makeObjectCatalog(region1, nobj)
+ objects2 = makeObjectCatalog(region2, nobj, start_id=nobj * 2)
+
+ # With the default 10 minutes replica chunk window we should have 4
+ # records. All timestamps are far in the past, means that replication
+ # of the last chunk can run without waiting.
+ visits = [
+ (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
+ (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
+ (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
+ (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
+ (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
+ (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
+ (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
+ (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
+ ]
+
+ # Time when updates are applied.
+ update_time = astropy.time.Time("2021-03-01T12:00:00")
+
+ update_records = []
+ start_id = 0
+ for visit, (visit_time, objects) in enumerate(visits):
+ sources = makeSourceCatalog(objects, visit_time, visit=visit, start_id=start_id)
+ fsources = makeForcedSourceCatalog(objects, visit_time, visit=visit)
+ apdb.store(visit_time, objects, sources, fsources)
+ start_id += nobj
+
+ if include_update_records and visit == (len(visits) - 1):
+ # Generate a few update records.
+ update_records = _make_update_records(sources, fsources, update_time)
+
+ if include_update_records:
+ chunk = ReplicaChunk.make_replica_chunk(update_time, apdb.getConfig().replica_chunk_seconds)
+ # All our tests use SQL APDB.
+ assert isinstance(apdb, ApdbSql), "Expecting ApdbSql instance"
+ apdb._storeUpdateRecords(update_records, chunk, store_chunk=True)
+
+
+class PpdbTest(unittest.TestCase, ABC):
"""Base class for Ppdb tests that can be specialized for concrete
implementation.
-
- This can only be used as a mixin class for a unittest.TestCase and it
- calls various assert methods.
"""
include_update_records = False
@@ -102,112 +192,6 @@ def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig:
"""
raise NotImplementedError()
- def test_empty_db(self) -> None:
- """Test for instantiation a database and making queries on empty
- database.
- """
- config = self.make_instance()
- ppdb = Ppdb.from_config(config)
- chunks = ppdb.get_replica_chunks()
- if chunks is not None:
- self.assertEqual(len(chunks), 0)
-
- def _fill_apdb(self, apdb: Apdb) -> None:
- """Populate APDB with some data to replicate."""
- visit_time = astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai")
- region1 = _make_region((1.0, 1.0, -1.0))
- region2 = _make_region((-1.0, -1.0, -1.0))
- nobj = 100
- objects1 = makeObjectCatalog(region1, nobj)
- objects2 = makeObjectCatalog(region2, nobj, start_id=nobj * 2)
-
- # With the default 10 minutes replica chunk window we should have 4
- # records. All timestamps are far in the past, means that replication
- # of the last chunk can run without waiting.
- visits = [
- (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
- (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
- (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
- (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
- (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
- (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
- (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
- (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
- ]
-
- # Time when apdates are applied.
- update_time = astropy.time.Time("2021-03-01T12:00:00")
-
- update_records = []
- start_id = 0
- for visit, (visit_time, objects) in enumerate(visits):
- sources = makeSourceCatalog(objects, visit_time, visit=visit, start_id=start_id)
- fsources = makeForcedSourceCatalog(objects, visit_time, visit=visit)
- apdb.store(visit_time, objects, sources, fsources)
- start_id += nobj
-
- if self.include_update_records and visit == (len(visits) - 1):
- # Generate few update records.
- update_records = self._make_update_records(sources, fsources, update_time)
-
- if self.include_update_records:
- chunk = ReplicaChunk.make_replica_chunk(update_time, apdb.getConfig().replica_chunk_seconds)
- # All our tests use SQL APDB.
- assert isinstance(apdb, ApdbSql), "Expecting ApdbSql instance"
- apdb._storeUpdateRecords(update_records, chunk, store_chunk=True)
-
- def _make_update_records(
- self, sources: pandas.DataFrame, fsources: pandas.DataFrame, update_time: astropy.time.Time
- ) -> list[ApdbUpdateRecord]:
- update_time_ns = int(update_time.unix_tai * 1e9)
- records: list[ApdbUpdateRecord] = []
-
- # Reassign one DIASource to SSObject.
- dia_source = sources.iloc[0]
- records.append(
- ApdbReassignDiaSourceToSSObjectRecord(
- update_time_ns=update_time_ns,
- update_order=0,
- diaSourceId=int(dia_source["diaSourceId"]),
- ssObjectId=1,
- ssObjectReassocTimeMjdTai=float(update_time.tai.mjd),
- ra=float(dia_source["ra"]),
- dec=float(dia_source["dec"]),
- midpointMjdTai=60000.0,
- )
- )
-
- # Close validity interval for matching DIAObject.
- records.append(
- ApdbCloseDiaObjectValidityRecord(
- update_time_ns=update_time_ns,
- update_order=1,
- diaObjectId=int(dia_source["diaObjectId"]),
- validityEndMjdTai=update_time.tai.mjd,
- nDiaSources=None,
- ra=float(dia_source["ra"]),
- dec=float(dia_source["dec"]),
- )
- )
-
- # Withdraw one DIAForcedSource.
- dia_fsource = fsources.iloc[0]
- records.append(
- ApdbWithdrawDiaForcedSourceRecord(
- update_time_ns=update_time_ns,
- update_order=2,
- diaObjectId=int(dia_fsource["diaObjectId"]),
- visit=int(dia_fsource["visit"]),
- detector=int(dia_fsource["detector"]),
- timeWithdrawnMjdTai=update_time.tai.mjd,
- ra=float(dia_source["ra"]),
- dec=float(dia_source["dec"]),
- midpointMjdTai=60000.0,
- )
- )
-
- return records
-
def _check_chunks(
self, apdb_chunks: Sequence[ReplicaChunk], ppdb_chunks: Sequence[PpdbReplicaChunk]
) -> None:
@@ -218,12 +202,22 @@ def _check_chunks(
self.assertEqual(ppdb_chunks[i].last_update_time, apdb_chunks[i].last_update_time)
self.assertEqual(ppdb_chunks[i].unique_id, apdb_chunks[i].unique_id)
+ def test_empty_db(self) -> None:
+ """Test instantiating a database and making queries on an empty
+ database.
+ """
+ config = self.make_instance()
+ ppdb = Ppdb.from_config(config)
+ chunks = ppdb.get_replica_chunks()
+ if chunks is not None:
+ self.assertEqual(len(chunks), 0)
+
def test_replication_single(self) -> None:
"""Test replication from APDB to PPDB using a single chunk option."""
apdb_config = self.make_apdb_instance()
apdb = Apdb.from_config(apdb_config)
- self._fill_apdb(apdb)
+ fill_apdb(apdb, self.include_update_records)
expected_chunks = 5 if self.include_update_records else 4
@@ -286,7 +280,7 @@ def test_replication_all(self) -> None:
apdb_config = self.make_apdb_instance()
apdb = Apdb.from_config(apdb_config)
- self._fill_apdb(apdb)
+ fill_apdb(apdb, self.include_update_records)
expected_chunks = 5 if self.include_update_records else 4
diff --git a/python/lsst/dax/ppdb/tests/_updates.py b/python/lsst/dax/ppdb/tests/_updates.py
new file mode 100644
index 00000000..45d3fb1b
--- /dev/null
+++ b/python/lsst/dax/ppdb/tests/_updates.py
@@ -0,0 +1,156 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+
+from lsst.dax.apdb import (
+ ApdbCloseDiaObjectValidityRecord,
+ ApdbReassignDiaSourceToDiaObjectRecord,
+ ApdbReassignDiaSourceToSSObjectRecord,
+ ApdbUpdateNDiaSourcesRecord,
+ ApdbUpdateRecord,
+ ApdbWithdrawDiaForcedSourceRecord,
+ ApdbWithdrawDiaSourceRecord,
+)
+
+from ..bigquery.updates import UpdateRecords
+
+
+def _create_test_update_records() -> UpdateRecords:
+ """Create test UpdateRecords with sample ApdbUpdateRecord instances."""
+ records: list[ApdbUpdateRecord] = []
+
+ # Hardcoded test values
+ test_update_time_ns = 1640995200000000000 # 2022-01-01 00:00:00 UTC in nanoseconds
+ test_mjd_tai = 59580.0 # Corresponding MJD TAI for 2022-01-01
+ test_replica_chunk_id = 12345
+
+ # Reassign DIASource to different DIAObject
+ records.append(
+ ApdbReassignDiaSourceToDiaObjectRecord(
+ update_time_ns=test_update_time_ns,
+ update_order=0,
+ diaSourceId=100001,
+ diaObjectId=300001,
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+ )
+
+ # Reassign DIASource to SSObject
+ records.append(
+ ApdbReassignDiaSourceToSSObjectRecord(
+ update_time_ns=test_update_time_ns,
+ update_order=1,
+ diaSourceId=100002,
+ ssObjectId=2001,
+ ssObjectReassocTimeMjdTai=test_mjd_tai,
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+ )
+
+ # Withdraw DIASource
+ records.append(
+ ApdbWithdrawDiaSourceRecord(
+ update_time_ns=test_update_time_ns,
+ update_order=2,
+ diaSourceId=100003,
+ timeWithdrawnMjdTai=test_mjd_tai,
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+ )
+
+ # Withdraw DIAForcedSource
+ records.append(
+ ApdbWithdrawDiaForcedSourceRecord(
+ update_time_ns=test_update_time_ns,
+ update_order=3,
+ diaObjectId=200001,
+ visit=12345,
+ detector=42,
+ timeWithdrawnMjdTai=test_mjd_tai,
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+ )
+
+ # Close DIAObject validity interval
+ records.append(
+ ApdbCloseDiaObjectValidityRecord(
+ update_time_ns=test_update_time_ns,
+ update_order=4,
+ diaObjectId=200001,
+ validityEndMjdTai=test_mjd_tai,
+ nDiaSources=5,
+ ra=45.0,
+ dec=-30.0,
+ )
+ )
+
+ # Update DIAObject nDiaSources count
+ records.append(
+ ApdbUpdateNDiaSourcesRecord(
+ update_time_ns=test_update_time_ns,
+ update_order=5,
+ diaObjectId=200002,
+ nDiaSources=10,
+ ra=45.0,
+ dec=-30.0,
+ )
+ )
+
+ # Add duplicate records for testing deduplication
+ # Duplicate of the first record but with later timestamp (should be kept)
+ records.append(
+ ApdbReassignDiaSourceToDiaObjectRecord(
+ update_time_ns=test_update_time_ns + 1000000000, # 1 second later
+ update_order=0,
+ diaSourceId=100001,
+ diaObjectId=400001, # Different target object
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+ )
+
+ # Duplicate of the nDiaSources update but with earlier timestamp (should be
+ # discarded)
+ records.append(
+ ApdbUpdateNDiaSourcesRecord(
+ update_time_ns=test_update_time_ns - 1000000000, # 1 second earlier
+ update_order=5,
+ diaObjectId=200002,
+ nDiaSources=8, # Different value but older timestamp
+ ra=45.0,
+ dec=-30.0,
+ )
+ )
+
+ return UpdateRecords(
+ replica_chunk_id=test_replica_chunk_id,
+ record_count=len(records),
+ records=records,
+ )
diff --git a/python/lsst/dax/ppdb/tests/config/__init__.py b/python/lsst/dax/ppdb/tests/config/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/requirements.txt b/requirements.txt
index 4f213f08..c0cc7069 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,13 @@
astropy
+google-cloud-bigquery
pyarrow
pydantic >=2,<3
pyyaml >= 5.1
sqlalchemy
+
lsst-dax-apdb @ git+https://github.com/lsst/dax_apdb@main
-lsst-utils @ git+https://github.com/lsst/utils@main
-lsst-resources[s3] @ git+https://github.com/lsst/resources@main
-lsst-felis @ git+https://github.com/lsst/felis@main
lsst-dax-ppdbx-gcp @ git+https://github.com/lsst-dm/dax_ppdbx_gcp@main
+lsst-felis @ git+https://github.com/lsst/felis@main
lsst-sdm-schemas @ git+https://github.com/lsst/sdm_schemas@main
+lsst-utils @ git+https://github.com/lsst/utils@main
+lsst-resources[s3] @ git+https://github.com/lsst/resources@main
diff --git a/tests/test_ppdbBigQuery.py b/tests/test_ppdbBigQuery.py
deleted file mode 100644
index 198f0bca..00000000
--- a/tests/test_ppdbBigQuery.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# This file is part of dax_ppdb.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see .
-
-import gc
-import os
-import shutil
-import tempfile
-import unittest
-from typing import Any
-
-from lsst.dax.apdb import ApdbConfig
-from lsst.dax.apdb.sql import ApdbSql
-from lsst.dax.ppdb import PpdbConfig
-from lsst.dax.ppdb.bigquery import PpdbBigQuery
-from lsst.dax.ppdb.tests import PpdbTest
-
-try:
- import testing.postgresql
-except ImportError:
- testing = None
-
-TEST_SCHEMA = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config/schema.yaml")
-
-TEST_CONFIG = {
- "db_drop": True,
- "validate_config": False,
- "delete_existing_dirs": True,
- "bucket_name": "test_bucket",
- "object_prefix": "test_prefix",
- "dataset_id": "test_dataset",
- "project_id": "test_project",
-}
-
-
-class SqliteTestCase(PpdbTest, unittest.TestCase):
- """A test case for the PpdbBigQuery class using a SQLite backend."""
-
- def setUp(self) -> None:
- self.tempdir = tempfile.mkdtemp()
- self.apdb_url = f"sqlite:///{self.tempdir}/apdb.sqlite3"
- self.ppdb_url = f"sqlite:///{self.tempdir}/ppdb.sqlite3"
-
- def tearDown(self) -> None:
- shutil.rmtree(self.tempdir, ignore_errors=True)
-
- def make_instance(self, **kwargs: Any) -> PpdbConfig:
- """Make config class instance used in all tests."""
- kw = {
- **TEST_CONFIG,
- "db_url": self.ppdb_url,
- "felis_path": TEST_SCHEMA,
- "replication_dir": self.tempdir,
- }
- bq_config = PpdbBigQuery.init_bigquery(
- **kw,
- ) # type: ignore[arg-type]
- return bq_config
-
- def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig:
- """Make APDB instance for tests."""
- kw = {
- "schema_file": TEST_SCHEMA,
- "ss_schema_file": "",
- "db_url": self.apdb_url,
- "enable_replica": True,
- }
- kw.update(kwargs)
- return ApdbSql.init_database(**kw) # type: ignore[arg-type]
-
-
-@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
-class PostgresTestCase(PpdbTest, unittest.TestCase):
- """A test case for the PpdbBigQuery class using a Postgres backend."""
-
- postgresql: Any
-
- @classmethod
- def setUpClass(cls) -> None:
- # Create the postgres test server.
- cls.postgresql = testing.postgresql.PostgresqlFactory(cache_initialized_db=True)
- super().setUpClass()
-
- @classmethod
- def tearDownClass(cls) -> None:
- # Clean up any lingering SQLAlchemy engines/connections
- # so they're closed before we shut down the server.
- gc.collect()
- cls.postgresql.clear_cache()
- super().tearDownClass()
-
- def setUp(self) -> None:
- self.server = self.postgresql()
- self.tempdir = tempfile.mkdtemp()
-
- def tearDown(self) -> None:
- self.server = self.postgresql()
- shutil.rmtree(self.tempdir, ignore_errors=True)
-
- def make_instance(self, **kwargs: Any) -> PpdbConfig:
- """Make config class instance used in all tests."""
- kw = {
- **TEST_CONFIG,
- "db_url": self.server.url(),
- "db_schema": None,
- "felis_path": TEST_SCHEMA,
- "replication_dir": self.tempdir,
- }
- bq_config = PpdbBigQuery.init_bigquery(**kw) # type: ignore[arg-type]
- return bq_config
-
- def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig:
- kw = {
- "schema_file": TEST_SCHEMA,
- "ss_schema_file": "",
- "db_url": self.server.url(),
- "namespace": "apdb",
- "enable_replica": True,
- }
- kw.update(kwargs)
- return ApdbSql.init_database(**kw) # type: ignore[arg-type]
diff --git a/tests/test_ppdb_bigquery.py b/tests/test_ppdb_bigquery.py
new file mode 100644
index 00000000..43a870d9
--- /dev/null
+++ b/tests/test_ppdb_bigquery.py
@@ -0,0 +1,39 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import unittest
+
+from lsst.dax.ppdb.tests import PpdbTest
+from lsst.dax.ppdb.tests._bigquery import PostgresMixin, SqliteMixin
+
+try:
+ import testing.postgresql
+except ImportError:
+ testing = None
+
+
+class SqliteTestCase(SqliteMixin, PpdbTest, unittest.TestCase):
+ """A test case for the PpdbBigQuery class using a SQLite backend."""
+
+
+@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
+class PostgresTestCase(PostgresMixin, PpdbTest, unittest.TestCase):
+ """A test case for the PpdbBigQuery class using a Postgres backend."""
diff --git a/tests/test_ppdbSql.py b/tests/test_ppdb_sql.py
similarity index 91%
rename from tests/test_ppdbSql.py
rename to tests/test_ppdb_sql.py
index f8675079..b6a6a1ab 100644
--- a/tests/test_ppdbSql.py
+++ b/tests/test_ppdb_sql.py
@@ -20,7 +20,6 @@
# along with this program. If not, see .
import gc
-import os
import shutil
import tempfile
import unittest
@@ -30,15 +29,13 @@
from lsst.dax.apdb.sql import ApdbSql
from lsst.dax.ppdb import PpdbConfig
from lsst.dax.ppdb.sql import PpdbSql
-from lsst.dax.ppdb.tests import PpdbTest
+from lsst.dax.ppdb.tests import TEST_SCHEMA_RESOURCE_PATH, PpdbTest
try:
import testing.postgresql
except ImportError:
testing = None
-TEST_SCHEMA = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config/schema.yaml")
-
class ApdbSQLiteTestCase(PpdbTest, unittest.TestCase):
"""A test case for PpdbSql class using SQLite backend."""
@@ -55,11 +52,11 @@ def tearDown(self) -> None:
def make_instance(self, **kwargs: Any) -> PpdbConfig:
"""Make config class instance used in all tests."""
- return PpdbSql.init_database(db_url=self.ppdb_url, schema_file=TEST_SCHEMA, **kwargs)
+ return PpdbSql.init_database(db_url=self.ppdb_url, schema_file=TEST_SCHEMA_RESOURCE_PATH, **kwargs)
def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig:
kw = {
- "schema_file": TEST_SCHEMA,
+ "schema_file": TEST_SCHEMA_RESOURCE_PATH,
"ss_schema_file": "",
"db_url": self.apdb_url,
"enable_replica": True,
@@ -98,11 +95,13 @@ def tearDown(self) -> None:
def make_instance(self, **kwargs: Any) -> PpdbConfig:
"""Make config class instance used in all tests."""
- return PpdbSql.init_database(db_url=self.server.url(), schema_file=TEST_SCHEMA, **kwargs)
+ return PpdbSql.init_database(
+ db_url=self.server.url(), schema_file=TEST_SCHEMA_RESOURCE_PATH, **kwargs
+ )
def make_apdb_instance(self, **kwargs: Any) -> ApdbConfig:
kw = {
- "schema_file": TEST_SCHEMA,
+ "schema_file": TEST_SCHEMA_RESOURCE_PATH,
"ss_schema_file": "",
"db_url": self.server.url(),
"namespace": "apdb",
diff --git a/tests/test_update_record_expander.py b/tests/test_update_record_expander.py
new file mode 100644
index 00000000..5038bfb5
--- /dev/null
+++ b/tests/test_update_record_expander.py
@@ -0,0 +1,322 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import datetime
+import unittest
+
+import astropy.time
+
+from lsst.dax.apdb import (
+ ApdbCloseDiaObjectValidityRecord,
+ ApdbReassignDiaSourceToDiaObjectRecord,
+ ApdbReassignDiaSourceToSSObjectRecord,
+ ApdbUpdateNDiaSourcesRecord,
+ ApdbWithdrawDiaForcedSourceRecord,
+ ApdbWithdrawDiaSourceRecord,
+)
+from lsst.dax.ppdb.bigquery.updates import ExpandedUpdateRecord, UpdateRecordExpander, UpdateRecords
+from lsst.dax.ppdb.tests._updates import _create_test_update_records
+
+
+class UpdateRecordExpanderTestCase(unittest.TestCase):
+ """Test UpdateRecordExpander functionality."""
+
+ def setUp(self) -> None:
+ """Set up test fixtures."""
+ # Test time for consistent timestamps
+ self.update_time = astropy.time.Time("2021-03-01T12:00:00", format="isot", scale="tai")
+ self.update_time_ns = int(self.update_time.unix_tai * 1e9)
+
+ # Test replica chunk ID
+ self.replica_chunk_id = 12345
+
+ def test_get_update_fields(self) -> None:
+ """Test get_update_fields class method."""
+ # Test known update types
+ self.assertEqual(
+ UpdateRecordExpander.get_update_fields("reassign_diasource_to_diaobject"), ["diaObjectId"]
+ )
+ self.assertEqual(
+ UpdateRecordExpander.get_update_fields("reassign_diasource_to_ssobject"),
+ ["ssObjectId", "ssObjectReassocTimeMjdTai"],
+ )
+ self.assertEqual(
+ UpdateRecordExpander.get_update_fields("withdraw_diasource"), ["timeWithdrawnMjdTai"]
+ )
+ self.assertEqual(
+ UpdateRecordExpander.get_update_fields("withdraw_diaforcedsource"), ["timeWithdrawnMjdTai"]
+ )
+ self.assertEqual(
+ UpdateRecordExpander.get_update_fields("close_diaobject_validity"),
+ ["validityEndMjdTai", "nDiaSources"],
+ )
+ self.assertEqual(UpdateRecordExpander.get_update_fields("update_n_dia_sources"), ["nDiaSources"])
+
+ # Test unknown update type
+ with self.assertRaises(ValueError) as cm:
+ UpdateRecordExpander.get_update_fields("unknown_update_type")
+ self.assertIn("Unknown update_type: unknown_update_type", str(cm.exception))
+
+ def test_get_record_id_field_names(self) -> None:
+ """Test get_record_id_field class method."""
+ from lsst.dax.ppdb.bigquery.updates import UpdateRecordExpander
+
+ self.assertEqual(
+ UpdateRecordExpander.get_record_id_fields("reassign_diasource_to_diaobject"), ["diaSourceId"]
+ )
+ self.assertEqual(
+ UpdateRecordExpander.get_record_id_fields("reassign_diasource_to_ssobject"), ["diaSourceId"]
+ )
+ self.assertEqual(UpdateRecordExpander.get_record_id_fields("withdraw_diasource"), ["diaSourceId"])
+ self.assertEqual(
+ UpdateRecordExpander.get_record_id_fields("withdraw_diaforcedsource"),
+ ["diaObjectId", "visit", "detector"],
+ )
+ self.assertEqual(
+ UpdateRecordExpander.get_record_id_fields("close_diaobject_validity"), ["diaObjectId"]
+ )
+ self.assertEqual(UpdateRecordExpander.get_record_id_fields("update_n_dia_sources"), ["diaObjectId"])
+
+ # Test unknown update type
+ with self.assertRaises(ValueError) as cm:
+ UpdateRecordExpander.get_record_id_fields("unknown_update_type")
+ self.assertIn("Unknown update_type: unknown_update_type", str(cm.exception))
+
+ def test_reassign_diasource_to_diaobject(self) -> None:
+ """Test expand_single_record with
+ ApdbReassignDiaSourceToDiaObjectRecord.
+ """
+ from lsst.dax.ppdb.bigquery.updates import ExpandedUpdateRecord, UpdateRecordExpander
+
+ record = ApdbReassignDiaSourceToDiaObjectRecord(
+ update_time_ns=self.update_time_ns,
+ update_order=0,
+ diaSourceId=100001,
+ diaObjectId=300001,
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+
+ expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)
+
+ # Should expand to 1 record (diaObjectId)
+ self.assertEqual(len(expanded), 1)
+
+ expanded_record = expanded[0]
+ self.assertIsInstance(expanded_record, ExpandedUpdateRecord)
+ self.assertEqual(expanded_record.table_name, "DiaSource")
+ self.assertEqual(expanded_record.record_id, [100001])
+ self.assertEqual(expanded_record.field_name, "diaObjectId")
+ self.assertEqual(expanded_record.value_json, 300001)
+ self.assertEqual(expanded_record.replica_chunk_id, self.replica_chunk_id)
+ self.assertEqual(expanded_record.update_order, 0)
+ self.assertEqual(expanded_record.update_time_ns, self.update_time_ns)
+
+ def test_reassign_diasource_to_ssobject(self) -> None:
+ """Test expand_single_record with
+ ApdbReassignDiaSourceToSSObjectRecord.
+ """
+ record = ApdbReassignDiaSourceToSSObjectRecord(
+ update_time_ns=self.update_time_ns,
+ update_order=0,
+ diaSourceId=100001,
+ ssObjectId=2001,
+ ssObjectReassocTimeMjdTai=float(self.update_time.tai.mjd),
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+
+ expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)
+
+ # Should expand to 2 records (ssObjectId and ssObjectReassocTimeMjdTai)
+ self.assertEqual(len(expanded), 2)
+
+ # Check first expanded record (ssObjectId)
+ first_record = expanded[0]
+ self.assertIsInstance(first_record, ExpandedUpdateRecord)
+ self.assertEqual(first_record.table_name, "DiaSource")
+ self.assertEqual(first_record.record_id, [100001])
+ self.assertEqual(first_record.field_name, "ssObjectId")
+ self.assertEqual(first_record.value_json, 2001)
+ self.assertEqual(first_record.replica_chunk_id, self.replica_chunk_id)
+ self.assertEqual(first_record.update_order, 0)
+ self.assertEqual(first_record.update_time_ns, self.update_time_ns)
+
+ # Check second expanded record (ssObjectReassocTimeMjdTai)
+ second_record = expanded[1]
+ self.assertEqual(second_record.table_name, "DiaSource")
+ self.assertEqual(second_record.record_id, [100001])
+ self.assertEqual(second_record.field_name, "ssObjectReassocTimeMjdTai")
+ self.assertEqual(second_record.value_json, float(self.update_time.tai.mjd))
+
+ def test_withdraw_diasource(self) -> None:
+ """Test expand_single_record with ApdbWithdrawDiaSourceRecord."""
+ record = ApdbWithdrawDiaSourceRecord(
+ update_time_ns=self.update_time_ns,
+ update_order=2,
+ diaSourceId=100003,
+ timeWithdrawnMjdTai=self.update_time.tai.mjd,
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+
+ expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)
+
+ # Should expand to 1 record (timeWithdrawnMjdTai)
+ self.assertEqual(len(expanded), 1)
+
+ expanded_record = expanded[0]
+ self.assertEqual(expanded_record.table_name, "DiaSource")
+ self.assertEqual(expanded_record.record_id, [100003])
+ self.assertEqual(expanded_record.field_name, "timeWithdrawnMjdTai")
+ self.assertEqual(expanded_record.value_json, self.update_time.tai.mjd)
+
+ def test_update_n_dia_sources(self) -> None:
+ """Test expand_single_record with ApdbUpdateNDiaSourcesRecord."""
+ record = ApdbUpdateNDiaSourcesRecord(
+ update_time_ns=self.update_time_ns,
+ update_order=5,
+ diaObjectId=200002,
+ nDiaSources=10,
+ ra=45.0,
+ dec=-30.0,
+ )
+
+ expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)
+
+ # Should expand to 1 record (nDiaSources)
+ self.assertEqual(len(expanded), 1)
+
+ expanded_record = expanded[0]
+ self.assertEqual(expanded_record.table_name, "DiaObject")
+ self.assertEqual(expanded_record.record_id, [200002])
+ self.assertEqual(expanded_record.field_name, "nDiaSources")
+ self.assertEqual(expanded_record.value_json, 10)
+
+ def test_close_diaobject_validity(self) -> None:
+ """Test expand_single_record with ApdbCloseDiaObjectValidityRecord."""
+ record = ApdbCloseDiaObjectValidityRecord(
+ update_time_ns=self.update_time_ns,
+ update_order=4,
+ diaObjectId=200001,
+ validityEndMjdTai=self.update_time.tai.mjd,
+ nDiaSources=5,
+ ra=45.0,
+ dec=-30.0,
+ )
+
+ expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)
+
+ # Should expand to 2 records (validityEndMjdTai and nDiaSources)
+ self.assertEqual(len(expanded), 2)
+
+ # Check first expanded record (validityEndMjdTai)
+ first_record = expanded[0]
+ self.assertIsInstance(first_record, ExpandedUpdateRecord)
+ self.assertEqual(first_record.table_name, "DiaObject")
+ self.assertEqual(first_record.record_id, [200001])
+ self.assertEqual(first_record.field_name, "validityEndMjdTai")
+ self.assertEqual(first_record.value_json, self.update_time.tai.mjd)
+
+ # Check second expanded record (nDiaSources)
+ second_record = expanded[1]
+ self.assertEqual(second_record.table_name, "DiaObject")
+ self.assertEqual(second_record.record_id, [200001])
+ self.assertEqual(second_record.field_name, "nDiaSources")
+ self.assertEqual(second_record.value_json, 5)
+
+ def test_withdraw_diaforcedsource(self) -> None:
+ """Test expand_single_record with ApdbWithdrawDiaForcedSourceRecord."""
+ record = ApdbWithdrawDiaForcedSourceRecord(
+ update_time_ns=self.update_time_ns,
+ update_order=2,
+ diaObjectId=200001,
+ visit=12345,
+ detector=42,
+ timeWithdrawnMjdTai=self.update_time.tai.mjd,
+ ra=45.0,
+ dec=-30.0,
+ midpointMjdTai=60000.0,
+ )
+
+ expanded = UpdateRecordExpander.expand_single_record(record, self.replica_chunk_id)
+
+ # Should expand to 1 record (timeWithdrawnMjdTai)
+ self.assertEqual(len(expanded), 1)
+
+ expanded_record = expanded[0]
+ self.assertEqual(expanded_record.table_name, "DiaForcedSource")
+ # The record ID should be a list of the composite key components
+ # [diaObjectId, visit, detector] for BigQuery compatibility
+ expected_record_id = [200001, 12345, 42]
+ self.assertEqual(expanded_record.record_id, expected_record_id)
+ self.assertEqual(expanded_record.field_name, "timeWithdrawnMjdTai")
+ self.assertEqual(expanded_record.value_json, self.update_time.tai.mjd)
+
+ def test_update_records_all(self) -> None:
+ """Test the full expand_updates method with multiple record types."""
+ update_records = _create_test_update_records()
+
+ expanded = UpdateRecordExpander.expand_updates(update_records)
+
+ self.assertEqual(len(expanded), 10)
+
+ # Verify all expanded records have correct replica_chunk_id
+ for record in expanded:
+ self.assertEqual(record.replica_chunk_id, self.replica_chunk_id)
+ self.assertIsInstance(record.update_time_ns, int)
+ self.assertIsInstance(record.update_order, int)
+
+ # Check that we have the expected table names
+ table_names = {record.table_name for record in expanded}
+ expected_tables = {"DiaSource", "DiaObject", "DiaForcedSource"}
+ self.assertEqual(table_names, expected_tables)
+
+ # Check that we have the expected field names
+ field_names = {record.field_name for record in expanded}
+ expected_fields = {
+ "diaObjectId", # from reassign to diaobject
+ "ssObjectId",
+ "ssObjectReassocTimeMjdTai", # from reassign to ssobject
+ "timeWithdrawnMjdTai", # from withdraw diasource and withdraw forced source
+ "validityEndMjdTai",
+ "nDiaSources", # from close validity and update n dia sources
+ }
+ self.assertEqual(field_names, expected_fields)
+
+ def test_empty_records(self) -> None:
+ """Test expand_updates with empty records list."""
+ empty_update_records = UpdateRecords(
+ replica_chunk_id=self.replica_chunk_id,
+ record_count=0,
+ records=[],
+ file_created_at=datetime.datetime.now(datetime.UTC),
+ )
+
+ expanded = UpdateRecordExpander.expand_updates(empty_update_records)
+ self.assertEqual(len(expanded), 0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_update_records.py b/tests/test_update_records.py
new file mode 100644
index 00000000..87c3dc8d
--- /dev/null
+++ b/tests/test_update_records.py
@@ -0,0 +1,329 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import json
+import unittest
+
+import pytest
+from google.cloud import storage
+
+from lsst.dax.apdb import (
+ Apdb,
+ ApdbReplica,
+ apdbUpdateRecord,
+)
+from lsst.dax.ppdb import Ppdb
+from lsst.dax.ppdb.bigquery import PpdbBigQuery
+from lsst.dax.ppdb.bigquery.updates import UpdateRecords
+from lsst.dax.ppdb.replicator import Replicator
+from lsst.dax.ppdb.tests import fill_apdb
+from lsst.dax.ppdb.tests._bigquery import (
+ ChunkUploaderWithoutPubSub,
+ PostgresMixin,
+ delete_test_bucket,
+ generate_test_bucket_name,
+ have_valid_google_credentials,
+)
+
+
+@unittest.skipIf(not have_valid_google_credentials(), "Missing valid Google credentials")
+class UpdateRecordsTestCase(PostgresMixin, unittest.TestCase):
+ """A test case for the handling of APDB record updates by PpdbBigQuery and
+ related classes including the ChunkUploader.
+ """
+
+ def setUp(self):
+ super().setUp()
+
+ # Make APDB instance and fill it with test data.
+ apdb_config = self.make_apdb_instance()
+ apdb = Apdb.from_config(apdb_config)
+ fill_apdb(apdb, include_update_records=True)
+ apdb_replica = ApdbReplica.from_config(apdb_config)
+
+ # Make PPDB instance.
+ self.ppdb_config = self.make_instance()
+ self.ppdb = Ppdb.from_config(self.ppdb_config)
+ assert isinstance(self.ppdb, PpdbBigQuery)
+
+ # Replicate APDB replica chunks to the PPDB.
+ replicator = Replicator(
+ apdb_replica, self.ppdb, update=False, min_wait_time=0, max_wait_time=0, check_interval=0
+ )
+ replicator.run(exit_on_empty=True)
+
+ def test_json_serialization(self) -> None:
+ """Test that the APDB update records are correctly saved to a JSON file
+ in the replication output and can be read back as valid UpdateRecords
+ objects.
+ """
+ update_records_path = self.ppdb.replication_path / "2021/03/01/1614600000" / "update_records.json"
+ self.assertTrue(update_records_path.exists(), "Update records file not found in replication output")
+
+ update_records = UpdateRecords.from_json_file(update_records_path)
+ print("\n" + str(update_records))
+
+ self.assertEqual(
+ update_records.replica_chunk_id,
+ 1614600000,
+ "Unexpected replica chunk ID in deserialized update records",
+ )
+
+ self.assertEqual(update_records.record_count, 3, "Unexpected number of update records deserialized")
+
+ self.assertEqual(
+ len(update_records.records), 3, "Unexpected number of update records in the deserialized object"
+ )
+
+ for record in update_records.records:
+ self.assertIsInstance(
+ record,
+ apdbUpdateRecord.ApdbUpdateRecord,
+ "Deserialized record is not an instance of ApdbUpdateRecord",
+ )
+
+ update_record = update_records.records[0]
+ self.assertIsInstance(
+ update_record,
+ apdbUpdateRecord.ApdbReassignDiaSourceToSSObjectRecord,
+ "Deserialized record is not an instance of ApdbReassignDiaSourceToSSObjectRecord",
+ )
+ assert isinstance(update_record, apdbUpdateRecord.ApdbReassignDiaSourceToSSObjectRecord)
+ self.assertEqual(
+ update_record.diaSourceId,
+ 700,
+ "Unexpected diaSourceId in deserialized ApdbReassignDiaSourceToSSObjectRecord",
+ )
+ self.assertEqual(
+ update_record.ssObjectId,
+ 1,
+ "Unexpected ssObjectId in deserialized ApdbReassignDiaSourceToSSObjectRecord",
+ )
+ self.assertEqual(
+ update_record.update_time_ns,
+ 1614600037000000000,
+ "Unexpected update_time_ns in deserialized ApdbReassignDiaSourceToSSObjectRecord",
+ )
+ self.assertEqual(
+ update_record.update_order,
+ 0,
+ "Unexpected update_order in deserialized ApdbReassignDiaSourceToSSObjectRecord",
+ )
+ self.assertEqual(
+ update_record.midpointMjdTai,
+ 60000.0,
+ "Unexpected midpointMjdTai in deserialized ApdbReassignDiaSourceToSSObjectRecord",
+ )
+ self.assertEqual(
+ update_record.ssObjectReassocTimeMjdTai,
+ 59274.50042824074,
+ "Unexpected ssObjectReassocTimeMjdTai in deserialized ApdbReassignDiaSourceToSSObjectRecord",
+ )
+ self.assertNotEqual(
+ update_record.ra,
+ 0.0,
+ "Unexpected ra in deserialized ApdbReassignDiaSourceToSSObjectRecord, should not be 0.0",
+ )
+ self.assertNotEqual(
+ update_record.dec,
+ 0.0,
+ "Unexpected dec in deserialized ApdbReassignDiaSourceToSSObjectRecord, should not be 0.0",
+ )
+
+ update_record = update_records.records[1]
+ self.assertIsInstance(
+ update_record,
+ apdbUpdateRecord.ApdbCloseDiaObjectValidityRecord,
+ "Deserialized record is not an instance of ApdbCloseDiaObjectValidityRecord",
+ )
+ self.assertEqual(
+ update_record.diaObjectId,
+ 200,
+ "Unexpected diaObjectId in deserialized ApdbCloseDiaObjectValidityRecord",
+ )
+ self.assertNotEqual(
+ update_record.ra,
+ 0.0,
+ "Unexpected ra in deserialized ApdbCloseDiaObjectValidityRecord, should not be 0.0",
+ )
+ self.assertNotEqual(
+ update_record.dec,
+ 0.0,
+ "Unexpected dec in deserialized ApdbCloseDiaObjectValidityRecord, should not be 0.0",
+ )
+ self.assertEqual(
+ update_record.update_time_ns,
+ 1614600037000000000,
+ "Unexpected update_time_ns in deserialized ApdbCloseDiaObjectValidityRecord",
+ )
+ self.assertEqual(
+ update_record.update_order,
+ 1,
+ "Unexpected update_order in deserialized ApdbCloseDiaObjectValidityRecord",
+ )
+ self.assertEqual(
+ update_record.validityEndMjdTai,
+ 59274.50042824074,
+ "Unexpected validityEndMjdTai in deserialized ApdbCloseDiaObjectValidityRecord",
+ )
+ self.assertIsNone(
+ update_record.nDiaSources,
+ "Unexpected nDiaSources in deserialized ApdbCloseDiaObjectValidityRecord, expected None",
+ )
+
+ update_record = update_records.records[2]
+ self.assertIsInstance(
+ update_record,
+ apdbUpdateRecord.ApdbWithdrawDiaForcedSourceRecord,
+ "Deserialized record is not an instance of ApdbWithdrawDiaForcedSourceRecord",
+ )
+ self.assertEqual(
+ update_record.diaObjectId,
+ 200,
+ "Unexpected diaObjectId in deserialized ApdbWithdrawDiaForcedSourceRecord",
+ )
+ self.assertEqual(
+ update_record.visit,
+ 7,
+ "Unexpected visit in deserialized ApdbWithdrawDiaForcedSourceRecord",
+ )
+ self.assertEqual(
+ update_record.detector,
+ 1,
+ "Unexpected detector in deserialized ApdbWithdrawDiaForcedSourceRecord",
+ )
+ self.assertNotEqual(
+ update_record.ra,
+ 0.0,
+ "Unexpected ra in deserialized ApdbWithdrawDiaForcedSourceRecord, should not be 0.0",
+ )
+ self.assertNotEqual(
+ update_record.dec,
+ 0.0,
+ "Unexpected dec in deserialized ApdbWithdrawDiaForcedSourceRecord, should not be 0.0",
+ )
+ self.assertEqual(
+ update_record.midpointMjdTai,
+ 60000.0,
+ "Unexpected midpointMjdTai in deserialized ApdbWithdrawDiaForcedSourceRecord",
+ )
+ self.assertEqual(
+ update_record.update_time_ns,
+ 1614600037000000000,
+ "Unexpected update_time_ns in deserialized ApdbWithdrawDiaForcedSourceRecord",
+ )
+ self.assertEqual(
+ update_record.update_order,
+ 2,
+ "Unexpected update_order in deserialized ApdbWithdrawDiaForcedSourceRecord",
+ )
+ self.assertEqual(
+ update_record.timeWithdrawnMjdTai,
+ 59274.50042824074,
+ "Unexpected timeWithdrawnMjdTai in deserialized ApdbWithdrawDiaForcedSourceRecord",
+ )
+ self.assertNotEqual(
+ update_record.ra,
+ 0.0,
+ "Unexpected ra in deserialized ApdbWithdrawDiaForcedSourceRecord, should not be 0.0",
+ )
+ self.assertNotEqual(
+ update_record.dec,
+ 0.0,
+ "Unexpected dec in deserialized ApdbWithdrawDiaForcedSourceRecord, should not be 0.0",
+ )
+
+ # FIXME: This should be in a separate test case and probably a separate
+ # module as well.
+    def test_chunk_uploader(self) -> None:
+        """Test that the update records are correctly uploaded to Google Cloud
+        Storage after replication.
+        """
+        # Skip at call time if the optional GCP extension is unavailable.
+        # Calling importorskip inside a skipif decorator argument evaluates at
+        # collection time and would skip the entire module instead.
+        pytest.importorskip("lsst.dax.ppdbx.gcp", reason="dax_ppdbx_gcp is not installed")
+ # Change the configuration to use a unique test bucket name to avoid
+ # conflicts
+ self.ppdb_config.bucket_name = generate_test_bucket_name("ppdb-test-gcs-upload")
+
+ # Create the test GCS bucket
+ storage_client = storage.Client()
+ try:
+ bucket = storage_client.bucket(self.ppdb_config.bucket_name)
+ bucket.create(location="US")
+ except Exception as e:
+ self.fail(f"Failed to create test GCS bucket: {e}")
+
+ # Configure and run the uploader
+ uploader = ChunkUploaderWithoutPubSub(
+ self.ppdb_config,
+ wait_interval=0,
+ exit_on_empty=True,
+ exit_on_error=True,
+ )
+ print(f"Uploader will copy files to {uploader.bucket_name}/{uploader.prefix}/")
+ uploader.run()
+
+        # Retrieve the update records file
+ blobs = list(bucket.list_blobs(match_glob="**/update_records.json"))
+ update_records_files = [b.name for b in blobs]
+ self.assertEqual(
+ len(update_records_files),
+ 1,
+ f"Expected exactly one update_records.json file in GCS, found "
+ f"{len(update_records_files)}: {update_records_files}",
+ )
+
+ # Download the contents of the update records file as a string
+ update_records_str = blobs[0].download_as_text()
+
+ # Print the contents of the update records file for debugging
+ update_records_json = json.loads(update_records_str)
+ print(f"Contents of update_records.json in GCS:\n{json.dumps(update_records_json, indent=2)}")
+
+ # Load the update records into the data model and perform a few basic
+ # checks (test_json_serialization already tests this in detail, so we
+ # just check a few key fields here).
+ update_records = UpdateRecords.model_validate(update_records_json)
+ self.assertEqual(
+ update_records.replica_chunk_id,
+ 1614600000,
+ "Unexpected replica chunk ID in update records file from GCS",
+ )
+ self.assertEqual(
+ update_records.record_count,
+ 3,
+ f"Expected record_count of 3 in update records file from GCS, found "
+ f"{update_records.record_count}",
+ )
+ self.assertEqual(
+ len(update_records.records),
+ 3,
+ f"Expected 3 update records in the file from GCS, found {len(update_records.records)}",
+ )
+
+ # FIXME: This should be in a tearDown() method.
+ # Delete the test GCS bucket
+ try:
+ delete_test_bucket(bucket)
+ except Exception as e:
+ raise RuntimeError(f"Failed to delete test GCS bucket: {e}") from e
diff --git a/tests/test_updates_manager.py b/tests/test_updates_manager.py
new file mode 100644
index 00000000..b9721fe0
--- /dev/null
+++ b/tests/test_updates_manager.py
@@ -0,0 +1,250 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import unittest
+import uuid
+from collections.abc import Collection, Sequence
+
+import astropy
+import felis
+from google.cloud import bigquery, storage
+
+from lsst.dax.apdb import (
+ ApdbTableData,
+ ReplicaChunk,
+)
+from lsst.dax.ppdb import Ppdb
+from lsst.dax.ppdb.bigquery import PpdbBigQuery
+from lsst.dax.ppdb.bigquery.updates.updates_manager import UpdatesManager
+from lsst.dax.ppdb.tests._bigquery import (
+ ChunkUploaderWithoutPubSub,
+ PostgresMixin,
+ generate_test_bucket_name,
+ have_valid_google_credentials,
+ json_rows_to_buf,
+)
+from lsst.dax.ppdb.tests._updates import _create_test_update_records
+
+
+@unittest.skipIf(not have_valid_google_credentials(), "Missing valid Google credentials")
+class UpdatesManagerTestCase(PostgresMixin, unittest.TestCase):
+ """A test case for the handling of APDB record updates by PpdbBigQuery and
+ related classes including the ChunkUploader.
+ """
+
+ def setUp(self):
+ super().setUp()
+
+ # Set up BigQuery client and test dataset
+ self.bq_client = bigquery.Client()
+
+ bucket_name = generate_test_bucket_name("ppdb-updates-manager-test")
+ dataset_id = f"test_updates_manager_{uuid.uuid4().hex[:8]}"
+ project_id = self.bq_client.project
+ config = {
+ "db_drop": True,
+ "validate_config": False,
+ "delete_existing_dirs": True,
+ "bucket_name": bucket_name,
+ "object_prefix": "data/test",
+ "dataset_id": dataset_id,
+ "project_id": project_id,
+ }
+
+ # Setup the Postgres database and create the config instance
+ self.ppdb_config = self.make_instance(config)
+
+ # Create the test dataset and tables in BigQuery
+ self.target_dataset_fqn = f"{project_id}.{dataset_id}"
+ self._create_test_dataset(self.bq_client, dataset_id)
+
+ # Create the test GCS bucket
+ storage_client = storage.Client()
+ try:
+ bucket = storage_client.bucket(self.ppdb_config.bucket_name)
+ bucket.create(location="US")
+ except Exception as e:
+ self.fail(f"Failed to create test GCS bucket: {e}")
+
+ # Create the PPDB instance
+ self.ppdb = Ppdb.from_config(self.ppdb_config)
+ assert isinstance(self.ppdb, PpdbBigQuery)
+
+ def tearDown(self):
+ # Delete the test dataset
+ try:
+ self.bq_client.delete_dataset(
+ self.ppdb_config.dataset_id, delete_contents=True, not_found_ok=True
+ )
+ except Exception as e:
+ print(f"Failed to delete test dataset: {e}")
+
+ # Delete the test GCS bucket
+ storage_client = storage.Client()
+ try:
+ bucket = storage_client.bucket(self.ppdb_config.bucket_name)
+ blobs = list(bucket.list_blobs())
+ for blob in blobs:
+ blob.delete()
+ bucket.delete()
+ except Exception as e:
+ print(f"Failed to delete test GCS bucket: {e}")
+
+ super().tearDown()
+
+ def _create_test_dataset(self, client: bigquery.Client, dataset_id: str) -> None:
+ dataset = bigquery.Dataset(f"{client.project}.{dataset_id}")
+ client.create_dataset(dataset, exists_ok=False)
+
+ # Create DiaObject table
+ schema = [
+ bigquery.SchemaField("diaObjectId", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("validityEndMjdTai", "FLOAT", mode="NULLABLE"),
+ bigquery.SchemaField("nDiaSources", "INTEGER", mode="NULLABLE"),
+ ]
+ table_fqn = f"{self.target_dataset_fqn}.DiaObject"
+ table = bigquery.Table(table_fqn, schema=schema)
+ client.create_table(table)
+ rows = [
+ {"diaObjectId": 200001, "validityEndMjdTai": None, "nDiaSources": 3},
+ {"diaObjectId": 200002, "validityEndMjdTai": None, "nDiaSources": 7},
+ {"diaObjectId": 200003, "validityEndMjdTai": 59000.0, "nDiaSources": 2},
+ ]
+ buf = json_rows_to_buf(rows)
+ job = client.load_table_from_file(
+ buf,
+ table_fqn,
+ job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON),
+ )
+ job.result()
+
+ # Create test DiaSource table
+ schema = [
+ bigquery.SchemaField("diaSourceId", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("diaObjectId", "INTEGER", mode="NULLABLE"),
+ bigquery.SchemaField("ssObjectId", "INTEGER", mode="NULLABLE"),
+ bigquery.SchemaField("ssObjectReassocTimeMjdTai", "FLOAT", mode="NULLABLE"),
+ bigquery.SchemaField("timeWithdrawnMjdTai", "FLOAT", mode="NULLABLE"),
+ ]
+ table_fqn = f"{self.target_dataset_fqn}.DiaSource"
+ table = bigquery.Table(table_fqn, schema=schema)
+ self.bq_client.create_table(table)
+ rows = [
+ {
+ "diaSourceId": 100001,
+ "diaObjectId": 200001,
+ "ssObjectId": None,
+ "ssObjectReassocTimeMjdTai": None,
+ "timeWithdrawnMjdTai": None,
+ },
+ {
+ "diaSourceId": 100002,
+ "diaObjectId": 200002,
+ "ssObjectId": None,
+ "ssObjectReassocTimeMjdTai": None,
+ "timeWithdrawnMjdTai": None,
+ },
+ {
+ "diaSourceId": 100003,
+ "diaObjectId": 200003,
+ "ssObjectId": None,
+ "ssObjectReassocTimeMjdTai": None,
+ "timeWithdrawnMjdTai": None,
+ },
+ {
+ "diaSourceId": 100004,
+ "diaObjectId": 200004,
+ "ssObjectId": None,
+ "ssObjectReassocTimeMjdTai": None,
+ "timeWithdrawnMjdTai": None,
+ },
+ ]
+ job = client.load_table_from_file(
+ json_rows_to_buf(rows),
+ table_fqn,
+ job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON),
+ )
+ job.result()
+
+ # Create test DiaForcedSource table
+ schema = [
+ bigquery.SchemaField("diaObjectId", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("visit", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("detector", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("timeWithdrawnMjdTai", "FLOAT", mode="NULLABLE"),
+ ]
+ table_fqn = f"{self.target_dataset_fqn}.DiaForcedSource"
+ table = bigquery.Table(table_fqn, schema=schema)
+ self.bq_client.create_table(table)
+ rows = [
+ {"diaObjectId": 200001, "visit": 12345, "detector": 42, "timeWithdrawnMjdTai": None},
+ {"diaObjectId": 200001, "visit": 12346, "detector": 42, "timeWithdrawnMjdTai": None},
+ ]
+ job = self.bq_client.load_table_from_file(
+ json_rows_to_buf(rows),
+ table_fqn,
+ job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON),
+ )
+ job.result()
+
+ def test_apply_updates(self):
+ """Test that the update records are correctly uploaded to Google Cloud
+ Storage after replication.
+ """
+
+ class DummyApdbTableData(ApdbTableData):
+ def column_names(self) -> Sequence[str]:
+ return []
+
+ def column_defs(self) -> Sequence[tuple[str, felis.datamodel.DataType]]:
+ return []
+
+ def rows(self) -> Collection[tuple]:
+ return []
+
+ # Create and store the test update records
+ update_records = _create_test_update_records()
+ self.ppdb.store(
+ ReplicaChunk(
+ id=update_records.replica_chunk_id,
+ last_update_time=astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"),
+ unique_id=uuid.uuid4(),
+ ),
+ objects=DummyApdbTableData(),
+ sources=DummyApdbTableData(),
+ forced_sources=DummyApdbTableData(),
+ update_records=update_records.records,
+ update=True,
+ )
+
+ # Configure and run the uploader
+ uploader = ChunkUploaderWithoutPubSub(
+ self.ppdb_config,
+ wait_interval=0,
+ exit_on_empty=True,
+ exit_on_error=True,
+ )
+ print(f"Uploader will copy files to {uploader.bucket_name}/{uploader.prefix}")
+ uploader.run()
+
+ # Apply the updates to the target tables
+ updates_manager = UpdatesManager(self.ppdb)
+ updates_manager.apply_updates([update_records.replica_chunk_id])
diff --git a/tests/test_updates_merger.py b/tests/test_updates_merger.py
new file mode 100644
index 00000000..df8c144d
--- /dev/null
+++ b/tests/test_updates_merger.py
@@ -0,0 +1,235 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import unittest
+import uuid
+
+try:
+ from google.cloud import bigquery
+except ImportError:
+ bigquery = None
+
+from lsst.dax.ppdb.bigquery.updates import (
+ DiaForcedSourceUpdatesMerger,
+ DiaObjectUpdatesMerger,
+ DiaSourceUpdatesMerger,
+ UpdateRecordExpander,
+ UpdatesTable,
+)
+from lsst.dax.ppdb.tests._bigquery import have_valid_google_credentials, json_rows_to_buf
+from lsst.dax.ppdb.tests._updates import _create_test_update_records
+
+
+@unittest.skipIf(bigquery is None or not have_valid_google_credentials(), "Missing bigquery or valid Google credentials")
+class TestUpdatesMerger(unittest.TestCase):
+ """Test UpdatesMerger functionality."""
+
+ def setUp(self):
+ self.client = bigquery.Client()
+ self.dataset_id = f"test_merger_{uuid.uuid4().hex[:8]}"
+ self.project_id = self.client.project
+ self.updates_table_fqn = f"{self.project_id}.{self.dataset_id}.updates"
+ self.target_dataset_fqn = f"{self.project_id}.{self.dataset_id}"
+ dataset = bigquery.Dataset(f"{self.project_id}.{self.dataset_id}")
+ dataset.default_table_expiration_ms = 3600000
+ self.client.create_dataset(dataset)
+
+ def tearDown(self):
+ try:
+ self.client.delete_dataset(self.dataset_id, delete_contents=True, not_found_ok=True)
+ except Exception:
+ pass
+
+ def _create_target_table(self):
+ schema = [
+ bigquery.SchemaField("diaObjectId", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("validityEndMjdTai", "FLOAT", mode="NULLABLE"),
+ bigquery.SchemaField("nDiaSources", "INTEGER", mode="NULLABLE"),
+ ]
+ table_fqn = f"{self.target_dataset_fqn}.DiaObject"
+ table = bigquery.Table(table_fqn, schema=schema)
+ self.client.create_table(table)
+ rows = [
+ {"diaObjectId": 200001, "validityEndMjdTai": None, "nDiaSources": 3},
+ {"diaObjectId": 200002, "validityEndMjdTai": None, "nDiaSources": 7},
+ {"diaObjectId": 200003, "validityEndMjdTai": 59000.0, "nDiaSources": 2},
+ ]
+ buf = json_rows_to_buf(rows)
+ job = self.client.load_table_from_file(
+ buf,
+ table_fqn,
+ job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON),
+ )
+ job.result()
+
+ def test_merge_diaobject(self):
+ self._create_target_table()
+ updates_table = UpdatesTable(self.client, self.updates_table_fqn)
+ updates_table.create()
+ update_records = _create_test_update_records()
+ expanded = UpdateRecordExpander.expand_updates(update_records)
+ updates_table.insert(expanded)
+ dedup_fqn = f"{self.updates_table_fqn}_dedup"
+ updates_table.deduplicate_to(dedup_fqn)
+ table_fqn = f"{self.target_dataset_fqn}.DiaObject"
+ query = f"SELECT * FROM `{table_fqn}` ORDER BY diaObjectId"
+ before = {r.diaObjectId: r for r in self.client.query(query).result()}
+ print("Before merge:", before)
+ merger = DiaObjectUpdatesMerger(self.client)
+ merger.merge(updates_table_fqn=dedup_fqn, target_dataset_fqn=self.target_dataset_fqn)
+ after = {r.diaObjectId: r for r in self.client.query(query).result()}
+ print("After merge:", after)
+ self.assertEqual(after[200001].validityEndMjdTai, 59580.0)
+ self.assertEqual(after[200001].nDiaSources, 5)
+ self.assertIsNone(after[200002].validityEndMjdTai)
+ self.assertEqual(after[200002].nDiaSources, 10)
+ self.assertEqual(after[200003].validityEndMjdTai, before[200003].validityEndMjdTai)
+ self.assertEqual(after[200003].nDiaSources, before[200003].nDiaSources)
+
+ def test_merge_diasource(self):
+ schema = [
+ bigquery.SchemaField("diaSourceId", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("diaObjectId", "INTEGER", mode="NULLABLE"),
+ bigquery.SchemaField("ssObjectId", "INTEGER", mode="NULLABLE"),
+ bigquery.SchemaField("ssObjectReassocTimeMjdTai", "FLOAT", mode="NULLABLE"),
+ bigquery.SchemaField("timeWithdrawnMjdTai", "FLOAT", mode="NULLABLE"),
+ ]
+ table_fqn = f"{self.target_dataset_fqn}.DiaSource"
+ table = bigquery.Table(table_fqn, schema=schema)
+ self.client.create_table(table)
+ rows = [
+ {
+ "diaSourceId": 100001,
+ "diaObjectId": 200001,
+ "ssObjectId": None,
+ "ssObjectReassocTimeMjdTai": None,
+ "timeWithdrawnMjdTai": None,
+ },
+ {
+ "diaSourceId": 100002,
+ "diaObjectId": 200002,
+ "ssObjectId": None,
+ "ssObjectReassocTimeMjdTai": None,
+ "timeWithdrawnMjdTai": None,
+ },
+ {
+ "diaSourceId": 100003,
+ "diaObjectId": 200003,
+ "ssObjectId": None,
+ "ssObjectReassocTimeMjdTai": None,
+ "timeWithdrawnMjdTai": None,
+ },
+ {
+ "diaSourceId": 100004,
+ "diaObjectId": 200004,
+ "ssObjectId": None,
+ "ssObjectReassocTimeMjdTai": None,
+ "timeWithdrawnMjdTai": None,
+ },
+ ]
+ job = self.client.load_table_from_file(
+ json_rows_to_buf(rows),
+ table_fqn,
+ job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON),
+ )
+ job.result()
+
+ updates_table = UpdatesTable(self.client, self.updates_table_fqn)
+ updates_table.create()
+ update_records = _create_test_update_records()
+ expanded = UpdateRecordExpander.expand_updates(update_records)
+ updates_table.insert(expanded)
+ dedup_fqn = f"{self.updates_table_fqn}_dedup"
+ updates_table.deduplicate_to(dedup_fqn)
+
+ query = f"SELECT * FROM `{table_fqn}` ORDER BY diaSourceId"
+ before = {r.diaSourceId: r for r in self.client.query(query).result()}
+ merger = DiaSourceUpdatesMerger(self.client)
+ merger.merge(updates_table_fqn=dedup_fqn, target_dataset_fqn=self.target_dataset_fqn)
+ after = {r.diaSourceId: r for r in self.client.query(query).result()}
+
+ self.assertEqual(after[100001].diaObjectId, 400001)
+ self.assertEqual(after[100002].ssObjectId, 2001)
+ self.assertEqual(after[100002].ssObjectReassocTimeMjdTai, 59580.0)
+ self.assertEqual(after[100003].timeWithdrawnMjdTai, 59580.0)
+ self.assertEqual(after[100004].diaObjectId, before[100004].diaObjectId)
+ self.assertEqual(after[100004].ssObjectId, before[100004].ssObjectId)
+ self.assertEqual(after[100004].timeWithdrawnMjdTai, before[100004].timeWithdrawnMjdTai)
+
+ def test_merge_diaforcedsource(self):
+ schema = [
+ bigquery.SchemaField("diaObjectId", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("visit", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("detector", "INTEGER", mode="REQUIRED"),
+ bigquery.SchemaField("timeWithdrawnMjdTai", "FLOAT", mode="NULLABLE"),
+ ]
+ table_fqn = f"{self.target_dataset_fqn}.DiaForcedSource"
+ table = bigquery.Table(table_fqn, schema=schema)
+ self.client.create_table(table)
+ rows = [
+ {"diaObjectId": 200001, "visit": 12345, "detector": 42, "timeWithdrawnMjdTai": None},
+ {"diaObjectId": 200001, "visit": 12346, "detector": 42, "timeWithdrawnMjdTai": None},
+ ]
+ job = self.client.load_table_from_file(
+ json_rows_to_buf(rows),
+ table_fqn,
+ job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON),
+ )
+ job.result()
+
+ updates_table = UpdatesTable(self.client, self.updates_table_fqn)
+ updates_table.create()
+ update_records = _create_test_update_records()
+ expanded = UpdateRecordExpander.expand_updates(update_records)
+ updates_table.insert(expanded)
+ dedup_fqn = f"{self.updates_table_fqn}_dedup"
+ updates_table.deduplicate_to(dedup_fqn)
+
+ query = f"SELECT * FROM `{table_fqn}` ORDER BY diaObjectId, visit, detector"
+ before = {(r.diaObjectId, r.visit, r.detector): r for r in self.client.query(query).result()}
+ merger = DiaForcedSourceUpdatesMerger(self.client)
+ merger.merge(updates_table_fqn=dedup_fqn, target_dataset_fqn=self.target_dataset_fqn)
+ after = {(r.diaObjectId, r.visit, r.detector): r for r in self.client.query(query).result()}
+
+ self.assertEqual(after[(200001, 12345, 42)].timeWithdrawnMjdTai, 59580.0)
+ self.assertEqual(
+ after[(200001, 12346, 42)].timeWithdrawnMjdTai,
+ before[(200001, 12346, 42)].timeWithdrawnMjdTai,
+ )
+
+ def test_merge_no_updates(self):
+ self._create_target_table()
+ updates_table = UpdatesTable(self.client, self.updates_table_fqn)
+ updates_table.create()
+ dedup_fqn = f"{self.updates_table_fqn}_dedup"
+ updates_table.deduplicate_to(dedup_fqn)
+ table_fqn = f"{self.target_dataset_fqn}.DiaObject"
+ before = {r.diaObjectId: r for r in self.client.query(f"SELECT * FROM `{table_fqn}`").result()}
+ merger = DiaObjectUpdatesMerger(self.client)
+ merger.merge(updates_table_fqn=dedup_fqn, target_dataset_fqn=self.target_dataset_fqn)
+ after = {r.diaObjectId: r for r in self.client.query(f"SELECT * FROM `{table_fqn}`").result()}
+ for obj_id in before:
+ self.assertEqual(before[obj_id].validityEndMjdTai, after[obj_id].validityEndMjdTai)
+ self.assertEqual(before[obj_id].nDiaSources, after[obj_id].nDiaSources)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_updates_table.py b/tests/test_updates_table.py
new file mode 100644
index 00000000..07fc1202
--- /dev/null
+++ b/tests/test_updates_table.py
@@ -0,0 +1,202 @@
+# This file is part of dax_ppdb.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import unittest
+import uuid
+
+from google.cloud import bigquery
+
+from lsst.dax.ppdb.bigquery.updates import UpdateRecordExpander, UpdatesTable
+from lsst.dax.ppdb.tests._bigquery import have_valid_google_credentials
+from lsst.dax.ppdb.tests._updates import _create_test_update_records
+
+
+@unittest.skipIf(not have_valid_google_credentials(), "Missing valid Google credentials")
+class TestUpdatesTable(unittest.TestCase):
+    """Test UpdatesTable functionality against a live BigQuery project."""
+
+    def setUp(self) -> None:
+        """Create a uniquely named, self-expiring dataset and an UpdatesTable."""
+        # Create BigQuery client (uses ambient Google credentials).
+        self.client = bigquery.Client()
+
+        # Create unique dataset name for this test run so concurrent runs
+        # cannot collide.
+        self.dataset_id = f"test_updates_{uuid.uuid4().hex[:8]}"
+        self.project_id = self.client.project
+        self.table_name = "updates"
+        self.table_fqn = f"{self.project_id}.{self.dataset_id}.{self.table_name}"
+
+        # Create the test dataset.
+        dataset = bigquery.Dataset(f"{self.project_id}.{self.dataset_id}")
+        # Set a short expiration for cleanup safety (1 hour) in case tearDown
+        # fails to delete the dataset.
+        dataset.default_table_expiration_ms = 3600000  # 1 hour
+        self.dataset = self.client.create_dataset(dataset)
+
+        # Create the UpdatesTable instance under test.
+        self.updates_table = UpdatesTable(self.client, self.table_fqn)
+
+    def tearDown(self) -> None:
+        """Delete the test dataset, whether the test passed or failed."""
+        try:
+            self.client.delete_dataset(self.dataset_id, delete_contents=True, not_found_ok=True)
+        except Exception:
+            # Best-effort cleanup: if deletion fails, the table expiration
+            # set in setUp will clean things up within the hour.
+            pass
+
+    def test_table_fqn_property(self) -> None:
+        """Test the table_fqn property round-trips the constructor argument."""
+        self.assertEqual(self.updates_table.table_fqn, self.table_fqn)
+
+    def test_create_table(self) -> None:
+        """Test creating the updates table produces the expected schema."""
+        table = self.updates_table.create()
+
+        # Verify table was created in the expected dataset.
+        self.assertEqual(table.table_id, self.table_name)
+        self.assertEqual(table.dataset_id, self.dataset_id)
+
+        # Verify schema is correct: field name -> (type, mode).
+        expected_fields = {
+            "table_name": ("STRING", "REQUIRED"),
+            "record_id": ("INTEGER", "REPEATED"),
+            "record_id_hash": ("STRING", "REQUIRED"),
+            "field_name": ("STRING", "REQUIRED"),
+            "value_json": ("JSON", "REQUIRED"),
+            "replica_chunk_id": ("INTEGER", "REQUIRED"),
+            "update_order": ("INTEGER", "NULLABLE"),
+            "update_time_ns": ("INTEGER", "NULLABLE"),
+        }
+
+        actual_fields = {field.name: (field.field_type, field.mode) for field in table.schema}
+        self.assertEqual(actual_fields, expected_fields)
+
+    def test_create_table_already_exists(self) -> None:
+        """Test creating a table that already exists raises an error."""
+        # Create table first time - should succeed.
+        self.updates_table.create()
+
+        # Try to create again - should raise a Conflict from the API.
+        with self.assertRaises(Exception) as cm:
+            self.updates_table.create()
+
+        # Check that it's a conflict-type error.
+        self.assertIn("already exists", str(cm.exception).lower())
+
+    def test_insert_records(self) -> None:
+        """Test insertion of expanded records into the table."""
+        # Create the table first.
+        self.updates_table.create()
+
+        # Get test update records and expand them into one row per field.
+        update_records = _create_test_update_records()
+        expanded_records = UpdateRecordExpander.expand_updates(update_records)
+
+        # Insert the records.
+        job = self.updates_table.insert(expanded_records)
+
+        # Verify the load job completed successfully.
+        self.assertIsNone(job.errors)
+
+        # Verify records were inserted by querying the table.
+        query = f"SELECT COUNT(*) as count FROM `{self.table_fqn}`"
+        result = list(self.client.query(query).result())
+        record_count = result[0].count
+
+        # Should have 10 total expanded records based on the test data
+        # (1 + 2 + 1 + 1 + 2 + 1 from original records + 2 duplicates).
+        self.assertEqual(record_count, 10)
+
+        # Spot-check that every inserted row carries its REQUIRED columns.
+        query = f"""
+        SELECT table_name, record_id, field_name, replica_chunk_id
+        FROM `{self.table_fqn}`
+        """
+        results = list(self.client.query(query).result())
+        self.assertEqual(len(results), record_count)
+        for row in results:
+            self.assertTrue(row.table_name)
+            self.assertTrue(row.field_name)
+            self.assertIsNotNone(row.replica_chunk_id)
+        # TODO(review): add exact per-row value assertions (e.g. for the
+        # DiaForcedSource record) once the expected replica_chunk_id of the
+        # shared fixture records is exposed by _create_test_update_records.
+
+    def test_insert_empty_records(self) -> None:
+        """Test insertion of an empty record list is a no-op."""
+        # Create the table first.
+        self.updates_table.create()
+
+        # Insert empty list.
+        job = self.updates_table.insert([])
+
+        # Verify the job completed successfully.
+        self.assertIsNone(job.errors)
+
+        # Verify no records were inserted.
+        query = f"SELECT COUNT(*) as count FROM `{self.table_fqn}`"
+        result = list(self.client.query(query).result())
+        record_count = result[0].count
+        self.assertEqual(record_count, 0)
+
+    def test_deduplicate_records(self) -> None:
+        """Test deduplication keeps only the latest update per field."""
+        # Create the source table.
+        self.updates_table.create()
+
+        # Get test records (which include duplicates) and expand them.
+        update_records = _create_test_update_records()
+        expanded_records = UpdateRecordExpander.expand_updates(update_records)
+
+        # Insert all records (including duplicates).
+        self.updates_table.insert(expanded_records)
+
+        # Count original records.
+        query = f"SELECT COUNT(*) as count FROM `{self.table_fqn}`"
+        original_count = list(self.client.query(query).result())[0].count
+
+        # Create deduplicated table.
+        dedup_table_fqn = f"{self.table_fqn}_dedup"
+        self.updates_table.deduplicate_to(dedup_table_fqn)
+
+        # Count deduplicated records.
+        query = f"SELECT COUNT(*) as count FROM `{dedup_table_fqn}`"
+        dedup_count = list(self.client.query(query).result())[0].count
+
+        # Should have fewer records after deduplication.
+        self.assertLess(dedup_count, original_count)
+
+        # Verify specific deduplication behavior: the duplicated field for
+        # record [100001] must collapse to a single row holding the value
+        # from the later update.
+        record_id_hash = UpdatesTable._compute_record_id_hash([100001])
+        query = f"""
+        SELECT value_json
+        FROM `{dedup_table_fqn}`
+        WHERE record_id_hash = '{record_id_hash}' AND field_name = 'diaObjectId'
+        """
+        result = list(self.client.query(query).result())
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0].value_json, 400001)  # Should be the later update
+
+
+# Allow running this test module directly with ``python`` as well as pytest.
+if __name__ == "__main__":
+    unittest.main()