diff --git a/classes/__init__.py b/classes/__init__.py
index 82e4a7f..f1a8a7c 100644
--- a/classes/__init__.py
+++ b/classes/__init__.py
@@ -3,7 +3,21 @@ from shared import data
 
 class Response:
-    '''A response object'''
+    """HTTP response object for STAT Function modules.
+
+    This class encapsulates the response data, status code, and content type
+    for HTTP responses returned by STAT Function modules. Every module should return an instance of this class.
+
+    Args:
+        body: The response body content, typically a dictionary or object.
+        statuscode (int, optional): HTTP status code. Defaults to 200.
+        contenttype (str, optional): Content type header. Defaults to 'application/json'.
+
+    Attributes:
+        body: The response body content.
+        statuscode (int): HTTP status code.
+        contenttype (str): Content type header.
+    """
 
     def __init__(self, body, statuscode=200, contenttype='application/json'):
         self.body = body
@@ -11,7 +25,21 @@ def __init__(self, body, statuscode=200, contenttype='application/json'):
         self.contenttype = contenttype
 
 class STATError(Exception):
-    '''A handled STAT exception'''
+    """Exception class for STAT Function errors.
+
+    This exception is raised when a handled error occurs in STAT Function processing.
+    It includes additional context such as source error details and HTTP status codes.
+
+    Args:
+        error (str): The error message describing what went wrong.
+        source_error (dict, optional): Additional error details from the source. Defaults to {}.
+        status_code (int, optional): HTTP status code associated with the error. Defaults to 400.
+
+    Attributes:
+        error (str): The error message.
+        source_error (dict): Additional error details from the source.
+        status_code (int): HTTP status code associated with the error.
+    """
 
     def __init__(self, error:str, source_error:dict={}, status_code:int=400):
         self.error = error
@@ -19,11 +47,30 @@ def __init__(self, error:str, source_error:dict={}, status_code:int=400):
         self.status_code = status_code
 
 class STATNotFound(STATError):
-    '''A handled STAT exception where the API call returned a 404 error'''
+    """STAT exception raised when an API call returns a 404 Not Found error.
+
+    This exception is a specialized version of STATError for cases where
+    a resource or endpoint could not be found.
+    """
     pass
 
 class STATTooManyRequests(STATError):
-    '''A handled STAT exception where the API call returned a 429 error'''
+    """STAT exception raised when an API call returns a 429 Too Many Requests error.
+
+    This exception includes retry timing information to support rate limit handling.
+
+    Args:
+        error (str): The error message describing what went wrong.
+        source_error (dict, optional): Additional error details from the source. Defaults to {}.
+        status_code (int, optional): HTTP status code associated with the error. Defaults to 400.
+        retry_after (int, optional): Number of seconds to wait before retrying. Defaults to 10.
+
+    Attributes:
+        error (str): The error message.
+        source_error (dict): Additional error details from the source.
+        status_code (int): HTTP status code associated with the error.
+        retry_after (str): String representation of seconds to wait before retrying.
+ """ def __init__(self, error:str, source_error:dict={}, status_code:int=400, retry_after:int=10): self.error = error self.source_error = source_error diff --git a/shared/data.py b/shared/data.py index 7b0785a..27b9bfe 100644 --- a/shared/data.py +++ b/shared/data.py @@ -8,7 +8,26 @@ tz_info = os.getenv('TIME_ZONE') def list_to_html_table(input_list:list, max_rows:int=20, max_cols:int=10, nan_str:str='N/A', escape_html:bool=True, columns:list=None, index:bool=False, justify:str='left', drop_empty_cols:bool=False): - '''Convert a list of dictionaries into an HTML table''' + """Convert a list of dictionaries into an HTML table. + + This function takes a list of dictionaries and converts it to an HTML table format, + with optional time zone and date formatting based on environment variables. + + Args: + input_list (list): List of dictionaries to convert to HTML table. + max_rows (int, optional): Maximum number of rows to display. Defaults to 20. + max_cols (int, optional): Maximum number of columns to display. Defaults to 10. + nan_str (str, optional): String to replace NaN values. Defaults to 'N/A'. + escape_html (bool, optional): Whether to escape HTML characters. Defaults to True. + columns (list, optional): Specific columns to include. Defaults to None. + index (bool, optional): Whether to include row index. Defaults to False. + justify (str, optional): Text alignment for table. Defaults to 'left'. + drop_empty_cols (bool, optional): Whether to drop empty columns. Defaults to False. + + Returns: + str: HTML table representation of the input data. + """ + df = pd.DataFrame(input_list) df.index = df.index + 1 @@ -49,7 +68,22 @@ def list_to_html_table(input_list:list, max_rows:int=20, max_cols:int=10, nan_st return html_table def update_column_value_in_list(input_list:list, col_name:str, update_str:str): - '''Updates the value of a column in each dict in the list with a value from another column, to include the column value in your replacement use [col_value]''' + """Update column values in a list of dictionaries with templated strings. + + Updates the value of a specified column in each dictionary within the list. + The update string can include [col_value] as a placeholder which will be + replaced with the current column value. + + Args: + input_list (list): List of dictionaries to update. + col_name (str): Name of the column to update. + update_str (str): Template string for the new value. Use [col_value] + to reference the current column value. + + Returns: + list: New list with updated dictionaries. + """ + updated_list = [] for row in input_list: current_row = copy.copy(row) @@ -59,7 +93,21 @@ def update_column_value_in_list(input_list:list, col_name:str, update_str:str): return updated_list def replace_column_value_in_list(input_list:list, col_name:str, original_value:str, replacement_value:str): - '''Updates the value of a column in each dict in the list, with a static value''' + """Replace specific values in a column across a list of dictionaries. + + Updates the value of a specified column in each dictionary by replacing + occurrences of the original value with the replacement value. + + Args: + input_list (list): List of dictionaries to update. + col_name (str): Name of the column to update. + original_value (str): Value to be replaced. + replacement_value (str): Value to replace with. + + Returns: + list: New list with updated dictionaries. 
+ """ + updated_list = [] for row in input_list: current_row = copy.copy(row) @@ -69,7 +117,20 @@ def replace_column_value_in_list(input_list:list, col_name:str, original_value:s return updated_list def return_highest_value(input_list:list, key:str, order:list=['High','Medium','Low','Informational','None','Unknown']): - '''Locate the highest value in a list of dictionaries by key''' + """Find the highest value in a list of dictionaries. + + Searches through a list of dictionaries to find the highest value + for a specific key, based on a predefined priority order. + + Args: + input_list (list): List of dictionaries to search. + key (str): Key name to check in each dictionary. + order (list, optional): Priority order from highest to lowest. + Defaults to ['High','Medium','Low','Informational','None','Unknown']. + + Returns: + str: The highest priority value found, or 'Unknown' if none match. + """ unsorted_list = [] for item in input_list: @@ -82,7 +143,20 @@ def return_highest_value(input_list:list, key:str, order:list=['High','Medium',' return 'Unknown' def join_lists(left_list, right_list, kind, left_key, right_key, fill_nan=None): - '''Join 2 lists of objects using a key. Supported join kinds left, right, outer, inner, cross''' + """Join two lists of dictionaries using specified keys and join kind. + + Args: + left_list (list): List of dictionaries to join from the left. + right_list (list): List of dictionaries to join from the right. + kind (str): Type of join to perform ('left', 'right', 'outer', 'inner', 'cross'). + left_key (str): Key name in the left list to join on. + right_key (str): Key name in the right list to join on. + fill_nan (any, optional): Value to fill NaN entries in the result. + + Returns: + list: List of dictionaries resulting from the join operation. + """ + left_df = pd.DataFrame(left_list) right_df = pd.DataFrame(right_list) join_data = left_df.merge(right=right_df, how=kind, left_on=left_key, right_on=right_key) @@ -94,6 +168,16 @@ def join_lists(left_list, right_list, kind, left_key, right_key, fill_nan=None): return join_data.to_dict('records') def sum_column_by_key(input_list, key): + """Calculate the sum of a numeric column in a list of dictionaries. + + Args: + input_list (list): List of dictionaries containing the data. + key (str): Key name of the column to sum. + + Returns: + int: Sum of the column values, or 0 if key doesn't exist. + """ + df = pd.DataFrame(input_list) try: val = int(df[key].sum()) @@ -102,6 +186,16 @@ def sum_column_by_key(input_list, key): return val def max_column_by_key(input_list, key): + """Find the maximum value in a numeric column of a list of dictionaries. + + Args: + input_list (list): List of dictionaries containing the data. + key (str): Key name of the column to find maximum value. + + Returns: + int: Maximum value in the column, or 0 if key doesn't exist. + """ + df = pd.DataFrame(input_list) try: val = int(df[key].max()) @@ -110,6 +204,16 @@ def max_column_by_key(input_list, key): return val def min_column_by_key(input_list, key): + """Find the minimum value in a numeric column of a list of dictionaries. + + Args: + input_list (list): List of dictionaries containing the data. + key (str): Key name of the column to find minimum value. + + Returns: + int: Minimum value in the column, or 0 if key doesn't exist. 
+ """ + df = pd.DataFrame(input_list) try: val = int(df[key].min()) @@ -118,6 +222,18 @@ def min_column_by_key(input_list, key): return val def sort_list_by_key(input_list, key, ascending=False, drop_columns:list=[]): + """Sort a list of dictionaries by a specified key. + + Args: + input_list (list): List of dictionaries to sort. + key (str): Key name to sort by. + ascending (bool, optional): Sort in ascending order. Defaults to False. + drop_columns (list, optional): List of column names to drop from result. Defaults to []. + + Returns: + list: Sorted list of dictionaries. + """ + df = pd.DataFrame(input_list) df = df.sort_values(by=[key], ascending=ascending) if drop_columns: @@ -125,12 +241,36 @@ def sort_list_by_key(input_list, key, ascending=False, drop_columns:list=[]): return df.to_dict('records') def coalesce(*args): + """Return the first non-None argument. + + Similar to SQL COALESCE function, returns the first argument that is not None. + + Args: + *args: Variable number of arguments to check. + + Returns: + Any: The first non-None argument, or None if all arguments are None. + """ + for arg in args: if arg is not None: return arg def version_check(current_version:str, avaialble_version:str, update_check_type:str): + """Check if an update is available based on version comparison. + + Compares current version with available version to determine if an update + is available based on the specified update check type. + + Args: + current_version (str): Current version string in format "major.minor.build". + avaialble_version (str): Available version string in format "major.minor.build". + update_check_type (str): Type of update to check for ('Major', 'Minor', or 'Build'). + Returns: + dict: Dictionary with 'UpdateAvailable' (bool) and 'UpdateType' (str) keys. + """ + if current_version == 'Unknown': return {'UpdateAvailable': False, 'UpdateType': 'None'} @@ -153,16 +293,46 @@ def version_check(current_version:str, avaialble_version:str, update_check_type: return {'UpdateAvailable': False, 'UpdateType': 'None'} def return_property_as_list(input_list:list, property_name:str): + """Returns a list of property data from a list of dictionaries. + + From a list of dicstionaries, extracts the values of a specified property + and returns them as a list. + + Args: + input_list (list): List of dictionaries to extract property from. + property_name (str): Name of the property to extract from each dictionary. + + Returns: + list: List of values for the specified property from each dictionary. + """ + return_list = [] for item in input_list: return_list.append(item[property_name]) return return_list def return_slope(input_list_x:list, input_list_y:list): + """Calculate the slope of a linear regression line for two lists of values. + + Args: + input_list_x (list): List of x-values. + input_list_y (list): List of y-values. + + Returns: + float: Slope of the linear regression line calculated from the two lists. + """ + slope = ( ( len(input_list_y) * sum([a * b for a, b in zip(input_list_x, input_list_y)]) ) - ( sum(input_list_x) * sum(input_list_y)) ) / ( ( len(input_list_y) * sum([sq ** 2 for sq in input_list_x])) - (sum(input_list_x) ** 2)) return slope def get_current_version(): + """Get the current STAT Function version from version.json file. + + Reads the version information from the modules/version.json file. + + Returns: + str: Current function version string, or 'Unknown' if unable to read. 
+ """ try: with open(pathlib.Path(__file__).parent.parent / 'modules/version.json') as f: @@ -173,14 +343,44 @@ def get_current_version(): return stat_version def load_json_from_file(file_name:str): + """Load and parse JSON data from a file in the modules/files directory. + + Args: + file_name (str): Name of the JSON file to load. + + Returns: + dict: Parsed JSON data from the file. + """ + with open(pathlib.Path(__file__).parent.parent / f'modules/files/{file_name}') as f: return json.loads(f.read()) def load_text_from_file(file_name:str, **kwargs): + """Load text from a file and format it with provided keyword arguments. + + Args: + file_name (str): Name of the text file to load from modules/files directory. + **kwargs: Keyword arguments to use for string formatting. + + Returns: + str: Formatted text content from the file. + """ + with open(pathlib.Path(__file__).parent.parent / f'modules/files/{file_name}') as f: return f.read().format(**kwargs) def list_to_string(list_in:list, delimiter:str=', ', empty_str:str='N/A'): + """Convert a list to a delimited string. + + Args: + list_in (list): List to convert to string. + delimiter (str, optional): Delimiter to use between items. Defaults to ', '. + empty_str (str, optional): String to return if list is empty. Defaults to 'N/A'. + + Returns: + str: Delimited string representation of the list, or empty_str if list is empty. + """ + if not list_in: return empty_str @@ -193,12 +393,15 @@ def list_to_string(list_in:list, delimiter:str=', ', empty_str:str='N/A'): return list_in def parse_kv_string(kv:str, item_delimitter:str=';', value_delimitter:str='='): - """ - Parse a string of key-value pairs into a dictionary. - :param kv: The string to parse. - :param item_delimitter: The delimiter between items. Default is ';'. - :param value_delimitter: The delimiter between key and value. Default is '='. - :return: A dictionary of key-value pairs. + """Parse a string of key-value pairs into a dictionary. + + Args: + kv (str): The string containing key-value pairs to parse. + item_delimitter (str, optional): Delimiter between items. Defaults to ';'. + value_delimitter (str, optional): Delimiter between key and value. Defaults to '='. + + Returns: + dict: Dictionary of parsed key-value pairs. Values are None if no delimiter found. """ kv_pairs = kv.split(item_delimitter) diff --git a/shared/rest.py b/shared/rest.py index ec8818e..3539923 100644 --- a/shared/rest.py +++ b/shared/rest.py @@ -23,6 +23,19 @@ kv_secret = None def token_cache(base_module:BaseModule, api:str): + """Retrieve cached authentication token for the specified API. + + Gets the appropriate authentication token for the specified API service, + handling multi-tenant configurations and token expiration checking. + + Args: + base_module (BaseModule): Base module containing multi-tenant configuration. + api (str): API service name ('arm', 'msgraph', 'la', 'm365', 'mde'). + + Returns: + str: Authentication token for the specified API. + """ + global stat_token default_tenant = os.getenv('AZURE_TENANT_ID') @@ -93,27 +106,73 @@ def get_kv_secret(): return kv_return.value def rest_call_get(base_module:BaseModule, api:str, path:str, headers:dict={}): - '''Perform a GET HTTP call to a REST API. Accepted API values are arm, msgraph, la, m365 and mde''' + """Perform a GET HTTP call to a REST API. + + Args: + base_module (BaseModule): Base module containing incident information. + api (str): API service name ('arm', 'msgraph', 'la', 'm365', 'mde'). + path (str): API endpoint path to call. 
+ headers (dict, optional): Additional headers to include in the request. Defaults to an empty dictionary. + + Returns: + Response: HTTP response object containing the result of the API call. + """ response = execute_rest_call(base_module, 'get', api, path, None, headers) return response def rest_call_post(base_module:BaseModule, api:str, path:str, body, headers:dict={}): - '''Perform a POST HTTP call to a REST API. Accepted API values are arm, msgraph, la, m365 and mde''' + """Perform a POST HTTP call to a REST API. + + Args: + base_module (BaseModule): Base module containing incident information. + api (str): API service name ('arm', 'msgraph', 'la', 'm365', 'mde'). + path (str): API endpoint path to call. + body: Request body to send with the POST request. + headers (dict, optional): Additional headers to include in the request. Defaults to an empty dictionary. + + Returns: + Response: HTTP response object containing the result of the API call. + """ response = execute_rest_call(base_module, 'post', api, path, body, headers) return response def rest_call_put(base_module:BaseModule, api:str, path:str, body, headers:dict={}): - '''Perform a PUT HTTP call to a REST API. Accepted API values are arm, msgraph, la, m365 and mde''' + """Perform a PUT HTTP call to a REST API. + + Args: + base_module (BaseModule): Base module containing incident information. + api (str): API service name ('arm', 'msgraph', 'la', 'm365', 'mde'). + path (str): API endpoint path to call. + body: Request body to send with the PUT request. + headers (dict, optional): Additional headers to include in the request. Defaults to an empty dictionary. + + Returns: + Response: HTTP response object containing the result of the API call. + """ response = execute_rest_call(base_module, 'put', api, path, body, headers) return response def execute_rest_call(base_module:BaseModule, method:str, api:str, path:str, body=None, headers:dict={}): + """Execute a REST API call with retry logic for handling rate limits and connection errors. + This should only be called from the rest_call_get, rest_call_post, or rest_call_put functions. + + Args: + base_module (BaseModule): Base module containing incident information. + method (str): HTTP method to use ('get', 'post', 'put'). + api (str): API service name ('arm', 'msgraph', 'la', 'm365', 'mde'). + path (str): API endpoint path to call. + body: Request body to send with the request, if applicable. + headers (dict, optional): Additional headers to include in the request. Defaults to an empty dictionary. + Returns: + Response: HTTP response object containing the result of the API call. + """ + token = token_cache(base_module, api) url = get_endpoint(api) + path headers['Authorization'] = 'Bearer ' + token.token @@ -163,6 +222,18 @@ def check_rest_response(response:Response, api, path): return def execute_la_query(base_module:BaseModule, query:str, lookbackindays:int, endpoint:str='query'): + """Execute a Log Analytics Query. + + Args: + base_module (BaseModule): Base module containing incident information. + query (str): Log Analytics query to execute. + lookbackindays (int): Number of days to look back in the query. + endpoint (str): Endpoint to use for the query ('query' or 'search'). Defaults to 'query'. + + Returns: + list: List of query results, where each result is a dictionary mapping column names to values. 
+ """ + duration = 'P' + str(lookbackindays) + 'D' if endpoint == 'search': @@ -189,6 +260,16 @@ def execute_la_query(base_module:BaseModule, query:str, lookbackindays:int, endp return query_results def execute_m365d_query(base_module:BaseModule, query:str): + """Execute a M365 Advanced Hunting Query. + + Args: + base_module (BaseModule): Base module containing incident information. + query (str): Log Analytics query to execute. + + Returns: + list: List of query results, where each result is a dictionary mapping column names to values. + """ + path = '/api/advancedhunting/run' body = {'Query': query} response = rest_call_post(base_module, 'm365', path, body) @@ -197,6 +278,16 @@ def execute_m365d_query(base_module:BaseModule, query:str): return data['Results'] def execute_mde_query(base_module:BaseModule, query:str): + """Execute a MDE Advanced Hunting Query. + + Args: + base_module (BaseModule): Base module containing incident information. + query (str): Log Analytics query to execute. + + Returns: + list: List of query results, where each result is a dictionary mapping column names to values. + """ + path = '/api/advancedqueries/run' body = {'Query': query} response = rest_call_post(base_module, 'mde', path, body) @@ -223,6 +314,19 @@ def get_endpoint(api:str): '(ARM_ENDPOINT, GRAPH_ENDPOINT, LOGANALYTICS_ENDPOINT, M365_ENDPOINT, and MDE_ENDPOINT).') def add_incident_comment(base_module:BaseModule, comment:str): + """Add a comment to a Microsoft Sentinel incident. + + Creates a new comment on the specified incident using the Azure REST API. + The comment is truncated to 30,000 characters if longer. + + Args: + base_module (BaseModule): Base module containing incident information. + comment (str): Comment text to add to the incident. + + Returns: + Response or str: API response object on success, or 'Comment failed' on error. 
+ """ + path = base_module.IncidentARMId + '/comments/' + str(uuid.uuid4()) + '?api-version=2023-02-01' try: response = rest_call_put(base_module, 'arm', path, {'properties': {'message': comment[:30000]}}) diff --git a/tests/test_data.py b/tests/test_data.py index b2e2c9d..439736d 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,4 +1,5 @@ from shared import data +import pytest def test_return_highest_value(): @@ -121,6 +122,86 @@ def test_parse_kv_string(): assert parsed_data['key1'] == 'value1' assert parsed_data['key2'] == 'value2' +def test_parse_kv_string_empty(): + """Test parse_kv_string with empty input""" + empty_data = data.parse_kv_string('') + # The function returns {'': None} for empty input + assert empty_data == {'': None} + +def test_parse_kv_string_malformed(): + """Test parse_kv_string with malformed input""" + malformed_data = data.parse_kv_string('key1=value1;invalidpair;key2=value2') + assert malformed_data['key1'] == 'value1' + assert malformed_data['key2'] == 'value2' + # Should handle malformed pairs gracefully + +def test_list_to_string_with_custom_separator(): + """Test list_to_string with custom separator""" + list_data = ['a', 'b', 'c'] + string_data = data.list_to_string(list_data, delimiter=' | ') + assert string_data == 'a | b | c' + +def test_return_highest_value_empty_list(): + """Test return_highest_value with empty list""" + result = data.return_highest_value([], 'Severity') + assert result == 'Unknown' + +def test_return_highest_value_missing_key(): + """Test return_highest_value with missing key in data raises KeyError""" + test_data = [{'Other': 'value1'}, {'Other': 'value2'}] + with pytest.raises(KeyError): + data.return_highest_value(test_data, 'MissingKey') + +def test_sort_list_by_key_missing_key(): + """Test sort_list_by_key with missing key""" + test_data = [{'Value': 5}, {'NoValue': 3}, {'Value': 1}] + sorted_data = data.sort_list_by_key(test_data, 'Value', False) + # Should handle missing keys gracefully + assert len(sorted_data) == 3 + assert sorted_data[0]['Value'] == 5 + assert sorted_data[2]['NoValue'] == 3 + +def test_max_column_by_key_empty_list(): + """Test max_column_by_key with empty list""" + result = data.max_column_by_key([], 'Value') + assert result == 0 + +def test_sum_column_by_key_empty_list(): + """Test sum_column_by_key with empty list""" + result = data.sum_column_by_key([], 'Value') + assert result == 0 + +def test_join_lists_empty_lists(): + """Test join_lists with empty lists raises KeyError""" + with pytest.raises(KeyError): + data.join_lists([], [], 'left', 'Key', 'Key') + +def test_return_slope_edge_cases(): + """Test return_slope with edge cases""" + # Same values should have 0 slope + dates = [1, 2, 3, 4] + values = [5, 5, 5, 5] + result = data.return_slope(dates, values) + assert result == 0.0 + + # Single point raises ZeroDivisionError + single_date = [1] + single_value = [5] + with pytest.raises(ZeroDivisionError): + data.return_slope(single_date, single_value) + +def test_load_text_from_file(): + """Test load_text_from_file with a valid file path""" + content = data.load_text_from_file('exposure-device.kql', mde_id_list=['12345']) + assert "let mde_ids = dynamic(['12345']);" in content + +def test_load_json_from_file(): + """Test load_json_from_file with a valid file path""" + content = data.load_json_from_file('privileged-roles.json') + assert "Exchange Administrator" in content + assert "Global Administrator" in content + assert "Security Administrator" in content + def list_data(): test_data = [ { 
diff --git a/tests/test_rest.py b/tests/test_rest.py index 7ffc684..fd9ce58 100644 --- a/tests/test_rest.py +++ b/tests/test_rest.py @@ -1,7 +1,8 @@ from shared import rest -from classes import BaseModule +from classes import BaseModule, STATError, STATNotFound, STATTooManyRequests import json, os import requests +import pytest def test_get_endpoint(): @@ -27,6 +28,28 @@ def test_execute_mde_query(): result = rest.execute_mde_query(get_base_module_object(), 'DeviceInfo | take 5') assert len(result) == 5 +def test_rest_response(): + response_429 = requests.Response() + response_429.status_code = 429 + response_404 = requests.Response() + response_404.status_code = 404 + response_401 = requests.Response() + response_401.status_code = 401 + response_200 = requests.Response() + response_200.status_code = 200 + + with pytest.raises(STATTooManyRequests): + rest.check_rest_response(response_429, 'test', 'test') + + with pytest.raises(STATNotFound): + rest.check_rest_response(response_404, 'test', 'test') + + with pytest.raises(STATError): + rest.check_rest_response(response_401, 'test', 'test') + + assert None is rest.check_rest_response(response_200, 'test', 'test') + + def get_base_module_object(): base_module_body = json.loads(requests.get(url=os.getenv('BASEDATA')).content) base_object = BaseModule() diff --git a/tests/test_stat_classes.py b/tests/test_stat_classes.py new file mode 100644 index 0000000..5cdaab3 --- /dev/null +++ b/tests/test_stat_classes.py @@ -0,0 +1,473 @@ +from classes import * +import pytest + +def test_base_module_initialization(): + """Test BaseModule class initialization""" + base = BaseModule() + + # Test that all required attributes are initialized + assert hasattr(base, 'Accounts') + assert hasattr(base, 'AccountsCount') + assert hasattr(base, 'IPs') + assert hasattr(base, 'IPsCount') + assert hasattr(base, 'Domains') + assert hasattr(base, 'DomainsCount') + assert hasattr(base, 'FileHashes') + assert hasattr(base, 'FileHashesCount') + assert hasattr(base, 'Files') + assert hasattr(base, 'FilesCount') + assert hasattr(base, 'Hosts') + assert hasattr(base, 'HostsCount') + assert hasattr(base, 'URLs') + assert hasattr(base, 'URLsCount') + assert hasattr(base, 'OtherEntities') + assert hasattr(base, 'OtherEntitiesCount') + assert hasattr(base, 'ModuleName') + + # Test default values + assert base.ModuleName == 'BaseModule' + +def test_base_module_add_ip_entity(): + """Test BaseModule add_ip_entity method""" + base = BaseModule() + + # Add a global IP + base.add_ip_entity('4.172.0.0', {'country': 'CA'}, {'raw': 'data'}, 1) + + assert len(base.IPs) == 1 + assert base.IPs[0]['Address'] == '4.172.0.0' + assert base.IPs[0]['IPType'] == 1 + assert base.IPs[0]['GeoData']['country'] == 'CA' + + # Add a private IP + base.add_ip_entity('192.168.1.1', {}, {'raw': 'data'}, 2) + + assert len(base.IPs) == 2 + assert base.IPs[1]['Address'] == '192.168.1.1' + assert base.IPs[1]['IPType'] == 2 + + # Add an IP with default type (unknown) + base.add_ip_entity('10.0.0.1', {}, {'raw': 'data'}) + + assert len(base.IPs) == 3 + assert base.IPs[2]['IPType'] == 9 # default unknown type + +def test_base_module_get_ip_list(): + """Test BaseModule get_ip_list method""" + base = BaseModule() + base.add_ip_entity('4.172.0.0', {'country': 'CA'}, {'raw': 'data'}, 1) + base.add_ip_entity('192.168.1.1', {}, {'raw': 'data'}, 2) + base.add_ip_entity('10.0.0.1', {}, {'raw': 'data'}) + + ip_list = base.get_ip_list() + assert len(ip_list) == 3 + assert ip_list[0] == '4.172.0.0' + +def 
test_base_module_get_ip_kql_table(): + """Test BaseModule get_ip_kql_table method""" + base = BaseModule() + base.add_ip_entity('4.172.0.0', {'country': 'CA'}, {'raw': 'data'}, 1) + base.add_ip_entity('192.168.1.1', {}, {'raw': 'data'}, 2) + base.add_ip_entity('10.0.0.1', {}, {'raw': 'data'}) + kql = base.get_ip_kql_table() + + assert '4.172.0.0' in kql + assert 'let ipEntities = print t = todynamic(url_decode(' in kql + assert 'IPAddress=tostring(t.Address)' in kql + +def test_base_module_get_host_mdeid_list(): + """Test BaseModule get_host_mdeid_list method""" + base = BaseModule() + base.add_host_entity('host1.contoso.com', 'host1', 'contoso.com', 'mdedeviceid1', {'raw': 'data'}) + base.add_host_entity('host2.contoso.com', 'host2', 'contoso.com', 'mdedeviceid2', {'raw': 'data'}) + + host_list = base.get_host_mdeid_list() + assert len(host_list) == 2 + assert host_list[0] == 'mdedeviceid1' + +def test_base_module_get_host_kql_table(): + """Test BaseModule get_host_kql_table method""" + base = BaseModule() + base.add_host_entity('host1.contoso.com', 'host1', 'contoso.com', 'mdedeviceid1', {'raw': 'data'}) + base.add_host_entity('host2.contoso.com', 'host2', 'contoso.com', 'mdedeviceid2', {'raw': 'data'}) + + kql = base.get_host_kql_table() + + assert 'host1.contoso.com' in kql + assert 'host2.contoso.com' in kql + assert 'let hostEntities = print t = todynamic(url_decode(' in kql + assert 'FQDN=tostring(t.FQDN), Hostname=tostring(t.Hostname);' in kql + +def test_base_module_get_filehash_list(): + """Test BaseModule get_filehash_list method""" + base = BaseModule() + base.FileHashes = [ + {'FileHash': 'abc123', 'Algorithm': 'SHA256'}, + {'FileHash': 'def456', 'Algorithm': 'MD5'} + ] + + hash_list = base.get_filehash_list() + assert len(hash_list) == 2 + assert hash_list[0] == 'abc123' + +def test_base_module_get_filehash_kql_table(): + """Test BaseModule get_filehash_kql_table method""" + base = BaseModule() + base.FileHashes = [ + {'FileHash': 'abc123', 'Algorithm': 'SHA256'}, + {'FileHash': 'def456', 'Algorithm': 'MD5'} + ] + + kql = base.get_filehash_kql_table() + + assert 'abc123' in kql + assert 'def456' in kql + assert 'let hashEntities = print t = todynamic(url_decode(' in kql + assert 'FileHash=tostring(t.FileHash)' in kql + assert 'Algorithm=tostring(t.Algorithm)' in kql + +def test_base_module_get_domain_list(): + """Test BaseModule get_domain_list method""" + base = BaseModule() + base.Domains = [ + {'Domain': 'contoso.com'}, + {'Domain': 'fabrikam.com'}, + ] + + domain_list = base.get_domain_list() + assert len(domain_list) == 2 + assert domain_list[0] == 'contoso.com' + +def test_base_module_get_domain_kql_table(): + """Test BaseModule get_domain_kql_table method""" + base = BaseModule() + base.Domains = [ + {'Domain': 'contoso.com'}, + {'Domain': 'fabrikam.com'}, + ] + + kql = base.get_domain_kql_table() + + assert 'contoso.com' in kql + assert 'fabrikam.com' in kql + assert 'let domainEntities = print t = todynamic(url_decode(' in kql + assert 'Domain=tostring(t.Domain)' in kql + +def test_base_module_get_account_kql_table(): + """Test BaseModule get_account_kql_table method""" + base = BaseModule() + base.Accounts = [ + { + 'userPrincipalName': 'user1@contoso.com', + 'onPremisesSamAccountName': 'user1', + 'onPremisesSecurityIdentifier': 'S-1-5-21-123', + 'id': 'abc-123', + 'manager': { + 'userPrincipalName': 'manager@contoso.com' + } + } + ] + + kql = base.get_account_kql_table() + print(kql) + + assert 'user1%40contoso.com' in kql + assert 'S-1-5-21-123' in kql + assert 
'manager%40contoso.com' in kql + assert 'let accountEntities = print t = todynamic(url_decode(' in kql + assert 'UserPrincipalName=tostring(t.userPrincipalName)' in kql + assert 'SamAccountName=tostring(t.SamAccountName)' in kql + assert 'ObjectSID=tostring(t.SID)' in kql + +def test_base_module_get_url_list(): + """Test BaseModule get_url_list method""" + base = BaseModule() + base.URLs = [ + {'Url': 'https://contoso.com'}, + {'Url': 'https://fabrikam.com'}, + ] + + url_list = base.get_url_list() + assert len(url_list) == 2 + assert url_list[0] == 'https://contoso.com' + +def test_base_module_get_url_kql_table(): + """Test BaseModule get_url_kql_table method""" + base = BaseModule() + base.URLs = [ + {'Url': 'https://contoso.com'}, + {'Url': 'https://fabrikam.com'}, + ] + + kql = base.get_url_kql_table() + + assert 'https%3A//contoso.com' in kql + assert 'https%3A//fabrikam.com' in kql + assert 'let urlEntities = print t = todynamic(url_decode(' in kql + assert 'Url=tostring(t.Url)' in kql + +def test_base_module_check_global_and_local_ips(): + """Test BaseModule check_global_and_local_ips method""" + base = BaseModule() + base.IPs = [ + {'Address': '8.8.8.8', 'IPType': 1}, # global + {'Address': '192.168.1.1', 'IPType': 2}, # private + {'Address': '127.0.0.1', 'IPType': 9} # unknown + ] + + result = base.check_global_and_local_ips() + + # The method should return information about IP types present + assert result is True + +def test_user_exposure_module(): + """Test UserExposureModule load_from_input method""" + user_exp = UserExposureModule() + + test_data = { + 'AnalyzedEntities': 5, + 'Nodes': [{'UserNodeId': 'user1'}, {'UserNodeId': 'user2'}], + 'Paths': [{'UserNodeId': 'user1', 'path': 'test'}] + } + + user_exp.load_from_input(test_data) + + assert user_exp.AnalyzedEntities == 5 + assert len(user_exp.Nodes) == 2 + assert len(user_exp.Paths) == 1 + + nodes_without_paths = user_exp.nodes_without_paths() + + # Should return only user2 since user1 has a path + assert len(nodes_without_paths) == 1 + assert nodes_without_paths[0]['UserNodeId'] == 'user2' + +def test_device_exposure_module(): + """Test DeviceExposureModule class""" + device_exp = DeviceExposureModule() + + test_data = { + 'AnalyzedEntities': 3, + 'Nodes': [{'ComputerNodeId': 'comp1'}, {'ComputerNodeId': 'comp2'}], + 'Paths': [{'ComputerNodeId': 'comp1', 'path': 'test'}] + } + + device_exp.load_from_input(test_data) + + assert device_exp.AnalyzedEntities == 3 + assert len(device_exp.Nodes) == 2 + assert len(device_exp.Paths) == 1 + + nodes_without_paths = device_exp.nodes_without_paths() + + # Should return comp2 + assert len(nodes_without_paths) == 1 + assert nodes_without_paths[0]['ComputerNodeId'] == 'comp2' + +def test_response_initialization(): + """Test Response class initialization""" + body = {'key': 'value'} + response = Response(body, statuscode=200, contenttype='application/json') + + assert response.body == body + assert response.statuscode == 200 + assert response.contenttype == 'application/json' + +def test_kql_module_initialization(): + """Test KQLModule class initialization""" + kql_module = KQLModule() + + assert kql_module.ModuleName == 'KQLModule' + assert kql_module.DetailedResults == [] + assert kql_module.ResultsCount == 0 + assert kql_module.ResultsFound is False + +def test_watchlist_module_initialization(): + """Test WatchlistModule class initialization""" + watchlist_module = WatchlistModule() + + assert watchlist_module.ModuleName == 'WatchlistModule' + assert watchlist_module.DetailedResults == 
[] + assert watchlist_module.EntitiesAnalyzedCount == 0 + assert watchlist_module.EntitiesOnWatchlist is False + assert watchlist_module.EntitiesOnWatchlistCount == 0 + assert watchlist_module.WatchlistName == '' + +def test_ti_module_initialization(): + """Test TIModule class initialization""" + ti_module = TIModule() + + assert ti_module.ModuleName == 'TIModule' + assert ti_module.AnyTIFound is False + assert ti_module.DetailedResults == [] + assert ti_module.DomainEntitiesCount == 0 + assert ti_module.DomainEntitiesWithTI == 0 + assert ti_module.DomainTIFound is False + assert ti_module.FileHashEntitiesCount == 0 + assert ti_module.FileHashEntitiesWithTI == 0 + assert ti_module.FileHashTIFound is False + assert ti_module.IPEntitiesCount == 0 + assert ti_module.IPEntitiesWithTI == 0 + assert ti_module.IPTIFound is False + assert ti_module.TotalTIMatchCount == 0 + assert ti_module.URLEntitiesCount == 0 + assert ti_module.URLEntitiesWithTI == 0 + assert ti_module.URLTIFound is False + +def test_related_alerts_module_initialization(): + """Test RelatedAlertsModule class initialization""" + related_alerts_module = RelatedAlertsModule() + + assert related_alerts_module.ModuleName == 'RelatedAlerts' + assert related_alerts_module.AllTactics == [] + assert related_alerts_module.AllTacticsCount == 0 + assert related_alerts_module.DetailedResults == [] + assert related_alerts_module.FusionIncident is False + assert related_alerts_module.HighestSeverityAlert == '' + assert related_alerts_module.RelatedAccountAlertsCount == 0 + assert related_alerts_module.RelatedAccountAlertsFound is False + assert related_alerts_module.RelatedAlertsCount == 0 + assert related_alerts_module.RelatedAlertsFound is False + assert related_alerts_module.RelatedHostAlertsCount == 0 + assert related_alerts_module.RelatedHostAlertsFound is False + assert related_alerts_module.RelatedIPAlertsCount == 0 + assert related_alerts_module.RelatedIPAlertsFound is False + +def test_ueba_module_initialization(): + """Test UEBAModule class initialization""" + ueba_module = UEBAModule() + + assert ueba_module.ModuleName == 'UEBAModule' + assert ueba_module.AllEntityEventCount == 0 + assert ueba_module.AllEntityInvestigationPriorityAverage == 0.0 + assert ueba_module.AllEntityInvestigationPriorityMax == 0 + assert ueba_module.AllEntityInvestigationPrioritySum == 0 + assert ueba_module.AnomaliesFound is False + assert ueba_module.AnomalyCount == 0 + assert ueba_module.AnomalyTactics == [] + assert ueba_module.AnomalyTacticsCount == 0 + assert ueba_module.DetailedResults == [] + assert ueba_module.InvestigationPrioritiesFound is False + assert ueba_module.ThreatIntelFound is False + assert ueba_module.ThreatIntelMatchCount == 0 + +def test_scoring_module_initialization(): + """Test ScoringModule class initialization""" + scoring_module = ScoringModule() + + assert scoring_module.ModuleName == 'ScoringModule' + assert scoring_module.DetailedResults == [] + assert scoring_module.TotalScore == 0 + +def test_aad_module_initialization(): + """Test AADModule class initialization""" + aad_module = AADModule() + + assert aad_module.ModuleName == 'AADRisksModule' + assert aad_module.AnalyzedEntities == 0 + assert aad_module.FailedMFATotalCount == 0 + assert aad_module.HighestRiskLevel == '' + assert aad_module.MFAFraudTotalCount == 0 + assert aad_module.SuspiciousActivityReportTotalCount == 0 + assert aad_module.DetailedResults == [] + assert aad_module.RiskDetectionTotalCount == 0 + +def test_file_module_initialization(): + """Test FileModule 
class initialization""" + file_module = FileModule() + + assert file_module.ModuleName == 'FileModule' + assert file_module.AnalyzedEntities == 0 + assert file_module.DeviceUniqueDeviceTotalCount == 0 + assert file_module.DeviceUniqueFileNameTotalCount == 0 + assert file_module.DeviceFileActionTotalCount == 0 + assert file_module.EntitiesAttachmentCount == 0 + assert file_module.HashesLinkedToThreatCount == 0 + assert file_module.HashesNotMicrosoftSignedCount == 0 + assert file_module.HashesThreatList == [] + assert file_module.MaximumGlobalPrevalence == 0 + assert file_module.MinimumGlobalPrevalence == 0 + assert file_module.DetailedResults == [] + +def test_run_playbook_initialization(): + """Test RunPlaybook class initialization""" + run_playbook = RunPlaybook() + + assert run_playbook.LogicAppArmId == '' + assert run_playbook.TenantId == '' + assert run_playbook.PlaybookName == '' + assert run_playbook.IncidentArmId == '' + assert run_playbook.ModuleName == 'RunPlaybook' + +def test_exchange_module_initialization(): + """Test ExchangeModule class initialization""" + exchange_module = ExchangeModule() + + assert exchange_module.AllUsersInOffice is True + assert exchange_module.AllUsersOutOfOffice is False + assert exchange_module.Rules == [] + assert exchange_module.AuditEvents == [] + assert exchange_module.OOF == [] + assert exchange_module.UsersInOffice == 0 + assert exchange_module.UsersOutOfOffice == 0 + assert exchange_module.PrivilegedUsersWithMailbox == 0 + assert exchange_module.UsersUnknown == 0 + assert exchange_module.RulesDelete == 0 + assert exchange_module.RulesMove == 0 + assert exchange_module.RulesForward == 0 + assert exchange_module.DelegationsFound == 0 + assert exchange_module.ModuleName == 'ExchangeModule' + +def test_mde_module_initialization(): + """Test MDEModule class initialization""" + mde_module = MDEModule() + + assert mde_module.AnalyzedEntities == 0 + assert mde_module.IPsHighestExposureLevel == '' + assert mde_module.IPsHighestRiskScore == '' + assert mde_module.UsersHighestExposureLevel == '' + assert mde_module.UsersHighestRiskScore == '' + assert mde_module.HostsHighestExposureLevel == '' + assert mde_module.HostsHighestRiskScore == '' + assert mde_module.ModuleName == 'MDEModule' + assert mde_module.DetailedResults == {} + +def test_device_exposure_module_initialization(): + """Test DeviceExposureModule class initialization""" + device_exp_module = DeviceExposureModule() + + assert device_exp_module.AnalyzedEntities == 0 + assert device_exp_module.ModuleName == 'DeviceExposureModule' + assert device_exp_module.Nodes == [] + assert device_exp_module.Paths == [] + +def test_user_exposure_module_initialization(): + """Test UserExposureModule class initialization""" + user_exp_module = UserExposureModule() + + assert user_exp_module.AnalyzedEntities == 0 + assert user_exp_module.ModuleName == 'UserExposureModule' + assert user_exp_module.Nodes == [] + assert user_exp_module.Paths == [] + +def test_create_incident_initialization(): + """Test CreateIncident class initialization""" + create_incident = CreateIncident() + + assert create_incident.IncidentARMId == '' + assert create_incident.AlertARMId == '' + assert create_incident.Title == '' + assert create_incident.Description == '' + assert create_incident.Severity == '' + assert create_incident.IncidentNumber == 0 + assert create_incident.IncidentUrl == '' + assert create_incident.ModuleName == 'CreateIncident' + +def test_debug_module_initialization(): + """Test DebugModule class initialization""" + 
debug_module = DebugModule({'Test': 'Debug', 'Params': {'param1': 'value1'}}) + + assert debug_module.ModuleName == 'DebugModule' + assert debug_module.STATVersion is not None # Assuming data.get_current_version() returns a version + assert debug_module.Test == 'Debug' + assert debug_module.Params == {'param1': 'value1'}
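
Finally, a minimal sketch of how the classes exercised by these tests are meant to fit together in a module; the module logic, attribute values, and retry loop are illustrative assumptions rather than code from the STAT repository.

    import time
    from classes import Response, STATError, STATTooManyRequests, KQLModule

    def run_kql_module():
        # Populate a module result object and wrap it in the standard Response.
        kql = KQLModule()
        kql.DetailedResults = [{'Computer': 'host1', 'Count': 5}]
        kql.ResultsCount = 1
        kql.ResultsFound = True
        return Response(kql)  # statuscode=200 and contenttype='application/json' by default

    def call_with_retries(func, attempts=3):
        # Back off using the exception's retry_after value, as raised for HTTP 429 responses.
        for _ in range(attempts):
            try:
                return func()
            except STATTooManyRequests as e:
                time.sleep(int(e.retry_after))
        raise STATError('Exceeded retry attempts', status_code=429)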