diff --git a/airflow_dbt/hooks/__init__.py b/airflow_dbt/hooks/__init__.py
index 8644e88..a5c997a 100644
--- a/airflow_dbt/hooks/__init__.py
+++ b/airflow_dbt/hooks/__init__.py
@@ -1 +1,2 @@
 from .dbt_hook import DbtCliHook
+from .dbt_google_hook import DbtCloudBuildHook
diff --git a/airflow_dbt/hooks/dbt_google_hook.py b/airflow_dbt/hooks/dbt_google_hook.py
new file mode 100644
index 0000000..36aaebd
--- /dev/null
+++ b/airflow_dbt/hooks/dbt_google_hook.py
@@ -0,0 +1,165 @@
+import logging
+import os
+import pprint
+import tarfile
+from tempfile import NamedTemporaryFile
+from typing import Any, Dict, List
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook
+from airflow.providers.google.cloud.hooks.gcs import (
+    GCSHook, _parse_gcs_url,
+)
+from airflow.providers.google.get_provider_info import get_provider_info
+from packaging import version
+
+from .dbt_hook import DbtBaseHook
+
+# Check we're using the right google provider version. As Composer is the
+# most broadly used Airflow installation we default to the latest version
+# Composer is using.
+google_providers_version = get_provider_info().get('versions')[0]
+v_min = version.parse('5.0.0')
+v_max = version.parse('6.0.0')
+v_provider = version.parse(google_providers_version)
+if not v_min <= v_provider < v_max:
+    raise Exception(
+        f'The provider "apache-airflow-providers-google" version "'
+        f'{google_providers_version}" is not compatible with the current API. '
+        f'Please install a compatible version in the range [{v_min}, {v_max})'
+    )
+
+
+class DbtCloudBuildHook(DbtBaseHook):
+    """
+    Runs the dbt command in a Cloud Build job in GCP
+
+    :type dir: str
+    :param dir: Optional, if set the process considers that sources must be
+        uploaded prior to running the DBT job
+    :type env: dict
+    :param env: If set, passed to the dbt executor
+    :param dbt_bin: The `dbt` CLI. Defaults to `dbt`, so assumes it's on your
+        `PATH`
+    :type dbt_bin: str
+
+    :param project_id: GCP Project ID as stated in the console
+    :type project_id: str
+    :param timeout: Default is set in Cloud Build itself as ten minutes. A
+        duration in seconds with up to nine fractional digits, terminated by
+        's'. Example: "3.5s"
+    :type timeout: str
+    :param wait: Waits for the cloud build process to finish. That is, wait
+        for the DBT command to finish running or run asynchronously
+    :type wait: bool
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: str
+    :param gcs_staging_location: Where to store the sources to be fetched later
+        by the cloud build job. It should be the GCS url for a folder. For
+        example: `gs://my-bucket/stored`. A sub-folder will be generated to
+        avoid collision between possible different concurrent runs.
+    :type gcs_staging_location: str
+    :param dbt_version: the DBT version to be fetched from dockerhub.
+        Defaults to '0.21.0'
+    :type dbt_version: str
+    """
+
+    def __init__(
+        self,
+        project_id: str,
+        dir: str = None,
+        gcs_staging_location: str = None,
+        gcp_conn_id: str = "google_cloud_default",
+        dbt_version: str = '0.21.0',
+        env: Dict = None,
+        dbt_bin='',
+        service_account=None,
+    ):
+        staging_bucket, staging_blob = _parse_gcs_url(gcs_staging_location)
+        # we expect something similar to 'gs://<bucket>/<path/to/file>.tar.gz'
+        if not staging_blob.endswith('.tar.gz'):
+            raise AirflowException(
+                f'The provided blob "{staging_blob}" to a compressed file does not ' +
+                f'have the right extension ".tar.gz"'
+            )
+        self.gcs_staging_bucket = staging_bucket
+        self.gcs_staging_blob = staging_blob
+
+        self.dbt_version = dbt_version
+        self.cloud_build_hook = CloudBuildHook(gcp_conn_id=gcp_conn_id)
+        self.gcp_conn_id = gcp_conn_id
+        self.project_id = project_id
+        self.service_account = service_account
+
+        super().__init__(dir=dir, env=env, dbt_bin=dbt_bin)
+
+    def get_conn(self) -> Any:
+        """Returns the cloud build connection, which is a gcp connection"""
+        return self.cloud_build_hook.get_conn()
+
+    def upload_dbt_sources(self) -> None:
+        """Upload sources from local to a staging location"""
+        logging.info(
+            f'Files in "{self.dir}" will be uploaded to GCS with the '
+            f'prefix "gs://{self.gcs_staging_bucket}/{self.gcs_staging_blob}"'
+        )
+        gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
+        with NamedTemporaryFile() as compressed_file:
+            # close the tarball before uploading so the gzip stream is complete
+            with tarfile.open(compressed_file.name, "w:gz") as tar:
+                tar.add(self.dir, arcname=os.path.basename(self.dir))
+            gcs_hook.upload(
+                bucket_name=self.gcs_staging_bucket,
+                object_name=self.gcs_staging_blob,
+                filename=compressed_file.name,
+            )
+
+    def run_dbt(self, dbt_cmd: List[str]):
+        """
+        Run the dbt cli
+
+        :param dbt_cmd: The dbt whole command to run
+        :type dbt_cmd: List[str]
+
+        See: https://cloud.google.com/cloud-build/docs/api/reference/rest
+        /v1/projects.builds
+        """
+        # if we indicate that the sources are in a local directory by setting
+        # the "dir" pointing to a local path, then those sources will be
+        # uploaded to the expected blob
+        if self.dir is not None:
+            self.upload_dbt_sources()
+
+        cloud_build_config = {
+            'steps': [{
+                'name': f'fishtownanalytics/dbt:{self.dbt_version}',
+                'args': dbt_cmd,
+                'env': [f'{k}={v}' for k, v in self.env.items()]
+            }],
+            'source': {
+                'storageSource': {
+                    "bucket": self.gcs_staging_bucket,
+                    "object": self.gcs_staging_blob,
+                }
+            }
+        }
+
+        if self.service_account is not None:
+            cloud_build_config['serviceAccount'] = self.service_account
+
+        cloud_build_config_str = pprint.pformat(cloud_build_config)
+        logging.info(
+            f'Running the following cloud build config:\n{cloud_build_config_str}'
+        )
+
+        results = self.cloud_build_hook.create_build(
+            body=cloud_build_config,
+            project_id=self.project_id,
+        )
+        logging.info(
+            f'Triggered build {results["id"]}. You can find the logs at '
+            f'{results["logUrl"]}'
+        )
+
+    def on_kill(self):
+        """Stopping the build is not implemented until google providers v6"""
+        raise NotImplementedError
diff --git a/airflow_dbt/hooks/dbt_hook.py b/airflow_dbt/hooks/dbt_hook.py
index 0b4caf0..3bc1cdd 100644
--- a/airflow_dbt/hooks/dbt_hook.py
+++ b/airflow_dbt/hooks/dbt_hook.py
@@ -1,142 +1,172 @@
 from __future__ import print_function
-import os
-import signal
-import subprocess
+
 import json
-from airflow.exceptions import AirflowException
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Union
+
 from airflow.hooks.base_hook import BaseHook
+from airflow.hooks.subprocess import SubprocessHook
 
 
-class DbtCliHook(BaseHook):
+class DbtBaseHook(BaseHook, ABC):
     """
     Simple wrapper around the dbt CLI.
 
-    :param profiles_dir: If set, passed as the `--profiles-dir` argument to the `dbt` command
-    :type profiles_dir: str
-    :param target: If set, passed as the `--target` argument to the `dbt` command
     :type dir: str
     :param dir: The directory to run the CLI in
-    :type vars: str
-    :param vars: If set, passed as the `--vars` argument to the `dbt` command
-    :type vars: dict
-    :param full_refresh: If `True`, will fully-refresh incremental models.
-    :type full_refresh: bool
-    :param models: If set, passed as the `--models` argument to the `dbt` command
-    :type models: str
-    :param warn_error: If `True`, treat warnings as errors.
-    :type warn_error: bool
-    :param exclude: If set, passed as the `--exclude` argument to the `dbt` command
-    :type exclude: str
-    :param select: If set, passed as the `--select` argument to the `dbt` command
-    :type select: str
-    :param dbt_bin: The `dbt` CLI. Defaults to `dbt`, so assumes it's on your `PATH`
+    :type env: dict
+    :param env: If set, passed to the dbt executor
+    :param dbt_bin: The `dbt` CLI. Defaults to `dbt`, so assumes it's on your
+        `PATH`
    :type dbt_bin: str
-    :param output_encoding: Output encoding of bash command. Defaults to utf-8
-    :type output_encoding: str
-    :param verbose: The operator will log verbosely to the Airflow logs
-    :type verbose: bool
     """
 
-    def __init__(self,
-                 profiles_dir=None,
-                 target=None,
-                 dir='.',
-                 vars=None,
-                 full_refresh=False,
-                 data=False,
-                 schema=False,
-                 models=None,
-                 exclude=None,
-                 select=None,
-                 dbt_bin='dbt',
-                 output_encoding='utf-8',
-                 verbose=True,
-                 warn_error=False):
-        self.profiles_dir = profiles_dir
-        self.dir = dir
-        self.target = target
-        self.vars = vars
-        self.full_refresh = full_refresh
-        self.data = data
-        self.schema = schema
-        self.models = models
-        self.exclude = exclude
-        self.select = select
+    def __init__(self, dir: str = '.', env: Dict = None, dbt_bin='dbt'):
+        super().__init__()
+        self.dir = dir
+        self.env = env if env is not None else {}
         self.dbt_bin = dbt_bin
-        self.verbose = verbose
-        self.warn_error = warn_error
-        self.output_encoding = output_encoding
-
-    def _dump_vars(self):
-        # The dbt `vars` parameter is defined using YAML. Unfortunately the standard YAML library
-        # for Python isn't very good and I couldn't find an easy way to have it formatted
-        # correctly. However, as YAML is a super-set of JSON, this works just fine.
-        return json.dumps(self.vars)
-
-    def run_cli(self, *command):
+    def generate_dbt_cli_command(
+        self,
+        base_command: str,
+        profiles_dir: str = '.',
+        project_dir: str = '.',
+        target: str = None,
+        vars: Dict = None,
+        full_refresh: bool = False,
+        data: bool = False,
+        schema: bool = False,
+        models: str = None,
+        exclude: str = None,
+        select: str = None,
+        use_colors: bool = None,
+        warn_error: bool = False,
+    ) -> List[str]:
         """
-        Run the dbt cli
-
-        :param command: The dbt command to run
-        :type command: str
+        Generate the command that will be run based on class properties,
+        presets and dbt commands
+
+        :param base_command: The dbt sub-command to run
+        :type base_command: str
+        :param profiles_dir: If set, passed as the `--profiles-dir` argument to
+            the `dbt` command
+        :type profiles_dir: str
+        :param project_dir: If set, passed as the `--project-dir` argument to
+            the `dbt` command
+        :type project_dir: str
+        :param target: If set, passed as the `--target` argument to the `dbt`
+            command
+        :type target: str
+        :param vars: If set, passed as the `--vars` argument to the `dbt`
+            command
+        :type vars: Union[str, dict]
+        :param full_refresh: If `True`, will fully-refresh incremental models.
+        :type full_refresh: bool
+        :param data: If `True`, passed as the `--data` flag (run data tests only)
+        :type data: bool
+        :param schema: If `True`, passed as the `--schema` flag (run schema
+            tests only)
+        :type schema: bool
+        :param models: If set, passed as the `--models` argument to the `dbt`
+            command
+        :type models: str
+        :param warn_error: If `True`, treat warnings as errors.
+        :type warn_error: bool
+        :param exclude: If set, passed as the `--exclude` argument to the `dbt`
+            command
+        :type exclude: str
+        :param select: If set, passed as the `--select` argument to the `dbt`
+            command
+        :type select: str
+        :param use_colors: If set, passed as `--use-colors` or
+            `--no-use-colors` to the `dbt` command
+        :type use_colors: bool
         """
+        # if there's no bin do not append it. Rather generate the command
+        # without the `/path/to/dbt` prefix. That is useful for running it
+        # inside containers
+        if self.dbt_bin == '' or self.dbt_bin is None:
+            dbt_cmd = []
+        else:
+            dbt_cmd = [self.dbt_bin]
+
+        dbt_cmd.append(base_command)
 
-        dbt_cmd = [self.dbt_bin, *command]
+        if profiles_dir is not None:
+            dbt_cmd.extend(['--profiles-dir', profiles_dir])
 
-        if self.profiles_dir is not None:
-            dbt_cmd.extend(['--profiles-dir', self.profiles_dir])
+        if project_dir is not None:
+            dbt_cmd.extend(['--project-dir', project_dir])
 
-        if self.target is not None:
-            dbt_cmd.extend(['--target', self.target])
+        if target is not None:
+            dbt_cmd.extend(['--target', target])
 
-        if self.vars is not None:
-            dbt_cmd.extend(['--vars', self._dump_vars()])
+        if vars is not None:
+            dbt_cmd.extend(['--vars', json.dumps(vars)])
 
-        if self.data:
-            dbt_cmd.extend(['--data'])
+        if data:
+            dbt_cmd.append('--data')
 
-        if self.schema:
-            dbt_cmd.extend(['--schema'])
+        if schema:
+            dbt_cmd.append('--schema')
 
-        if self.models is not None:
-            dbt_cmd.extend(['--models', self.models])
+        if models is not None:
+            dbt_cmd.extend(['--models', models])
 
-        if self.exclude is not None:
-            dbt_cmd.extend(['--exclude', self.exclude])
+        if exclude is not None:
+            dbt_cmd.extend(['--exclude', exclude])
 
-        if self.select is not None:
-            dbt_cmd.extend(['--select', self.select])
+        if select is not None:
+            dbt_cmd.extend(['--select', select])
 
-        if self.full_refresh:
-            dbt_cmd.extend(['--full-refresh'])
+        if full_refresh:
+            dbt_cmd.append('--full-refresh')
 
-        if self.warn_error:
+        if warn_error:
             dbt_cmd.insert(1, '--warn-error')
 
-        if self.verbose:
-            self.log.info(" ".join(dbt_cmd))
-
-        sp = subprocess.Popen(
-            dbt_cmd,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            cwd=self.dir,
-            close_fds=True)
-        self.sp = sp
-        self.log.info("Output:")
-        line = ''
-        for line in iter(sp.stdout.readline, b''):
-            line = line.decode(self.output_encoding).rstrip()
-            self.log.info(line)
-        sp.wait()
-        self.log.info(
-            "Command exited with return code %s",
-            sp.returncode
-        )
+        if use_colors is not None:
+            colors_flag = "--use-colors" if use_colors else "--no-use-colors"
+            dbt_cmd.append(colors_flag)
+
+        return dbt_cmd
+
+    @abstractmethod
+    def run_dbt(self, dbt_cmd: Union[str, List[str]]):
+        """Run the dbt command"""
+
+
+class DbtCliHook(DbtBaseHook):
+    """
+    Run the dbt command in the same Airflow worker the task runs on.
+    This requires the `dbt` python package to be installed on it first. Also,
+    the dbt binary might not be on the `PATH`, so it could be necessary to
+    set `dbt_bin` in the constructor.
 
-        if sp.returncode:
-            raise AirflowException("dbt command failed")
+    :type dir: str
+    :param dir: The directory to run the CLI in
+    :type env: dict
+    :param env: If set, passed to the dbt executor
+    :param dbt_bin: The `dbt` CLI. Defaults to `dbt`, so assumes it's on your
+        `PATH`
+    :type dbt_bin: str
+    """
+
+    def __init__(self, dir: str = '.', env: Dict = None, dbt_bin='dbt'):
+        self.sp = SubprocessHook()
+        super().__init__(dir=dir, env=env, dbt_bin=dbt_bin)
+
+    def get_conn(self) -> Any:
+        """
+        Return the subprocess connection, which isn't implemented, just for
+        conformity
+        """
+        return self.sp.get_conn()
+
+    def run_dbt(self, dbt_cmd: Union[str, List[str]]):
+        """
+        Run the dbt cli
+
+        :param dbt_cmd: The dbt whole command to run
+        :type dbt_cmd: List[str]
+        """
+        self.sp.run_command(
+            command=dbt_cmd,
+            env=self.env,
+            cwd=self.dir,
+        )
 
     def on_kill(self):
-        self.log.info('Sending SIGTERM signal to dbt command')
-        os.killpg(os.getpgid(self.sp.pid), signal.SIGTERM)
+        """Kill the open subprocess if the task gets killed by Airflow"""
+        self.sp.send_sigterm()
diff --git a/airflow_dbt/operators/dbt_operator.py b/airflow_dbt/operators/dbt_operator.py
index 6233d8d..65a5efb 100644
--- a/airflow_dbt/operators/dbt_operator.py
+++ b/airflow_dbt/operators/dbt_operator.py
@@ -1,7 +1,11 @@
-from airflow_dbt.hooks.dbt_hook import DbtCliHook
+import logging
+from typing import Any, Dict
+
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
 
+from airflow_dbt.hooks.dbt_hook import DbtCliHook
+
 
 class DbtBaseOperator(BaseOperator):
     """
@@ -13,6 +17,8 @@ class DbtBaseOperator(BaseOperator):
     :param target: If set, passed as the `--target` argument to the `dbt` command
     :type dir: str
     :param dir: The directory to run the CLI in
+    :param env: If set, passed to the dbt executor
+    :type env: dict
     :type vars: str
     :param vars: If set, passed as the `--vars` argument to the `dbt` command
     :type vars: dict
@@ -30,17 +36,31 @@ class DbtBaseOperator(BaseOperator):
     :type dbt_bin: str
     :param verbose: The operator will log verbosely to the Airflow logs
     :type verbose: bool
+    :param dbt_hook: The dbt hook to use as executor. For now the
+        implemented ones are: DbtCliHook, DbtCloudBuildHook. It should be an
+        instance of one of those, or another that inherits from DbtBaseHook.
+        If not provided, a DbtCliHook is instantiated by default with the
+        provided params
+    :type dbt_hook: DbtBaseHook
+    :param base_command: The dbt sub command to run, for example for `dbt
+        run` the base_command will be `run`. Any other flag not covered by
+        the parameters above can also be appended to this string
+    :type base_command: str
     """
 
     ui_color = '#d6522a'
 
-    template_fields = ['vars']
+    template_fields = ['profiles_dir', 'project_dir', 'target', 'env',
+        'vars', 'models', 'exclude', 'select', 'dbt_bin', 'verbose',
+        'warn_error', 'full_refresh', 'data', 'schema', 'base_command']
 
     @apply_defaults
     def __init__(self,
                  profiles_dir=None,
+                 project_dir=None,
+                 dir: str = '.',
                  target=None,
-                 dir='.',
+                 env: Dict = None,
                  vars=None,
                  models=None,
                  exclude=None,
@@ -51,13 +71,16 @@ def __init__(self,
                  full_refresh=False,
                  data=False,
                  schema=False,
+                 dbt_hook=None,
+                 base_command=None,
                  *args,
                  **kwargs):
         super(DbtBaseOperator, self).__init__(*args, **kwargs)
 
         self.profiles_dir = profiles_dir
+        self.project_dir = project_dir if project_dir is not None else dir
         self.target = target
-        self.dir = dir
+        self.env = {} if env is None else env
         self.vars = vars
         self.models = models
         self.full_refresh = full_refresh
@@ -68,13 +91,20 @@ def __init__(self,
         self.dbt_bin = dbt_bin
         self.verbose = verbose
         self.warn_error = warn_error
-        self.create_hook()
-
-    def create_hook(self):
-        self.hook = DbtCliHook(
+        self.base_command = base_command
+        self.hook = dbt_hook if dbt_hook is not None else DbtCliHook(
+            dir=dir,
+            env=self.env,
+            dbt_bin=dbt_bin
+        )
+
+    def execute(self, context: Any):
+        """Runs the provided command in the provided execution environment"""
+        dbt_cli_command = self.hook.generate_dbt_cli_command(
+            base_command=self.base_command,
             profiles_dir=self.profiles_dir,
+            project_dir=self.project_dir,
             target=self.target,
-            dir=self.dir,
             vars=self.vars,
             full_refresh=self.full_refresh,
             data=self.data,
@@ -82,63 +112,85 @@ def create_hook(self):
             models=self.models,
             exclude=self.exclude,
             select=self.select,
-            dbt_bin=self.dbt_bin,
-            verbose=self.verbose,
-            warn_error=self.warn_error)
-
-        return self.hook
+            warn_error=self.warn_error,
+        )
+        logging.info(f'Running dbt command "{dbt_cli_command}"')
+        self.hook.run_dbt(dbt_cli_command)
 
 
 class DbtRunOperator(DbtBaseOperator):
+    """Runs a dbt run command"""
     @apply_defaults
     def __init__(self, profiles_dir=None, target=None, *args, **kwargs):
-        super(DbtRunOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs)
-
-    def execute(self, context):
-        self.create_hook().run_cli('run')
+        super().__init__(
+            profiles_dir=profiles_dir,
+            target=target,
+            base_command='run',
+            *args,
+            **kwargs
+        )
 
 
 class DbtTestOperator(DbtBaseOperator):
+    """Runs a dbt test command"""
     @apply_defaults
     def __init__(self, profiles_dir=None, target=None, *args, **kwargs):
-        super(DbtTestOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs)
-
-    def execute(self, context):
-        self.create_hook().run_cli('test')
+        super().__init__(
+            profiles_dir=profiles_dir,
+            target=target,
+            base_command='test',
+            *args,
+            **kwargs
+        )
 
 
 class DbtDocsGenerateOperator(DbtBaseOperator):
+    """Runs a dbt docs generate command"""
    @apply_defaults
     def __init__(self, profiles_dir=None, target=None, *args, **kwargs):
-        super(DbtDocsGenerateOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args,
-                                                      **kwargs)
-
-    def execute(self, context):
-        self.create_hook().run_cli('docs', 'generate')
+        super().__init__(
+            profiles_dir=profiles_dir,
+            target=target,
+            base_command='docs generate',
+            *args,
+            **kwargs
+        )
 
 
 class DbtSnapshotOperator(DbtBaseOperator):
+    """Runs a dbt snapshot command"""
     @apply_defaults
     def __init__(self, profiles_dir=None, target=None, *args, **kwargs):
-        super(DbtSnapshotOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs)
-
-    def execute(self, context):
-        self.create_hook().run_cli('snapshot')
+        super().__init__(
+            profiles_dir=profiles_dir,
+            target=target,
+            base_command='snapshot',
+            *args,
+            **kwargs
+        )
 
 
 class DbtSeedOperator(DbtBaseOperator):
+    """Runs a dbt seed command"""
     @apply_defaults
     def __init__(self, profiles_dir=None, target=None, *args, **kwargs):
-        super(DbtSeedOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs)
-
-    def execute(self, context):
-        self.create_hook().run_cli('seed')
+        super().__init__(
+            profiles_dir=profiles_dir,
+            target=target,
+            base_command='seed',
+            *args,
+            **kwargs
+        )
 
 
 class DbtDepsOperator(DbtBaseOperator):
+    """Runs a dbt deps command"""
     @apply_defaults
     def __init__(self, profiles_dir=None, target=None, *args, **kwargs):
-        super(DbtDepsOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs)
-
-    def execute(self, context):
-        self.create_hook().run_cli('deps')
+        super().__init__(
+            profiles_dir=profiles_dir,
+            target=target,
+            base_command='deps',
+            *args,
+            **kwargs
+        )
diff --git a/setup.py b/setup.py
index 7d3b2f8..325e51b 100644
--- a/setup.py
+++ b/setup.py
@@ -78,4 +78,7 @@ def run(self):
     cmdclass={
         'upload': UploadCommand,
     },
+    extras_require={
+        'google': 'apache-airflow-providers-google==5.0.0'
+    },
 )
diff --git a/tests/hooks/test_dbt_hook.py b/tests/hooks/test_dbt_hook.py
index 4dd39ed..e1990b8 100644
--- a/tests/hooks/test_dbt_hook.py
+++ b/tests/hooks/test_dbt_hook.py
@@ -1,54 +1,27 @@
-from unittest import TestCase
-from unittest import mock
-import subprocess
-from airflow_dbt.hooks.dbt_hook import DbtCliHook
+from unittest import TestCase, mock
 
+from airflow.hooks.subprocess import SubprocessHook
 
-class TestDbtHook(TestCase):
+from airflow_dbt.hooks.dbt_hook import DbtCliHook
 
-    @mock.patch('subprocess.Popen')
-    def test_sub_commands(self, mock_subproc_popen):
-        mock_subproc_popen.return_value \
-            .communicate.return_value = ('output', 'error')
-        mock_subproc_popen.return_value.returncode = 0
-        mock_subproc_popen.return_value \
-            .stdout.readline.side_effect = [b"placeholder"]
 
+class TestDbtHook(TestCase):
+    @mock.patch.object(SubprocessHook, 'run_command')
+    def test_sub_commands(self, mock_run_command):
         hook = DbtCliHook()
-        hook.run_cli('docs', 'generate')
-
-        mock_subproc_popen.assert_called_once_with(
-            [
-                'dbt',
-                'docs',
-                'generate'
-            ],
-            close_fds=True,
+        hook.run_dbt(['dbt', 'docs', 'generate'])
+        mock_run_command.assert_called_once_with(
+            command=['dbt', 'docs', 'generate'],
+            env={},
             cwd='.',
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT
-        )
-
-    @mock.patch('subprocess.Popen')
-    def test_vars(self, mock_subproc_popen):
-        mock_subproc_popen.return_value \
-            .communicate.return_value = ('output', 'error')
-        mock_subproc_popen.return_value.returncode = 0
-        mock_subproc_popen.return_value \
-            .stdout.readline.side_effect = [b"placeholder"]
+        )
 
-        hook = DbtCliHook(vars={"foo": "bar", "baz": "true"})
-        hook.run_cli('run')
+    def test_vars(self):
+        hook = DbtCliHook()
+        generated_command = hook.generate_dbt_cli_command(
+            'run',
+            vars={"foo": "bar", "baz": "true"}
+        )
 
-        mock_subproc_popen.assert_called_once_with(
-            [
-                'dbt',
-                'run',
-                '--vars',
-                '{"foo": "bar", "baz": "true"}'
-            ],
-            close_fds=True,
-            cwd='.',
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT
-        )
+        assert generated_command == ['dbt', 'run', '--vars',
+                                     '{"foo": "bar", "baz": "true"}']
diff --git a/tests/operators/test_dbt_operator.py b/tests/operators/test_dbt_operator.py
index 8ce2c5f..27544c8 100644
--- a/tests/operators/test_dbt_operator.py
+++ b/tests/operators/test_dbt_operator.py
@@ -1,17 +1,18 @@
 import datetime
 from unittest import TestCase, mock
+from unittest.mock import patch
+
 from airflow import DAG, configuration
+
+from airflow_dbt.hooks.dbt_google_hook import DbtCloudBuildHook
 from airflow_dbt.hooks.dbt_hook import DbtCliHook
 from airflow_dbt.operators.dbt_operator import (
-    DbtSeedOperator,
-    DbtSnapshotOperator,
-    DbtRunOperator,
+    DbtDepsOperator, DbtRunOperator, DbtSeedOperator, DbtSnapshotOperator,
     DbtTestOperator,
-    DbtDepsOperator
 )
 
 
-class TestDbtOperator(TestCase):
+class TestDbtCliOperator(TestCase):
     def setUp(self):
         configuration.conf.load_test_config()
         args = {
@@ -20,47 +21,83 @@ def setUp(self):
         }
         self.dag = DAG('test_dag_id', default_args=args)
 
-    @mock.patch.object(DbtCliHook, 'run_cli')
-    def test_dbt_run(self, mock_run_cli):
+    @mock.patch.object(DbtCliHook, 'run_dbt')
+    def test_dbt_run(self, mock_run_dbt):
         operator = DbtRunOperator(
             task_id='run',
             dag=self.dag
         )
         operator.execute(None)
-        mock_run_cli.assert_called_once_with('run')
+        mock_run_dbt.assert_called_once_with(['dbt', 'run'])
 
-    @mock.patch.object(DbtCliHook, 'run_cli')
-    def test_dbt_test(self, mock_run_cli):
+    @mock.patch.object(DbtCliHook, 'run_dbt')
+    def test_dbt_test(self, mock_run_dbt):
         operator = DbtTestOperator(
             task_id='test',
             dag=self.dag
         )
         operator.execute(None)
-        mock_run_cli.assert_called_once_with('test')
+        mock_run_dbt.assert_called_once_with(['dbt', 'test'])
 
-    @mock.patch.object(DbtCliHook, 'run_cli')
-    def test_dbt_snapshot(self, mock_run_cli):
+    @mock.patch.object(DbtCliHook, 'run_dbt')
+    def test_dbt_snapshot(self, mock_run_dbt):
         operator = DbtSnapshotOperator(
             task_id='snapshot',
             dag=self.dag
         )
         operator.execute(None)
-        mock_run_cli.assert_called_once_with('snapshot')
+        mock_run_dbt.assert_called_once_with(['dbt', 'snapshot'])
 
-    @mock.patch.object(DbtCliHook, 'run_cli')
-    def test_dbt_seed(self, mock_run_cli):
+    @mock.patch.object(DbtCliHook, 'run_dbt')
+    def test_dbt_seed(self, mock_run_dbt):
         operator = DbtSeedOperator(
             task_id='seed',
             dag=self.dag
         )
         operator.execute(None)
-        mock_run_cli.assert_called_once_with('seed')
+        mock_run_dbt.assert_called_once_with(['dbt', 'seed'])
 
-    @mock.patch.object(DbtCliHook, 'run_cli')
-    def test_dbt_deps(self, mock_run_cli):
+    @mock.patch.object(DbtCliHook, 'run_dbt')
+    def test_dbt_deps(self, mock_run_dbt):
         operator = DbtDepsOperator(
             task_id='deps',
             dag=self.dag
         )
         operator.execute(None)
-        mock_run_cli.assert_called_once_with('deps')
+        mock_run_dbt.assert_called_once_with(['dbt', 'deps'])
+
+
+class TestDbtRunWithCloudBuild(TestCase):
+    def setUp(self):
+        configuration.conf.load_test_config()
+        args = {
+            'owner': 'airflow',
+            'start_date': datetime.datetime(2020, 2, 27)
+        }
+        self.dag = DAG('test_dag_id', default_args=args)
+
+    @patch('airflow_dbt.hooks.dbt_google_hook.NamedTemporaryFile')
+    @patch('airflow_dbt.hooks.dbt_google_hook.CloudBuildHook')
+    @patch('airflow_dbt.hooks.dbt_google_hook.GCSHook')
+    def test_upload_files(self, MockGCSHook, MockCBHook, MockTempFile):
+        # Change the filename returned by the NamedTemporaryFile context manager
+        MockTempFile.return_value.__enter__.return_value.name = 'tempfile'
+        operator = DbtRunOperator(
+            task_id='test_dbt_run_on_cloud_build',
+            dbt_hook=DbtCloudBuildHook(
+                project_id='my-project-id',
+                gcp_conn_id='my_conn_id',
+                dir='.',
+                gcs_staging_location='gs://my-bucket/certain-folder'
+                                     '/stored_dbt_files.tar.gz'
+            ),
+            dag=self.dag
+        )
+        operator.execute(None)
+        MockCBHook.assert_called_once_with(gcp_conn_id='my_conn_id')
+        MockGCSHook.assert_called_once_with(gcp_conn_id='my_conn_id')
+        MockGCSHook().upload.assert_called_once_with(
+            bucket_name='my-bucket',
+            object_name='certain-folder/stored_dbt_files.tar.gz',
+            filename='tempfile'
+        )
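
For reference, the sketch below shows how the two executors introduced by this patch could be wired into a DAG. It is illustrative only and not part of the diff: the DAG id, schedule, project id, bucket, paths and connection id are placeholder values; the operator and hook parameters mirror the ones defined above and exercised in tests/operators/test_dbt_operator.py.

```python
# Illustrative sketch only (not part of the patch); all literal values below
# are placeholders.
import datetime

from airflow import DAG

from airflow_dbt.hooks.dbt_google_hook import DbtCloudBuildHook
from airflow_dbt.operators.dbt_operator import DbtRunOperator

with DAG(
    dag_id='example_dbt',                          # placeholder
    start_date=datetime.datetime(2021, 1, 1),      # placeholder
    schedule_interval=None,
) as dag:
    # Default executor: a DbtCliHook is created internally and `dbt run`
    # is executed with SubprocessHook on the worker itself.
    dbt_run_local = DbtRunOperator(
        task_id='dbt_run_local',
        dir='/opt/dbt-project',                    # placeholder path
        profiles_dir='/opt/dbt-project',
        target='dev',
    )

    # Cloud Build executor: the sources under `dir` are tarred, uploaded to
    # `gcs_staging_location` and run inside the fishtownanalytics/dbt image.
    dbt_run_cloud_build = DbtRunOperator(
        task_id='dbt_run_cloud_build',
        dbt_hook=DbtCloudBuildHook(
            project_id='my-gcp-project',           # placeholder
            gcp_conn_id='google_cloud_default',
            dir='/opt/dbt-project',                # placeholder path
            gcs_staging_location='gs://my-bucket/dbt/sources.tar.gz',
            dbt_version='0.21.0',
        ),
    )

    dbt_run_local >> dbt_run_cloud_build
```

Note that when a `dbt_hook` is supplied, the command prefix comes from the hook's own `dbt_bin` (empty by default for `DbtCloudBuildHook`, since the Cloud Build image already provides the entrypoint).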