lisa/platform_.py (15 additions, 0 deletions)
@@ -99,6 +99,9 @@ def _get_environment_information(self, environment: Environment) -> Dict[str, str]:

def _get_node_information(self, node: Node) -> Dict[str, str]:
return {}

def _get_runbook_information(self, runbook: schema.TypedSchema) -> Dict[str, str]:
return {}

def _cleanup(self) -> None:
"""
@@ -140,6 +143,18 @@ def get_node_information(self, node: Node) -> Dict[str, str]:

return information

@hookimpl
def get_runbook_information(self, runbook: Any) -> Dict[str, str]:
information: Dict[str, str] = {}
try:
information.update(self._get_runbook_information(runbook=self.runbook))
except Exception as identifier:
self._log.exception(
"failed to get runbook information on platform", exc_info=identifier
)

return information

def prepare_environment(self, environment: Environment) -> Environment:
"""
return prioritized environments.
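For context on the new hook: a platform can surface settings from its runbook in the collected test information by overriding _get_runbook_information. A minimal sketch follows (not part of this change); the copied field names are hypothetical and only illustrate the Dict[str, str] contract.

from typing import Dict

from lisa import schema
from lisa.platform_ import Platform


class ExamplePlatform(Platform):
    ...

    def _get_runbook_information(self, runbook: schema.TypedSchema) -> Dict[str, str]:
        information: Dict[str, str] = {}
        # copy a few fields from the runbook, guarding against fields that may
        # not exist on every platform runbook (the names here are hypothetical)
        for field in ("location", "vm_size"):
            value = getattr(runbook, field, None)
            if value:
                information[f"platform_{field}"] = str(value)
        return information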
lisa/sut_orchestrator/azure/hooks.py (13 additions, 0 deletions)
@@ -34,6 +34,19 @@ def azure_update_arm_template(
"""
...

@hookspec
def azure_update_sku_capability(
self, vmsize: str, environment: Environment
) -> None:
"""
Implement it to update SKU capability.

Args:
vmsize: the vm sku whose capability is to be updated.
environment: the deploying environment.
"""
...


class AzureHookSpecDefaultImpl:
__error_maps: List[Tuple[str, Pattern[str], Any]] = [
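A minimal sketch of an implementation of the new azure_update_sku_capability hookspec (not part of this change); the import location of the hookimpl marker is assumed to match the rest of LISA, and the body is only a placeholder.

from lisa.environment import Environment
from lisa.util import hookimpl  # assumed location of the pluggy hookimpl marker


class ExampleAzureSkuHook:
    @hookimpl
    def azure_update_sku_capability(self, vmsize: str, environment: Environment) -> None:
        # a real implementation would adjust the SKU capability computed for
        # `vmsize` before the environment is deployed
        ...

Such a class would be registered with the LISA plugin manager (for example, plugin_manager.register(ExampleAzureSkuHook())) so that pluggy dispatches it alongside the default Azure hooks.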
lisa/sut_orchestrator/azure/platform_.py (68 additions, 2 deletions)
@@ -2098,6 +2098,47 @@ def _process_marketplace_image_plan(

return plan

def _generate_sku_capability(
self,
vm_size: str,
location: str,
cap_file: str,
) -> AzureCapability:
# some vm sizes cannot be queried from the API, and their capability
# may instead be read from capability files through hooks.
capability_dict = plugin_manager.hook.azure_add_sku_capability(
Review comment (Collaborator): Is this supposed to match the hookspec azure_update_sku_capability? Or is this a different hook?
Reply (Member): I guess they should be the same.
vm_size,
cap_file,
)

node_space = schema.NodeSpace(
node_count=1,
core_count=search_space.IntRange(min=1),
memory_mb=search_space.IntRange(min=0),
gpu_count=search_space.IntRange(min=0),
)

azure_capability = AzureCapability(
location=location,
vm_size=vm_size,
capability=node_space,
resource_sku={},
)

node_space.name = f"{location}_{vm_size}"
node_space.features = search_space.SetSpace[schema.FeatureSettings](
is_allow_set=True
)

# all nodes support the following features
all_features = self.supported_features()
node_space.features.update(
[schema.FeatureSettings.create(x.name()) for x in all_features]
)
convert_to_azure_node_space(node_space)

return azure_capability

def _generate_max_capability(self, vm_size: str, location: str) -> AzureCapability:
# some vm sizes cannot be queried from the API, so use a default capability
# to run with a best guess on capability.
@@ -2462,12 +2503,24 @@ def _get_normalized_vm_size(self, name: str, location: str, log: Logger) -> str:
return matched_name

def _get_capabilities(
self, vm_sizes: List[str], location: str, use_max_capability: bool, log: Logger
self,
vm_sizes: List[str],
location: str,
use_max_capability: bool,
cap_file: str,
log: Logger,
) -> List[AzureCapability]:
candidate_caps: List[AzureCapability] = []
caps = self.get_location_info(location, log).capabilities

for vm_size in vm_sizes:
# force reading the SKU capability from the capability file if one is provided.
if cap_file:
candidate_caps.append(
self._generate_sku_capability(vm_size, location, cap_file)
)
continue

# force to use max capability to run test cases as much as possible,
# or force to support a non-existent vm size.
if use_max_capability:
@@ -2599,6 +2652,7 @@ def _get_allowed_capabilities(
self, req: schema.NodeSpace, location: str, log: Logger
) -> Tuple[List[AzureCapability], str]:
node_runbook = req.get_extended_runbook(AzureNodeSchema, AZURE)
cap_file: str = ""
error: str = ""
if node_runbook.vm_size:
# find the vm_size
@@ -2613,7 +2667,18 @@
f"no vm size matched '{node_runbook.vm_size}' on location "
f"'{location}', using the raw string as vm size name."
)
allowed_vm_sizes = [node_runbook.vm_size]
# first, check whether a capability file is appended to the vm_size value
vm_sizes = node_runbook.vm_size
split_vm_sizes = [x.strip() for x in node_runbook.vm_size.split("|")]
if len(split_vm_sizes) > 1:
if split_vm_sizes[0].startswith("Compute"):
cap_file = split_vm_sizes[0]
vm_sizes = split_vm_sizes[1]
else:
cap_file = split_vm_sizes[1]
vm_sizes = split_vm_sizes[0]

allowed_vm_sizes = [x.strip() for x in vm_sizes.split(",")]
else:
location_info = self.get_location_info(location, log)
allowed_vm_sizes = [key for key, _ in location_info.capabilities.items()]
@@ -2624,6 +2689,7 @@
vm_sizes=allowed_vm_sizes,
location=location,
use_max_capability=node_runbook.maximize_capability,
cap_file=cap_file,
log=log,
)

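To illustrate the vm_size parsing added in _get_allowed_capabilities above, here is a standalone sketch; the sample value is hypothetical, but the splitting rules mirror the diff: a "|" separates the capability file (the segment starting with "Compute") from a comma-separated list of vm sizes.

# standalone sketch of the new vm_size parsing; the sample value is hypothetical
vm_size = "ComputeCaps_eastus.json|Standard_D2s_v3, Standard_D4s_v3"

cap_file = ""
vm_sizes = vm_size
split_vm_sizes = [x.strip() for x in vm_size.split("|")]
if len(split_vm_sizes) > 1:
    # the segment starting with "Compute" is treated as the capability file
    if split_vm_sizes[0].startswith("Compute"):
        cap_file = split_vm_sizes[0]
        vm_sizes = split_vm_sizes[1]
    else:
        cap_file = split_vm_sizes[1]
        vm_sizes = split_vm_sizes[0]

allowed_vm_sizes = [x.strip() for x in vm_sizes.split(",")]
print(cap_file)          # ComputeCaps_eastus.json
print(allowed_vm_sizes)  # ['Standard_D2s_v3', 'Standard_D4s_v3']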