diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 1f0ffb7..eb9d949 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -19,7 +19,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10","3.12"] + python-version: ["3.12"] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/stat_function_build.yml b/.github/workflows/stat_function_build.yml index fe83c03..1759fdf 100644 --- a/.github/workflows/stat_function_build.yml +++ b/.github/workflows/stat_function_build.yml @@ -37,7 +37,7 @@ jobs: run: | pushd './${{ env.AZURE_FUNCTIONAPP_PACKAGE_PATH }}' python -m pip install --upgrade pip - pip install -r requirements.txt --target=".python_packages/lib/site-packages" + pip install -r requirements.txt --platform manylinux_2_17_x86_64 --only-binary=:all: --target=".python_packages/lib/site-packages" popd - name: 'ZIP Function App' diff --git a/.vscode/launch.json b/.vscode/launch.json index 4508b45..7bb67f5 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -3,10 +3,14 @@ "configurations": [ { "name": "Attach to Python Functions", - "type": "python", + "type": "debugpy", "request": "attach", - "port": 9091, - "preLaunchTask": "func: host start" + "connect": { + "host": "localhost", + "port": 9091, + }, + "preLaunchTask": "func: host start", + "postDebugTask": "Release Blocked Port" } ] } \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json index ba75962..37ad114 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -1,27 +1,33 @@ { - "version": "2.0.0", - "tasks": [ - { - "type": "func", - "label": "func: host start", - "command": "host start", - "problemMatcher": "$func-python-watch", - "isBackground": true, - "dependsOn": "pip install (functions)" - }, - { - "label": "pip install (functions)", - "type": "shell", - "osx": { - "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" - }, - "windows": { - "command": "${config:azureFunctions.pythonVenv}\\Scripts\\python -m pip install -r requirements.txt" - }, - "linux": { - "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" - }, - "problemMatcher": [] - } - ] + "version": "2.0.0", + "tasks": [ + { + "type": "func", + "label": "func: host start", + "command": "host start", + "problemMatcher": "$func-python-watch", + "isBackground": true, + "dependsOn": "pip install (functions)" + }, + { + "label": "pip install (functions)", + "type": "shell", + "osx": { + "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" + }, + "windows": { + "command": "${config:azureFunctions.pythonVenv}\\Scripts\\python -m pip install -r requirements.txt" + }, + "linux": { + "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" + }, + "problemMatcher": [] + }, + { + "label": "Release Blocked Port", + "type": "shell", + "command": "powershell.exe -ExecutionPolicy Bypass -Command \"Stop-Process -Id (Get-NetTCPConnection -LocalPort 7071).OwningProcess -Force\"", + "problemMatcher": [] + } + ] } \ No newline at end of file diff --git a/README.md b/README.md index 0d7d36c..1ff4057 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ To debug in VS Code create a local.settings.json file in the root of the project "AZURE_TENANT_ID": "", "AZURE_CLIENT_ID": "", "AZURE_CLIENT_SECRET": "", - "AZURE_AUTHORITY_HOST": "login.microsoftonline.com", + "AZURE_AUTHORITY_HOST": 
"https://login.microsoftonline.com", "ARM_ENDPOINT": "management.azure.com", "GRAPH_ENDPOINT": "graph.microsoft.com", "LOGANALYTICS_ENDPOINT": "api.loganalytics.io", diff --git a/classes/__init__.py b/classes/__init__.py index f1a8a7c..2108d0d 100644 --- a/classes/__init__.py +++ b/classes/__init__.py @@ -46,6 +46,14 @@ def __init__(self, error:str, source_error:dict={}, status_code:int=400): self.source_error = source_error self.status_code = status_code +class STATServerError(STATError): + """STAT exception raised when an API call returns a 5xx series error. + + This exception is a specialized version of STATError for cases where + a server side error was encountered and a retry may be needed. + """ + pass + class STATNotFound(STATError): """STAT exception raised when an API call returns a 404 Not Found error. @@ -54,6 +62,14 @@ class STATNotFound(STATError): """ pass +class STATFailedToDecodeToken(STATError): + """STAT exception raised when the JWT can't be decoded to check for app roles.""" + pass + +class STATInsufficientPermissions(STATError): + """STAT exception raised when the STAT Function identity does not have sufficient permissions.""" + pass + class STATTooManyRequests(STATError): """STAT exception raised when an API call returns a 429 Too Many Requests error. @@ -85,6 +101,7 @@ def __init__(self): self.AccountsCount = 0 self.AccountsOnPrem = [] self.Alerts = [] + self.CreatedTime = '' self.Domains = [] self.DomainsCount = 0 self.EntitiesCount = 0 @@ -99,6 +116,8 @@ def __init__(self): self.IncidentARMId = "" self.IncidentTriggered = False self.IncidentAvailable = False + self.MailMessages = [] + self.MailMessagesCount = 0 self.ModuleVersions = {} self.MultiTenantConfig = {} self.OtherEntities = [] @@ -117,6 +136,7 @@ def __init__(self): def load_incident_trigger(self, req_body): self.IncidentARMId = req_body['object']['id'] + self.CreatedTime = req_body['object']['properties']['createdTimeUtc'] self.IncidentTriggered = True self.IncidentAvailable = True self.SentinelRGARMId = "/subscriptions/" + req_body['workspaceInfo']['SubscriptionId'] + "/resourceGroups/" + req_body['workspaceInfo']['ResourceGroupName'] @@ -127,6 +147,7 @@ def load_incident_trigger(self, req_body): def load_alert_trigger(self, req_body): self.IncidentTriggered = False + self.CreatedTime = req_body['EndTimeUtc'] self.SentinelRGARMId = "/subscriptions/" + req_body['WorkspaceSubscriptionId'] + "/resourceGroups/" + req_body['WorkspaceResourceGroup'] self.WorkspaceId = req_body['WorkspaceId'] @@ -135,6 +156,7 @@ def load_from_input(self, basebody): self.AccountsCount = basebody['AccountsCount'] self.AccountsOnPrem = basebody.get('AccountsOnPrem', []) self.Alerts = basebody.get('Alerts', []) + self.CreatedTime = basebody.get('CreatedTime', '') self.Domains = basebody['Domains'] self.DomainsCount = basebody['DomainsCount'] self.EntitiesCount = basebody['EntitiesCount'] @@ -149,6 +171,8 @@ def load_from_input(self, basebody): self.IncidentTriggered = basebody['IncidentTriggered'] self.IncidentAvailable = basebody['IncidentAvailable'] self.IncidentARMId = basebody['IncidentARMId'] + self.MailMessages = basebody.get('MailMessages', []) + self.MailMessagesCount = basebody.get('MailMessagesCount', 0) self.ModuleVersions = basebody['ModuleVersions'] self.MultiTenantConfig = basebody.get('MultiTenantConfig', {}) self.OtherEntities = basebody['OtherEntities'] @@ -189,11 +213,16 @@ def add_account_entity(self, data): def add_onprem_account_entity(self, data): self.AccountsOnPrem.append(data) - def get_ip_list(self): + 
def get_ip_list(self, include_mail_ips:bool=True): ip_list = [] for ip in self.IPs: ip_list.append(ip['Address']) + if include_mail_ips: + for message in self.MailMessages: + if message.get('senderDetail', {}).get('ipv4'): + ip_list.append(message.get('senderDetail', {}).get('ipv4')) + return ip_list def get_domain_list(self): @@ -203,27 +232,43 @@ def get_domain_list(self): return domain_list - def get_url_list(self): + def get_url_list(self, include_mail_urls:bool=True): url_list = [] for url in self.URLs: url_list.append(url['Url']) + if include_mail_urls: + for message in self.MailMessages: + for url in message.get('urls', []): + url_list.append(url.get('url')) + return url_list - def get_filehash_list(self): + def get_filehash_list(self, include_mail_hashes:bool=True): hash_list = [] for hash in self.FileHashes: hash_list.append(hash['FileHash']) + + if include_mail_hashes: + for message in self.MailMessages: + for attachment in message.get('attachments', []): + if attachment.get('sha256'): + hash_list.append(attachment.get('sha256')) return hash_list - def get_ip_kql_table(self): + def get_ip_kql_table(self, include_mail_ips:bool=True): ip_data = [] for ip in self.IPs: ip_data.append({'Address': ip.get('Address'), 'Latitude': ip.get('GeoData').get('latitude'), 'Longitude': ip.get('GeoData').get('longitude'), \ 'Country': ip.get('GeoData').get('country'), 'State': ip.get('GeoData').get('state')}) + + if include_mail_ips: + for message in self.MailMessages: + if message.get('senderDetail', {}).get('ipv4'): + ip_data.append({'Address': message.get('senderDetail', {}).get('ipv4')}) encoded = urllib.parse.quote(json.dumps(ip_data)) @@ -268,12 +313,17 @@ def get_host_kql_table(self): ''' return kql - def get_url_kql_table(self): + def get_url_kql_table(self, include_mail_urls:bool=True): url_data = [] for url in self.URLs: url_data.append({'Url': url.get('Url')}) + if include_mail_urls: + for message in self.MailMessages: + for url in message.get('urls', []): + url_data.append({'Url': url.get('url')}) + encoded = urllib.parse.quote(json.dumps(url_data)) kql = f'''let urlEntities = print t = todynamic(url_decode('{encoded}')) @@ -282,12 +332,18 @@ def get_url_kql_table(self): ''' return kql - def get_filehash_kql_table(self): + def get_filehash_kql_table(self, include_mail_hashes:bool=True): hash_data = [] for hash in self.FileHashes: hash_data.append({'FileHash': hash.get('FileHash'), 'Algorithm': hash.get('Algorithm')}) + if include_mail_hashes: + for message in self.MailMessages: + for attachment in message.get('attachments', []): + if attachment.get('sha256'): + hash_data.append({'FileHash': attachment.get('sha256'), 'Algorithm': 'SHA256'}) + encoded = urllib.parse.quote(json.dumps(hash_data)) kql = f'''let hashEntities = print t = todynamic(url_decode('{encoded}')) @@ -308,6 +364,21 @@ def get_domain_kql_table(self): kql = f'''let domainEntities = print t = todynamic(url_decode('{encoded}')) | mv-expand t | project Domain=tostring(t.Domain); +''' + return kql + + def get_mail_kql_table(self): + + mail_data = [] + + for mail in self.MailMessages: + mail_data.append({'rec': mail.get('recipientEmailAddress'), 'nid': mail.get('networkMessageId'), 'send': mail.get('senderDetail', {}).get('fromAddress'), 'sendfrom': mail.get('senderDetail', {}).get('mailFromAddress')}) + + encoded = urllib.parse.quote(json.dumps(mail_data)) + + kql = f'''let mailEntities = print t = todynamic(url_decode('{encoded}')) +| mv-expand t +| project RecipientEmailAddress=tostring(t.rec), 
NetworkMessageId=tostring(t.nid), SenderFromAddress=tostring(t.send), SenderMailFromAddress=tostring(t.sendfrom);
'''
 return kql
diff --git a/debug/debug.py b/debug/debug.py
index 64b96c8..1c9d43c 100644
--- a/debug/debug.py
+++ b/debug/debug.py
@@ -26,6 +26,12 @@ def debug_module (req_body):
 case 'comment':
 default_debug(debug_out)
 comment_debug(debug_out)
+ case 'tag':
+ default_debug(debug_out)
+ tag_debug(debug_out)
+ case 'task':
+ default_debug(debug_out)
+ task_debug(debug_out)
 case 'exception':
 exception_debug(debug_out)
 case _:
@@ -145,6 +151,42 @@ def comment_debug(debug_out:DebugModule):
 base_module.MultiTenantConfig = debug_out.Params.get('MultiTenantConfig', {})
 base_module.IncidentARMId = debug_out.Params['IncidentARMId']
 base_module.IncidentAvailable = True
- response = rest.add_incident_comment(base_module, comment)
- debug_out.CommentStatus = response.status_code
- debug_out.CommentResponse = response.json()
\ No newline at end of file
+ try:
+ response = rest.add_incident_comment(base_module, comment, raise_on_error=True)
+ debug_out.CommentStatus = response.status_code
+ debug_out.CommentResponse = response.json()
+ except STATError as e:
+ debug_out.CommentStatus = 'Comment failed'
+ debug_out.CommentSourceError = e.source_error
+ debug_out.CommentStatusCode = e.status_code
+
+def task_debug(debug_out:DebugModule):
+ title = debug_out.Params.get('TaskTitle')
+ description = debug_out.Params.get('TaskDescription', '')
+ base_module = BaseModule()
+ base_module.MultiTenantConfig = debug_out.Params.get('MultiTenantConfig', {})
+ base_module.IncidentARMId = debug_out.Params['IncidentARMId']
+ base_module.IncidentAvailable = True
+ try:
+ response = rest.add_incident_task(base_module, title, description, raise_on_error=True)
+ debug_out.TaskStatus = response.status_code
+ debug_out.TaskResponse = response.json()
+ except STATError as e:
+ debug_out.TaskStatus = 'Task failed'
+ debug_out.TaskSourceError = e.source_error
+ debug_out.TaskStatusCode = e.status_code
+
+def tag_debug(debug_out:DebugModule):
+ tag = [debug_out.Params.get('Tag')]
+ base_module = BaseModule()
+ base_module.MultiTenantConfig = debug_out.Params.get('MultiTenantConfig', {})
+ base_module.IncidentARMId = debug_out.Params['IncidentARMId']
+ base_module.IncidentAvailable = True
+ try:
+ response = rest.add_incident_tags(base_module, tag, raise_on_error=True)
+ debug_out.TagStatus = response.status_code
+ debug_out.TagResponse = response.json()
+ except STATError as e:
+ debug_out.TagStatus = 'Tag failed'
+ debug_out.TagSourceError = e.source_error
+ debug_out.TagStatusCode = e.status_code
\ No newline at end of file
diff --git a/host.json b/host.json
index fd4bee7..f2b7c0d 100644
--- a/host.json
+++ b/host.json
@@ -10,6 +10,6 @@
 },
 "extensionBundle": {
 "id": "Microsoft.Azure.Functions.ExtensionBundle",
- "version": "[3.*, 4.0.0)"
+ "version": "[4.0.0, 5.0.0)"
 }
}
\ No newline at end of file
diff --git a/modules/aadrisks.py b/modules/aadrisks.py
index 55c0c80..06359b3 100644
--- a/modules/aadrisks.py
+++ b/modules/aadrisks.py
@@ -1,11 +1,15 @@
 from classes import BaseModule, Response, AADModule, STATError, STATNotFound
 from shared import rest, data
 import json, datetime
+import logging
 
 def execute_aadrisks_module (req_body):
 #Inputs AddIncidentComments, AddIncidentTask, Entities, IncidentTaskInstructions, LookbackInDays, MFAFailureLookup, MFAFraudLookup, SuspiciousActivityReportLookup
-
+ # Log module invocation with parameters (excluding BaseModuleBody)
+ log_params = {k: v for k, v in req_body.items() if
k != 'BaseModuleBody'} + logging.info(f'AAD Risks Module invoked with parameters: {log_params}') + base_object = BaseModule() base_object.load_from_input(req_body['BaseModuleBody']) diff --git a/modules/base.py b/modules/base.py index c196c32..7380f77 100644 --- a/modules/base.py +++ b/modules/base.py @@ -5,6 +5,7 @@ import logging import requests import ipaddress +import datetime as dt stat_version = None @@ -14,6 +15,10 @@ def execute_base_module (req_body): global enrich_roles global enrich_mde_device + # Log module invocation with parameters (excluding incident/alert body data) + log_params = {k: v for k, v in req_body.items() if k != 'Body'} + logging.info(f'Base Module invoked with parameters: {log_params}') + base_object = BaseModule() try: @@ -26,6 +31,10 @@ def execute_base_module (req_body): enrich_roles = req_body.get('EnrichAccountsWithRoles', True) enrich_mde_device = req_body.get('EnrichHostsWithMDE', True) + #Check for Directory.Read.All or combination of other sufficient roles + rest.check_app_role2(base_object, 'msgraph', ['Organization.Read.All', 'Directory.Read.All', 'Organization.ReadWrite.All', 'Directory.ReadWrite.All'], raise_on_fail_to_decode=False) + rest.check_app_role2(base_object, 'msgraph', ['User.Read.All', 'User.ReadWrite.All', 'Directory.Read.All', 'Directory.ReadWrite.All'], raise_on_fail_to_decode=False) + if trigger_type.lower() == 'incident': entities = process_incident_trigger(req_body) else: @@ -43,10 +52,11 @@ def execute_base_module (req_body): enrich_files(entities) enrich_filehashes(entities) enrich_urls(entities) + enrich_mail_message(entities) append_other_entities(entities) base_object.CurrentVersion = data.get_current_version() - base_object.EntitiesCount = base_object.AccountsCount + base_object.IPsCount + base_object.DomainsCount + base_object.FileHashesCount + base_object.FilesCount + base_object.HostsCount + base_object.OtherEntitiesCount + base_object.URLsCount + base_object.EntitiesCount = base_object.AccountsCount + base_object.IPsCount + base_object.DomainsCount + base_object.FileHashesCount + base_object.FilesCount + base_object.HostsCount + base_object.OtherEntitiesCount + base_object.URLsCount + base_object.MailMessagesCount org_info = json.loads(rest.rest_call_get(base_object, api='msgraph', path='/v1.0/organization').content) base_object.TenantDisplayName = org_info['value'][0]['displayName'] @@ -67,15 +77,26 @@ def execute_base_module (req_body): account_comment = '' ip_comment = '' + mail_comment = '' if req_body.get('AddAccountComments', True) and base_object.AccountsCount > 0: - account_comment = 'Account Info:
<br>' + get_account_comment()
+ account_comment = '<br><br>Account Info:<br><br>' + get_account_comment()
 
 if req_body.get('AddIPComments', True) and base_object.check_global_and_local_ips():
- ip_comment = 'IP Info:<br>' + get_ip_comment()
-
- if (req_body.get('AddAccountComments', True) and base_object.AccountsCount > 0) or (req_body.get('AddIPComments', True) and base_object.check_global_and_local_ips()):
- comment = account_comment + '<br><br>' + ip_comment
+ ip_comment = '<br><br>IP Info:<br><br>' + get_ip_comment()
+
+ if req_body.get('AddMailComments', True) and base_object.MailMessages:
+ mail_comment = '<br><br>Mail Message Info:<br><br>' + get_mail_comment()
+
+ if (req_body.get('AddAccountComments', True) and base_object.AccountsCount > 0) or (req_body.get('AddIPComments', True) and base_object.check_global_and_local_ips()) or (req_body.get('AddMailComments', True) and base_object.MailMessages):
+ comment = ''
+ if account_comment:
+ comment += account_comment + '<br><br>'
+ if ip_comment:
+ comment += ip_comment + '<br><br>
' + if mail_comment: + comment += mail_comment + rest.add_incident_comment(base_object, comment) return Response(base_object) @@ -212,13 +233,81 @@ def enrich_domains(entities): raw_entity = data.coalesce(domain.get('properties'), domain) base_object.Domains.append({'Domain': domain_name, 'RawEntity': raw_entity}) +def enrich_mail_message(entities): + mail_entities = list(filter(lambda x: x['kind'].lower() == 'mailmessage', entities)) + base_object.MailMessagesCount = len(mail_entities) + message_role = rest.check_app_role(base_object, 'msgraph', ['SecurityAnalyzedMessage.Read.All','SecurityAnalyzedMessage.ReadWrite.All']) + + for mail in mail_entities: + recipient = data.coalesce(mail.get('properties',{}).get('recipient'), mail.get('Recipient')) + network_message_id = data.coalesce(mail.get('properties',{}).get('networkMessageId'), mail.get('NetworkMessageId')) + receive_date = data.coalesce(mail.get('properties',{}).get('receiveDate'), mail.get('ReceivedDate')) + + if receive_date: + try: + start_time = (data.convert_from_iso_format(receive_date) + dt.timedelta(days=-14)).strftime("%Y-%m-%dT%H:%M:%SZ") + end_time = (data.convert_from_iso_format(receive_date) + dt.timedelta(days=14)).strftime("%Y-%m-%dT%H:%M:%SZ") + except ValueError as e: + start_time = (data.convert_from_iso_format(base_object.CreatedTime) + dt.timedelta(days=-14)).strftime("%Y-%m-%dT%H:%M:%SZ") + end_time = (data.convert_from_iso_format(base_object.CreatedTime) + dt.timedelta(days=14)).strftime("%Y-%m-%dT%H:%M:%SZ") + else: + start_time = (data.convert_from_iso_format(base_object.CreatedTime) + dt.timedelta(days=-14)).strftime("%Y-%m-%dT%H:%M:%SZ") + end_time = (data.convert_from_iso_format(base_object.CreatedTime) + dt.timedelta(days=14)).strftime("%Y-%m-%dT%H:%M:%SZ") + + raw_entity = data.coalesce(mail.get('properties'), mail) + + if not message_role: + base_object.MailMessages.append({'networkMessageId': network_message_id, 'recipientEmailAddress': recipient, 'EnrichmentMethod': 'MailMessage - No App Role', 'RawEntity': raw_entity}) + logging.warning(f"mailMessage appended without enrichment - Missing app role SecurityAnalyzedMessage.Read.All") + continue + + if recipient and network_message_id: + try: + get_message = json.loads(rest.rest_call_get(base_object, api='msgraph', path=f"/beta/security/collaboration/analyzedemails?startTime={start_time}&endTime={end_time}&filter=networkMessageId eq '{network_message_id}' and recipientEmailAddress eq '{recipient}'").content) + if get_message['value']: + message_details = json.loads(rest.rest_call_get(base_object, api='msgraph', path=f"/beta/security/collaboration/analyzedemails/{get_message['value'][0]['id']}").content) + message_details['RawEntity'] = raw_entity + else: + message_details = { + 'networkMessageId': network_message_id, + 'recipientEmailAddress': recipient, + 'EnrichmentMethod': 'MailMessage - analyzedMessage could not be found', + 'RawEntity': raw_entity + } + except: + message_details = { + 'networkMessageId': network_message_id, + 'recipientEmailAddress': recipient, + 'EnrichmentMethod': 'MailMessage - Failed to get analyzedMessage', + 'RawEntity': raw_entity + } + + else: + message_details = {'EnrichmentMethod': 'MailMessage - No Recipient or NetworkMessageId', 'RawEntity': raw_entity} + + base_object.MailMessages.append(message_details) + + def enrich_files(entities): file_entities = list(filter(lambda x: x['kind'].lower() == 'file', entities)) base_object.FilesCount = len(file_entities) for file in file_entities: raw_entity = 
data.coalesce(file.get('properties'), file) - base_object.Files.append({'FileName': data.coalesce(file.get('properties',{}).get('friendlyName'), file.get('Name')),'RawEntity': raw_entity}) + file_field = data.coalesce(file.get('properties',{}).get('friendlyName'), file.get('Name')) + file_name = file_field.split('/')[-1].split('\\')[-1] + file_name_path_f = file_field.rsplit('/', 1) + file_name_path_b = file_field.rsplit('\\', 1) + + if len(file_name_path_f) > 1: + file_name_path = file_name_path_f[0] + '/' + elif len(file_name_path_b) > 1: + file_name_path = file_name_path_b[0] + '\\' + else: + file_name_path = None + + file_directory = data.coalesce(file.get('properties',{}).get('directory'), file.get('Directory'), file_name_path, '') + base_object.Files.append({'FileName': file_name, 'FilePath': f'{file_directory}{file_name}', 'RawEntity': raw_entity}) def enrich_filehashes(entities): filehash_entities = list(filter(lambda x: x['kind'].lower() == 'filehash', entities)) @@ -240,7 +329,7 @@ def enrich_urls(entities): base_object.URLs.append({'Url': url_data, 'RawEntity': raw_entity}) def append_other_entities(entities): - other_entities = list(filter(lambda x: x['kind'].lower() not in ('ip','account','dnsresolution','dns','file','filehash','host','url'), entities)) + other_entities = list(filter(lambda x: x['kind'].lower() not in ('ip','account','dnsresolution','dns','file','filehash','host','url','mailmessage'), entities)) base_object.OtherEntitiesCount = len(other_entities) for entity in other_entities: @@ -444,17 +533,25 @@ def get_account_comment(): upn_data = f'{account_upn}
<br>(Contact User)'
 else:
 upn_data = account_upn
-
- account_list.append({'UserPrincipalName': upn_data, 'City': account.get('city'), 'Country': account.get('country'), \
- 'Department': account.get('department'), 'JobTitle': account.get('jobTitle'), 'Office': account.get('officeLocation'), \
- 'AADRoles': account.get('AssignedRoles'), 'ManagerUPN': account.get('manager', {}).get('userPrincipalName'), \
- 'MfaRegistered': account.get('isMfaRegistered'), 'SSPREnabled': account.get('isSSPREnabled'), \
- 'SSPRRegistered': account.get('isSSPRRegistered')})
-
+
+ if upn_data:
+ account_list.append({
+ 'User': f"{upn_data}<br>JobTitle: {account.get('jobTitle')}",
+ 'Location': f"Department: {account.get('department')}<br>Office: {account.get('officeLocation')}<br>City: {account.get('city')}<br>Country: {account.get('country')}",
+ 'OtherDetails': f"AADRoles: {', '.join(account.get('AssignedRoles', []))}<br>Manager: {account.get('manager', {}).get('userPrincipalName')}<br>MFA Registered: {account.get('isMfaRegistered')}<br>SSPR Enabled: {account.get('isSSPREnabled')}<br>SSPR Registered: {account.get('isSSPRRegistered')}<br>OnPremSynced: {account.get('onPremisesSyncEnabled')}"
+ })
+ else:
+ account_list.append({
+ 'User': "Unknown User",
+ 'OtherDetails': f"Failed to lookup account details for 1 account entity<br>Enrichment Method: {account.get('EnrichmentMethod')}",
+ })
+
 for onprem_acct in base_object.AccountsOnPrem:
- account_list.append(
- {'UserPrincipalName': data.coalesce(onprem_acct.get('userPrincipalName'),onprem_acct.get('onPremisesSamAccountName')), 'Department': onprem_acct.get('department'), 'JobTitle': onprem_acct.get('jobTitle'), 'ManagerUPN': onprem_acct.get('manager'), 'Notes': 'On-Prem - No Entra Sync'}
- )
+ account_list.append({
+ 'User': f"{data.coalesce(onprem_acct.get('userPrincipalName'),onprem_acct.get('onPremisesSamAccountName'))}<br>JobTitle: {onprem_acct.get('jobTitle')}",
+ 'Location': f"Department: {onprem_acct.get('department')}",
+ 'OtherDetails': f"Manager: {onprem_acct.get('manager')}<br>OnPremSynced: On-Prem Only"
+ })
 
 return data.list_to_html_table(account_list, 20, 20, escape_html=False)
@@ -465,12 +562,38 @@ def get_ip_comment():
 if ip.get('IPType') != 3: #Excludes link local addresses from the IP comment
 geo = ip.get('GeoData')
- ip_list.append({'IP': ip.get('Address'), 'City': geo.get('city'), 'State': geo.get('state'), 'Country': geo.get('country'), \
- 'Organization': geo.get('organization'), 'OrganizationType': geo.get('organizationType'), 'ASN': geo.get('asn'), 'IPType': ip.get('IPType')})
+
+ ip_list.append({
+ 'IP': ip.get('Address'),
+ 'Location': f"City: {geo.get('city', 'Unknown')}<br>State: {geo.get('state', 'Unknown')}<br>Country: {geo.get('country', 'Unknown')}",
+ 'OtherDetails': f"Organization: {geo.get('organization', 'Unknown')}<br>OrganizationType: {geo.get('organizationType', 'Unknown')}<br>ASN: {geo.get('asn', 'Unknown')}",
+ 'IPType': ip.get('IPType')
+ })
 
 ip_list = data.sort_list_by_key(ip_list, 'IPType', ascending=True, drop_columns=['IPType'])
 
- return data.list_to_html_table(ip_list)
+ return data.list_to_html_table(ip_list, escape_html=False)
+
+def get_mail_comment():
+
+ mail_list = []
+ for msg in base_object.MailMessages:
+ if msg.get('EnrichmentMethod'):
+ mail_list.append({
+ 'MessageDetails': f"NetworkMessageId: {msg.get('networkMessageId')}<br>Recipient: {msg.get('recipientEmailAddress', 'Unknown')}",
+ 'EnrichmentMethod': f"Enrichment Method: {msg.get('EnrichmentMethod')}",
+ })
+ else:
+ msg_time = msg.get('loggedDateTime')
+ explorer_link = f"https://security.microsoft.com/emailentity?f=summary&startTime={msg_time}&endTime={msg_time}&id={msg.get('networkMessageId')}&recipient={msg.get('recipientEmailAddress')}&tid={base_object.TenantId}"
+ mail_list.append({
+ 'MessageDetails': f"Recipient: {msg.get('recipientEmailAddress')}<br>Sender: {msg.get('senderDetail', {}).get('fromAddress')}<br>SenderFromAddress: {msg.get('senderDetail', {}).get('mailFromAddress')}<br>Subject: {msg.get('subject')}<br>AttachmentCount: {len(msg.get('attachments', []))}<br>URLCount: {len(msg.get('urls', []))}<br><a href='{explorer_link}' target='_blank'>(Open Entity Page)</a>",
+ 'Delivery': f"Original Delivery: {msg.get('originalDelivery', {}).get('location')}<br>Latest Delivery: {msg.get('latestDelivery', {}).get('location')}",
+ 'Authentication': f"SPF: {msg.get('authenticationDetails', {}).get('senderPolicyFramework')}<br>DKIM: {msg.get('authenticationDetails', {}).get('dkim')}<br>DMARC: {msg.get('authenticationDetails', {}).get('dmarc')}",
+ 'ThreatInfo': f"ThreatTypes: {', '.join(msg.get('threatTypes', []))}<br>
DetectionMethods: {', '.join(msg.get('detectionMethods', []))}" + }) + + return data.list_to_html_table(mail_list, escape_html=False) def get_stat_version(version_check_type): diff --git a/modules/createincident.py b/modules/createincident.py index 0229197..2864e8a 100644 --- a/modules/createincident.py +++ b/modules/createincident.py @@ -2,9 +2,14 @@ from shared import rest, data import json import uuid +import logging def execute_create_incident (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'Create Incident Module invoked with parameters: {log_params}') + #Inputs: Severity, Title, Description base_object = BaseModule() diff --git a/modules/exchange.py b/modules/exchange.py index 316f622..21ba481 100644 --- a/modules/exchange.py +++ b/modules/exchange.py @@ -3,9 +3,14 @@ import json import re from datetime import datetime, timezone +import logging def execute_exchange_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'Exchange Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, Entities, IncidentTaskInstructions, KQLQuery, LookbackInDays, QueryDescription, RunQueryAgainst base_object = BaseModule() @@ -90,7 +95,7 @@ def append_enabled(exch:ExchangeModule, upn, internal, external): replace_nbsp = re.compile(' |\\n') int_msg = re.sub(replace_nbsp, ' ', re.sub(clean_html, '', internal)) ext_msg = re.sub(replace_nbsp, ' ', re.sub(clean_html, '', external)) - exch.OOF.append({'ExternalMessage': ext_msg, 'InternalMessage': int_msg, 'OOFStatus': 'enabled', 'UPN': upn}) + exch.OOF.append({'ExternalMessage': ext_msg.strip(), 'InternalMessage': int_msg.strip(), 'OOFStatus': 'enabled', 'UPN': upn}) def audit_check(base_object:BaseModule, exch:ExchangeModule, module_lookback:int): #Retrieve OfficeActivity Audits @@ -113,14 +118,23 @@ def oof_check(base_object:BaseModule, exch:ExchangeModule, upn): raise STATError(e.error, e.source_error, e.status_code) else: current_time = datetime.now(timezone.utc) + try: + oof_start = datetime.strptime(results['scheduledStartDateTime']['dateTime'][:-3], "%Y-%m-%dT%H:%M:%S.%f").replace(tzinfo=timezone.utc) + oof_end = datetime.strptime(results['scheduledEndDateTime']['dateTime'][:-3], "%Y-%m-%dT%H:%M:%S.%f").replace(tzinfo=timezone.utc) + except: + oof_start = None + oof_end = None + if results['status'].lower() == 'disabled': exch.UsersInOffice += 1 append_disabled(exch, upn) elif results['status'].lower() == 'enabled' or results['status'].lower() == 'alwaysenabled': exch.UsersOutOfOffice += 1 append_enabled(exch, upn, results['internalReplyMessage'], results['externalReplyMessage']) - elif results['status'].lower() == 'scheduled' and current_time >= results['scheduledStartDateTime']['dateTime'] \ - and current_time <= results['scheduledEndDateTime']['dateTime']: + elif not(oof_start): + exch.UsersInOffice += 1 + append_disabled(exch, upn) + elif results['status'].lower() == 'scheduled' and current_time >= oof_start and current_time <= oof_end: exch.UsersOutOfOffice += 1 append_enabled(exch, upn, results['internalReplyMessage'], results['externalReplyMessage']) else: @@ -138,7 +152,7 @@ def rule_check(base_object:BaseModule, exch:ExchangeModule, account): except STATNotFound: exch.UsersUnknown += 1 return - + privileged_roles = 
data.load_json_from_file('privileged-roles.json') if set(account_roles).intersection(privileged_roles): exch.PrivilegedUsersWithMailbox += 1 diff --git a/modules/exposure_device.py b/modules/exposure_device.py index f221793..f7f0a42 100644 --- a/modules/exposure_device.py +++ b/modules/exposure_device.py @@ -1,9 +1,14 @@ from classes import BaseModule, Response, DeviceExposureModule from shared import rest, data import json +import logging def execute_device_exposure_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'Device Exposure Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, BaseModuleBody, IncidentTaskInstructions, AddIncidentTags base_object = BaseModule() diff --git a/modules/exposure_user.py b/modules/exposure_user.py index 91c82c1..84c15c0 100644 --- a/modules/exposure_user.py +++ b/modules/exposure_user.py @@ -1,9 +1,14 @@ from classes import BaseModule, Response, UserExposureModule from shared import rest, data import json +import logging def execute_user_exposure_module (req_body): - + + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'User Exposure Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, BaseModuleBody, IncidentTaskInstructions, AddIncidentTags base_object = BaseModule() diff --git a/modules/file.py b/modules/file.py index 18357a9..aa83398 100644 --- a/modules/file.py +++ b/modules/file.py @@ -1,9 +1,14 @@ from classes import BaseModule, Response, FileModule, STATError, STATNotFound from shared import rest, data import json +import logging def execute_file_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'File Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, BaseModuleBody, IncidentTaskInstructions base_object = BaseModule() diff --git a/modules/kql.py b/modules/kql.py index 2a67995..ea50c4f 100644 --- a/modules/kql.py +++ b/modules/kql.py @@ -1,10 +1,15 @@ from classes import BaseModule, Response, KQLModule from shared import rest, data +import logging def execute_kql_module (req_body): #Inputs AddIncidentComments, AddIncidentTask, Entities, IncidentTaskInstructions, KQLQuery, LookbackInDays, QueryDescription, RunQueryAgainst + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'KQL Module invoked with parameters: {log_params}') + base_object = BaseModule() base_object.load_from_input(req_body['BaseModuleBody']) endpoint = req_body.get('Endpoint', 'query').lower() @@ -15,8 +20,12 @@ def execute_kql_module (req_body): ip_entities = base_object.get_ip_kql_table() account_entities = base_object.get_account_kql_table(include_unsynced=True) host_entities = base_object.get_host_kql_table() + mail_entities = base_object.get_mail_kql_table() + + query = arm_id + ip_entities + account_entities + host_entities + mail_entities + req_body['KQLQuery'] - query = arm_id + ip_entities + account_entities + host_entities + req_body['KQLQuery'] + # Log the executed KQL query for troubleshooting + logging.info(f'Executing KQL query: {query}') if req_body.get('RunQueryAgainst') 
== 'M365': results = rest.execute_m365d_query(base_object, query) diff --git a/modules/mdca.py b/modules/mdca.py index 79f42a1..6113397 100644 --- a/modules/mdca.py +++ b/modules/mdca.py @@ -1,9 +1,14 @@ from classes import BaseModule, Response, MDCAModule, STATError, STATNotFound from shared import rest, data import json,os,base64 +import logging def execute_mdca_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'MDCA Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, ScoreThreshold, TopUserThreshold base_object = BaseModule() diff --git a/modules/mde.py b/modules/mde.py index ed76827..3608fa3 100644 --- a/modules/mde.py +++ b/modules/mde.py @@ -1,9 +1,14 @@ from classes import BaseModule, Response, MDEModule from shared import rest, data import json +import logging def execute_mde_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'MDE Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, Entities, IncidentTaskInstructions base_object = BaseModule() diff --git a/modules/playbook.py b/modules/playbook.py index ffa9562..cf30a82 100644 --- a/modules/playbook.py +++ b/modules/playbook.py @@ -1,8 +1,13 @@ from classes import BaseModule, Response, STATError, RunPlaybook from shared import rest +import logging def execute_playbook_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'Playbook Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, LogicAppResourceId, PlaybookName, TenantId base_object = BaseModule() diff --git a/modules/relatedalerts.py b/modules/relatedalerts.py index 9063cdf..da86970 100644 --- a/modules/relatedalerts.py +++ b/modules/relatedalerts.py @@ -2,9 +2,14 @@ from shared import rest, data import datetime as dt import json +import logging def execute_relatedalerts_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'Related Alerts Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, BaseModuleBody, IncidentTaskInstructions #LookbackInDays, CheckAccountEntityMatches, CheckHostEntityMatches, CheckIPEntityMatches, AlertKQLFilter diff --git a/modules/scoring.py b/modules/scoring.py index a23a038..d3031c7 100644 --- a/modules/scoring.py +++ b/modules/scoring.py @@ -1,8 +1,14 @@ from classes import * from shared import data, rest +import logging +import json def execute_scoring_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody, ScoringData) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody' and k != 'ScoringData'} + logging.info(f'Scoring Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, BaseModuleBody, IncidentTaskInstructions, ScoringData base_object = BaseModule() @@ -14,6 +20,14 @@ def execute_scoring_module (req_body): for input_module in req_body['ScoringData']: module_body = input_module['ModuleBody'] + + # Convert ModuleBody from string if submitted as a JSON string + if isinstance(module_body, str): + try: + 
module_body = json.loads(module_body) + except json.JSONDecodeError as e: + raise STATError(f'Failed to parse ModuleBody of scoring data for item with label {input_module.get("ScoreLabel","Unknown")}, invalid JSON string', {'Error': str(e)}) + module = module_body.get('ModuleName') label = input_module.get('ScoreLabel', module) multiplier = float(input_module.get('ScoreMultiplier', 1)) diff --git a/modules/ti.py b/modules/ti.py index 08d41f0..a49a29f 100644 --- a/modules/ti.py +++ b/modules/ti.py @@ -1,7 +1,12 @@ from classes import BaseModule, Response, TIModule, STATError from shared import rest, data +import logging def execute_ti_module (req_body): + + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'Threat Intelligence Module invoked with parameters: {log_params}') #Inputs AddIncidentComments, AddIncidentTask, BaseModuleBody, IncidentTaskInstructions, CheckDomains, CheckFileHashes, CheckIPs, CheckURLs diff --git a/modules/ueba.py b/modules/ueba.py index f338430..fc13f5d 100644 --- a/modules/ueba.py +++ b/modules/ueba.py @@ -1,9 +1,14 @@ from classes import BaseModule, Response, UEBAModule from shared import rest, data import ast +import logging def execute_ueba_module (req_body): + # Log module invocation with parameters (excluding BaseModuleBody) + log_params = {k: v for k, v in req_body.items() if k != 'BaseModuleBody'} + logging.info(f'UEBA Module invoked with parameters: {log_params}') + #Inputs AddIncidentComments, AddIncidentTask, BaseModuleBody, IncidentTaskInstructions #LookbackInDays, MinimumInvestigationPriority diff --git a/modules/version.json b/modules/version.json index d20a644..a2445d3 100644 --- a/modules/version.json +++ b/modules/version.json @@ -1,3 +1,3 @@ { - "FunctionVersion": "2.2.0" + "FunctionVersion": "2.3.0" } diff --git a/requirements.txt b/requirements.txt index 0c034fc..d0b7ec2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,3 @@ azure-identity azure-keyvault-secrets requests pandas -cryptography==43.0.3 - -#Limiting cryptograpy due to https://github.com/Azure/azure-functions-python-worker/issues/1651 \ No newline at end of file diff --git a/shared/data.py b/shared/data.py index 27b9bfe..bf7788f 100644 --- a/shared/data.py +++ b/shared/data.py @@ -3,6 +3,7 @@ import pathlib import json import os +from datetime import datetime, timezone date_format = os.getenv('DATE_FORMAT') tz_info = os.getenv('TIME_ZONE') @@ -412,4 +413,21 @@ def parse_kv_string(kv:str, item_delimitter:str=';', value_delimitter:str='='): result[key.strip()] = value.strip() else: result[pair.strip()] = None - return result \ No newline at end of file + return result + +def convert_from_iso_format(date_string): + """Convert an ISO 8601 date string to a datetime object and assume UTC if no time zone is provided. + + Args: + date_string (str): The ISO 8601 date string to convert. + + Returns: + datetime: A datetime object representing the date and time with time zone information. 
+ """ + + ts = datetime.fromisoformat(date_string) + + if ts.tzinfo is not None and ts.tzinfo.utcoffset(ts) is not None: + return ts + else: + return ts.replace(tzinfo=timezone.utc) diff --git a/shared/rest.py b/shared/rest.py index 3539923..3d516b9 100644 --- a/shared/rest.py +++ b/shared/rest.py @@ -8,7 +8,8 @@ import uuid import time import base64 -from classes import STATError, STATNotFound, BaseModule, STATTooManyRequests +import logging +from classes import STATError, STATNotFound, BaseModule, STATTooManyRequests, STATServerError, STATFailedToDecodeToken, STATInsufficientPermissions stat_token = {} graph_endpoint = os.getenv('GRAPH_ENDPOINT') @@ -199,12 +200,23 @@ def execute_rest_call(base_module:BaseModule, method:str, api:str, path:str, bod retry_after = 10 wait_time += retry_after if wait_time > 60: + logging.warning(f'The API call to {api} with path {path} failed with {e.source_error}. Maximum retry time exceeded.') raise STATTooManyRequests(error=e.error, source_error=e.source_error, status_code=e.status_code, retry_after=e.retry_after) + logging.info(f'The API call to {api} with path {path} failed with status {e.source_error}. Retrying.') time.sleep(retry_after) + except STATServerError as e: + wait_time += 15 + if wait_time > 60: + logging.warning(f'The API call to {api} with path {path} failed with {e.source_error}. Maximum retry time exceeded.') + raise STATServerError(error=f'Server error returned by {url}', source_error=e.source_error, status_code=500) + logging.info(f'The API call to {api} with path {path} failed with status {e.source_error}. Retrying.') + time.sleep(15) except ConnectionError as e: wait_time += 20 - if wait_time >= 60: + if wait_time > 60: + logging.warning(f'The API call to {api} with path {path} failed with status {e}. Maximum retry time exceeded.') raise STATError(error=f'Failed to establish a new connection to {url}', source_error=e, status_code=500) + logging.info(f'The API call to {api} with path {path} failed with status Connection Error. Retrying.') time.sleep(20) else: retry_call = False @@ -214,10 +226,14 @@ def execute_rest_call(base_module:BaseModule, method:str, api:str, path:str, bod def check_rest_response(response:Response, api, path): if response.status_code == 404: + logging.info(f'The API call to {api} with path {path} returned a 404 Not Found error, in some cases this is expected.') raise STATNotFound(f'The API call to {api} with path {path} failed with status {response.status_code}', source_error={'status_code': int(response.status_code), 'reason': str(response.reason)}) - elif response.status_code == 429 or response.status_code == 408: + elif response.status_code in (429, 408): raise STATTooManyRequests(f'The API call to {api} with path {path} failed with status {response.status_code}', source_error={'status_code': int(response.status_code), 'reason': str(response.reason)}, retry_after=response.headers.get('Retry-After', 10), status_code=int(response.status_code)) + elif response.status_code >= 500: + raise STATServerError(f'The API call to {api} with path {path} failed with status {response.status_code}', source_error={'status_code': int(response.status_code), 'reason': str(response.reason)}, status_code=int(response.status_code)) elif response.status_code >= 300: + logging.warning(f'The API call to {api} with path {path} failed with status {response.status_code} and reason {str(response.reason)}. 
No Retries will be attempted.') raise STATError(f'The API call to {api} with path {path} failed with status {response.status_code}', source_error={'status_code': int(response.status_code), 'reason': str(response.reason)}) return @@ -313,7 +329,7 @@ def get_endpoint(api:str): 'Ensure that all API endpoint enrivonrment variables are correctly set in the STAT Function App ' '(ARM_ENDPOINT, GRAPH_ENDPOINT, LOGANALYTICS_ENDPOINT, M365_ENDPOINT, and MDE_ENDPOINT).') -def add_incident_comment(base_module:BaseModule, comment:str): +def add_incident_comment(base_module:BaseModule, comment:str, raise_on_error:bool=False): """Add a comment to a Microsoft Sentinel incident. Creates a new comment on the specified incident using the Azure REST API. @@ -322,6 +338,7 @@ def add_incident_comment(base_module:BaseModule, comment:str): Args: base_module (BaseModule): Base module containing incident information. comment (str): Comment text to add to the incident. + raise_on_error (bool): Whether to raise an exception on error. Defaults to False. Returns: Response or str: API response object on success, or 'Comment failed' on error. @@ -330,26 +347,44 @@ def add_incident_comment(base_module:BaseModule, comment:str): path = base_module.IncidentARMId + '/comments/' + str(uuid.uuid4()) + '?api-version=2023-02-01' try: response = rest_call_put(base_module, 'arm', path, {'properties': {'message': comment[:30000]}}) + except STATError as e: + logging.warning(f'Failed to add comment to incident {base_module.IncidentARMId} with error: {e.source_error}, status code: {e.status_code}') + if raise_on_error: + raise + else: + response = 'Comment failed' except: response = 'Comment failed' return response -def add_incident_task(base_module:BaseModule, title:str, description:str, status:str='New'): +def add_incident_task(base_module:BaseModule, title:str, description:str, status:str='New', raise_on_error:bool=False): path = base_module.IncidentARMId + '/tasks/' + str(uuid.uuid4()) + '?api-version=2024-03-01' if description is None or description == '': try: response = rest_call_put(base_module, 'arm', path, {'properties': {'title': title, 'status': status}}) + except STATError as e: + logging.warning(f'Failed to add task to incident {base_module.IncidentARMId} with error: {e.source_error}, status code: {e.status_code}') + if raise_on_error: + raise + else: + response = 'Task addition failed' except: response = 'Task addition failed' else: try: response = rest_call_put(base_module, 'arm', path, {'properties': {'title': title, 'description': description[:3000], 'status': status}}) + except STATError as e: + logging.warning(f'Failed to add task to incident {base_module.IncidentARMId} with error: {e.source_error}, status code: {e.status_code}') + if raise_on_error: + raise + else: + response = 'Task addition failed' except: response = 'Task addition failed' return response -def add_incident_tags(base_module:BaseModule, tags:list): +def add_incident_tags(base_module:BaseModule, tags:list, raise_on_error:bool=False): if tags: path = base_module.IncidentARMId + '?api-version=2025-03-01' tags_to_add = False @@ -378,20 +413,76 @@ def add_incident_tags(base_module:BaseModule, tags:list): response_put = rest_call_put(base_module, 'arm', path, body=body) else: response_put = 'No new tags to add' + except STATError as e: + logging.warning(f'Failed to add tag to incident {base_module.IncidentARMId} with error: {e.source_error}, status code: {e.status_code}') + if raise_on_error: + raise + else: + response_put = 'Tag addition failed' 
except:
 response_put = 'Tag addition failed'
 
 return response_put
 
 def check_app_role(base_module:BaseModule, token_type:str, app_roles:list):
+ """Check if the current application token has at least one of the required roles.
+ Args:
+ base_module (BaseModule): Base module containing incident information.
+ token_type (str): Type of token to check ('msgraph', 'la', 'm365', 'mde').
+ app_roles (list): List of required application roles to check against the token.
+ Returns:
+ bool: True if the token has at least one of the required roles, False otherwise.
+ Raises:
+ STATFailedToDecodeToken: If the JWT token cannot be decoded to check for roles.
+ """
 token = token_cache(base_module, token_type)
- content = token.token.split('.')[1] + '=='
- b64_decoded = base64.urlsafe_b64decode(content)
- decoded_token = json.loads(b64_decoded)
- token_roles = decoded_token.get('roles')
+ try:
+ content = token.token.split('.')[1] + '=='
+ b64_decoded = base64.urlsafe_b64decode(content)
+ decoded_token = json.loads(b64_decoded)
+ token_roles = decoded_token.get('roles', [])
+ except:
+ logging.warning(f'Failed to decode the JWT token for {token_type}.')
+ raise STATFailedToDecodeToken(f'Failed to decode the JWT token for {token_type}. Ensure the token is valid and properly formatted.')
 matched_roles = [item for item in app_roles if item in token_roles]
 if matched_roles:
 return True
 return False
- 
\ No newline at end of file
+
+def check_app_role2(base_module:BaseModule, token_type:str, app_roles:list, raise_on_fail_to_decode:bool=False, token:str=None):
+ """Check if the current application token has at least one of the required roles.
+ Args:
+ base_module (BaseModule): Base module containing incident information.
+ token_type (str): Type of token to check ('msgraph', 'la', 'm365', 'mde').
+ app_roles (list): List of required application roles to check against the token.
+ raise_on_fail_to_decode (bool): Whether to raise an exception if the token cannot be decoded. Defaults to False.
+ token (str, optional): JWT token to check. If not provided, it will be fetched from the cache.
+ Returns:
+ bool: True if the token has at least one of the required roles, or if the token could not be decoded and raise_on_fail_to_decode is False.
+ Raises:
+ STATFailedToDecodeToken: If the JWT token cannot be decoded to check for roles.
+ STATInsufficientPermissions: If the token does not have sufficient permissions.
+ """
+ if not token:
+ token = token_cache(base_module, token_type)
+
+ try:
+ content = token.token.split('.')[1] + '=='
+ b64_decoded = base64.urlsafe_b64decode(content)
+ decoded_token = json.loads(b64_decoded)
+ token_roles = decoded_token.get('roles', [])
+ except:
+ if raise_on_fail_to_decode:
+ logging.warning(f'Failed to decode the JWT token for {token_type}, raising exception.')
+ raise STATFailedToDecodeToken(f'Failed to decode the JWT token for {token_type}. Ensure the token is valid and properly formatted.')
+ else:
+ logging.warning(f'Failed to decode the JWT token for {token_type}, returning without exception.')
+ return True
+
+ matched_roles = [item for item in app_roles if item in token_roles]
+ if matched_roles:
+ return True
+ else:
+ raise STATInsufficientPermissions(f'The Microsoft Sentinel Triage AssistanT identity does not have sufficient permissions to perform this operation.
Please ensure to run the GrantPermissions.ps1 script against the identity used by the STAT function app.') + \ No newline at end of file diff --git a/tests/test_data.py b/tests/test_data.py index 439736d..635d827 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,4 +1,5 @@ from shared import data +from datetime import datetime import pytest def test_return_highest_value(): @@ -202,6 +203,23 @@ def test_load_json_from_file(): assert "Global Administrator" in content assert "Security Administrator" in content +def test_datetime_conversion(): + """Test convert_from_iso_format with various date strings""" + converted_date = data.convert_from_iso_format("2025-07-25T15:00:00.123456Z") + converted_date2 = data.convert_from_iso_format("2025-07-25T15:00:00.1234567Z") + converted_date3 = data.convert_from_iso_format("2025-07-25T15:00:00.123456789+00:00") + converted_date4 = data.convert_from_iso_format("2025-07-25T15:00:00.123456789") + converted_date5 = data.convert_from_iso_format("2025-07-25T15:00:00Z") + converted_date6 = data.convert_from_iso_format("2025-07-25T15:00:00.000000Z") + converted_date7 = data.convert_from_iso_format("2025-07-25T16:00:00.000000+01:00") + + assert isinstance(converted_date, datetime) + assert converted_date == converted_date2 + assert converted_date == converted_date3 + assert converted_date == converted_date4 + assert converted_date5 == converted_date6 + assert converted_date5 == converted_date7 + def list_data(): test_data = [ { diff --git a/tests/test_rest.py b/tests/test_rest.py index fd9ce58..bae709b 100644 --- a/tests/test_rest.py +++ b/tests/test_rest.py @@ -1,5 +1,5 @@ from shared import rest -from classes import BaseModule, STATError, STATNotFound, STATTooManyRequests +from classes import BaseModule, STATError, STATNotFound, STATTooManyRequests, STATInsufficientPermissions, STATFailedToDecodeToken import json, os import requests import pytest @@ -49,6 +49,24 @@ def test_rest_response(): assert None is rest.check_rest_response(response_200, 'test', 'test') +def test_check_app_role(): + base = BaseModule() + assert rest.check_app_role(base, 'msgraph', ['User.ReadWrite.All', 'Directory.Read.All']) == True + assert rest.check_app_role(base, 'msgraph', ['User.Read.All']) == False + + +def test_check_app_role2(): + base = BaseModule() + assert rest.check_app_role2(base, 'msgraph', ['User.ReadWrite.All', 'Directory.Read.All'], raise_on_fail_to_decode=True) == True + + with pytest.raises(STATInsufficientPermissions): + rest.check_app_role2(base, 'msgraph', ['Directory.ReadWrite.All', 'Mail.Send'], raise_on_fail_to_decode=True) + + assert rest.check_app_role2(base, 'msgraph', ['User.Read.All'], raise_on_fail_to_decode=False, token='abc') == True + + with pytest.raises(STATFailedToDecodeToken): + rest.check_app_role2(base, 'msgraph', ['User.Read.All'], raise_on_fail_to_decode=True, token='abc') + def get_base_module_object(): base_module_body = json.loads(requests.get(url=os.getenv('BASEDATA')).content) diff --git a/tests/test_stat_base.py b/tests/test_stat_base.py index 7aa6254..136a0e8 100644 --- a/tests/test_stat_base.py +++ b/tests/test_stat_base.py @@ -7,6 +7,7 @@ def test_base_module_incident(): assert base_response.statuscode == 200 assert base_response.body.AccountsCount == 2 + assert base_response.body.MailMessagesCount == 1 assert len(base_response.body.Accounts) == base_response.body.AccountsCount assert len(base_response.body.Domains) == base_response.body.DomainsCount assert len(base_response.body.FileHashes) == 
base_response.body.FileHashesCount @@ -15,6 +16,7 @@ def test_base_module_incident(): assert len(base_response.body.IPs) == base_response.body.IPsCount assert len(base_response.body.URLs) == base_response.body.URLsCount assert len(base_response.body.OtherEntities) == base_response.body.OtherEntitiesCount + assert len(base_response.body.MailMessages) == base_response.body.MailMessagesCount def test_base_module_alert(): @@ -29,6 +31,7 @@ def test_base_module_alert(): assert len(base_response.body.IPs) == base_response.body.IPsCount assert len(base_response.body.URLs) == base_response.body.URLsCount assert len(base_response.body.OtherEntities) == base_response.body.OtherEntitiesCount + assert len(base_response.body.MailMessages) == base_response.body.MailMessagesCount def get_incident_trigger_data(): trigger_data = json.loads(requests.get(url=os.getenv('INCIDENTDATA')).content)
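
Note on the app-role checks introduced in shared/rest.py: they inspect only the unverified claims segment of the service token. A minimal standalone sketch of that decode-and-match step, assuming a base64url-encoded JWT payload carrying a 'roles' claim; the token built below is fabricated purely for illustration, and the padding is computed exactly rather than with the fixed '==' used in the module:

import base64
import json

def roles_from_token(jwt_token: str) -> list:
    # A JWT is three base64url segments; the middle segment holds the claims.
    payload = jwt_token.split('.')[1]
    payload += '=' * (-len(payload) % 4)  # restore stripped base64 padding
    claims = json.loads(base64.urlsafe_b64decode(payload))
    return claims.get('roles', [])

# Fabricated, unsigned token used only to exercise the helper.
body = base64.urlsafe_b64encode(json.dumps({'roles': ['Directory.Read.All']}).encode()).decode().rstrip('=')
print(roles_from_token(f'header.{body}.signature'))  # ['Directory.Read.All']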
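
The UTC assumption in shared/data.py convert_from_iso_format can be exercised in isolation; this sketch mirrors the function rather than importing it (datetime.fromisoformat accepts the trailing 'Z' on Python 3.11+, consistent with the 3.12 runtime pinned in the pytest workflow):

from datetime import datetime, timezone

def convert_from_iso_format(date_string):
    # Offset-aware timestamps pass through; naive ones are pinned to UTC.
    ts = datetime.fromisoformat(date_string)
    if ts.tzinfo is not None and ts.tzinfo.utcoffset(ts) is not None:
        return ts
    return ts.replace(tzinfo=timezone.utc)

assert convert_from_iso_format('2025-07-25T15:00:00').tzinfo == timezone.utc
assert convert_from_iso_format('2025-07-25T16:00:00+01:00') == convert_from_iso_format('2025-07-25T15:00:00Z')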
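
The 5xx handling added to execute_rest_call retries on a fixed 15-second delay until the accumulated wait passes 60 seconds, then re-raises. A simplified sketch of that loop shape; the function name and the exception type here are placeholders, not the module's actual signatures:

import time

def call_with_retry(do_call, delay=15, max_wait=60):
    # Retry transient server-side failures; give up once the wait budget is spent.
    wait_time = 0
    while True:
        try:
            return do_call()
        except RuntimeError:  # stand-in for STATServerError
            wait_time += delay
            if wait_time > max_wait:
                raise
            time.sleep(delay)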