Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/DM-53664.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for automatic assignment of ``tap:column_index`` to column references.
40 changes: 35 additions & 5 deletions python/felis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,46 @@
help="Generate IDs for all objects that do not have them",
default=True,
)
@click.option(
"--column-ref-index-increment",
type=int,
help="Automatically set 'tap:column_index' on column references, using the specified increment "
"(must be at least 1)",
default=None,
)
@click.pass_context
def cli(ctx: click.Context, log_level: str, log_file: str | None, id_generation: bool) -> None:
def cli(
ctx: click.Context,
log_level: str,
log_file: str | None,
id_generation: bool,
column_ref_index_increment: int | None,
) -> None:
"""Felis command line tools"""
ctx.ensure_object(dict)

# Configure logging (must come first)
if log_file:
logging.basicConfig(filename=log_file, level=log_level)
else:
logging.basicConfig(level=log_level)

# Configure ID generation (flag can only turn it off)
ctx.obj["id_generation"] = id_generation
if ctx.obj["id_generation"]:
logger.info("ID generation is enabled")
else:
logger.info("ID generation is disabled")
if log_file:
logging.basicConfig(filename=log_file, level=log_level)
else:
logging.basicConfig(level=log_level)

# Configure automatic indexing of column references (optional)
if column_ref_index_increment is not None and column_ref_index_increment < 1:
raise click.ClickException("column_ref_index_increment must be at least 1")
ctx.obj["column_ref_index_increment"] = column_ref_index_increment
if ctx.obj["column_ref_index_increment"] is not None:
logger.info(
f"Automatic indexing of column references is enabled with increment "
f"{ctx.obj['column_ref_index_increment']}"
)


@cli.command("create", help="Create database objects from the Felis file")
Expand Down Expand Up @@ -322,6 +349,7 @@ def load_tap_schema(
file,
context={
"id_generation": ctx.obj["id_generation"],
"column_ref_index_increment": ctx.obj["column_ref_index_increment"],
"force_unbounded_arraysize": force_unbounded_arraysize,
},
)
Expand Down Expand Up @@ -471,8 +499,10 @@ def validate(
"check_tap_table_indexes": check_tap_table_indexes,
"check_tap_principal": check_tap_principal,
"id_generation": ctx.obj["id_generation"],
"column_ref_index_increment": ctx.obj["column_ref_index_increment"],
},
)
logger.info(f"Successfully validated {file_name}")
except ValidationError as e:
logger.error(e)
rc = 1
Expand Down
45 changes: 38 additions & 7 deletions python/felis/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,10 +1304,12 @@ def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema:
tables in this schema.
"""
context = info.context
if context is not None and context.get("dereference_resources", False):
dereference_resources = True
else:
dereference_resources = False
column_ref_index_increment: int | None = None
dereference_resources = False
if context is not None:
dereference_resources = context.get("dereference_resources", False)
column_ref_index_increment = context.get("column_ref_index_increment", None)

for table in self.tables:
if column_refs := table.column_refs:
for resource_name, tables in column_refs.items():
Expand All @@ -1319,6 +1321,7 @@ def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema:
tables,
resource_schema,
dereference_resources,
column_ref_index_increment,
)
if dereference_resources and len(table.column_refs) > 0:
# Clear column refs in table if fully dereferencing
Expand All @@ -1334,10 +1337,13 @@ def _process_column_refs(
ref_tables: ResourceTableMap,
resource_schema: Schema,
dereference_resources: bool = False,
column_ref_index_increment: int | None = None,
) -> None:
"""Process column references from an external resource and add them
to the given table.
"""
current_column_index = column_ref_index_increment if column_ref_index_increment is not None else -1

for table_name, columns in ref_tables.items():
try:
resource_table = resource_schema._find_table_by_name(table_name)
Expand Down Expand Up @@ -1375,20 +1381,46 @@ def _process_column_refs(
# Create a copy of the base column and apply
# overrides
column_copy = base_column.model_copy()

# Set the local name (key from the mapping)
column_copy.name = local_column_name

if not dereference_resources:
# Flag the column as a resource reference so it will not be
# written out during serialization
column_copy._is_resource_ref = True

# Apply overrides to the original column definition
overrides: ColumnOverrides | None = None
if column_ref is not None and column_ref.overrides is not None:
overrides = column_ref.overrides
for field_name, override_value in overrides.model_dump().items():
if override_value is not None:
setattr(column_copy, field_name, override_value)

# Manually set the ID of the copied column as
# ID generation has already occurred
# Manually set the ID of the copied column as ID generation has
# already occurred by now
column_copy.id = f"{table.id}.{local_column_name}"

# Apply automatic assignment of 'tap:column_index', if enabled
if column_ref_index_increment is not None:
if (not overrides) or (overrides.tap_column_index is None):
column_copy.tap_column_index = current_column_index
current_column_index += column_ref_index_increment
logger.debug(
f"Automatically assigned 'tap:column_index' {column_copy.tap_column_index} to "
f"column '{local_column_name}' in table '{table_name}' from resource "
f"'{resource_schema.name}'"
)
else:
# Skip automatic assignment of 'tap:column_index' if it
# is already overridden
logger.debug(
f"Skipping automatic assignment of 'tap:column_index' for column "
f"'{local_column_name}' in table '{table_name}' from resource "
f"'{resource_schema.name}' as it is already overridden to "
f"{column_copy.tap_column_index}"
)
table.columns.append(column_copy)
logger.debug(
f"Dereferenced column '{local_column_name}' from table '{table_name}' "
Expand Down Expand Up @@ -1839,7 +1871,6 @@ def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any]
pydantic.ValidationError
Raised if the schema fails validation.
"""
logger.debug(f"Loading schema from: '{resource_path}'")
try:
rp_stream = ResourcePath(resource_path).read()
except Exception as e:
Expand Down
128 changes: 128 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import logging
import os
import shutil
import tempfile
Expand Down Expand Up @@ -47,6 +48,11 @@ def setUp(self) -> None:
self.sqlite_url = f"sqlite:///{self.tmpdir}/db.sqlite3"
print(f"Using temporary directory: {self.tmpdir}")

# Clear any existing logging handlers to ensure fresh configuration for
# each test
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)

def tearDown(self) -> None:
"""Clean up temporary directory."""
shutil.rmtree(self.tmpdir, ignore_errors=True)
Expand Down Expand Up @@ -88,6 +94,15 @@ def test_validate(self) -> None:
"""Test validate command."""
run_cli(["validate", TEST_YAML])

def test_validate_with_log_file(self) -> None:
"""Test validate command with log file."""
log_file = os.path.join(self.tmpdir, "validate.log")
run_cli([f"--log-file={log_file}", "validate", TEST_YAML], log_level=logging.DEBUG, print_cmd=True)
if not os.path.exists(log_file):
self.fail("Log file was not created")
if os.path.getsize(log_file) == 0:
self.fail("Log file is empty")

def test_validate_with_id_generation(self) -> None:
"""Test that loading a schema with IDs works if ID generation is
enabled. This is the default behavior.
Expand Down Expand Up @@ -358,5 +373,118 @@ def test_generate_and_load_sql(self) -> None:
self.fail(f"Test failed with exception: {e}")


class ColumnRefsTestCase(unittest.TestCase):
"""Test handling of column references in CLI."""

def setUp(self) -> None:
"""Set up a temporary directory for tests."""
self.temp_dir = tempfile.mkdtemp(dir=TEST_DIR)
self.sqlite_url = f"sqlite:///{self.temp_dir}/db.sqlite3"

# Write out source schema file
source_schema_content = """
name: source_schema
tables:
- name: source_table
columns:
- name: ref_col1
datatype: int
- name: ref_col2
datatype: string
length: 64
- name: ref_col3
datatype: float
"""
source_schema_path = os.path.join(self.temp_dir, "source_schema.yaml")
with open(source_schema_path, "w") as f:
f.write(source_schema_content.strip())

# Write out referencing schema file
ref_schema_content = """
name: ref_schema
resources:
source_schema:
uri: {resource_path}
tables:
- name: ref_table
columnRefs:
source_schema:
source_table:
ref_col1:
ref_col2:
overrides:
tap:column_index: 15
col3:
ref_name: ref_col3
"""
self.ref_schema_path = os.path.join(self.temp_dir, "ref_schema.yaml")
ref_content = ref_schema_content.format(resource_path=source_schema_path)
with open(self.ref_schema_path, "w") as f:
f.write(ref_content.strip())

def tearDown(self) -> None:
"""Clean up temporary directory."""
shutil.rmtree(self.temp_dir, ignore_errors=True)

def test_validate_with_column_ref_index_increment(self) -> None:
"""Test that passing a valid value for column reference index increment
works.
"""
run_cli(
[
"--column-ref-index-increment=1",
"validate",
self.ref_schema_path,
]
)

def test_validate_with_column_ref_index_increment_error(self) -> None:
"""Test that passing an invalid value for column reference index
increment raises an error.
"""
run_cli(
[
"--column-ref-index-increment=-1",
"validate",
self.ref_schema_path,
],
expect_error=True,
)

def test_load_tap_schema_with_column_refs(self) -> None:
"""Test load-tap-schema command with column references when no index
increment is specified.
"""
# Create the TAP_SCHEMA database
run_cli(["init-tap-schema", f"--engine-url={self.sqlite_url}"])

# Load the TAP_SCHEMA data that includes column references
run_cli(
[
"load-tap-schema",
f"--engine-url={self.sqlite_url}",
self.ref_schema_path,
]
)

def test_load_tap_schema_with_column_ref_index_increment(self) -> None:
"""Test load-tap-schema command with column reference index
increment.
"""
# Create the TAP_SCHEMA database
run_cli(["init-tap-schema", f"--engine-url={self.sqlite_url}"])

# Load the TAP_SCHEMA data with automatic indexing of column
# references enabled
run_cli(
[
"--column-ref-index-increment=1",
"load-tap-schema",
f"--engine-url={self.sqlite_url}",
self.ref_schema_path,
]
)


if __name__ == "__main__":
unittest.main()
58 changes: 58 additions & 0 deletions tests/test_datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,6 +1492,64 @@ def test_ref_schema_with_dereference_columns(self) -> None:
# Check that column_refs is empty after dereferencing
self.assertEqual(len(ref_table.column_refs), 0)

def test_tap_column_index_with_overrides(self) -> None:
"""Test that TAP column index is correctly assigned when an override
of that field is present in the column ref.
"""
# Write out source schema file
source_schema_content = """
name: source_schema
tables:
- name: source_table
columns:
- name: col1
datatype: int
- name: col2
datatype: int
- name: col3
datatype: int
"""
source_schema_path = os.path.join(self.temp_dir, "source_schema.yaml")
with open(source_schema_path, "w") as f:
f.write(source_schema_content.strip())

# Write out referencing schema file
ref_schema_content = """
name: ref_schema
resources:
source_schema:
uri: {resource_path}
tables:
- name: ref_table
columnRefs:
source_schema:
source_table:
col1:
col2:
overrides:
tap:column_index: 15
col3:
"""
ref_schema_path = os.path.join(self.temp_dir, "ref_schema.yaml")
ref_content = ref_schema_content.format(resource_path=source_schema_path)
with open(ref_schema_path, "w") as f:
f.write(ref_content.strip())

ref_schema = Schema.from_uri(
ref_schema_path,
context={"id_generation": True, "column_ref_index_increment": 10},
)

for column in ref_schema.tables[0].columns:
if column.name == "col1":
self.assertEqual(column.tap_column_index, 10)
elif column.name == "col2":
self.assertEqual(column.tap_column_index, 15)
elif column.name == "col3":
self.assertEqual(column.tap_column_index, 20)
else:
self.fail(f"Unexpected column name: {column.name}")


if __name__ == "__main__":
unittest.main()
Loading
Loading