diff --git a/th_cli/api_lib_autogen/api/test_run_executions_api.py b/th_cli/api_lib_autogen/api/test_run_executions_api.py index b74b67d..e0d8681 100644 --- a/th_cli/api_lib_autogen/api/test_run_executions_api.py +++ b/th_cli/api_lib_autogen/api/test_run_executions_api.py @@ -15,7 +15,7 @@ # # flake8: noqa E501 from asyncio import get_event_loop -from typing import IO, TYPE_CHECKING, Any, Awaitable, Dict, List, Optional +from typing import IO, TYPE_CHECKING, Any, Awaitable from fastapi.encoders import jsonable_encoder @@ -31,12 +31,12 @@ def __init__(self, api_client: "ApiClient"): def _build_for_abort_testing_api_v1_test_run_executions_abort_testing_post( self, - ) -> Awaitable[Dict[str, str]]: + ) -> Awaitable[dict[str, str]]: """ Cancel the current testing """ return self.api_client.request( - type_=Dict[str, str], + type_=dict[str, str], method="POST", url="/api/v1/test_run_executions/abort-testing", ) @@ -81,7 +81,7 @@ def _build_for_create_test_run_execution_cli_api_v1_test_run_executions_cli_post ) def _build_for_download_log_api_v1_test_run_executions_id_log_get( - self, id: int, json_entries: Optional[bool] = None, download: Optional[bool] = None + self, id: int, json_entries: bool | None = None, download: bool | None = None ) -> Awaitable[None]: """ Download the logs from a test run. Args: id (int): Id of the TestRunExectution the log is requested for json_entries (bool, optional): When set, return each log line as a json object download (bool, optional): When set, return as attachment @@ -116,14 +116,29 @@ def _build_for_get_test_runner_status_api_v1_test_run_executions_status_get( def _build_for_get_chip_server_info_api_v1_test_run_executions_chip_server_info_get( self, + discriminator: str | None = None, + setup_pin_code: str | None = None, + version: int | None = None, + vendor_id: int | None = None, + product_id: int | None = None, ) -> Awaitable[m.ChipServerInfo]: """ - Retrieve ChipServer node information. + Retrieve ChipServer node ID information and optionally generate manual pairing code. """ + all_params = { + "discriminator": discriminator, + "setup_pin_code": setup_pin_code, + "version": version, + "vendor_id": vendor_id, + "product_id": product_id, + } + query_params = {k: str(v) for k, v in all_params.items() if v is not None} + return self.api_client.request( type_=m.ChipServerInfo, method="GET", url="/api/v1/test_run_executions/chip-server/info", + params=query_params, ) def _build_for_read_test_run_execution_api_v1_test_run_executions_id_get( @@ -143,12 +158,12 @@ def _build_for_read_test_run_execution_api_v1_test_run_executions_id_get( def _build_for_read_test_run_executions_api_v1_test_run_executions_get( self, - project_id: Optional[int] = None, - archived: Optional[bool] = None, - search_query: Optional[str] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - ) -> Awaitable[List[m.TestRunExecutionWithStats]]: + project_id: int | None = None, + archived: bool | None = None, + search_query: str | None = None, + skip: int | None = None, + limit: int | None = None, + ) -> Awaitable[list[m.TestRunExecutionWithStats]]: """ Retrieve test runs, including statistics. Args: project_id: Filter test runs by project. archived: Get archived test runs, when true will return archived test runs only, when false only non-archived test runs are returned. skip: Pagination offset. limit: Max number of records to return. Returns: List of test runs with execution statistics. """ @@ -165,7 +180,7 @@ def _build_for_read_test_run_executions_api_v1_test_run_executions_get( query_params["limit"] = str(limit) return self.api_client.request( - type_=List[m.TestRunExecutionWithStats], + type_=list[m.TestRunExecutionWithStats], method="GET", url="/api/v1/test_run_executions/", params=query_params, @@ -220,8 +235,8 @@ def _build_for_upload_file_api_v1_test_run_executions_file_upload_post(self, fil """ Upload a file to the specified path of the current test run. Args: file: The file to upload. """ - files: Dict[str, IO[Any]] = {} # noqa F841 - data: Dict[str, Any] = {} # noqa F841 + files: dict[str, IO[Any]] = {} # noqa F841 + data: dict[str, Any] = {} # noqa F841 files["file"] = file return self.api_client.request( @@ -232,7 +247,7 @@ def _build_for_upload_file_api_v1_test_run_executions_file_upload_post(self, fil class AsyncTestRunExecutionsApi(_TestRunExecutionsApi): async def abort_testing_api_v1_test_run_executions_abort_testing_post( self, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Cancel the current testing """ @@ -267,7 +282,7 @@ async def create_test_run_execution_cli_api_v1_test_run_executions_cli_post( ) async def download_log_api_v1_test_run_executions_id_log_get( - self, id: int, json_entries: Optional[bool] = None, download: Optional[bool] = None + self, id: int, json_entries: bool | None = None, download: bool | None = None ) -> None: """ Download the logs from a test run. Args: id (int): Id of the TestRunExectution the log is requested for json_entries (bool, optional): When set, return each log line as a json object download (bool, optional): When set, return as attachment @@ -286,11 +301,22 @@ async def get_test_runner_status_api_v1_test_run_executions_status_get( async def get_chip_server_info_api_v1_test_run_executions_chip_server_info_get( self, + discriminator: str | None = None, + setup_pin_code: str | None = None, + version: int | None = None, + vendor_id: int | None = None, + product_id: int | None = None, ) -> m.ChipServerInfo: """ - Retrieve ChipServer node information. + Retrieve ChipServer node ID information and optionally generate manual pairing code. """ - return await self._build_for_get_chip_server_info_api_v1_test_run_executions_chip_server_info_get() + return await self._build_for_get_chip_server_info_api_v1_test_run_executions_chip_server_info_get( + discriminator=discriminator, + setup_pin_code=setup_pin_code, + version=version, + vendor_id=vendor_id, + product_id=product_id, + ) async def read_test_run_execution_api_v1_test_run_executions_id_get( self, id: int @@ -302,12 +328,12 @@ async def read_test_run_execution_api_v1_test_run_executions_id_get( async def read_test_run_executions_api_v1_test_run_executions_get( self, - project_id: Optional[int] = None, - archived: Optional[bool] = None, - search_query: Optional[str] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - ) -> List[m.TestRunExecutionWithStats]: + project_id: int | None = None, + archived: bool | None = None, + search_query: str | None = None, + skip: int | None = None, + limit: int | None = None, + ) -> list[m.TestRunExecutionWithStats]: """ Retrieve test runs, including statistics. Args: project_id: Filter test runs by project. archived: Get archived test runs, when true will return archived test runs only, when false only non-archived test runs are returned. skip: Pagination offset. limit: Max number of records to return. Returns: List of test runs with execution statistics. """ @@ -347,7 +373,7 @@ async def upload_file_api_v1_test_run_executions_file_upload_post(self, file: IO class SyncTestRunExecutionsApi(_TestRunExecutionsApi): def abort_testing_api_v1_test_run_executions_abort_testing_post( self, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Cancel the current testing """ @@ -386,7 +412,7 @@ def create_test_run_execution_cli_api_v1_test_run_executions_cli_post( return get_event_loop().run_until_complete(coroutine) def download_log_api_v1_test_run_executions_id_log_get( - self, id: int, json_entries: Optional[bool] = None, download: Optional[bool] = None + self, id: int, json_entries: bool | None = None, download: bool | None = None ) -> None: """ Download the logs from a test run. Args: id (int): Id of the TestRunExectution the log is requested for json_entries (bool, optional): When set, return each log line as a json object download (bool, optional): When set, return as attachment @@ -414,12 +440,12 @@ def read_test_run_execution_api_v1_test_run_executions_id_get(self, id: int) -> def read_test_run_executions_api_v1_test_run_executions_get( self, - project_id: Optional[int] = None, - archived: Optional[bool] = None, - search_query: Optional[str] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - ) -> List[m.TestRunExecutionWithStats]: + project_id: int | None = None, + archived: bool | None = None, + search_query: str | None = None, + skip: int | None = None, + limit: int | None = None, + ) -> list[m.TestRunExecutionWithStats]: """ Retrieve test runs, including statistics. Args: project_id: Filter test runs by project. archived: Get archived test runs, when true will return archived test runs only, when false only non-archived test runs are returned. skip: Pagination offset. limit: Max number of records to return. Returns: List of test runs with execution statistics. """ diff --git a/th_cli/api_lib_autogen/models.py b/th_cli/api_lib_autogen/models.py index 916c332..fb7958b 100644 --- a/th_cli/api_lib_autogen/models.py +++ b/th_cli/api_lib_autogen/models.py @@ -234,6 +234,7 @@ class TestRunnerStatus(BaseModel): class ChipServerInfo(BaseModel): node_id: "int" = Field(..., alias="node_id") node_id_hex: "str" = Field(..., alias="node_id_hex") + manual_pairing_code: "str | None" = Field(None, alias="manual_pairing_code") class TestStepExecution(BaseModel): diff --git a/th_cli/commands/run_tests.py b/th_cli/commands/run_tests.py index d540398..263497c 100644 --- a/th_cli/commands/run_tests.py +++ b/th_cli/commands/run_tests.py @@ -155,7 +155,7 @@ async def run_tests( pics=pics, project_id=project_id, ) - socket = TestRunSocket(new_test_run) + socket = TestRunSocket(new_test_run, project_config_dict) socket_task = asyncio.create_task(socket.connect_websocket()) new_test_run = await __start_test_run(async_apis, new_test_run) socket.run = new_test_run @@ -170,9 +170,7 @@ async def run_tests( await client.aclose() -async def __project_config( - async_apis: AsyncApis, project_id: int | None = None -) -> m.TestEnvironmentConfig: +async def __project_config(async_apis: AsyncApis, project_id: int | None = None) -> m.TestEnvironmentConfig: """Retrieve project configuration for given project ID or default configuration if none provided.""" projects_api = async_apis.projects_api @@ -223,16 +221,7 @@ async def __start_test_run( id = colorize_key_value("ID", str(test_run.id)) click.echo("") - click.echo(f"{header}:\n- {title}\n- {id}") - - # Fetch and display ChipServer node ID information - try: - chip_info = await test_run_executions_api.get_chip_server_info_api_v1_test_run_executions_chip_server_info_get() - node_id_hex = colorize_key_value("Node ID", chip_info.node_id_hex) - click.echo(f"- {node_id_hex}\n") - except UnexpectedResponse as e: - # If we can't fetch node_id, don't fail - just skip displaying it - click.echo(f"Could not fetch ChipServer node ID information: {e}\n") + click.echo(f"{header}:\n- {title}\n- {id}\n") try: return await test_run_executions_api.start_test_run_execution_api_v1_test_run_executions_id_start_post( diff --git a/th_cli/test_run/websocket.py b/th_cli/test_run/websocket.py index cd62ecc..7ed57d7 100644 --- a/th_cli/test_run/websocket.py +++ b/th_cli/test_run/websocket.py @@ -20,13 +20,22 @@ from websockets.client import WebSocketClientProtocol from websockets.client import connect as websocket_connect +from th_cli.api_lib_autogen.api_client import AsyncApis from th_cli.api_lib_autogen.models import ( TestCaseExecution, TestRunExecutionWithChildren, TestStepExecution, TestSuiteExecution, ) -from th_cli.colorize import HierarchyEnum, colorize_error, colorize_hierarchy_prefix, colorize_state +from th_cli.client import get_client +from th_cli.colorize import ( + HierarchyEnum, + colorize_error, + colorize_header, + colorize_hierarchy_prefix, + colorize_key_value, + colorize_state, +) from th_cli.config import config from th_cli.shared_constants import MessageTypeEnum @@ -55,8 +64,10 @@ class TestRunSocket: - def __init__(self, run: TestRunExecutionWithChildren): + def __init__(self, run: TestRunExecutionWithChildren, project_config_dict: dict | None = None): self.run = run + self.project_config_dict = project_config_dict or {} + self._chip_server_info_displayed = False # Track test step errors for WebRTC detection # Key: (suite_index, case_index), Value: list of error strings from all steps self.test_case_step_errors: dict[tuple[int, int], list[str]] = {} @@ -119,16 +130,66 @@ async def __handle_test_update(self, socket: WebSocketClientProtocol, update: Te elif isinstance(update.body, TestSuiteUpdate): self.__log_test_suite_update(update.body) elif isinstance(update.body, TestRunUpdate): - self.__log_test_run_update(update.body) + await self.__log_test_run_update(update.body) if update.body.state != "executing": # Test run ended disconnect. await socket.close() - def __log_test_run_update(self, update: TestRunUpdate) -> None: + async def __log_test_run_update(self, update: TestRunUpdate) -> None: + # Display CHIP server info when test run starts executing (SDK container already running) + if update.state.value == "executing" and not self._chip_server_info_displayed: + await self.__display_manual_pairing_code() + self._chip_server_info_displayed = True + test_run_text = colorize_hierarchy_prefix("Test Run", HierarchyEnum.TEST_RUN.value) colored_state = colorize_state(update.state.value) click.echo(f"{test_run_text} {colored_state}") + async def __display_manual_pairing_code(self) -> None: + """Fetch and display manual pairing code after SDK container has started.""" + try: + # Extract device configuration + dut_config = self.project_config_dict.get("dut_config", {}) + discriminator = dut_config.get("discriminator") + setup_pin_code = dut_config.get("setup_code") + + if not discriminator or not setup_pin_code: + return # No device config available + + # Extract vendor_id and product_id from test_parameters if available + test_parameters = self.project_config_dict.get("test_parameters", {}) + version = test_parameters.get("version") + vendor_id = test_parameters.get("vendor_id") + product_id = test_parameters.get("product_id") + + # Create API client and fetch chip server info + client = get_client() + try: + test_run_api = AsyncApis(client).test_run_executions_api + chip_info = await test_run_api.get_chip_server_info_api_v1_test_run_executions_chip_server_info_get( + discriminator=discriminator, + setup_pin_code=setup_pin_code, + version=version, + vendor_id=vendor_id, + product_id=product_id, + ) + + if chip_info.manual_pairing_code: + node_id = colorize_key_value("Node ID", chip_info.node_id_hex) + manual_code = colorize_key_value("Manual Pairing Code", chip_info.manual_pairing_code) + click.echo("═══════════════════════════════════════════════════════") + click.echo(colorize_header("CHIP Server Information:")) + click.echo(f"- {node_id}") + click.echo(f"- {manual_code}") + click.echo("═══════════════════════════════════════════════════════") + click.echo("") + finally: + await client.aclose() + + except Exception as e: + logger.debug(f"Could not fetch manual pairing code: {e}") + # Don't fail the test run if we can't get the pairing code + def __log_test_suite_update(self, update: TestSuiteUpdate) -> None: suite = self.__suite(update.test_suite_execution_index) title = suite.test_suite_metadata.title @@ -159,7 +220,8 @@ def __log_test_case_update(self, update: TestCaseUpdate) -> None: if case_key in self.test_case_step_errors: all_errors.extend(self.test_case_step_errors[case_key]) logger.debug( - f"Found {len(self.test_case_step_errors[case_key])} tracked step error(s): {self.test_case_step_errors[case_key]}" + f"Found {len(self.test_case_step_errors[case_key])} tracked step error(s): " + f"{self.test_case_step_errors[case_key]}" ) else: logger.debug(f"No tracked step errors found for test case {case_key}")