Skip to content

Commit 1b66735

Browse files
feature: add xsd validation for xml readers
* feat: add xsd validation for xml readers
1 parent 62b573e commit 1b66735

File tree

22 files changed

+379
-433
lines changed

22 files changed

+379
-433
lines changed

.github/workflows/ci_testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- name: Install extra dependencies for a python install
1818
run: |
1919
sudo apt-get update
20-
sudo apt -y install --no-install-recommends liblzma-dev libbz2-dev libreadline-dev
20+
sudo apt -y install --no-install-recommends liblzma-dev libbz2-dev libreadline-dev libxml2-utils
2121
2222
- name: Install asdf cli
2323
uses: asdf-vm/actions/setup@v4

src/dve/core_engine/backends/implementations/duckdb/readers/csv.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ class DuckDBCSVReader(BaseFileReader):
3030
# TODO - stringify or not
3131
def __init__(
3232
self,
33+
*,
3334
header: bool = True,
3435
delim: str = ",",
3536
quotechar: str = '"',
3637
connection: Optional[DuckDBPyConnection] = None,
38+
**_,
3739
):
3840
self.header = header
3941
self.delim = delim

src/dve/core_engine/backends/implementations/duckdb/readers/json.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
class DuckDBJSONReader(BaseFileReader):
2121
"""A reader for JSON files"""
2222

23-
def __init__(self, json_format: Optional[str] = "array"):
23+
def __init__(
24+
self,
25+
*,
26+
json_format: Optional[str] = "array",
27+
**_,
28+
):
2429
self._json_format = json_format
2530

2631
super().__init__()

src/dve/core_engine/backends/implementations/duckdb/readers/xml.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pydantic import BaseModel
99

1010
from dve.core_engine.backends.base.reader import read_function
11+
from dve.core_engine.backends.exceptions import MessageBearingError
1112
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import duckdb_write_parquet
1213
from dve.core_engine.backends.readers.xml import XMLStreamReader
1314
from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model
@@ -18,13 +19,21 @@
1819
class DuckDBXMLStreamReader(XMLStreamReader):
1920
"""A reader for XML files"""
2021

21-
def __init__(self, ddb_connection: Optional[DuckDBPyConnection] = None, **kwargs):
22+
def __init__(self, *, ddb_connection: Optional[DuckDBPyConnection] = None, **kwargs):
2223
self.ddb_connection = ddb_connection if ddb_connection else default_connection
2324
super().__init__(**kwargs)
2425

2526
@read_function(DuckDBPyRelation)
2627
def read_to_relation(self, resource: URI, entity_name: str, schema: type[BaseModel]):
2728
"""Returns a relation object from the source xml"""
29+
if self.xsd_location:
30+
msg = self._run_xmllint(file_uri=resource)
31+
if msg:
32+
raise MessageBearingError(
33+
"Submitted file failed XSD validation.",
34+
messages=[msg],
35+
)
36+
2837
polars_schema: dict[str, pl.DataType] = { # type: ignore
2938
fld.name: get_polars_type_from_annotation(fld.annotation)
3039
for fld in stringify_model(schema).__fields__.values()

src/dve/core_engine/backends/implementations/spark/readers/csv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
multi_line: bool = False,
3232
encoding: str = "utf-8-sig",
3333
spark_session: Optional[SparkSession] = None,
34+
**_,
3435
) -> None:
3536

3637
self.delimiter = delimiter

src/dve/core_engine/backends/implementations/spark/readers/json.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
encoding: Optional[str] = "utf-8",
2828
multi_line: Optional[bool] = True,
2929
spark_session: Optional[SparkSession] = None,
30+
**_,
3031
) -> None:
3132

3233
self.encoding = encoding

src/dve/core_engine/backends/implementations/spark/readers/xml.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@
55
from typing import Any, Optional
66

77
from pydantic import BaseModel
8+
from pyspark.errors.exceptions.base import AnalysisException
89
from pyspark.sql import DataFrame, SparkSession
910
from pyspark.sql import functions as sf
1011
from pyspark.sql.column import Column
1112
from pyspark.sql.types import StringType, StructField, StructType
12-
from pyspark.sql.utils import AnalysisException
1313
from typing_extensions import Literal
1414

15-
from dve.core_engine.backends.base.reader import BaseFileReader, read_function
16-
from dve.core_engine.backends.exceptions import EmptyFileError
15+
from dve.core_engine.backends.base.reader import read_function
16+
from dve.core_engine.backends.exceptions import EmptyFileError, MessageBearingError
1717
from dve.core_engine.backends.implementations.spark.spark_helpers import (
1818
df_is_empty,
1919
get_type_from_annotation,
2020
spark_write_parquet,
2121
)
22-
from dve.core_engine.backends.readers.xml import XMLStreamReader
22+
from dve.core_engine.backends.readers.xml import BasicXMLFileReader, XMLStreamReader
2323
from dve.core_engine.type_hints import URI, EntityName
2424
from dve.parser.file_handling import get_content_length
2525
from dve.parser.file_handling.service import open_stream
@@ -43,7 +43,7 @@ def read_to_dataframe(
4343
) -> DataFrame:
4444
"""Stream an XML file into a Spark data frame"""
4545
if not self.spark:
46-
self.spark = SparkSession.builder.getOrCreate()
46+
self.spark = SparkSession.builder.getOrCreate() # type: ignore
4747
spark_schema = get_type_from_annotation(schema)
4848
return self.spark.createDataFrame( # type: ignore
4949
list(self.read_to_py_iterator(resource, entity_name, schema)),
@@ -52,7 +52,7 @@ def read_to_dataframe(
5252

5353

5454
@spark_write_parquet
55-
class SparkXMLReader(BaseFileReader): # pylint: disable=too-many-instance-attributes
55+
class SparkXMLReader(BasicXMLFileReader): # pylint: disable=too-many-instance-attributes
5656
"""A reader for XML files built atop Spark-XML."""
5757

5858
def __init__(
@@ -70,21 +70,33 @@ def __init__(
7070
sanitise_multiline: bool = True,
7171
namespace=None,
7272
trim_cells=True,
73+
xsd_location: Optional[URI] = None,
74+
xsd_error_code: Optional[str] = None,
75+
xsd_error_message: Optional[str] = None,
76+
rules_location: Optional[URI] = None,
7377
**_,
7478
) -> None:
75-
self.record_tag = record_tag
76-
self.spark_session = spark_session or SparkSession.builder.getOrCreate()
79+
80+
super().__init__(
81+
record_tag=record_tag,
82+
root_tag=root_tag,
83+
trim_cells=trim_cells,
84+
null_values=null_values,
85+
sanitise_multiline=sanitise_multiline,
86+
xsd_location=xsd_location,
87+
xsd_error_code=xsd_error_code,
88+
xsd_error_message=xsd_error_message,
89+
rules_location=rules_location,
90+
)
91+
92+
self.spark_session = spark_session or SparkSession.builder.getOrCreate() # type: ignore
7793
self.sampling_ratio = sampling_ratio
7894
self.exclude_attribute = exclude_attribute
7995
self.mode = mode
8096
self.infer_schema = infer_schema
8197
self.ignore_namespace = ignore_namespace
82-
self.root_tag = root_tag
8398
self.sanitise_multiline = sanitise_multiline
84-
self.null_values = null_values
8599
self.namespace = namespace
86-
self.trim_cells = trim_cells
87-
super().__init__()
88100

89101
def read_to_py_iterator(
90102
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
@@ -106,6 +118,14 @@ def read_to_dataframe(
106118
if get_content_length(resource) == 0:
107119
raise EmptyFileError(f"File at {resource} is empty.")
108120

121+
if self.xsd_location:
122+
msg = self._run_xmllint(file_uri=resource)
123+
if msg:
124+
raise MessageBearingError(
125+
"Submitted file failed XSD validation.",
126+
messages=[msg],
127+
)
128+
109129
spark_schema: StructType = get_type_from_annotation(schema)
110130
kwargs = {
111131
"rowTag": self.record_tag,
@@ -143,7 +163,7 @@ def read_to_dataframe(
143163
kwargs["rowTag"] = f"{namespace}:{self.record_tag}"
144164
df = (
145165
self.spark_session.read.format("xml")
146-
.options(**kwargs)
166+
.options(**kwargs) # type: ignore
147167
.load(resource, schema=read_schema)
148168
)
149169
if self.root_tag and df.columns:

src/dve/core_engine/backends/readers/csv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
trim_cells: bool = True,
3737
null_values: Collection[str] = frozenset({"NULL", "null", ""}),
3838
encoding: str = "utf-8-sig",
39+
**_,
3940
):
4041
"""Init function for the base CSV reader.
4142

src/dve/core_engine/backends/readers/xml.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212

1313
from dve.core_engine.backends.base.reader import BaseFileReader
1414
from dve.core_engine.backends.exceptions import EmptyFileError
15+
from dve.core_engine.backends.readers.xml_linting import run_xmllint
1516
from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model
1617
from dve.core_engine.loggers import get_logger
18+
from dve.core_engine.message import FeedbackMessage
1719
from dve.core_engine.type_hints import URI, EntityName
1820
from dve.parser.file_handling import NonClosingTextIOWrapper, get_content_length, open_stream
1921
from dve.parser.file_handling.implementations.file import (
@@ -101,7 +103,7 @@ def clear(self) -> None:
101103
def __iter__(self) -> Iterator["XMLElement"]: ...
102104

103105

104-
class BasicXMLFileReader(BaseFileReader):
106+
class BasicXMLFileReader(BaseFileReader): # pylint: disable=R0902
105107
"""A reader for XML files built atop LXML."""
106108

107109
def __init__(
@@ -114,6 +116,10 @@ def __init__(
114116
sanitise_multiline: bool = True,
115117
encoding: str = "utf-8-sig",
116118
n_records_to_read: Optional[int] = None,
119+
xsd_location: Optional[URI] = None,
120+
xsd_error_code: Optional[str] = None,
121+
xsd_error_message: Optional[str] = None,
122+
rules_location: Optional[URI] = None,
117123
**_,
118124
):
119125
"""Init function for the base XML reader.
@@ -148,6 +154,15 @@ def __init__(
148154
"""Encoding of the XML file."""
149155
self.n_records_to_read = n_records_to_read
150156
"""The maximum number of records to read from a document."""
157+
if rules_location is not None and xsd_location is not None:
158+
self.xsd_location = rules_location + xsd_location
159+
else:
160+
self.xsd_location = xsd_location # type: ignore
161+
"""The URI of the xsd file if wishing to perform xsd validation."""
162+
self.xsd_error_code = xsd_error_code
163+
"""The error code to be reported if xsd validation fails (if xsd)"""
164+
self.xsd_error_message = xsd_error_message
165+
"""The error message to be reported if xsd validation fails"""
151166
super().__init__()
152167
self._logger = get_logger(__name__)
153168

@@ -259,6 +274,22 @@ def _parse_xml(
259274
for element in elements:
260275
yield self._parse_element(element, template_row)
261276

277+
def _run_xmllint(self, file_uri: URI) -> FeedbackMessage | None:
278+
"""Run xmllint package to validate against a given xsd. Requires xmlint to be installed
279+
onto the system to run succesfully."""
280+
if self.xsd_location is None:
281+
raise AttributeError("Trying to run XML lint with no `xsd_location` provided.")
282+
if self.xsd_error_code is None:
283+
raise AttributeError("Trying to run XML with no `xsd_error_code` provided.")
284+
if self.xsd_error_message is None:
285+
raise AttributeError("Trying to run XML with no `xsd_error_message` provided.")
286+
return run_xmllint(
287+
file_uri=file_uri,
288+
schema_uri=self.xsd_location,
289+
error_code=self.xsd_error_code,
290+
error_message=self.xsd_error_message,
291+
)
292+
262293
def read_to_py_iterator(
263294
self,
264295
resource: URI,

0 commit comments

Comments
 (0)