diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index d7eac2a4d..85682c66f 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -58,6 +58,8 @@ from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable from google.api_core.exceptions import Aborted +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper import google.auth.credentials import google.auth._default @@ -657,6 +659,7 @@ async def execute_query( DeadlineExceeded, ServiceUnavailable, ), + column_info: dict[str, Message | EnumTypeWrapper] | None = None, ) -> "ExecuteQueryIteratorAsync": """ Executes an SQL query on an instance. @@ -705,6 +708,62 @@ async def execute_query( If None, defaults to prepare_operation_timeout. prepare_retryable_errors: a list of errors that will be retried if encountered during prepareQuery. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + column_info: (Optional) A dictionary mapping column names to Protobuf message instances or EnumTypeWrapper objects. + This dictionary provides the necessary type information for deserializing PROTO and + ENUM column values from the query results. When an entry is provided + for a PROTO or ENUM column, the client library will attempt to deserialize the raw data. + + - For PROTO columns: The value in the dictionary should be an + instance of the Protobuf message class (e.g., ``my_pb2.MyMessage()``). + - For ENUM columns: The value should be the Protobuf EnumTypeWrapper + object (e.g., ``my_pb2.MyEnum``). + + Example:: + + import my_pb2 + + column_info = { + "my_proto_column": my_pb2.MyMessage(), + "my_enum_column": my_pb2.MyEnum + } + + If ``column_info`` is not provided, or if a specific column name is not found + in the dictionary: + + - PROTO columns will be returned as raw bytes. + - ENUM columns will be returned as integers. + + Note for Nested PROTO or ENUM Fields: + + To specify types for PROTO or ENUM fields within STRUCTs or MAPs, use a dot-separated + path from the top-level column name. + + - For STRUCTs: ``struct_column_name.field_name`` + - For MAPs: ``map_column_name.key`` or ``map_column_name.value`` to specify types + for the map keys or values, respectively. + + Example:: + + import my_pb2 + + column_info = { + # Top-level column + "my_proto_column": my_pb2.MyMessage(), + "my_enum_column": my_pb2.MyEnum, + + # Nested field in a STRUCT column named 'my_struct' + "my_struct.nested_proto_field": my_pb2.OtherMessage(), + "my_struct.nested_enum_field": my_pb2.AnotherEnum, + + # Nested field in a MAP column named 'my_map' + "my_map.key": my_pb2.MapKeyEnum, # If map keys were enums + "my_map.value": my_pb2.MapValueMessage(), + + # PROTO field inside a STRUCT, where the STRUCT is the value in a MAP column + "struct_map.value.nested_proto_field": my_pb2.DeeplyNestedProto(), + "struct_map.value.nested_enum_field": my_pb2.DeeplyNestedEnum + } + Returns: ExecuteQueryIteratorAsync: an asynchronous iterator that yields rows returned by the query Raises: @@ -714,6 +773,7 @@ async def execute_query( google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error google.cloud.bigtable.data.exceptions.ParameterTypeInferenceFailed: Raised if a parameter is passed without an explicit type, and the type cannot be infered + google.protobuf.message.DecodeError: raised if the deserialization of a PROTO value fails.
""" instance_name = self._gapic_client.instance_path(self.project, instance_id) converted_param_types = _to_param_types(parameters, parameter_types) @@ -771,6 +831,7 @@ async def execute_query( attempt_timeout, operation_timeout, retryable_excs=retryable_excs, + column_info=column_info, ) @CrossSync.convert(sync_name="__enter__") diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index a7e07e20d..37c647028 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -49,6 +49,8 @@ from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable from google.api_core.exceptions import Aborted +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper import google.auth.credentials import google.auth._default from google.api_core import client_options as client_options_lib @@ -485,6 +487,7 @@ def execute_query( DeadlineExceeded, ServiceUnavailable, ), + column_info: dict[str, Message | EnumTypeWrapper] | None = None, ) -> "ExecuteQueryIterator": """Executes an SQL query on an instance. Returns an iterator to asynchronously stream back columns from selected rows. @@ -532,6 +535,62 @@ def execute_query( If None, defaults to prepare_operation_timeout. prepare_retryable_errors: a list of errors that will be retried if encountered during prepareQuery. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + column_info: (Optional) A dictionary mapping column names to Protobuf message classes or EnumTypeWrapper objects. + This dictionary provides the necessary type information for deserializing PROTO and + ENUM column values from the query results. When an entry is provided + for a PROTO or ENUM column, the client library will attempt to deserialize the raw data. + + - For PROTO columns: The value in the dictionary should be the + Protobuf Message class (e.g., ``my_pb2.MyMessage``). + - For ENUM columns: The value should be the Protobuf EnumTypeWrapper + object (e.g., ``my_pb2.MyEnum``). + + Example:: + + import my_pb2 + + column_info = { + "my_proto_column": my_pb2.MyMessage, + "my_enum_column": my_pb2.MyEnum + } + + If ``column_info`` is not provided, or if a specific column name is not found + in the dictionary: + + - PROTO columns will be returned as raw bytes. + - ENUM columns will be returned as integers. + + Note for Nested PROTO or ENUM Fields: + + To specify types for PROTO or ENUM fields within STRUCTs or MAPs, use a dot-separated + path from the top-level column name. + + - For STRUCTs: ``struct_column_name.field_name`` + - For MAPs: ``map_column_name.key`` or ``map_column_name.value`` to specify types + for the map keys or values, respectively. 
+ + Example:: + + import my_pb2 + + column_info = { + # Top-level column + "my_proto_column": my_pb2.MyMessage(), + "my_enum_column": my_pb2.MyEnum, + + # Nested field in a STRUCT column named 'my_struct' + "my_struct.nested_proto_field": my_pb2.OtherMessage(), + "my_struct.nested_enum_field": my_pb2.AnotherEnum, + + # Nested field in a MAP column named 'my_map' + "my_map.key": my_pb2.MapKeyEnum, # If map keys were enums + "my_map.value": my_pb2.MapValueMessage(), + + # PROTO field inside a STRUCT, where the STRUCT is the value in a MAP column + "struct_map.value.nested_proto_field": my_pb2.DeeplyNestedProto(), + "struct_map.value.nested_enum_field": my_pb2.DeeplyNestedEnum + } + Returns: ExecuteQueryIterator: an asynchronous iterator that yields rows returned by the query Raises: @@ -541,6 +600,7 @@ def execute_query( google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error google.cloud.bigtable.data.exceptions.ParameterTypeInferenceFailed: Raised if a parameter is passed without an explicit type, and the type cannot be infered + google.protobuf.message.DecodeError: raised if the deserialization of a PROTO value fails. """ instance_name = self._gapic_client.instance_path(self.project, instance_id) converted_param_types = _to_param_types(parameters, parameter_types) @@ -592,6 +652,7 @@ def execute_query( attempt_timeout, operation_timeout, retryable_excs=retryable_excs, + column_info=column_info, ) def __enter__(self): diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index d3ca890b4..bb105179e 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -23,6 +23,8 @@ TYPE_CHECKING, ) from google.api_core import retry as retries +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor from google.cloud.bigtable.data._helpers import ( @@ -87,6 +89,7 @@ def __init__( operation_timeout: float, req_metadata: Sequence[Tuple[str, str]] = (), retryable_excs: Sequence[type[Exception]] = (), + column_info: dict[str, Message | EnumTypeWrapper] | None = None, ) -> None: """ Collects responses from ExecuteQuery requests and parses them into QueryResultRows. @@ -107,6 +110,8 @@ def __init__( Failed requests will be retried within the budget req_metadata: metadata used while sending the gRPC request retryable_excs: a list of errors that will be retried if encountered. + column_info: dict with mappings between column names and additional column information + for protobuf deserialization.
Raises: {NO_LOOP} :class:`ValueError ` as a safeguard if data is processed in an unexpected state @@ -135,6 +140,7 @@ def __init__( exception_factory=_retry_exception_factory, ) self._req_metadata = req_metadata + self._column_info = column_info try: self._register_instance_task = CrossSync.create_task( self._client._register_instance, @@ -202,7 +208,9 @@ async def _next_impl(self) -> CrossSync.Iterator[QueryResultRow]: raise ValueError( "Error parsing response before finalizing metadata" ) - results = self._reader.consume(batches_to_parse, self.metadata) + results = self._reader.consume( + batches_to_parse, self.metadata, self._column_info + ) if results is None: continue diff --git a/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py b/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py index 4cb5db291..a43539e55 100644 --- a/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py +++ b/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations -from typing import Any, Callable, Dict, Type +from typing import Any, Callable, Dict, Type, Optional, Union + +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper from google.cloud.bigtable.data.execute_query.values import Struct from google.cloud.bigtable.data.execute_query.metadata import SqlType from google.cloud.bigtable_v2 import Value as PBValue @@ -30,24 +34,36 @@ SqlType.Struct: "array_value", SqlType.Array: "array_value", SqlType.Map: "array_value", + SqlType.Proto: "bytes_value", + SqlType.Enum: "int_value", } -def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> Any: +def _parse_array_type( + value: PBValue, + metadata_type: SqlType.Array, + column_name: str | None, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, +) -> list[Any]: """ used for parsing an array represented as a protobuf to a python list. """ return list( map( lambda val: _parse_pb_value_to_python_value( - val, metadata_type.element_type + val, metadata_type.element_type, column_name, column_info ), value.array_value.values, ) ) -def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any: +def _parse_map_type( + value: PBValue, + metadata_type: SqlType.Map, + column_name: str | None, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, +) -> dict[Any, Any]: """ used for parsing a map represented as a protobuf to a python dict. 
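[Editor's note, not part of the patch: the map and struct parsers below qualify column names with a dot-separated path before looking them up in ``column_info``. The sketch that follows is illustrative only; ``_qualify`` is a hypothetical helper, and the real code builds these strings inline, as the hunks above and below show.]

```python
# Illustrative only: how nested column_info keys are derived.
# A MAP column "m" looks up its keys under "m.key" and its values under
# "m.value"; a STRUCT column "s" looks up a field "f" under "s.f".
# Unnamed parents (or unnamed struct fields) propagate None, which
# disables nested lookups and triggers the raw-bytes/int fallback.
from typing import Optional


def _qualify(column_name: Optional[str], suffix: str) -> Optional[str]:
    return f"{column_name}.{suffix}" if column_name is not None else None


assert _qualify("my_map", "value") == "my_map.value"
assert _qualify("my_struct", "nested_proto_field") == "my_struct.nested_proto_field"
assert _qualify(None, "value") is None
```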
@@ -64,10 +80,16 @@ def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any: map( lambda map_entry: ( _parse_pb_value_to_python_value( - map_entry.array_value.values[0], metadata_type.key_type + map_entry.array_value.values[0], + metadata_type.key_type, + f"{column_name}.key" if column_name is not None else None, + column_info, ), _parse_pb_value_to_python_value( - map_entry.array_value.values[1], metadata_type.value_type + map_entry.array_value.values[1], + metadata_type.value_type, + f"{column_name}.value" if column_name is not None else None, + column_info, ), ), value.array_value.values, @@ -77,7 +99,12 @@ def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any: raise ValueError("Invalid map entry - less or more than two values.") -def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct: +def _parse_struct_type( + value: PBValue, + metadata_type: SqlType.Struct, + column_name: str | None, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, +) -> Struct: """ used for parsing a struct represented as a protobuf to a google.cloud.bigtable.data.execute_query.Struct @@ -88,13 +115,27 @@ def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct: struct = Struct() for value, field in zip(value.array_value.values, metadata_type.fields): field_name, field_type = field - struct.add_field(field_name, _parse_pb_value_to_python_value(value, field_type)) + nested_column_name: str | None + if column_name and field_name: + # qualify the column name for nested lookups + nested_column_name = f"{column_name}.{field_name}" + else: + nested_column_name = None + struct.add_field( + field_name, + _parse_pb_value_to_python_value( + value, field_type, nested_column_name, column_info + ), + ) return struct def _parse_timestamp_type( - value: PBValue, metadata_type: SqlType.Timestamp + value: PBValue, + metadata_type: SqlType.Timestamp, + column_name: str | None, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, ) -> DatetimeWithNanoseconds: """ used for parsing a timestamp represented as a protobuf to DatetimeWithNanoseconds @@ -102,15 +143,105 @@ def _parse_timestamp_type( return DatetimeWithNanoseconds.from_timestamp_pb(value.timestamp_value) -_TYPE_PARSERS: Dict[Type[SqlType.Type], Callable[[PBValue, Any], Any]] = { +def _parse_proto_type( + value: PBValue, + metadata_type: SqlType.Proto, + column_name: str | None, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, +) -> Message | bytes: + """ + Parses a serialized protobuf message into a Message object using type information + provided in column_info. + + Args: + value: The value to parse, expected to have a bytes_value attribute. + metadata_type: The expected SQL type (Proto). + column_name: The name of the column. + column_info: (Optional) A dictionary mapping column names to their + corresponding Protobuf message instances. This information is used + to deserialize the raw bytes. + + Returns: + A deserialized Protobuf Message object if parsing is successful. + If the required type information is not found in column_info, the function + returns the original serialized data as bytes (value.bytes_value). + This fallback ensures that the raw data is still accessible. + + Raises: + google.protobuf.message.DecodeError: If `value.bytes_value` cannot be + parsed as the Message type specified in `column_info`.
+ """ + if ( + column_name is not None + and column_info is not None + and column_info.get(column_name) is not None + ): + default_proto_message = column_info.get(column_name) + if isinstance(default_proto_message, Message): + proto_message = type(default_proto_message)() + proto_message.ParseFromString(value.bytes_value) + return proto_message + return value.bytes_value + + +def _parse_enum_type( + value: PBValue, + metadata_type: SqlType.Enum, + column_name: str | None, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, +) -> int | str: + """ + Parses an integer value into a Protobuf enum name string using type information + provided in column_info. + + Args: + value: The value to parse, expected to have an int_value attribute. + metadata_type: The expected SQL type (Enum). + column_name: The name of the column. + column_info: (Optional) A dictionary mapping column names to their + corresponding Protobuf EnumTypeWrapper objects. This information + is used to convert the integer to an enum name. + + Returns: + A string representing the name of the enum value if conversion is successful. + If conversion fails for any reason, such as the required EnumTypeWrapper + not being found in column_info, or if an error occurs during the name lookup + (e.g., the integer is not a valid enum value), the function returns the + original integer value (value.int_value). This fallback ensures the + raw integer representation is still accessible. + """ + if ( + column_name is not None + and column_info is not None + and column_info.get(column_name) is not None + ): + proto_enum = column_info.get(column_name) + if isinstance(proto_enum, EnumTypeWrapper): + return proto_enum.Name(value.int_value) + return value.int_value + + +ParserCallable = Callable[ + [PBValue, Any, Optional[str], Optional[Dict[str, Union[Message, EnumTypeWrapper]]]], + Any, +] + +_TYPE_PARSERS: Dict[Type[SqlType.Type], ParserCallable] = { SqlType.Timestamp: _parse_timestamp_type, SqlType.Struct: _parse_struct_type, SqlType.Array: _parse_array_type, SqlType.Map: _parse_map_type, + SqlType.Proto: _parse_proto_type, + SqlType.Enum: _parse_enum_type, } -def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type) -> Any: +def _parse_pb_value_to_python_value( + value: PBValue, + metadata_type: SqlType.Type, + column_name: str | None, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, +) -> Any: """ used for converting the value represented as a protobufs to a python object. """ @@ -126,7 +257,7 @@ def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type) if kind in _TYPE_PARSERS: parser = _TYPE_PARSERS[kind] - return parser(value, metadata_type) + return parser(value, metadata_type, column_name, column_info) elif kind in _REQUIRED_PROTO_FIELDS: field_name = _REQUIRED_PROTO_FIELDS[kind] return getattr(value, field_name) diff --git a/google/cloud/bigtable/data/execute_query/_reader.py b/google/cloud/bigtable/data/execute_query/_reader.py index d9507fe35..467c2030f 100644 --- a/google/cloud/bigtable/data/execute_query/_reader.py +++ b/google/cloud/bigtable/data/execute_query/_reader.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations from typing import ( List, @@ -21,6 +22,8 @@ Sequence, ) from abc import ABC, abstractmethod +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper from google.cloud.bigtable_v2 import ProtoRows, Value as PBValue @@ -54,7 +57,10 @@ class _Reader(ABC, Generic[T]): @abstractmethod def consume( - self, batches_to_consume: List[bytes], metadata: Metadata + self, + batches_to_consume: List[bytes], + metadata: Metadata, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, ) -> Optional[Iterable[T]]: """This method receives a list of batches of bytes to be parsed as ProtoRows messages. It then uses the metadata to group the values in the parsed messages into rows. Returns @@ -64,6 +70,8 @@ def consume( :meth:`google.cloud.bigtable.byte_cursor._ByteCursor.consume` method. metadata: metadata used to transform values to rows + column_info: (Optional) dict with mappings between column names and additional column information + for protobuf deserialization. Returns: Iterable[T] or None: Iterable if gathered values can form one or more instances of T, @@ -89,7 +97,10 @@ def _parse_proto_rows(self, bytes_to_parse: bytes) -> Iterable[PBValue]: return proto_rows.values def _construct_query_result_row( - self, values: Sequence[PBValue], metadata: Metadata + self, + values: Sequence[PBValue], + metadata: Metadata, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, ) -> QueryResultRow: result = QueryResultRow() columns = metadata.columns @@ -99,12 +110,17 @@ def _construct_query_result_row( ), "This function should be called only when count of values matches count of columns." for column, value in zip(columns, values): - parsed_value = _parse_pb_value_to_python_value(value, column.column_type) + parsed_value = _parse_pb_value_to_python_value( + value, column.column_type, column.column_name, column_info + ) result.add_field(column.column_name, parsed_value) return result def consume( - self, batches_to_consume: List[bytes], metadata: Metadata + self, + batches_to_consume: List[bytes], + metadata: Metadata, + column_info: dict[str, Message | EnumTypeWrapper] | None = None, ) -> Optional[Iterable[QueryResultRow]]: num_columns = len(metadata.columns) rows = [] @@ -112,7 +128,11 @@ def consume( values = self._parse_proto_rows(batch_bytes) for row_data in batched(values, n=num_columns): if len(row_data) == num_columns: - rows.append(self._construct_query_result_row(row_data, metadata)) + rows.append( + self._construct_query_result_row( + row_data, metadata, column_info + ) + ) else: raise ValueError( "Unexpected error, recieved bad number of values. 
" diff --git a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py index 9c2d1c6d8..4eaeef1fa 100644 --- a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py @@ -18,6 +18,8 @@ from __future__ import annotations from typing import Any, Dict, Optional, Sequence, Tuple, TYPE_CHECKING from google.api_core import retry as retries +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor from google.cloud.bigtable.data._helpers import ( _attempt_timeout_generator, @@ -63,6 +65,7 @@ def __init__( operation_timeout: float, req_metadata: Sequence[Tuple[str, str]] = (), retryable_excs: Sequence[type[Exception]] = (), + column_info: dict[str, Message | EnumTypeWrapper] | None = None, ) -> None: """Collects responses from ExecuteQuery requests and parses them into QueryResultRows. @@ -82,6 +85,8 @@ def __init__( Failed requests will be retried within the budget req_metadata: metadata used while sending the gRPC request retryable_excs: a list of errors that will be retried if encountered. + column_info: dict with mappings between column names and additional column information + for protobuf deserialization. Raises: None :class:`ValueError ` as a safeguard if data is processed in an unexpected state @@ -110,6 +115,7 @@ def __init__( exception_factory=_retry_exception_factory, ) self._req_metadata = req_metadata + self._column_info = column_info try: self._register_instance_task = CrossSync._Sync_Impl.create_task( self._client._register_instance, @@ -164,7 +170,9 @@ def _next_impl(self) -> CrossSync._Sync_Impl.Iterator[QueryResultRow]: raise ValueError( "Error parsing response before finalizing metadata" ) - results = self._reader.consume(batches_to_parse, self.metadata) + results = self._reader.consume( + batches_to_parse, self.metadata, self._column_info + ) if results is None: continue except ValueError as e: diff --git a/google/cloud/bigtable/data/execute_query/metadata.py b/google/cloud/bigtable/data/execute_query/metadata.py index 2fd66947d..74b6cb836 100644 --- a/google/cloud/bigtable/data/execute_query/metadata.py +++ b/google/cloud/bigtable/data/execute_query/metadata.py @@ -296,6 +296,28 @@ def _to_value_pb_dict(self, value: Any) -> Dict[str, Any]: ) } + class Proto(Type): + """Proto SQL type.""" + + type_field_name = "proto_type" + + def _to_value_pb_dict(self, value: Any): + raise NotImplementedError("Proto is not supported as a query parameter") + + def _to_type_pb_dict(self) -> Dict[str, Any]: + raise NotImplementedError("Proto is not supported as a query parameter") + + class Enum(Type): + """Enum SQL type.""" + + type_field_name = "enum_type" + + def _to_value_pb_dict(self, value: Any): + raise NotImplementedError("Enum is not supported as a query parameter") + + def _to_type_pb_dict(self) -> Dict[str, Any]: + raise NotImplementedError("Enum is not supported as a query parameter") + class Metadata: """ @@ -388,6 +410,8 @@ def _pb_metadata_to_metadata_types( "bool_type": SqlType.Bool, "timestamp_type": SqlType.Timestamp, "date_type": SqlType.Date, + "proto_type": SqlType.Proto, + "enum_type": SqlType.Enum, "struct_type": SqlType.Struct, "array_type": SqlType.Array, "map_type": SqlType.Map, diff --git a/samples/testdata/README.md 
b/samples/testdata/README.md new file mode 100644 index 000000000..57520179f --- /dev/null +++ b/samples/testdata/README.md @@ -0,0 +1,5 @@ +#### To generate singer_pb2.py and descriptors.pb file from singer.proto using `protoc` +```shell +cd samples +protoc --proto_path=testdata/ --include_imports --descriptor_set_out=testdata/descriptors.pb --python_out=testdata/ testdata/singer.proto +``` \ No newline at end of file diff --git a/samples/testdata/descriptors.pb b/samples/testdata/descriptors.pb new file mode 100644 index 000000000..bddf04de3 Binary files /dev/null and b/samples/testdata/descriptors.pb differ diff --git a/samples/testdata/singer.proto b/samples/testdata/singer.proto new file mode 100644 index 000000000..d60e0dfb3 --- /dev/null +++ b/samples/testdata/singer.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package examples.bigtable.music; + +enum Genre { + POP = 0; + JAZZ = 1; + FOLK = 2; + ROCK = 3; +} + +message Singer { + string name = 1; + Genre genre = 2; +} diff --git a/samples/testdata/singer_pb2.py b/samples/testdata/singer_pb2.py new file mode 100644 index 000000000..d2a328df0 --- /dev/null +++ b/samples/testdata/singer_pb2.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: singer.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0csinger.proto\x12\x17\x65xamples.bigtable.music\"E\n\x06Singer\x12\x0c\n\x04name\x18\x01 \x01(\t\x12-\n\x05genre\x18\x02 \x01(\x0e\x32\x1e.examples.bigtable.music.Genre*.\n\x05Genre\x12\x07\n\x03POP\x10\x00\x12\x08\n\x04JAZZ\x10\x01\x12\x08\n\x04\x46OLK\x10\x02\x12\x08\n\x04ROCK\x10\x03\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'singer_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _GENRE._serialized_start=112 + _GENRE._serialized_end=158 + _SINGER._serialized_start=41 + _SINGER._serialized_end=110 +# @@protoc_insertion_point(module_scope) diff --git a/tests/system/admin_overlay/test_system_autogen.py b/tests/system/admin_overlay/test_system_autogen.py index 21e4aff3c..0b21b1a24 100644 --- a/tests/system/admin_overlay/test_system_autogen.py +++ b/tests/system/admin_overlay/test_system_autogen.py @@ -215,7 +215,7 @@ def test_optimize_restored_table( second_instance_storage_type, expect_optimize_operation, ): - instance_with_backup, table_to_backup = create_instance( + (instance_with_backup, table_to_backup) = create_instance( instance_admin_client, table_admin_client, data_client, @@ -223,7 +223,7 @@ def test_optimize_restored_table( instances_to_delete, admin_v2.StorageType.HDD, ) - instance_to_restore, _ = create_instance( + (instance_to_restore, _) = create_instance( instance_admin_client, table_admin_client, data_client, @@ -273,7 +273,7 @@ def test_wait_for_consistency( instances_to_delete, admin_overlay_project_id, ): - instance, table = create_instance( + (instance, table) = create_instance( instance_admin_client, table_admin_client, data_client, diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 
693b8d966..46e9c2215 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -249,7 +249,7 @@ def test_mutation_set_cell(self, target, temp_rows): """Ensure cells can be set properly""" row_key = b"bulk_mutate" new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( + (row_key, mutation) = self._create_row_and_mutation( target, temp_rows, new_value=new_value ) target.mutate_row(row_key, mutation) @@ -303,7 +303,7 @@ def test_bulk_mutations_set_cell(self, client, target, temp_rows): from google.cloud.bigtable.data.mutations import RowMutationEntry new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( + (row_key, mutation) = self._create_row_and_mutation( target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) @@ -338,11 +338,11 @@ def test_mutations_batcher_context_manager(self, client, target, temp_rows): """test batcher with context manager. Should flush on exit""" from google.cloud.bigtable.data.mutations import RowMutationEntry - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( target, temp_rows, new_value=new_value ) - row_key2, mutation2 = self._create_row_and_mutation( + (row_key2, mutation2) = self._create_row_and_mutation( target, temp_rows, new_value=new_value2 ) bulk_mutation = RowMutationEntry(row_key, [mutation]) @@ -363,7 +363,7 @@ def test_mutations_batcher_timer_flush(self, client, target, temp_rows): from google.cloud.bigtable.data.mutations import RowMutationEntry new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( + (row_key, mutation) = self._create_row_and_mutation( target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) @@ -385,12 +385,12 @@ def test_mutations_batcher_count_flush(self, client, target, temp_rows): """batch should flush after flush_limit_mutation_count mutations""" from google.cloud.bigtable.data.mutations import RowMutationEntry - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( + (row_key2, mutation2) = self._create_row_and_mutation( target, temp_rows, new_value=new_value2 ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) @@ -417,12 +417,12 @@ def test_mutations_batcher_bytes_flush(self, client, target, temp_rows): """batch should flush after flush_limit_bytes bytes""" from google.cloud.bigtable.data.mutations import RowMutationEntry - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( + (row_key2, mutation2) = self._create_row_and_mutation( target, temp_rows, new_value=new_value2 ) bulk_mutation2 = 
RowMutationEntry(row_key2, [mutation2]) @@ -448,11 +448,11 @@ def test_mutations_batcher_no_flush(self, client, target, temp_rows): new_value = uuid.uuid4().hex.encode() start_value = b"unchanged" - row_key, mutation = self._create_row_and_mutation( + (row_key, mutation) = self._create_row_and_mutation( target, temp_rows, start_value=start_value, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( + (row_key2, mutation2) = self._create_row_and_mutation( target, temp_rows, start_value=start_value, new_value=new_value ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) diff --git a/tests/unit/data/execute_query/sql_helpers.py b/tests/unit/data/execute_query/sql_helpers.py index 5d5569dba..119bb2d50 100644 --- a/tests/unit/data/execute_query/sql_helpers.py +++ b/tests/unit/data/execute_query/sql_helpers.py @@ -204,6 +204,18 @@ def date_type() -> Type: return t +def proto_type() -> Type: + t = Type() + t.proto_type = {} + return t + + +def enum_type() -> Type: + t = Type() + t.enum_type = {} + return t + + def array_type(elem_type: Type) -> Type: t = Type() arr_type = Type.Array() diff --git a/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py b/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py index ee0322272..0a1be1423 100644 --- a/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py +++ b/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py @@ -25,6 +25,7 @@ from google.cloud.bigtable.data.execute_query.metadata import SqlType from google.cloud.bigtable.data.execute_query.values import Struct from google.protobuf import timestamp_pb2 +from samples.testdata import singer_pb2 timestamp = int( datetime.datetime(2024, 5, 12, 17, 44, 12, tzinfo=datetime.timezone.utc).timestamp() @@ -267,6 +268,18 @@ def test_execute_query_parameters_not_supported_types(): {"test1": SqlType.Struct([("field1", SqlType.Int64())])}, ) + with pytest.raises(NotImplementedError, match="not supported"): + _format_execute_query_params( + {"test1": singer_pb2.Singer()}, + {"test1": SqlType.Proto()}, + ) + + with pytest.raises(NotImplementedError, match="not supported"): + _format_execute_query_params( + {"test1": singer_pb2.Genre.ROCK}, + {"test1": SqlType.Enum()}, + ) + def test_instance_execute_query_parameters_not_match(): with pytest.raises(ValueError, match="test2"): diff --git a/tests/unit/data/execute_query/test_query_result_parsing_utils.py b/tests/unit/data/execute_query/test_query_result_parsing_utils.py index 627570c37..ea03dfe9a 100644 --- a/tests/unit/data/execute_query/test_query_result_parsing_utils.py +++ b/tests/unit/data/execute_query/test_query_result_parsing_utils.py @@ -28,7 +28,8 @@ import datetime -from tests.unit.data.execute_query.sql_helpers import int64_type +from tests.unit.data.execute_query.sql_helpers import int64_type, proto_type, enum_type +from samples.testdata import singer_pb2 TYPE_BYTES = {"bytes_type": {}} TYPE_TIMESTAMP = {"timestamp_type": {}} @@ -82,9 +83,61 @@ def test_basic_types( assert type(metadata_type) is expected_metadata_type value = PBValue(value_dict) assert ( - _parse_pb_value_to_python_value(value._pb, metadata_type) == expected_value + _parse_pb_value_to_python_value(value._pb, metadata_type, "my_field") + == expected_value ) + def test__proto(self): + _type = PBType({"proto_type": {}}) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Proto + + singer = 
singer_pb2.Singer(name="John") + value = PBValue({"bytes_value": singer.SerializeToString()}) + + # without proto definition + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, "proto_field" + ) + assert result == singer.SerializeToString() + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + None, + {"proto_field": singer_pb2.Singer()}, + ) + assert result == singer.SerializeToString() + + # with proto definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "proto_field", + {"proto_field": singer_pb2.Singer()}, + ) + assert result == singer + + def test__enum(self): + _type = PBType({"enum_type": {}}) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Enum + + value = PBValue({"int_value": 1}) + + # without enum definition + result = _parse_pb_value_to_python_value(value._pb, metadata_type, "enum_field") + assert result == 1 + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, None, {"enum_field": singer_pb2.Genre} + ) + assert result == 1 + + # with enum definition + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, "enum_field", {"enum_field": singer_pb2.Genre} + ) + assert result == "JAZZ" + # Larger test cases were extracted for readability def test__array(self): _type = PBType({"array_type": {"element_type": int64_type()}}) @@ -103,7 +156,79 @@ def test__array(self): } } ) - assert _parse_pb_value_to_python_value(value._pb, metadata_type) == [1, 2, 3, 4] + assert _parse_pb_value_to_python_value( + value._pb, metadata_type, "array_field" + ) == [1, 2, 3, 4] + + def test__array_of_protos(self): + _type = PBType({"array_type": {"element_type": proto_type()}}) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Array + assert type(metadata_type.element_type) is SqlType.Proto + + singer1 = singer_pb2.Singer(name="John") + singer2 = singer_pb2.Singer(name="Taylor") + value = PBValue( + { + "array_value": { + "values": [ + {"bytes_value": singer1.SerializeToString()}, + {"bytes_value": singer2.SerializeToString()}, + ] + } + } + ) + + # without proto definition + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, "array_field" + ) + assert result == [singer1.SerializeToString(), singer2.SerializeToString()] + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, None, {"array_field": singer_pb2.Singer()} + ) + assert result == [singer1.SerializeToString(), singer2.SerializeToString()] + + # with proto definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "array_field", + {"array_field": singer_pb2.Singer()}, + ) + assert result == [singer1, singer2] + + def test__array_of_enums(self): + _type = PBType({"array_type": {"element_type": enum_type()}}) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Array + assert type(metadata_type.element_type) is SqlType.Enum + + value = PBValue( + { + "array_value": { + "values": [ + {"int_value": 0}, # POP + {"int_value": 1}, # JAZZ + ] + } + } + ) + + # without enum definition + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, "array_field" + ) + assert result == [0, 1] + + # with enum definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "array_field", + {"array_field": singer_pb2.Genre}, + ) + assert result == ["POP", "JAZZ"] def test__struct(self): _type = PBType( @@ -164,7 +289,9 @@ def 
test__struct(self): with pytest.raises(KeyError, match="Ambigious field name"): metadata_type["field3"] - result = _parse_pb_value_to_python_value(value._pb, metadata_type) + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, "struct_field" + ) assert isinstance(result, Struct) assert result["field1"] == result[0] == 1 assert result[1] == "test2" @@ -177,6 +304,87 @@ def test__struct(self): assert result[2] == [2, 3, 4, 5] assert result[3] == "test4" + def test__struct_with_proto_and_enum(self): + singer1 = singer_pb2.Singer(name="John") + singer2 = singer_pb2.Singer(name="Taylor") + _type = PBType( + { + "struct_type": { + "fields": [ + { + "field_name": "field1", + "type_": proto_type(), + }, + { + "field_name": None, + "type_": proto_type(), + }, + { + "field_name": "field2", + "type_": enum_type(), + }, + { + "field_name": None, + "type_": enum_type(), + }, + ] + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ + {"bytes_value": singer1.SerializeToString()}, + {"bytes_value": singer2.SerializeToString()}, + {"int_value": 0}, + {"int_value": 1}, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Struct + assert type(metadata_type["field1"]) is SqlType.Proto + assert type(metadata_type["field2"]) is SqlType.Enum + assert type(metadata_type[0]) is SqlType.Proto + assert type(metadata_type[1]) is SqlType.Proto + assert type(metadata_type[2]) is SqlType.Enum + assert type(metadata_type[3]) is SqlType.Enum + + # without proto definition + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, "struct_field" + ) + assert isinstance(result, Struct) + assert result["field1"] == singer1.SerializeToString() + assert result["field2"] == 0 + assert result[0] == singer1.SerializeToString() + assert result[1] == singer2.SerializeToString() + assert result[2] == 0 + assert result[3] == 1 + + # with proto definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "struct_field", + { + "struct_field.field1": singer_pb2.Singer(), + "struct_field.field2": singer_pb2.Genre, + }, + ) + assert isinstance(result, Struct) + assert result["field1"] == singer1 + assert result["field2"] == "POP" + assert result[0] == singer1 + # unnamed proto fields won't get parsed + assert result[1] == singer2.SerializeToString() + assert result[2] == "POP" + # unnamed enum fields won't get parsed + assert result[3] == 1 + def test__array_of_structs(self): _type = PBType( { @@ -254,7 +462,9 @@ def test__array_of_structs(self): assert type(metadata_type.element_type[1]) is SqlType.String assert type(metadata_type.element_type["field3"]) is SqlType.Bool - result = _parse_pb_value_to_python_value(value._pb, metadata_type) + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, "array_field" + ) assert isinstance(result, list) assert len(result) == 4 @@ -278,6 +488,106 @@ def test__array_of_structs(self): assert result[3][1] == "test4" assert not result[3]["field3"] + def test__array_of_structs_with_proto_and_enum(self): + singer1 = singer_pb2.Singer(name="John") + singer2 = singer_pb2.Singer(name="Taylor") + _type = PBType( + { + "array_type": { + "element_type": { + "struct_type": { + "fields": [ + { + "field_name": "proto_field", + "type_": proto_type(), + }, + { + "field_name": "enum_field", + "type_": enum_type(), + }, + { + "field_name": None, + "type_": proto_type(), + }, + ] + } + } + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ + { + "array_value": { + 
"values": [ + {"bytes_value": singer1.SerializeToString()}, + {"int_value": 0}, # POP + {"bytes_value": singer1.SerializeToString()}, + ] + } + }, + { + "array_value": { + "values": [ + {"bytes_value": singer2.SerializeToString()}, + {"int_value": 1}, # JAZZ + {"bytes_value": singer2.SerializeToString()}, + ] + } + }, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Array + assert type(metadata_type.element_type) is SqlType.Struct + assert type(metadata_type.element_type["proto_field"]) is SqlType.Proto + assert type(metadata_type.element_type["enum_field"]) is SqlType.Enum + assert type(metadata_type.element_type[2]) is SqlType.Proto + + # without proto definition + result = _parse_pb_value_to_python_value( + value._pb, metadata_type, "array_field" + ) + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], Struct) + assert result[0]["proto_field"] == singer1.SerializeToString() + assert result[0]["enum_field"] == 0 + assert result[0][2] == singer1.SerializeToString() + assert isinstance(result[1], Struct) + assert result[1]["proto_field"] == singer2.SerializeToString() + assert result[1]["enum_field"] == 1 + assert result[1][2] == singer2.SerializeToString() + + # with proto definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "array_field", + { + "array_field.proto_field": singer_pb2.Singer(), + "array_field.enum_field": singer_pb2.Genre, + "array_field": singer_pb2.Singer(), # unused + }, + ) + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], Struct) + assert result[0]["proto_field"] == singer1 + assert result[0]["enum_field"] == "POP" + # unnamed proto fields won't get parsed + assert result[0][2] == singer1.SerializeToString() + assert isinstance(result[1], Struct) + assert result[1]["proto_field"] == singer2 + assert result[1]["enum_field"] == "JAZZ" + # unnamed proto fields won't get parsed + assert result[1][2] == singer2.SerializeToString() + def test__map(self): _type = PBType( { @@ -333,7 +643,7 @@ def test__map(self): assert type(metadata_type.key_type) is SqlType.Int64 assert type(metadata_type.value_type) is SqlType.String - result = _parse_pb_value_to_python_value(value._pb, metadata_type) + result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field") assert isinstance(result, dict) assert len(result) == 4 @@ -387,13 +697,135 @@ def test__map_repeated_values(self): ) metadata_type = _pb_type_to_metadata_type(_type) - result = _parse_pb_value_to_python_value(value._pb, metadata_type) + result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field") assert len(result) == 1 assert result == { 1: "test3", } + def test__map_with_protos(self): + singer1 = singer_pb2.Singer(name="John") + singer2 = singer_pb2.Singer(name="Taylor") + _type = PBType( + { + "map_type": { + "key_type": int64_type(), + "value_type": proto_type(), + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ + { + "array_value": { + "values": [ + {"int_value": 1}, + {"bytes_value": singer1.SerializeToString()}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 2}, + {"bytes_value": singer2.SerializeToString()}, + ] + } + }, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Map + assert type(metadata_type.key_type) is SqlType.Int64 + assert type(metadata_type.value_type) is SqlType.Proto + + # without proto definition + 
result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field") + assert isinstance(result, dict) + assert len(result) == 2 + assert result[1] == singer1.SerializeToString() + assert result[2] == singer2.SerializeToString() + + # with proto definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "map_field", + { + "map_field.value": singer_pb2.Singer(), + }, + ) + assert isinstance(result, dict) + assert len(result) == 2 + assert result[1] == singer1 + assert result[2] == singer2 + + def test__map_with_enums(self): + _type = PBType( + { + "map_type": { + "key_type": int64_type(), + "value_type": enum_type(), + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ + { + "array_value": { + "values": [ + {"int_value": 1}, + {"int_value": 0}, # POP + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 2}, + {"int_value": 1}, # JAZZ + ] + } + }, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Map + assert type(metadata_type.key_type) is SqlType.Int64 + assert type(metadata_type.value_type) is SqlType.Enum + + # without enum definition + result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field") + assert isinstance(result, dict) + assert len(result) == 2 + assert result[1] == 0 + assert result[2] == 1 + + # with enum definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "map_field", + { + "map_field.value": singer_pb2.Genre, + }, + ) + assert isinstance(result, dict) + assert len(result) == 2 + assert result[1] == "POP" + assert result[2] == "JAZZ" + def test__map_of_maps_of_structs(self): _type = PBType( { @@ -539,7 +971,7 @@ def test__map_of_maps_of_structs(self): assert type(metadata_type.value_type.value_type) is SqlType.Struct assert type(metadata_type.value_type.value_type["field1"]) is SqlType.Int64 assert type(metadata_type.value_type.value_type["field2"]) is SqlType.String - result = _parse_pb_value_to_python_value(value._pb, metadata_type) + result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field") assert result[1]["1_1"]["field1"] == 1 assert result[1]["1_1"]["field2"] == "test1" @@ -553,23 +985,31 @@ def test__map_of_maps_of_structs(self): assert result[2]["2_2"]["field1"] == 4 assert result[2]["2_2"]["field2"] == "test4" - def test__map_of_lists_of_structs(self): + def test__map_of_maps_of_structs_with_proto_and_enum(self): + singer1 = singer_pb2.Singer(name="John") + singer2 = singer_pb2.Singer(name="Taylor") + _type = PBType( { "map_type": { - "key_type": TYPE_BYTES, + "key_type": int64_type(), "value_type": { - "array_type": { - "element_type": { + "map_type": { + "key_type": {"string_type": {}}, + "value_type": { "struct_type": { "fields": [ { - "field_name": "timestamp", - "type_": TYPE_TIMESTAMP, + "field_name": "int_field", + "type_": int64_type(), }, { - "field_name": "value", - "type_": TYPE_BYTES, + "field_name": "singer", + "type_": proto_type(), + }, + { + "field_name": "genre", + "type_": enum_type(), }, ] } @@ -582,20 +1022,225 @@ def test__map_of_lists_of_structs(self): value = PBValue( { "array_value": { - "values": [ # list of (byte, list) tuples + "values": [ # list of (int, map) tuples { "array_value": { - "values": [ # (byte, list) tuple - {"bytes_value": b"key1"}, + "values": [ # (int, map) tuples + {"int_value": 1}, { "array_value": { - "values": [ # list of structs + "values": [ # list of (str, struct) tuples { "array_value": { - "values": [ # (timestamp, 
bytes) tuple + "values": [ # (str, struct) tuples + {"string_value": "1_1"}, { - "timestamp_value": { - "seconds": 1111111111 + "array_value": { + "values": [ + { + "int_value": 12 + }, + { + "bytes_value": singer1.SerializeToString() + }, + { + "int_value": 0 + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (str, struct) tuples + {"string_value": "1_2"}, + { + "array_value": { + "values": [ + { + "int_value": 34 + }, + { + "bytes_value": singer2.SerializeToString() + }, + { + "int_value": 1 + }, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (int, map) tuples + {"int_value": 2}, + { + "array_value": { + "values": [ # list of (str, struct) tuples + { + "array_value": { + "values": [ # (str, struct) tuples + {"string_value": "2_1"}, + { + "array_value": { + "values": [ + { + "int_value": 56 + }, + { + "bytes_value": singer1.SerializeToString() + }, + { + "int_value": 2 + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (str, struct) tuples + {"string_value": "2_2"}, + { + "array_value": { + "values": [ + { + "int_value": 78 + }, + { + "bytes_value": singer2.SerializeToString() + }, + { + "int_value": 3 + }, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Map + assert type(metadata_type.key_type) is SqlType.Int64 + assert type(metadata_type.value_type) is SqlType.Map + assert type(metadata_type.value_type.key_type) is SqlType.String + assert type(metadata_type.value_type.value_type) is SqlType.Struct + assert type(metadata_type.value_type.value_type["int_field"]) is SqlType.Int64 + assert type(metadata_type.value_type.value_type["singer"]) is SqlType.Proto + assert type(metadata_type.value_type.value_type["genre"]) is SqlType.Enum + + # without proto definition + result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field") + + assert result[1]["1_1"]["int_field"] == 12 + assert result[1]["1_1"]["singer"] == singer1.SerializeToString() + assert result[1]["1_1"]["genre"] == 0 + + assert result[1]["1_2"]["int_field"] == 34 + assert result[1]["1_2"]["singer"] == singer2.SerializeToString() + assert result[1]["1_2"]["genre"] == 1 + + assert result[2]["2_1"]["int_field"] == 56 + assert result[2]["2_1"]["singer"] == singer1.SerializeToString() + assert result[2]["2_1"]["genre"] == 2 + + assert result[2]["2_2"]["int_field"] == 78 + assert result[2]["2_2"]["singer"] == singer2.SerializeToString() + assert result[2]["2_2"]["genre"] == 3 + + # with proto definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "map_field", + { + "map_field.value.value.singer": singer_pb2.Singer(), + "map_field.value.value.genre": singer_pb2.Genre, + }, + ) + + assert result[1]["1_1"]["int_field"] == 12 + assert result[1]["1_1"]["singer"] == singer1 + assert result[1]["1_1"]["genre"] == "POP" + + assert result[1]["1_2"]["int_field"] == 34 + assert result[1]["1_2"]["singer"] == singer2 + assert result[1]["1_2"]["genre"] == "JAZZ" + + assert result[2]["2_1"]["int_field"] == 56 + assert result[2]["2_1"]["singer"] == singer1 + assert result[2]["2_1"]["genre"] == "FOLK" + + assert result[2]["2_2"]["int_field"] == 78 + assert result[2]["2_2"]["singer"] == singer2 + assert result[2]["2_2"]["genre"] == "ROCK" + + def test__map_of_lists_of_structs(self): + _type = PBType( + { + "map_type": { + "key_type": TYPE_BYTES, + "value_type": { + "array_type": { + "element_type": { + 
"struct_type": { + "fields": [ + { + "field_name": "timestamp", + "type_": TYPE_TIMESTAMP, + }, + { + "field_name": "value", + "type_": TYPE_BYTES, + }, + ] + } + }, + } + }, + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ # list of (byte, list) tuples + { + "array_value": { + "values": [ # (byte, list) tuple + {"bytes_value": b"key1"}, + { + "array_value": { + "values": [ # list of structs + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 1111111111 } }, { @@ -679,7 +1324,7 @@ def test__map_of_lists_of_structs(self): is SqlType.Timestamp ) assert type(metadata_type.value_type.element_type["value"]) is SqlType.Bytes - result = _parse_pb_value_to_python_value(value._pb, metadata_type) + result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field") timestamp1 = DatetimeWithNanoseconds( 2005, 3, 18, 1, 58, 31, tzinfo=datetime.timezone.utc @@ -703,6 +1348,341 @@ def test__map_of_lists_of_structs(self): assert result[b"key2"][1]["timestamp"] == timestamp4 assert result[b"key2"][1]["value"] == b"key2-value2" + def test__map_of_lists_of_structs_with_protos(self): + singer1 = singer_pb2.Singer(name="John") + singer2 = singer_pb2.Singer(name="Taylor") + singer3 = singer_pb2.Singer(name="Jay") + singer4 = singer_pb2.Singer(name="Eric") + + _type = PBType( + { + "map_type": { + "key_type": TYPE_BYTES, + "value_type": { + "array_type": { + "element_type": { + "struct_type": { + "fields": [ + { + "field_name": "timestamp", + "type_": TYPE_TIMESTAMP, + }, + { + "field_name": "value", + "type_": proto_type(), + }, + ] + } + }, + } + }, + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ # list of (byte, list) tuples + { + "array_value": { + "values": [ # (byte, list) tuple + {"bytes_value": b"key1"}, + { + "array_value": { + "values": [ # list of structs + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 1111111111 + } + }, + { + "bytes_value": singer1.SerializeToString() + }, + ] + } + }, + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 2222222222 + } + }, + { + "bytes_value": singer2.SerializeToString() + }, + ] + } + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (byte, list) tuple + {"bytes_value": b"key2"}, + { + "array_value": { + "values": [ # list of structs + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 3333333333 + } + }, + { + "bytes_value": singer3.SerializeToString() + }, + ] + } + }, + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 4444444444 + } + }, + { + "bytes_value": singer4.SerializeToString() + }, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + } + ) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Map + assert type(metadata_type.key_type) is SqlType.Bytes + assert type(metadata_type.value_type) is SqlType.Array + assert type(metadata_type.value_type.element_type) is SqlType.Struct + assert ( + type(metadata_type.value_type.element_type["timestamp"]) + is SqlType.Timestamp + ) + assert type(metadata_type.value_type.element_type["value"]) is SqlType.Proto + + timestamp1 = DatetimeWithNanoseconds( + 2005, 3, 18, 1, 58, 31, tzinfo=datetime.timezone.utc + ) + timestamp2 = DatetimeWithNanoseconds( + 2040, 6, 2, 3, 57, 2, tzinfo=datetime.timezone.utc + ) + timestamp3 = DatetimeWithNanoseconds( + 2075, 8, 
18, 5, 55, 33, tzinfo=datetime.timezone.utc + ) + timestamp4 = DatetimeWithNanoseconds( + 2110, 11, 3, 7, 54, 4, tzinfo=datetime.timezone.utc + ) + + # without proto definition + result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field") + assert result[b"key1"][0]["timestamp"] == timestamp1 + assert result[b"key1"][0]["value"] == singer1.SerializeToString() + assert result[b"key1"][1]["timestamp"] == timestamp2 + assert result[b"key1"][1]["value"] == singer2.SerializeToString() + assert result[b"key2"][0]["timestamp"] == timestamp3 + assert result[b"key2"][0]["value"] == singer3.SerializeToString() + assert result[b"key2"][1]["timestamp"] == timestamp4 + assert result[b"key2"][1]["value"] == singer4.SerializeToString() + + # with proto definition + result = _parse_pb_value_to_python_value( + value._pb, + metadata_type, + "map_field", + { + "map_field.value.value": singer_pb2.Singer(), + }, + ) + assert result[b"key1"][0]["timestamp"] == timestamp1 + assert result[b"key1"][0]["value"] == singer1 + assert result[b"key1"][1]["timestamp"] == timestamp2 + assert result[b"key1"][1]["value"] == singer2 + assert result[b"key2"][0]["timestamp"] == timestamp3 + assert result[b"key2"][0]["value"] == singer3 + assert result[b"key2"][1]["timestamp"] == timestamp4 + assert result[b"key2"][1]["value"] == singer4 + + def test__map_of_lists_of_structs_with_enums(self): + _type = PBType( + { + "map_type": { + "key_type": TYPE_BYTES, + "value_type": { + "array_type": { + "element_type": { + "struct_type": { + "fields": [ + { + "field_name": "timestamp", + "type_": TYPE_TIMESTAMP, + }, + { + "field_name": "value", + "type_": enum_type(), + }, + ] + } + }, + } + }, + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ # list of (byte, list) tuples + { + "array_value": { + "values": [ # (byte, list) tuple + {"bytes_value": b"key1"}, + { + "array_value": { + "values": [ # list of structs + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 1111111111 + } + }, + {"int_value": 0}, + ] + } + }, + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 2222222222 + } + }, + {"int_value": 1}, + ] + } + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (byte, list) tuple + {"bytes_value": b"key2"}, + { + "array_value": { + "values": [ # list of structs + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 3333333333 + } + }, + {"int_value": 2}, + ] + } + }, + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 4444444444 + } + }, + {"int_value": 3}, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + } + ) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Map + assert type(metadata_type.key_type) is SqlType.Bytes + assert type(metadata_type.value_type) is SqlType.Array + assert type(metadata_type.value_type.element_type) is SqlType.Struct + assert ( + type(metadata_type.value_type.element_type["timestamp"]) + is SqlType.Timestamp + ) + assert type(metadata_type.value_type.element_type["value"]) is SqlType.Enum + + timestamp1 = DatetimeWithNanoseconds( + 2005, 3, 18, 1, 58, 31, tzinfo=datetime.timezone.utc + ) + timestamp2 = DatetimeWithNanoseconds( + 2040, 6, 2, 3, 57, 2, tzinfo=datetime.timezone.utc + ) + timestamp3 = DatetimeWithNanoseconds( + 2075, 8, 18, 5, 55, 33, tzinfo=datetime.timezone.utc + ) + timestamp4 = 
+        timestamp4 = DatetimeWithNanoseconds(
+            2110, 11, 3, 7, 54, 4, tzinfo=datetime.timezone.utc
+        )
+
+        # without an enum definition, ENUM values come back as plain integers
+        result = _parse_pb_value_to_python_value(value._pb, metadata_type, "map_field")
+        assert result[b"key1"][0]["timestamp"] == timestamp1
+        assert result[b"key1"][0]["value"] == 0
+        assert result[b"key1"][1]["timestamp"] == timestamp2
+        assert result[b"key1"][1]["value"] == 1
+        assert result[b"key2"][0]["timestamp"] == timestamp3
+        assert result[b"key2"][0]["value"] == 2
+        assert result[b"key2"][1]["timestamp"] == timestamp4
+        assert result[b"key2"][1]["value"] == 3
+
+        # with an enum definition, the integers resolve to Genre value names
+        result = _parse_pb_value_to_python_value(
+            value._pb,
+            metadata_type,
+            "map_field",
+            {
+                "map_field.value.value": singer_pb2.Genre,
+            },
+        )
+        assert result[b"key1"][0]["timestamp"] == timestamp1
+        assert result[b"key1"][0]["value"] == "POP"
+        assert result[b"key1"][1]["timestamp"] == timestamp2
+        assert result[b"key1"][1]["value"] == "JAZZ"
+        assert result[b"key2"][0]["timestamp"] == timestamp3
+        assert result[b"key2"][0]["value"] == "FOLK"
+        assert result[b"key2"][1]["timestamp"] == timestamp4
+        assert result[b"key2"][1]["value"] == "ROCK"
+
     def test__invalid_type_throws_exception(self):
         _type = PBType({"string_type": {}})
         value = PBValue({"int_value": 1})
@@ -712,4 +1692,4 @@ def test__invalid_type_throws_exception(self):
             ValueError,
             match="string_value field for String type not found in a Value.",
         ):
-            _parse_pb_value_to_python_value(value._pb, metadata_type)
+            _parse_pb_value_to_python_value(value._pb, metadata_type, "string_field")
diff --git a/tests/unit/data/execute_query/test_query_result_row_reader.py b/tests/unit/data/execute_query/test_query_result_row_reader.py
index 6adb1e3c7..ab98b54bd 100644
--- a/tests/unit/data/execute_query/test_query_result_row_reader.py
+++ b/tests/unit/data/execute_query/test_query_result_row_reader.py
@@ -32,7 +32,9 @@
     metadata,
     proto_rows_bytes,
     str_val,
+    bytes_val,
 )
+from samples.testdata import singer_pb2


 class TestQueryResultRowReader:
@@ -116,8 +118,8 @@ def test__received_values_are_passed_to_parser_in_batches(self):
         reader.consume([proto_rows_bytes(int_val(1), int_val(2))], metadata)
         parse_mock.assert_has_calls(
             [
-                mock.call(PBValue(int_val(1)), SqlType.Int64()),
-                mock.call(PBValue(int_val(2)), SqlType.Int64()),
+                mock.call(PBValue(int_val(1)), SqlType.Int64(), "test1", None),
+                mock.call(PBValue(int_val(2)), SqlType.Int64(), "test2", None),
             ]
         )

@@ -137,7 +139,7 @@ def test__parser_errors_are_forwarded(self):

         parse_mock.assert_has_calls(
             [
-                mock.call(PBValue(values[0]), SqlType.Int64()),
+                mock.call(PBValue(values[0]), SqlType.Int64(), "test1", None),
             ]
         )

@@ -243,6 +245,40 @@ def test_multiple_batches(self):
         assert row4["test1"] == 7
         assert row4["test2"] == 8

+    def test_multiple_batches_with_proto_and_enum_types(self):
+        singer1 = singer_pb2.Singer(name="John")
+        singer2 = singer_pb2.Singer(name="Taylor")
+        singer3 = singer_pb2.Singer(name="Jay")
+        singer4 = singer_pb2.Singer(name="Eric")
+
+        reader = _QueryResultRowReader()
+        batches = [
+            proto_rows_bytes(
+                bytes_val(singer1.SerializeToString()),
+                int_val(0),
+                bytes_val(singer2.SerializeToString()),
+                int_val(1),
+            ),
+            proto_rows_bytes(bytes_val(singer3.SerializeToString()), int_val(2)),
+            proto_rows_bytes(bytes_val(singer4.SerializeToString()), int_val(3)),
+        ]
+
+        results = reader.consume(
+            batches,
+            Metadata([("singer", SqlType.Proto()), ("genre", SqlType.Enum())]),
+            {"singer": singer_pb2.Singer(), "genre": singer_pb2.Genre},
+        )
+        assert len(results) == 4
+        [row1, row2, row3, row4] = results
+        assert row1["singer"] == singer1
+        assert row1["genre"] == "POP"
+        assert row2["singer"] == singer2
+        assert row2["genre"] == "JAZZ"
+        assert row3["singer"] == singer3
+        assert row3["genre"] == "FOLK"
+        assert row4["singer"] == singer4
+        assert row4["genre"] == "ROCK"
+

 class TestMetadata:
     def test__duplicate_column_names(self):
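
The decode behavior these tests pin down can be reproduced with the generated
protobuf API alone. A minimal sketch, assuming only the ``samples.testdata``
``singer_pb2`` fixture used above (illustrative, not the library's internal
code path; the ``Genre`` numbering 0-3 -> POP/JAZZ/FOLK/ROCK is taken from the
assertions)::

    from samples.testdata import singer_pb2

    # PROTO column: with no column_info entry the raw serialized bytes are
    # returned as-is; with a Singer message supplied, the bytes parse back
    # into an equal message.
    raw = singer_pb2.Singer(name="John").SerializeToString()
    decoded = singer_pb2.Singer()
    decoded.ParseFromString(raw)
    assert decoded == singer_pb2.Singer(name="John")

    # ENUM column: with no column_info entry the bare integer is returned;
    # with the Genre EnumTypeWrapper supplied, the integer resolves to the
    # enum value's name.
    assert singer_pb2.Genre.Name(0) == "POP"
    assert singer_pb2.Genre.Name(3) == "ROCK"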
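One subtlety worth noting: the column_info key ``map_field.value.value`` used
by both map tests contains two ``value`` segments with different meanings. The
first selects the MAP column's value side; the second is simply the STRUCT
field that this test schema happens to name ``value``. The row-reader tests,
by contrast, pass plain top-level column names (``singer``, ``genre``), and
the extra ``"test1", None`` arguments asserted on the parser mock are the
column name and the (here absent) column_info dict that
``_parse_pb_value_to_python_value`` now receives.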