|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -import math |
17 | | -from typing import Union |
| 16 | +from enum import IntEnum |
| 17 | +from typing import Union, Type |
18 | 18 |
|
19 | 19 | from pyignite.api.tx_api import tx_end, tx_start, tx_end_async, tx_start_async |
20 | 20 | from pyignite.datatypes import TransactionIsolation, TransactionConcurrency |
21 | 21 | from pyignite.exceptions import CacheError |
22 | 22 | from pyignite.utils import status_to_exception |
23 | 23 |
|
24 | 24 |
|
25 | | -def _convert_to_millis(timeout: Union[int, float]) -> int: |
26 | | - if isinstance(timeout, float): |
27 | | - return math.floor(timeout * 1000) |
28 | | - return timeout |
| 25 | +def _validate_int_enum_param(value: Union[int, IntEnum], cls: Type[IntEnum]): |
| 26 | + if value not in cls: |
| 27 | + raise ValueError(f'{value} not in {cls}') |
| 28 | + return value |
29 | 29 |
|
30 | 30 |
|
31 | | -class Transaction: |
| 31 | +def _validate_timeout(value): |
| 32 | + if not isinstance(value, int) or value < 0: |
| 33 | + raise ValueError(f'Timeout value should be a positive integer, {value} passed instead') |
| 34 | + return value |
| 35 | + |
| 36 | + |
| 37 | +def _validate_label(value): |
| 38 | + if value and not isinstance(value, str): |
| 39 | + raise ValueError(f'Label should be str, {type(value)} passed instead') |
| 40 | + return value |
| 41 | + |
| 42 | + |
| 43 | +class _BaseTransaction: |
| 44 | + def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC, |
| 45 | + isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None): |
| 46 | + self.client = client |
| 47 | + self.concurrency = _validate_int_enum_param(concurrency, TransactionConcurrency) |
| 48 | + self.isolation = _validate_int_enum_param(isolation, TransactionIsolation) |
| 49 | + self.timeout = _validate_timeout(timeout) |
| 50 | + self.label, self.closed = _validate_label(label), False |
| 51 | + |
| 52 | + |
| 53 | +class Transaction(_BaseTransaction): |
32 | 54 | """ |
33 | 55 | Thin client transaction. |
34 | 56 | """ |
35 | 57 | def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC, |
36 | 58 | isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None): |
37 | | - self.client, self.concurrency = client, concurrency |
38 | | - self.isolation, self.timeout = isolation, _convert_to_millis(timeout) |
39 | | - self.label, self.closed = label, False |
| 59 | + super().__init__(client, concurrency, isolation, timeout, label) |
40 | 60 | self.tx_id = self.__start_tx() |
41 | 61 |
|
42 | 62 | def commit(self) -> None: |
@@ -77,15 +97,13 @@ def __end_tx(self, committed): |
77 | 97 | return tx_end(self.tx_id, committed) |
78 | 98 |
|
79 | 99 |
|
80 | | -class AioTransaction: |
| 100 | +class AioTransaction(_BaseTransaction): |
81 | 101 | """ |
82 | 102 | Async thin client transaction. |
83 | 103 | """ |
84 | 104 | def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC, |
85 | 105 | isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None): |
86 | | - self.client, self.concurrency = client, concurrency |
87 | | - self.isolation, self.timeout = isolation, _convert_to_millis(timeout) |
88 | | - self.label, self.closed = label, False |
| 106 | + super().__init__(client, concurrency, isolation, timeout, label) |
89 | 107 |
|
90 | 108 | def __await__(self): |
91 | 109 | return (yield from self.__aenter__().__await__()) |
|
0 commit comments