Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 58 additions & 32 deletions th_cli/api_lib_autogen/api/test_run_executions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#
# flake8: noqa E501
from asyncio import get_event_loop
from typing import IO, TYPE_CHECKING, Any, Awaitable, Dict, List, Optional
from typing import IO, TYPE_CHECKING, Any, Awaitable

from fastapi.encoders import jsonable_encoder

Expand All @@ -31,12 +31,12 @@ def __init__(self, api_client: "ApiClient"):

def _build_for_abort_testing_api_v1_test_run_executions_abort_testing_post(
self,
) -> Awaitable[Dict[str, str]]:
) -> Awaitable[dict[str, str]]:
"""
Cancel the current testing
"""
return self.api_client.request(
type_=Dict[str, str],
type_=dict[str, str],
method="POST",
url="/api/v1/test_run_executions/abort-testing",
)
Expand Down Expand Up @@ -81,7 +81,7 @@ def _build_for_create_test_run_execution_cli_api_v1_test_run_executions_cli_post
)

def _build_for_download_log_api_v1_test_run_executions_id_log_get(
self, id: int, json_entries: Optional[bool] = None, download: Optional[bool] = None
self, id: int, json_entries: bool | None = None, download: bool | None = None
) -> Awaitable[None]:
"""
Download the logs from a test run. Args: id (int): Id of the TestRunExectution the log is requested for json_entries (bool, optional): When set, return each log line as a json object download (bool, optional): When set, return as attachment
Expand Down Expand Up @@ -116,14 +116,29 @@ def _build_for_get_test_runner_status_api_v1_test_run_executions_status_get(

def _build_for_get_chip_server_info_api_v1_test_run_executions_chip_server_info_get(
self,
discriminator: str | None = None,
setup_pin_code: str | None = None,
version: int | None = None,
vendor_id: int | None = None,
product_id: int | None = None,
) -> Awaitable[m.ChipServerInfo]:
"""
Retrieve ChipServer node information.
Retrieve ChipServer node ID information and optionally generate manual pairing code.
"""
all_params = {
"discriminator": discriminator,
"setup_pin_code": setup_pin_code,
"version": version,
"vendor_id": vendor_id,
"product_id": product_id,
}
query_params = {k: str(v) for k, v in all_params.items() if v is not None}

return self.api_client.request(
type_=m.ChipServerInfo,
method="GET",
url="/api/v1/test_run_executions/chip-server/info",
params=query_params,
)

def _build_for_read_test_run_execution_api_v1_test_run_executions_id_get(
Expand All @@ -143,12 +158,12 @@ def _build_for_read_test_run_execution_api_v1_test_run_executions_id_get(

def _build_for_read_test_run_executions_api_v1_test_run_executions_get(
self,
project_id: Optional[int] = None,
archived: Optional[bool] = None,
search_query: Optional[str] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> Awaitable[List[m.TestRunExecutionWithStats]]:
project_id: int | None = None,
archived: bool | None = None,
search_query: str | None = None,
skip: int | None = None,
limit: int | None = None,
) -> Awaitable[list[m.TestRunExecutionWithStats]]:
"""
Retrieve test runs, including statistics. Args: project_id: Filter test runs by project. archived: Get archived test runs, when true will return archived test runs only, when false only non-archived test runs are returned. skip: Pagination offset. limit: Max number of records to return. Returns: List of test runs with execution statistics.
"""
Expand All @@ -165,7 +180,7 @@ def _build_for_read_test_run_executions_api_v1_test_run_executions_get(
query_params["limit"] = str(limit)

return self.api_client.request(
type_=List[m.TestRunExecutionWithStats],
type_=list[m.TestRunExecutionWithStats],
method="GET",
url="/api/v1/test_run_executions/",
params=query_params,
Expand Down Expand Up @@ -220,8 +235,8 @@ def _build_for_upload_file_api_v1_test_run_executions_file_upload_post(self, fil
"""
Upload a file to the specified path of the current test run. Args: file: The file to upload.
"""
files: Dict[str, IO[Any]] = {} # noqa F841
data: Dict[str, Any] = {} # noqa F841
files: dict[str, IO[Any]] = {} # noqa F841
data: dict[str, Any] = {} # noqa F841
files["file"] = file

return self.api_client.request(
Expand All @@ -232,7 +247,7 @@ def _build_for_upload_file_api_v1_test_run_executions_file_upload_post(self, fil
class AsyncTestRunExecutionsApi(_TestRunExecutionsApi):
async def abort_testing_api_v1_test_run_executions_abort_testing_post(
self,
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Cancel the current testing
"""
Expand Down Expand Up @@ -267,7 +282,7 @@ async def create_test_run_execution_cli_api_v1_test_run_executions_cli_post(
)

async def download_log_api_v1_test_run_executions_id_log_get(
self, id: int, json_entries: Optional[bool] = None, download: Optional[bool] = None
self, id: int, json_entries: bool | None = None, download: bool | None = None
) -> None:
"""
Download the logs from a test run. Args: id (int): Id of the TestRunExectution the log is requested for json_entries (bool, optional): When set, return each log line as a json object download (bool, optional): When set, return as attachment
Expand All @@ -286,11 +301,22 @@ async def get_test_runner_status_api_v1_test_run_executions_status_get(

async def get_chip_server_info_api_v1_test_run_executions_chip_server_info_get(
self,
discriminator: str | None = None,
setup_pin_code: str | None = None,
version: int | None = None,
vendor_id: int | None = None,
product_id: int | None = None,
) -> m.ChipServerInfo:
"""
Retrieve ChipServer node information.
Retrieve ChipServer node ID information and optionally generate manual pairing code.
"""
return await self._build_for_get_chip_server_info_api_v1_test_run_executions_chip_server_info_get()
return await self._build_for_get_chip_server_info_api_v1_test_run_executions_chip_server_info_get(
discriminator=discriminator,
setup_pin_code=setup_pin_code,
version=version,
vendor_id=vendor_id,
product_id=product_id,
)

async def read_test_run_execution_api_v1_test_run_executions_id_get(
self, id: int
Expand All @@ -302,12 +328,12 @@ async def read_test_run_execution_api_v1_test_run_executions_id_get(

async def read_test_run_executions_api_v1_test_run_executions_get(
self,
project_id: Optional[int] = None,
archived: Optional[bool] = None,
search_query: Optional[str] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> List[m.TestRunExecutionWithStats]:
project_id: int | None = None,
archived: bool | None = None,
search_query: str | None = None,
skip: int | None = None,
limit: int | None = None,
) -> list[m.TestRunExecutionWithStats]:
"""
Retrieve test runs, including statistics. Args: project_id: Filter test runs by project. archived: Get archived test runs, when true will return archived test runs only, when false only non-archived test runs are returned. skip: Pagination offset. limit: Max number of records to return. Returns: List of test runs with execution statistics.
"""
Expand Down Expand Up @@ -347,7 +373,7 @@ async def upload_file_api_v1_test_run_executions_file_upload_post(self, file: IO
class SyncTestRunExecutionsApi(_TestRunExecutionsApi):
def abort_testing_api_v1_test_run_executions_abort_testing_post(
self,
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Cancel the current testing
"""
Expand Down Expand Up @@ -386,7 +412,7 @@ def create_test_run_execution_cli_api_v1_test_run_executions_cli_post(
return get_event_loop().run_until_complete(coroutine)

def download_log_api_v1_test_run_executions_id_log_get(
self, id: int, json_entries: Optional[bool] = None, download: Optional[bool] = None
self, id: int, json_entries: bool | None = None, download: bool | None = None
) -> None:
"""
Download the logs from a test run. Args: id (int): Id of the TestRunExectution the log is requested for json_entries (bool, optional): When set, return each log line as a json object download (bool, optional): When set, return as attachment
Expand Down Expand Up @@ -414,12 +440,12 @@ def read_test_run_execution_api_v1_test_run_executions_id_get(self, id: int) ->

def read_test_run_executions_api_v1_test_run_executions_get(
self,
project_id: Optional[int] = None,
archived: Optional[bool] = None,
search_query: Optional[str] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> List[m.TestRunExecutionWithStats]:
project_id: int | None = None,
archived: bool | None = None,
search_query: str | None = None,
skip: int | None = None,
limit: int | None = None,
) -> list[m.TestRunExecutionWithStats]:
"""
Retrieve test runs, including statistics. Args: project_id: Filter test runs by project. archived: Get archived test runs, when true will return archived test runs only, when false only non-archived test runs are returned. skip: Pagination offset. limit: Max number of records to return. Returns: List of test runs with execution statistics.
"""
Expand Down
1 change: 1 addition & 0 deletions th_cli/api_lib_autogen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class TestRunnerStatus(BaseModel):
class ChipServerInfo(BaseModel):
node_id: "int" = Field(..., alias="node_id")
node_id_hex: "str" = Field(..., alias="node_id_hex")
manual_pairing_code: "str | None" = Field(None, alias="manual_pairing_code")


class TestStepExecution(BaseModel):
Expand Down
17 changes: 3 additions & 14 deletions th_cli/commands/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def run_tests(
pics=pics,
project_id=project_id,
)
socket = TestRunSocket(new_test_run)
socket = TestRunSocket(new_test_run, project_config_dict)
socket_task = asyncio.create_task(socket.connect_websocket())
new_test_run = await __start_test_run(async_apis, new_test_run)
socket.run = new_test_run
Expand All @@ -170,9 +170,7 @@ async def run_tests(
await client.aclose()


async def __project_config(
async_apis: AsyncApis, project_id: int | None = None
) -> m.TestEnvironmentConfig:
async def __project_config(async_apis: AsyncApis, project_id: int | None = None) -> m.TestEnvironmentConfig:
"""Retrieve project configuration for given project ID or default configuration if none provided."""
projects_api = async_apis.projects_api

Expand Down Expand Up @@ -223,16 +221,7 @@ async def __start_test_run(
id = colorize_key_value("ID", str(test_run.id))

click.echo("")
click.echo(f"{header}:\n- {title}\n- {id}")

# Fetch and display ChipServer node ID information
try:
chip_info = await test_run_executions_api.get_chip_server_info_api_v1_test_run_executions_chip_server_info_get()
node_id_hex = colorize_key_value("Node ID", chip_info.node_id_hex)
click.echo(f"- {node_id_hex}\n")
except UnexpectedResponse as e:
# If we can't fetch node_id, don't fail - just skip displaying it
click.echo(f"Could not fetch ChipServer node ID information: {e}\n")
click.echo(f"{header}:\n- {title}\n- {id}\n")

try:
return await test_run_executions_api.start_test_run_execution_api_v1_test_run_executions_id_start_post(
Expand Down
72 changes: 67 additions & 5 deletions th_cli/test_run/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,22 @@
from websockets.client import WebSocketClientProtocol
from websockets.client import connect as websocket_connect

from th_cli.api_lib_autogen.api_client import AsyncApis
from th_cli.api_lib_autogen.models import (
TestCaseExecution,
TestRunExecutionWithChildren,
TestStepExecution,
TestSuiteExecution,
)
from th_cli.colorize import HierarchyEnum, colorize_error, colorize_hierarchy_prefix, colorize_state
from th_cli.client import get_client
from th_cli.colorize import (
HierarchyEnum,
colorize_error,
colorize_header,
colorize_hierarchy_prefix,
colorize_key_value,
colorize_state,
)
from th_cli.config import config
from th_cli.shared_constants import MessageTypeEnum

Expand Down Expand Up @@ -55,8 +64,10 @@


class TestRunSocket:
def __init__(self, run: TestRunExecutionWithChildren):
def __init__(self, run: TestRunExecutionWithChildren, project_config_dict: dict | None = None):
self.run = run
self.project_config_dict = project_config_dict or {}
self._chip_server_info_displayed = False
# Track test step errors for WebRTC detection
# Key: (suite_index, case_index), Value: list of error strings from all steps
self.test_case_step_errors: dict[tuple[int, int], list[str]] = {}
Expand Down Expand Up @@ -119,16 +130,66 @@ async def __handle_test_update(self, socket: WebSocketClientProtocol, update: Te
elif isinstance(update.body, TestSuiteUpdate):
self.__log_test_suite_update(update.body)
elif isinstance(update.body, TestRunUpdate):
self.__log_test_run_update(update.body)
await self.__log_test_run_update(update.body)
if update.body.state != "executing":
# Test run ended disconnect.
await socket.close()

def __log_test_run_update(self, update: TestRunUpdate) -> None:
async def __log_test_run_update(self, update: TestRunUpdate) -> None:
# Display CHIP server info when test run starts executing (SDK container already running)
if update.state.value == "executing" and not self._chip_server_info_displayed:
await self.__display_manual_pairing_code()
self._chip_server_info_displayed = True

test_run_text = colorize_hierarchy_prefix("Test Run", HierarchyEnum.TEST_RUN.value)
colored_state = colorize_state(update.state.value)
click.echo(f"{test_run_text} {colored_state}")

async def __display_manual_pairing_code(self) -> None:
"""Fetch and display manual pairing code after SDK container has started."""
try:
# Extract device configuration
dut_config = self.project_config_dict.get("dut_config", {})
discriminator = dut_config.get("discriminator")
setup_pin_code = dut_config.get("setup_code")

if not discriminator or not setup_pin_code:
return # No device config available

# Extract vendor_id and product_id from test_parameters if available
test_parameters = self.project_config_dict.get("test_parameters", {})
version = test_parameters.get("version")
vendor_id = test_parameters.get("vendor_id")
product_id = test_parameters.get("product_id")

# Create API client and fetch chip server info
client = get_client()
try:
test_run_api = AsyncApis(client).test_run_executions_api
chip_info = await test_run_api.get_chip_server_info_api_v1_test_run_executions_chip_server_info_get(
discriminator=discriminator,
setup_pin_code=setup_pin_code,
version=version,
vendor_id=vendor_id,
product_id=product_id,
)

if chip_info.manual_pairing_code:
node_id = colorize_key_value("Node ID", chip_info.node_id_hex)
manual_code = colorize_key_value("Manual Pairing Code", chip_info.manual_pairing_code)
click.echo("═══════════════════════════════════════════════════════")
click.echo(colorize_header("CHIP Server Information:"))
click.echo(f"- {node_id}")
click.echo(f"- {manual_code}")
click.echo("═══════════════════════════════════════════════════════")
click.echo("")
finally:
await client.aclose()

except Exception as e:
logger.debug(f"Could not fetch manual pairing code: {e}")
# Don't fail the test run if we can't get the pairing code

def __log_test_suite_update(self, update: TestSuiteUpdate) -> None:
suite = self.__suite(update.test_suite_execution_index)
title = suite.test_suite_metadata.title
Expand Down Expand Up @@ -159,7 +220,8 @@ def __log_test_case_update(self, update: TestCaseUpdate) -> None:
if case_key in self.test_case_step_errors:
all_errors.extend(self.test_case_step_errors[case_key])
logger.debug(
f"Found {len(self.test_case_step_errors[case_key])} tracked step error(s): {self.test_case_step_errors[case_key]}"
f"Found {len(self.test_case_step_errors[case_key])} tracked step error(s): "
f"{self.test_case_step_errors[case_key]}"
)
else:
logger.debug(f"No tracked step errors found for test case {case_key}")
Expand Down