Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions scripts/gen_bridge_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def generate_python_services(
from __future__ import annotations

from datetime import timedelta
from typing import Mapping, Optional, Union, TYPE_CHECKING
from typing import TYPE_CHECKING
from collections.abc import Mapping

import google.protobuf.empty_pb2

$service_imports
Expand Down Expand Up @@ -110,8 +112,8 @@ async def $method_name(
self,
req: $request_type,
retry: bool = False,
metadata: Mapping[str, Union[str, bytes]] = {},
timeout: Optional[timedelta] = None,
metadata: Mapping[str, str | bytes] = {},
timeout: timedelta | None = None,
) -> $response_type:
"""Invokes the $service_name.$method_name rpc method."""
return await self._client._rpc_call(
Expand Down
4 changes: 2 additions & 2 deletions scripts/gen_payload_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def emit_loop(


def emit_singular(
field_name: str, access_expr: str, child_method: str, presence_word: Optional[str]
field_name: str, access_expr: str, child_method: str, presence_word: str | None
) -> str:
# Helper to emit a singular field visit with presence check and optional headers guard
if presence_word:
Expand Down Expand Up @@ -152,7 +152,7 @@ async def _visit_payload_container(self, fs, o):
""",
]

def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]:
def check_repeated(self, child_desc, field, iter_expr) -> str | None:
# Special case for repeated payloads, handle them directly
if child_desc.full_name == Payload.DESCRIPTOR.full_name:
return emit_singular(field.name, iter_expr, "payload_container", None)
Expand Down
5 changes: 3 additions & 2 deletions scripts/gen_protos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import subprocess
import sys
import tempfile
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import List, Mapping
from typing import List

base_dir = Path(__file__).parent.parent
proto_dir = (
Expand Down Expand Up @@ -64,7 +65,7 @@ def fix_generated_output(base_path: Path):
- protoc doesn't generate the correct import paths
(https://github.com/protocolbuffers/protobuf/issues/1491)
"""
imports: Mapping[str, List[str]] = collections.defaultdict(list)
imports: Mapping[str, list[str]] = collections.defaultdict(list)
for p in base_path.iterdir():
if p.is_dir():
fix_generated_output(p)
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import sys
import time
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncIterator

from temporalio import activity, workflow
from temporalio.testing import WorkflowEnvironment
Expand Down
72 changes: 33 additions & 39 deletions temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,16 @@
import inspect
import logging
import threading
from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
Union,
Expand All @@ -53,7 +49,7 @@ def defn(fn: CallableType) -> CallableType: ...

@overload
def defn(
*, name: Optional[str] = None, no_thread_cancel_exception: bool = False
*, name: str | None = None, no_thread_cancel_exception: bool = False
) -> Callable[[CallableType], CallableType]: ...


Expand All @@ -64,9 +60,9 @@ def defn(


def defn(
fn: Optional[CallableType] = None,
fn: CallableType | None = None, # type: ignore[reportInvalidTypeVarUse]
*,
name: Optional[str] = None,
name: str | None = None,
no_thread_cancel_exception: bool = False,
dynamic: bool = False,
):
Expand Down Expand Up @@ -111,11 +107,11 @@ class Info:
attempt: int
current_attempt_scheduled_time: datetime
heartbeat_details: Sequence[Any]
heartbeat_timeout: Optional[timedelta]
heartbeat_timeout: timedelta | None
is_local: bool
schedule_to_close_timeout: Optional[timedelta]
schedule_to_close_timeout: timedelta | None
scheduled_time: datetime
start_to_close_timeout: Optional[timedelta]
start_to_close_timeout: timedelta | None
started_time: datetime
task_queue: str
task_token: bytes
Expand All @@ -124,7 +120,7 @@ class Info:
workflow_run_id: str
workflow_type: str
priority: temporalio.common.Priority
retry_policy: Optional[temporalio.common.RetryPolicy]
retry_policy: temporalio.common.RetryPolicy | None
"""The retry policy of this activity.

Note that the server may have set a different policy than the one provided when scheduling the activity.
Expand All @@ -151,7 +147,7 @@ def _logger_details(self) -> Mapping[str, Any]:

@dataclass
class _ActivityCancellationDetailsHolder:
details: Optional[ActivityCancellationDetails] = None
details: ActivityCancellationDetails | None = None


@dataclass(frozen=True)
Expand Down Expand Up @@ -183,20 +179,20 @@ def _from_proto(
class _Context:
info: Callable[[], Info]
# This is optional because during interceptor init it is not present
heartbeat: Optional[Callable[..., None]]
heartbeat: Callable[..., None] | None
cancelled_event: _CompositeEvent
worker_shutdown_event: _CompositeEvent
shield_thread_cancel_exception: Optional[Callable[[], AbstractContextManager]]
payload_converter_class_or_instance: Union[
Type[temporalio.converter.PayloadConverter],
temporalio.converter.PayloadConverter,
]
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
client: Optional[Client]
shield_thread_cancel_exception: Callable[[], AbstractContextManager] | None
payload_converter_class_or_instance: (
type[temporalio.converter.PayloadConverter]
| temporalio.converter.PayloadConverter
)
runtime_metric_meter: temporalio.common.MetricMeter | None
client: Client | None
cancellation_details: _ActivityCancellationDetailsHolder
_logger_details: Optional[Mapping[str, Any]] = None
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
_metric_meter: Optional[temporalio.common.MetricMeter] = None
_logger_details: Mapping[str, Any] | None = None
_payload_converter: temporalio.converter.PayloadConverter | None = None
_metric_meter: temporalio.common.MetricMeter | None = None

@staticmethod
def current() -> _Context:
Expand Down Expand Up @@ -258,9 +254,9 @@ def metric_meter(self) -> temporalio.common.MetricMeter:
@dataclass
class _CompositeEvent:
# This should always be present, but is sometimes lazily set internally
thread_event: Optional[threading.Event]
thread_event: threading.Event | None
# Async event only for async activities
async_event: Optional[asyncio.Event]
async_event: asyncio.Event | None

def set(self) -> None:
if not self.thread_event:
Expand All @@ -279,7 +275,7 @@ async def wait(self) -> None:
raise RuntimeError("not in async activity")
await self.async_event.wait()

def wait_sync(self, timeout: Optional[float] = None) -> None:
def wait_sync(self, timeout: float | None = None) -> None:
if not self.thread_event:
raise RuntimeError("Missing event")
self.thread_event.wait(timeout)
Expand Down Expand Up @@ -330,7 +326,7 @@ def info() -> Info:
return _Context.current().info()


def cancellation_details() -> Optional[ActivityCancellationDetails]:
def cancellation_details() -> ActivityCancellationDetails | None:
"""Cancellation details of the current activity, if any. Once set, cancellation details do not change."""
return _Context.current().cancellation_details.details

Expand Down Expand Up @@ -398,7 +394,7 @@ async def wait_for_cancelled() -> None:
await _Context.current().cancelled_event.wait()


def wait_for_cancelled_sync(timeout: Optional[Union[timedelta, float]] = None) -> None:
def wait_for_cancelled_sync(timeout: timedelta | float | None = None) -> None:
"""Synchronously block while waiting for a cancellation request on this
activity.

Expand Down Expand Up @@ -437,7 +433,7 @@ async def wait_for_worker_shutdown() -> None:


def wait_for_worker_shutdown_sync(
timeout: Optional[Union[timedelta, float]] = None,
timeout: timedelta | float | None = None,
) -> None:
"""Synchronously block while waiting for shutdown to be called on the
worker.
Expand Down Expand Up @@ -511,9 +507,7 @@ class LoggerAdapter(logging.LoggerAdapter):
use by others. Default is False.
"""

def __init__(
self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]
) -> None:
def __init__(self, logger: logging.Logger, extra: Mapping[str, Any] | None) -> None:
"""Create the logger adapter."""
super().__init__(logger, extra or {})
self.activity_info_on_message = True
Expand All @@ -522,7 +516,7 @@ def __init__(

def process(
self, msg: Any, kwargs: MutableMapping[str, Any]
) -> Tuple[Any, MutableMapping[str, Any]]:
) -> tuple[Any, MutableMapping[str, Any]]:
"""Override to add activity details."""
if (
self.activity_info_on_message
Expand Down Expand Up @@ -559,16 +553,16 @@ def base_logger(self) -> logging.Logger:

@dataclass(frozen=True)
class _Definition:
name: Optional[str]
name: str | None
fn: Callable
is_async: bool
no_thread_cancel_exception: bool
# Types loaded on post init if both are None
arg_types: Optional[List[Type]] = None
ret_type: Optional[Type] = None
arg_types: list[type] | None = None
ret_type: type | None = None

@staticmethod
def from_callable(fn: Callable) -> Optional[_Definition]:
def from_callable(fn: Callable) -> _Definition | None:
defn = getattr(fn, "__temporal_activity_definition", None)
if isinstance(defn, _Definition):
# We have to replace the function with the given callable here
Expand All @@ -592,7 +586,7 @@ def must_from_callable(fn: Callable) -> _Definition:
def _apply_to_callable(
fn: Callable,
*,
activity_name: Optional[str],
activity_name: str | None,
no_thread_cancel_exception: bool = False,
) -> None:
# Validate the activity
Expand Down
41 changes: 21 additions & 20 deletions temporalio/bridge/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass
from datetime import timedelta
from typing import Mapping, Optional, Tuple, Type, TypeVar, Union
from typing import Optional, Tuple, Type, TypeVar, Union

import google.protobuf.message

Expand All @@ -20,10 +21,10 @@
class ClientTlsConfig:
"""Python representation of the Rust struct for configuring TLS."""

server_root_ca_cert: Optional[bytes]
domain: Optional[str]
client_cert: Optional[bytes]
client_private_key: Optional[bytes]
server_root_ca_cert: bytes | None
domain: str | None
client_cert: bytes | None
client_private_key: bytes | None


@dataclass
Expand All @@ -34,7 +35,7 @@ class ClientRetryConfig:
randomization_factor: float
multiplier: float
max_interval_millis: int
max_elapsed_time_millis: Optional[int]
max_elapsed_time_millis: int | None
max_retries: int


Expand All @@ -51,23 +52,23 @@ class ClientHttpConnectProxyConfig:
"""Python representation of the Rust struct for configuring HTTP proxy."""

target_host: str
basic_auth: Optional[Tuple[str, str]]
basic_auth: tuple[str, str] | None


@dataclass
class ClientConfig:
"""Python representation of the Rust struct for configuring the client."""

target_url: str
metadata: Mapping[str, Union[str, bytes]]
api_key: Optional[str]
metadata: Mapping[str, str | bytes]
api_key: str | None
identity: str
tls_config: Optional[ClientTlsConfig]
retry_config: Optional[ClientRetryConfig]
keep_alive_config: Optional[ClientKeepAliveConfig]
tls_config: ClientTlsConfig | None
retry_config: ClientRetryConfig | None
keep_alive_config: ClientKeepAliveConfig | None
client_name: str
client_version: str
http_connect_proxy_config: Optional[ClientHttpConnectProxyConfig]
http_connect_proxy_config: ClientHttpConnectProxyConfig | None


@dataclass
Expand All @@ -77,8 +78,8 @@ class RpcCall:
rpc: str
req: bytes
retry: bool
metadata: Mapping[str, Union[str, bytes]]
timeout_millis: Optional[int]
metadata: Mapping[str, str | bytes]
timeout_millis: int | None


ProtoMessage = TypeVar("ProtoMessage", bound=google.protobuf.message.Message)
Expand Down Expand Up @@ -108,11 +109,11 @@ def __init__(
self._runtime = runtime
self._ref = ref

def update_metadata(self, metadata: Mapping[str, Union[str, bytes]]) -> None:
def update_metadata(self, metadata: Mapping[str, str | bytes]) -> None:
"""Update underlying metadata on Core client."""
self._ref.update_metadata(metadata)

def update_api_key(self, api_key: Optional[str]) -> None:
def update_api_key(self, api_key: str | None) -> None:
"""Update underlying API key on Core client."""
self._ref.update_api_key(api_key)

Expand All @@ -122,10 +123,10 @@ async def call(
service: str,
rpc: str,
req: google.protobuf.message.Message,
resp_type: Type[ProtoMessage],
resp_type: type[ProtoMessage],
retry: bool,
metadata: Mapping[str, Union[str, bytes]],
timeout: Optional[timedelta],
metadata: Mapping[str, str | bytes],
timeout: timedelta | None,
) -> ProtoMessage:
"""Make RPC call using SDK Core."""
# Prepare call
Expand Down
Loading
Loading