diff --git a/dbt_cloud_plugin/__init__.py b/dbt_cloud_plugin/__init__.py index 085d692..df320c5 100644 --- a/dbt_cloud_plugin/__init__.py +++ b/dbt_cloud_plugin/__init__.py @@ -1,10 +1 @@ -from airflow.plugins_manager import AirflowPlugin -from dbt_cloud_plugin.hooks.dbt_cloud_hook import DbtCloudHook -from dbt_cloud_plugin.operators.dbt_cloud_run_job_operator import DbtCloudRunJobOperator -from dbt_cloud_plugin.sensors.dbt_cloud_run_sensor import DbtCloudRunSensor - -class DbtCloudPlugin(AirflowPlugin): - name = "dbt_cloud_plugin" - operators = [DbtCloudRunJobOperator] - hooks = [DbtCloudHook] - sensors = [DbtCloudRunSensor] +from .helpers import generate_dbt_model_dependency diff --git a/dbt_cloud_plugin/dbt_cloud/dbt_cloud.py b/dbt_cloud_plugin/dbt_cloud/dbt_cloud.py index 7b0afdb..906ecc7 100644 --- a/dbt_cloud_plugin/dbt_cloud/dbt_cloud.py +++ b/dbt_cloud_plugin/dbt_cloud/dbt_cloud.py @@ -3,6 +3,9 @@ import requests import time +from airflow.exceptions import AirflowException + + class DbtCloud(object): """ Class for interacting with the dbt Cloud API @@ -25,19 +28,23 @@ def _get(self, url_suffix): if response.status_code == 200: return json.loads(response.content) else: - raise RuntimeError(response.content) + raise RuntimeError(f'Error getting URL {url}:\n{str(response.content)}') def _post(self, url_suffix, data=None): url = self.api_base + url_suffix headers = {'Authorization': 'token %s' % self.api_token} - response = requests.post(url, headers=headers, data=data) + response = requests.post(url, headers=headers, json=data) if response.status_code == 200: return json.loads(response.content) else: raise RuntimeError(response.content) - def list_jobs(self): - return self._get('/accounts/%s/jobs/' % self.account_id).get('data') + def list_jobs(self, environment_id=None): + jobs = self._get('/accounts/%s/jobs/' % self.account_id).get('data') + if environment_id is not None: + return [j for j in jobs if str(j['environment_id']) == str(environment_id)] + else: + return jobs def get_run(self, run_id): return self._get('/accounts/%s/runs/%s/' % (self.account_id, run_id)).get('data') @@ -56,8 +63,8 @@ def try_get_run(self, run_id, max_tries=3): raise RuntimeError("Too many failures ({}) while querying for run status".format(run_id)) - def run_job(self, job_name, data=None): - jobs = self.list_jobs() + def run_job(self, job_name, data=None, environment_id=None): + jobs = self.list_jobs(environment_id=environment_id) job_matches = [j for j in jobs if j['name'] == job_name] @@ -67,3 +74,17 @@ def run_job(self, job_name, data=None): job_def = job_matches[0] trigger_resp = self.trigger_job_run(job_id=job_def['id'], data=data) return trigger_resp + + def get_artifact(self, run_id, artifact_filename, step=None): + if step is not None: + query_string = f'?step={step}' + else: + query_string = '' + + return self._get( + f'/accounts/{self.account_id}/runs/{run_id}/artifacts/{artifact_filename}{query_string}' + ) + + def get_job(self, job_id): + return self._get(f'/accounts/{self.account_id}/jobs/{job_id}').get('data') + diff --git a/dbt_cloud_plugin/helpers.py b/dbt_cloud_plugin/helpers.py new file mode 100644 index 0000000..efc4fea --- /dev/null +++ b/dbt_cloud_plugin/helpers.py @@ -0,0 +1,91 @@ +from airflow.utils.task_group import TaskGroup +from airflow.models import BaseOperator +from airflow.operators.python_operator import ShortCircuitOperator +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException + +from .operators.dbt_cloud_check_model_result_operator import DbtCloudCheckModelResultOperator +from .operators.dbt_cloud_run_job_operator import DbtCloudRunJobOperator + + +class DbtCloudRunException(AirflowException): + @apply_defaults + def __init__(self, dbt_cloud_run_id: int, dbt_cloud_account_id: int, dbt_cloud_project_id: int, error_message: str, dbt_errors_dict: dict, *args, **kwargs): + if dbt_cloud_run_id is None: + raise ValueError('dbt_cloud_run_id cannot be None.') + if dbt_cloud_account_id is None: + raise ValueError('dbt_cloud_run_id cannot be None.') + if dbt_cloud_project_id is None: + raise ValueError('dbt_cloud_run_id cannot be None.') + if error_message is None: + raise ValueError('error_message cannot be None.') + if dbt_errors_dict is None: + raise ValueError('dbt_errors_dict cannot be None.') + + self.dbt_cloud_run_id = dbt_cloud_run_id + self.dbt_cloud_account_id = dbt_cloud_account_id + self.dbt_cloud_project_id = dbt_cloud_project_id + self.error_message = error_message + self.dbt_errors_dict = dbt_errors_dict + + super(AirflowException, self).__init__(error_message, *args, **kwargs) + +def generate_dbt_model_dependency(dbt_job_task, downstream_tasks, dependent_models, ensure_models_ran=True, retries=0, params=None): + """ + Create a dependency from one or more tasks on a set of models succeeding + in a dbt task. This function generates a new DbtCloudCheckModelResultOperator + task between dbt_job_task and downstream_tasks, checking that dependent_models + all ran successfully. + + :param dbt_job_task: The dbt Cloud operator which kicked off the run you want to check. + Both the credentials and the run_id will be pulled from this task. + :type dbt_job_task: DbtCloudRunJobOperator or DbtCloudRunAndWatchJobOperator + :param downstream_tasks: The downstream task(s) which depend on the model(s) succeeding. + Can be either a single task, a single TaskGroup, or a list of tasks. + :type downstream_tasks: BaseOperator or TaskGroup or list[BaseOperator] or list[TaskGroup] + :param dependent_models: The name(s) of the model(s) to check. See + DbtCloudCheckModelResultOperator for more details. + :type dependent_models: str or list[str] + :param ensure_models_ran: Whether to require that the dependent_models actually ran in + the run. If False, it will silently ignore models that didn't run. + :type ensure_models_ran: bool, default True + """ + + if not isinstance(dbt_job_task, DbtCloudRunJobOperator): + raise TypeError('dbt_job_task must be of type DbtCloudRunJobOperator or DbtCloudRunAndWatchOperator') + + if isinstance(downstream_tasks, list): + if len(downstream_tasks) == 0: + raise ValueError('You must pass at least one task in downstream_tasks') + if not (isinstance(downstream_tasks[0], BaseOperator) or isinstance(downstream_tasks[0], TaskGroup)): + raise TypeError('The elements of the downstream_tasks list must be of type BaseOperator or TaskGRoup') + elif not (isinstance(downstream_tasks, TaskGroup) or isinstance(downstream_tasks, BaseOperator)): + raise TypeError('downstream_tasks must be of one of the following types: BaseOperator, TaskGroup, or a list of one of those two') + + if isinstance(dependent_models, str): + dependent_models = [dependent_models] + model_ids = '__'.join(dependent_models) + task_id = f'check_dbt_model_results__{dbt_job_task.task_id}__{model_ids}' + task_id = task_id[:255].replace('.', '__') + + with TaskGroup(group_id=task_id) as check_dbt_model_results: + check_upstream_dbt_job_state = ShortCircuitOperator( + task_id='check_upstream_dbt_job_state', + python_callable=lambda **context: context['dag_run'].get_task_instance(dbt_job_task.task_id).state in ['success', 'failed'], + trigger_rule='all_done', + provide_context=True + ) + + check_dbt_model_successful = DbtCloudCheckModelResultOperator( + task_id='check_dbt_model_successful', + dbt_cloud_conn_id=dbt_job_task.dbt_cloud_conn_id, + dbt_cloud_run_id=f'{{{{ ti.xcom_pull(task_ids="{dbt_job_task.task_id}", key="dbt_cloud_run_id") }}}}', + model_names=dependent_models, + ensure_models_ran=ensure_models_ran, + retries=retries, + params=params + ) + + check_upstream_dbt_job_state >> check_dbt_model_successful + + return dbt_job_task >> check_dbt_model_results >> downstream_tasks diff --git a/dbt_cloud_plugin/hooks/dbt_cloud_hook.py b/dbt_cloud_plugin/hooks/dbt_cloud_hook.py index f2f4980..1076693 100644 --- a/dbt_cloud_plugin/hooks/dbt_cloud_hook.py +++ b/dbt_cloud_plugin/hooks/dbt_cloud_hook.py @@ -1,4 +1,7 @@ -from dbt_cloud_plugin.dbt_cloud.dbt_cloud import DbtCloud +import time +import warnings + +from ..dbt_cloud.dbt_cloud import DbtCloud from airflow.hooks.base_hook import BaseHook from airflow.exceptions import AirflowException @@ -44,6 +47,20 @@ def get_conn(self): return DbtCloud(dbt_cloud_account_id, dbt_cloud_api_token) + def _get_conn_extra(self): + conn = self.get_connection(self.dbt_cloud_conn_id).extra_dejson + config = {} + if 'git_branch' in conn: + config['git_branch'] = conn['git_branch'] + if 'schema_override' in conn: + config['schema_override'] = conn['schema_override'] + if 'target_name_override' in conn: + config['target_name_override'] = conn['target_name_override'] + if 'environment_id' in conn: + config['environment_id'] = conn['environment_id'] + + return config + def get_run_status(self, run_id): """ Return the status of an dbt cloud run. @@ -53,3 +70,77 @@ def get_run_status(self, run_id): run = dbt_cloud.try_get_run(run_id=run_id) status_name = RunStatus.lookup(run['status']) return status_name + + def get_run_manifest(self, run_id): + """ + Return the manifest.json from a dbt Cloud run. + """ + dbt_cloud = self.get_conn() + return dbt_cloud.get_artifact(run_id, 'manifest.json') + + def get_all_run_results(self, run_id): + """ + Return the results array from run_results.json from a dbt Cloud run, + concatenated across all (real) steps. + """ + dbt_cloud = self.get_conn() + + # first, determine the number of steps in this job + # it will either be defined in the run or in the job definition + run = dbt_cloud.get_run(run_id) + total_steps = len(run['run_steps']) + if total_steps == 0: # not defined on the run, check the job + job_id = run['job_id'] + job = dbt_cloud.get_job(job_id) + total_steps = len(job['execute_steps']) + + # the first 3 steps of a dbt Cloud job are always the same and never have any run results + # occasionally, the run_results.json file can take a few seconds to generate + current_index = 4 + final_index = current_index + total_steps + attempts = 0 + all_run_results = [] + + while attempts < 3: + try: + for step in range(current_index, final_index): + run_results = dbt_cloud.get_artifact(run_id, 'run_results.json', step=step) + all_run_results.extend(run_results['results']) + current_index += 1 + break + except RuntimeError as e: + attempts += 1 + + if attempts == 3 and len(all_run_results) == 0: + raise e + elif attempts == 3: + # sometimes the last step is not available, so we need to return what we have + warnings.warn(f'Only {len(all_run_results)} of {total_steps} steps were available in run_results.json') + return all_run_results + + time.sleep(15) + + return all_run_results + + def run_job(self, job_name, git_branch=None, schema_override=None, + target_name_override=None, steps_override=None, environment_id=None): + dbt_cloud = self.get_conn() + extra = self._get_conn_extra() + + data = {'cause': 'Kicked off via Airflow'} + # add optional settings + if git_branch or extra.get('git_branch', None): + data['git_branch'] = git_branch or extra.get('git_branch', None) + if schema_override or extra.get('schema_override', None): + data['schema_override'] = schema_override or extra.get('schema_override', None) + if target_name_override or extra.get('target_name_override', None): + data['target_name_override'] = target_name_override or extra.get('target_name_override', None) + if steps_override: + data['steps_override'] = steps_override + + # get environment + environment_id = environment_id or extra.get('environment_id', None) + + self.log.info(f'Triggering job {job_name} with data {data}') + + return dbt_cloud.run_job(job_name, data=data, environment_id=environment_id) diff --git a/dbt_cloud_plugin/operators/dbt_cloud_check_model_result_operator.py b/dbt_cloud_plugin/operators/dbt_cloud_check_model_result_operator.py new file mode 100644 index 0000000..972a9f7 --- /dev/null +++ b/dbt_cloud_plugin/operators/dbt_cloud_check_model_result_operator.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +import json +import requests +import time + +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException, AirflowSkipException +from ..hooks.dbt_cloud_hook import DbtCloudHook + + +SUCCESSFUL_STATUSES = ['success', 'pass'] + + +class DbtModelException(Exception): + pass + + +class DbtModelFailedException(DbtModelException): + pass + + +class DbtModelNotRunException(DbtModelException): + pass + + +class DbtCloudCheckModelResultOperator(BaseOperator): + """ + Check the results of a dbt Cloud job to see whether the model(s) you + care about ran successfully. Useful if you have a large dbt Cloud job, + but each of your downstream tasks only requires a small subset of models + to succeed. + + :param dbt_cloud_run_id: Run ID of a finished dbt Cloud job. Note that + this task must not start running until the dbt Cloud job has completed, + otherwise it will error out. See DbtCloudRunAndWatchJobOperator or + DbtCloudRunSensor. + :type dbt_cloud_run_id: str or int + :param model_names: A single model name or list of model names to check. + Tests and snapshot names also work. Note this must be the /name/, not the + node ID. In addition to the model name(s) supplied, all of the tests for + that model will also be checked (though unlike the model itself, they need + not have been run). + :type model_names: str or list[str] + :param dbt_cloud_conn_id: dbt Cloud connection ID + :type dbt_cloud_conn_id: str + :param ensure_models_ran: Whether to ensure all of the model_names were actually + executed in this run. Defaults to True to avoid accidentally mistyping the + model name, negating the value of this check. + :type ensure_models_ran: bool, default True + """ + + template_fields = ['dbt_cloud_run_id'] + + @apply_defaults + def __init__(self, dbt_cloud_run_id=None, model_names=None, ensure_models_ran=True, dbt_cloud_conn_id='dbt_default', *args, **kwargs): + super(DbtCloudCheckModelResultOperator, self).__init__(*args, **kwargs) + + if dbt_cloud_run_id is None: + raise AirflowException('No dbt Cloud run_id was supplied.') + if model_names is None: + raise AirflowException('No model names supplied.') + + self.dbt_cloud_conn_id = dbt_cloud_conn_id + self.dbt_cloud_run_id = dbt_cloud_run_id + if isinstance(model_names, str): + model_names = [model_names] + self.model_names = model_names + self.ensure_models_ran = ensure_models_ran + + def _find_test_dependencies_for_model_id(self, model_id, manifest): + tests = [] + for node, values in manifest['nodes'].items(): + if values['resource_type'] != 'test': + continue + for dependency in values['depends_on']['nodes']: + if dependency == model_id: + tests.append(node) + return tests + + def _find_model_id_from_name(self, model_name, manifest): + models = manifest['nodes'].values() + for model in models: + if model['name'] == model_name: + return model['unique_id'] + + def _check_that_model_passed(self, model_name, manifest, run_results): + model_id = self._find_model_id_from_name(model_name, manifest) + tests = self._find_test_dependencies_for_model_id(model_id, manifest) + all_dependencies = [model_id] + tests + self.log.info(f'Checking all dependencies for {model_name}: {all_dependencies}') + + ran_model = False + for result in run_results: + if result['unique_id'] == model_id: + ran_model = True + if result['unique_id'] in all_dependencies: + if result['status'] not in SUCCESSFUL_STATUSES: + raise DbtModelFailedException(f'Dependency {result["unique_id"]} did not pass, status: {result["status"]}!') + + if not ran_model and self.ensure_models_ran: + raise DbtModelNotRunException(f'Model {model_id} was not run!') + + def execute(self, **kwargs): + if self.dbt_cloud_run_id is None or self.dbt_cloud_run_id in ('', 'None'): + raise ValueError('dbt_cloud_run_id is empty!') + dbt_cloud_hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id) + manifest = dbt_cloud_hook.get_run_manifest(self.dbt_cloud_run_id) + run_results = dbt_cloud_hook.get_all_run_results(self.dbt_cloud_run_id) + + for model in self.model_names: + self._check_that_model_passed(model, manifest, run_results) + diff --git a/dbt_cloud_plugin/operators/dbt_cloud_run_and_watch_job_operator.py b/dbt_cloud_plugin/operators/dbt_cloud_run_and_watch_job_operator.py new file mode 100644 index 0000000..95d2b39 --- /dev/null +++ b/dbt_cloud_plugin/operators/dbt_cloud_run_and_watch_job_operator.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +import time +import requests + +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException, AirflowSkipException +from ..hooks.dbt_cloud_hook import DbtCloudHook +from ..operators.dbt_cloud_run_job_operator import DbtCloudRunJobOperator +from ..helpers import DbtCloudRunException + + +class DbtCloudRunAndWatchJobOperator(DbtCloudRunJobOperator): + """ + Operator to run a dbt cloud job. + :param dbt_cloud_conn_id: dbt Cloud connection ID. + :type dbt_cloud_conn_id: string + :param project_id: dbt Cloud project ID. + :type project_id: int + :param job_name: dbt Cloud job name. + :type job_name: string + """ + + @apply_defaults + def __init__(self, + poke_interval=60, + timeout=60 * 60 * 24, + soft_fail=False, + *args, **kwargs): + self.poke_interval = poke_interval + self.timeout = timeout + self.soft_fail = soft_fail + super(DbtCloudRunAndWatchJobOperator, self).__init__(*args, **kwargs) + + def execute(self, **kwargs): + response = super(DbtCloudRunAndWatchJobOperator, self).execute(**kwargs) + run_id = response['id'] + + self.account_id = response['job']['account_id'] + self.project_id = response['job']['project_id'] + self.environment_id = response['job']['environment_id'] + + # basically copy-pasting the Sensor code + self.log.info(f'Starting poke for job {run_id}') + try_number = 1 + started_at = time.monotonic() + + def run_duration(): + nonlocal started_at + return time.monotonic() - started_at + + while not self.poke(run_id): + if run_duration() > self.timeout: + if self.soft_fail: + raise AirflowSkipException(f'Time is out!') + else: + raise AirflowException(f'Time is out!') + else: + time.sleep(self.poke_interval) + try_number += 1 + self.log.info('Success criteria met. Exiting.') + + def poke(self, run_id): + self.log.info('Sensor checking state of dbt cloud run ID: %s', run_id) + dbt_cloud_hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id) + try: + run_status = dbt_cloud_hook.get_run_status(run_id=run_id) + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: + # Tolerate transient connectivity errors during polling; keep the job alive + self.log.warning('Transient network error while fetching run status for %s: %s', run_id, str(e)) + return False + self.log.info('State of Run ID {}: {}'.format(run_id, run_status)) + + if run_status.strip() == 'Cancelled': + raise AirflowException(f'dbt cloud Run ID {run_id} Cancelled.') + + elif run_status.strip() == 'Error': + run_results = dbt_cloud_hook.get_all_run_results(run_id=run_id) + manifest = dbt_cloud_hook.get_run_manifest(run_id=run_id) + + errors = {} + fail_states = {result['unique_id'] for result in run_results if result['status'] in ['error', 'failure', 'fail']} + + for unique_id in fail_states: + tags = manifest['nodes'][unique_id]['tags'] + resource_type = manifest['nodes'][unique_id]['resource_type'] + depends_on = manifest['nodes'][unique_id]['depends_on']['nodes'] + parent_models = {model: manifest['nodes'][model]['tags'] for model in manifest['nodes'][unique_id]['depends_on']['nodes'] if model not in manifest['sources']} + errors[unique_id] = { + 'tags': tags, + 'resource_type': resource_type, + 'depends_on': depends_on, + 'parent_models': parent_models, + } + + raise DbtCloudRunException( + dbt_cloud_run_id=run_id, + dbt_cloud_account_id=self.account_id, + dbt_cloud_project_id=self.project_id, + error_message=f'dbt cloud Run ID {run_id} Failed.', + dbt_errors_dict=errors + ) + + elif run_status.strip() == 'Success': + return True + + else: + return False diff --git a/dbt_cloud_plugin/operators/dbt_cloud_run_job_operator.py b/dbt_cloud_plugin/operators/dbt_cloud_run_job_operator.py index 17b8faa..f5f6236 100644 --- a/dbt_cloud_plugin/operators/dbt_cloud_run_job_operator.py +++ b/dbt_cloud_plugin/operators/dbt_cloud_run_job_operator.py @@ -4,9 +4,9 @@ import time from airflow.models import BaseOperator -from dbt_cloud_plugin.hooks.dbt_cloud_hook import DbtCloudHook from airflow.utils.decorators import apply_defaults from airflow.exceptions import AirflowException +from ..hooks.dbt_cloud_hook import DbtCloudHook class DbtCloudRunJobOperator(BaseOperator): """ @@ -23,6 +23,11 @@ class DbtCloudRunJobOperator(BaseOperator): def __init__(self, dbt_cloud_conn_id=None, job_name=None, + git_branch=None, + schema_override=None, + target_name_override=None, + steps_override=None, + environment_id=None, *args, **kwargs): super(DbtCloudRunJobOperator, self).__init__(*args, **kwargs) @@ -34,6 +39,11 @@ def __init__(self, self.dbt_cloud_conn_id = dbt_cloud_conn_id self.job_name = job_name + self.git_branch = git_branch + self.schema_override = schema_override + self.target_name_override = target_name_override + self.steps_override = steps_override + self.environment_id = environment_id def execute(self, **kwargs): @@ -41,11 +51,18 @@ def execute(self, **kwargs): try: dbt_cloud_hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id) - dbt_cloud = dbt_cloud_hook.get_conn() - data = {'cause':'Kicked off via Airflow'} - trigger_resp = dbt_cloud.run_job(self.job_name, data=data) + trigger_resp = dbt_cloud_hook.run_job( + self.job_name, + git_branch=self.git_branch, + schema_override=self.schema_override, + target_name_override=self.target_name_override, + steps_override=self.steps_override, + environment_id=self.environment_id + ) self.log.info('Triggered Run ID {}'.format(trigger_resp['id'])) except RuntimeError as e: raise AirflowException("Error while triggering job {}: {}".format(self.job_name, e)) - return trigger_resp['id'] + run_id = trigger_resp['id'] + self.xcom_push(kwargs['context'], 'dbt_cloud_run_id', run_id) + return trigger_resp diff --git a/dbt_cloud_plugin/sensors/dbt_cloud_job_sensor.py b/dbt_cloud_plugin/sensors/dbt_cloud_job_sensor.py index 4c5c376..6ab907d 100644 --- a/dbt_cloud_plugin/sensors/dbt_cloud_job_sensor.py +++ b/dbt_cloud_plugin/sensors/dbt_cloud_job_sensor.py @@ -1,4 +1,4 @@ -from dbt_cloud_plugin.hooks.dbt_cloud_hook import DbtCloudHook +from ..hooks.dbt_cloud_hook import DbtCloudHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults from airflow.exceptions import AirflowException @@ -38,11 +38,11 @@ def poke(self, context): self.log.info('State of Run ID {}: {}'.format(self.run_id, run_status)) TERMINAL_RUN_STATES = ['Success', 'Error', 'Cancelled'] - FAILED_RUN_STATES = ['Error'] + FAILED_RUN_STATES = ['Error', 'Cancelled'] - if run_status in FAILED_RUN_STATES: - return AirflowException('dbt cloud Run ID {} Failed.'.format(self.run_id)) - if run_status in TERMINAL_RUN_STATES: + if run_status.strip() in FAILED_RUN_STATES: + raise AirflowException('dbt cloud Run ID {} Failed.'.format(self.run_id)) + if run_status.strip() in TERMINAL_RUN_STATES: return True else: return False diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..c1901e2 --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ + +from distutils.core import setup + +setup( + name='dbt_cloud_plugin', + version='0.2', + packages=[ + 'dbt_cloud_plugin', + 'dbt_cloud_plugin.dbt_cloud', + 'dbt_cloud_plugin.hooks', + 'dbt_cloud_plugin.operators', + 'dbt_cloud_plugin.sensors', + ], + install_requires=[ + 'apache-airflow' + ] + +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_dbt_cloud_check_model_result_operator.py b/tests/test_dbt_cloud_check_model_result_operator.py new file mode 100644 index 0000000..63db4cd --- /dev/null +++ b/tests/test_dbt_cloud_check_model_result_operator.py @@ -0,0 +1,161 @@ +import unittest +from unittest.mock import MagicMock, patch +from airflow.exceptions import AirflowException +from dbt_cloud_plugin.operators.dbt_cloud_check_model_result_operator import DbtCloudCheckModelResultOperator, DbtModelFailedException, DbtModelNotRunException + +class TestDbtCloudCheckModelResultOperator(unittest.TestCase): + + def setUp(self): + self.operator = DbtCloudCheckModelResultOperator( + task_id='test_task', + dbt_cloud_conn_id='dbt_cloud_conn', + dbt_cloud_run_id='run_id', + model_names=['model1', 'model2'], + ) + + def test_init_with_invalid_run_id(self): + with self.assertRaises(AirflowException): + DbtCloudCheckModelResultOperator( + task_id='test_task', + dbt_cloud_conn_id='dbt_cloud_conn', + dbt_cloud_run_id=None, + model_names=['model1'], + ) + + def test_init_with_invalid_model_names(self): + with self.assertRaises(AirflowException): + DbtCloudCheckModelResultOperator( + task_id='test_task', + dbt_cloud_conn_id='dbt_cloud_conn', + dbt_cloud_run_id='run_id', + model_names=None, + ) + + def test_find_test_dependencies_for_model_id(self): + manifest = { + 'nodes': { + 'model_id1': {'resource_type': 'model', 'depends_on': {'nodes': []}}, + 'model_id2': {'resource_type': 'model', 'depends_on': {'nodes': []}}, + 'test_id1': {'resource_type': 'test', 'depends_on': {'nodes': ['model_id1']}}, + 'test_id2': {'resource_type': 'test', 'depends_on': {'nodes': ['model_id2']}}, + } + } + result = self.operator._find_test_dependencies_for_model_id('model_id1', manifest) + self.assertEqual(result, ['test_id1']) + + def test_find_model_id_from_name(self): + manifest = { + 'nodes': { + 'model_id1': {'name': 'model1', 'unique_id': 'model_id1'}, + 'model_id2': {'name': 'model2', 'unique_id': 'model_id2'}, + } + } + result = self.operator._find_model_id_from_name('model2', manifest) + self.assertEqual(result, 'model_id2') + + def test_check_that_model_passed(self): + manifest = { + 'nodes': { + 'model_id1': {'name': 'model1', 'unique_id': 'model_id1', 'resource_type': 'model'}, + } + } + run_results = [{'unique_id': 'model_id1', 'status': 'success'}] + self.operator._check_that_model_passed('model1', manifest, run_results) + + run_results = [{'unique_id': 'model_id1', 'status': 'pass'}] + self.operator._check_that_model_passed('model1', manifest, run_results) + + def test_check_that_model_passed_test_dependency(self): + manifest = { + 'nodes': { + 'model_id1': {'name': 'model1', 'unique_id': 'model_id1', 'resource_type': 'model'}, + 'test_id1': {'name': 'test1', 'unique_id': 'test_id1', 'resource_type': 'test', + 'depends_on': {'nodes': ['model_id1']}} + } + } + run_results = [{'unique_id': 'model_id1', 'status': 'success'}] + self.operator._check_that_model_passed('model1', manifest, run_results) + + run_results = [{'unique_id': 'model_id1', 'status': 'pass'}] + self.operator._check_that_model_passed('model1', manifest, run_results) + + run_results = [{'unique_id': 'model_id1', 'status': 'pass'}, + {'unique_id': 'test_id1', 'status': 'failed'}] + with self.assertRaises(DbtModelFailedException): + self.operator._check_that_model_passed('model1', manifest, run_results) + + def test_check_that_model_failed(self): + manifest = { + 'nodes': { + 'model_id1': {'name': 'model1', 'unique_id': 'model_id1', 'resource_type': 'model'}, + } + } + run_results = [{'unique_id': 'model_id1', 'status': 'failed'}] + with self.assertRaises(DbtModelFailedException): + self.operator._check_that_model_passed('model1', manifest, run_results) + + def test_check_that_model_did_not_run_when_ensuring_models_ran(self): + manifest = { + 'nodes': { + 'model_id1': {'name': 'model1', 'unique_id': 'model_id1', 'resource_type': 'model'}, + } + } + run_results = [] + with self.assertRaises(DbtModelNotRunException): + self.operator._check_that_model_passed('model1', manifest, run_results) + + def test_ignore_that_model_did_not_run_when_not_ensuring_models_ran(self): + manifest = { + 'nodes': { + 'model_id1': {'name': 'model1', 'unique_id': 'model_id1', 'resource_type': 'model'}, + } + } + run_results = [] + self.operator.ensure_models_ran = False + self.operator._check_that_model_passed('model1', manifest, run_results) + + @patch('dbt_cloud_plugin.operators.dbt_cloud_check_model_result_operator.DbtCloudHook') + def test_execute(self, mock_hook_class): + mock_hook = MagicMock() + mock_hook_class.return_value = mock_hook + mock_hook.get_run_manifest.return_value = {} + mock_hook.get_all_run_results.return_value = [] + self.operator._check_that_model_passed = MagicMock() + self.operator._check_that_model_passed.return_value = None + + self.operator.execute() + + mock_hook.get_run_manifest.assert_called_once_with('run_id') + mock_hook.get_all_run_results.assert_called_once_with('run_id') + self.operator._check_that_model_passed.assert_any_call('model1', {}, []) + self.operator._check_that_model_passed.assert_any_call('model2', {}, []) + + @patch('dbt_cloud_plugin.operators.dbt_cloud_check_model_result_operator.DbtCloudHook') + def test_execute_fails_with_missing_run_id(self, mock_hook_class): + mock_hook = MagicMock() + mock_hook_class.return_value = mock_hook + mock_hook.get_run_manifest.return_value = {} + mock_hook.get_all_run_results.return_value = [] + self.operator._check_that_model_passed = MagicMock() + self.operator._check_that_model_passed.return_value = None + + self.operator.dbt_cloud_run_id = '' + with self.assertRaises(ValueError): + self.operator.execute() + + @patch('dbt_cloud_plugin.operators.dbt_cloud_check_model_result_operator.DbtCloudHook') + def test_execute_fails_with_missing_run_id(self, mock_hook_class): + mock_hook = MagicMock() + mock_hook_class.return_value = mock_hook + mock_hook.get_run_manifest.return_value = {} + mock_hook.get_all_run_results.return_value = [] + self.operator._check_that_model_passed = MagicMock() + self.operator._check_that_model_passed.return_value = None + + self.operator.dbt_cloud_run_id = None + with self.assertRaises(ValueError): + self.operator.execute() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_generate_dbt_model_dependency.py b/tests/test_generate_dbt_model_dependency.py new file mode 100644 index 0000000..8f0d473 --- /dev/null +++ b/tests/test_generate_dbt_model_dependency.py @@ -0,0 +1,102 @@ +import unittest +from unittest.mock import Mock +from datetime import datetime + +from airflow import DAG +from airflow.utils.task_group import TaskGroup +from airflow.operators.dummy import DummyOperator + +from dbt_cloud_plugin.helpers import generate_dbt_model_dependency +from dbt_cloud_plugin.operators.dbt_cloud_run_and_watch_job_operator import DbtCloudRunAndWatchJobOperator +from dbt_cloud_plugin.operators.dbt_cloud_check_model_result_operator import DbtCloudCheckModelResultOperator + + +class TestGenerateDbtModelDependency(unittest.TestCase): + + def test_generate_dbt_model_dependency_with_single_task(self): + # Create a minimal DAG for testing + dag = DAG(dag_id='test_dag', start_date=datetime(2023, 1, 1), catchup=False) + with dag: + dbt_job_task = DbtCloudRunAndWatchJobOperator(task_id='dbt_job_task', dbt_cloud_conn_id='dbt_default', job_name='test job') + + task1 = DummyOperator(task_id='task1') + + generate_dbt_model_dependency(dbt_job_task, task1, ['model1', 'model2'], ensure_models_ran=True) + + self.assertEqual(len(dag.tasks), 3) + # find the generated task + for task in dag.tasks: + if isinstance(task, DbtCloudCheckModelResultOperator): + result = task + + # Verify the generated dependencies + self.assertEqual(result.task_id, 'check_dbt_model_results__dbt_job_task__model1__model2') + self.assertIsInstance(result, DbtCloudCheckModelResultOperator) + self.assertEqual(result.dbt_cloud_conn_id, dbt_job_task.dbt_cloud_conn_id) + self.assertEqual(result.ensure_models_ran, True) + self.assertEqual(result.trigger_rule, 'all_done') + self.assertEqual(result.retries, 0) + self.assertIn(dbt_job_task, result.get_flat_relatives(upstream=True)) + self.assertIn(task1, result.get_flat_relatives(upstream=False)) + + def test_generate_dbt_model_dependency_with_list(self): + # Create a minimal DAG for testing + dag = DAG(dag_id='test_dag', start_date=datetime(2023, 1, 1), catchup=False) + with dag: + dbt_job_task = DbtCloudRunAndWatchJobOperator(task_id='dbt_job_task', dbt_cloud_conn_id='dbt_default', job_name='test job') + + task1 = DummyOperator(task_id='task1') + task2 = DummyOperator(task_id='task2') + downstream_tasks = [task1, task2] + + generate_dbt_model_dependency(dbt_job_task, downstream_tasks, ['model1', 'model2'], ensure_models_ran=True) + + self.assertEqual(len(dag.tasks), 4) + # find the generated task + for task in dag.tasks: + if isinstance(task, DbtCloudCheckModelResultOperator): + result = task + + # Verify the generated dependencies + self.assertEqual(result.task_id, 'check_dbt_model_results__dbt_job_task__model1__model2') + self.assertIsInstance(result, DbtCloudCheckModelResultOperator) + self.assertEqual(result.dbt_cloud_conn_id, dbt_job_task.dbt_cloud_conn_id) + self.assertEqual(result.ensure_models_ran, True) + self.assertEqual(result.trigger_rule, 'all_done') + self.assertEqual(result.retries, 0) + self.assertIn(dbt_job_task, result.get_flat_relatives(upstream=True)) + self.assertIn(task1, result.get_flat_relatives(upstream=False)) + self.assertIn(task2, result.get_flat_relatives(upstream=False)) + + + def test_generate_dbt_model_dependency_with_task_group(self): + # Create a minimal DAG for testing + dag = DAG(dag_id='test_dag', start_date=datetime(2023, 1, 1), catchup=False) + with dag: + dbt_job_task = DbtCloudRunAndWatchJobOperator(task_id='dbt_job_task', dbt_cloud_conn_id='dbt_default', job_name='test job') + + with TaskGroup(group_id='task_group_name') as downstream_tasks: + task1 = DummyOperator(task_id='task1') + task2 = DummyOperator(task_id='task2') + + generate_dbt_model_dependency(dbt_job_task, downstream_tasks, ['model1', 'model2'], ensure_models_ran=False) + + self.assertEqual(len(dag.tasks), 4) + # find the generated task + for task in dag.tasks: + if isinstance(task, DbtCloudCheckModelResultOperator): + result = task + + # Verify the generated dependencies + self.assertEqual(result.task_id, 'check_dbt_model_results__dbt_job_task__model1__model2') + self.assertEqual(result.dbt_cloud_conn_id, dbt_job_task.dbt_cloud_conn_id) + self.assertEqual(result.ensure_models_ran, False) + self.assertEqual(result.trigger_rule, 'all_done') + self.assertEqual(result.retries, 0) + self.assertIn(dbt_job_task, result.get_flat_relatives(upstream=True)) + self.assertIn(task1, result.get_flat_relatives(upstream=False)) + self.assertIn(task2, result.get_flat_relatives(upstream=False)) + + +if __name__ == '__main__': + unittest.main()