Skip to content

Commit 86fb4d6

Browse files
committed
basetypes: better DictRow
1 parent 468bcb4 commit 86fb4d6

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

skytools/basetypes.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,43 @@
22
"""
33

44
import io
5-
from typing import Sequence, Mapping, List, Any, Optional, Union, IO
6-
5+
from typing import IO, Any, Iterable, Mapping, Optional, Sequence, Tuple, Union
76

87
try:
98
from typing import Protocol
109
except ImportError:
1110
Protocol = object # type: ignore
1211

13-
RowType = Sequence[Any]
12+
13+
ExecuteParams = Union[Sequence[Any], Mapping[str, Any]]
14+
15+
16+
class DictRow(Protocol):
17+
"""Allow both key and index-based access.
18+
19+
Both Psycopg2 DictRow and PL/Python rows support this.
20+
"""
21+
def keys(self) -> Iterable[str]: ...
22+
def values(self) -> Iterable[Any]: ...
23+
def items(self) -> Iterable[Tuple[str, Any]]: ...
24+
def __getitem__(self, key: Union[str, int]) -> Any: ...
25+
def __iter__(self) -> Iterable[str]: ...
26+
def __len__(self) -> int: ...
27+
def __contains__(self, key: str) -> bool: ...
1428

1529

1630
class Cursor(Protocol):
17-
def execute(self, sql: str, params: Optional[Union[Sequence[Any], Mapping[str, Any]]] = None) -> None: ...
18-
def fetchall(self) -> List[RowType]: ...
19-
def fetchone(self) -> RowType: ...
31+
def execute(self, sql: str, params: Optional[ExecuteParams] = None) -> None: ...
32+
def fetchall(self) -> Sequence[DictRow]: ...
33+
def fetchone(self) -> DictRow: ...
2034
def copy_from(self, buf: IO[str], hdr: str) -> None: ...
2135
def copy_expert(self, sql: str, f: Union[IO[str], io.TextIOBase]) -> None: ...
2236

2337

2438
class Connection(Protocol):
2539
def cursor(self) -> Cursor: ...
40+
def rollback(self) -> None: ...
41+
def commit(self) -> None: ...
2642

2743

2844
class Runnable(Protocol):

skytools/querybuilder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import skytools
1818

19-
from .sqltools import Cursor
19+
from .basetypes import Cursor, DictRow
2020

2121
try:
2222
import plpy
@@ -278,7 +278,7 @@ def __init__(self, sql: str):
278278
self.arg_map = qb._arg_value_list
279279
self.sql = sql
280280

281-
def execute(self, arg_dict: Mapping[str, Any], all_keys_required=True) -> List[Mapping[str, Any]]:
281+
def execute(self, arg_dict: Mapping[str, Any], all_keys_required=True) -> Sequence[DictRow]:
282282
try:
283283
if all_keys_required:
284284
arg_list = [arg_dict[k] for k in self.arg_map]
@@ -300,7 +300,7 @@ def __repr__(self) -> str:
300300
def plpy_exec(gd: Optional[Dict[str, Any]],
301301
sql: str,
302302
args: Optional[Mapping[str, Any]],
303-
all_keys_required=True) -> Sequence[Mapping[str, Any]]:
303+
all_keys_required=True) -> Sequence[DictRow]:
304304
"""Cached plan execution for PL/Python.
305305
306306
@param gd: dict to store cached plans under. If None, caching is disabled.

0 commit comments

Comments
 (0)