From 99b7843e92191e6e8d824a527dcf5c8d5bff1761 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 5 Jan 2024 15:10:49 -0500 Subject: [PATCH 1/2] add unit tests --- tests/unit/__init__.py | 0 tests/unit/test_agate_helper.py | 227 ++++++++++++++ tests/unit/test_connection_retries.py | 59 ++++ tests/unit/test_core_dbt_utils.py | 73 +++++ tests/unit/test_event_handler.py | 40 +++ tests/unit/test_helper_types.py | 59 ++++ tests/unit/test_jinja.py | 425 ++++++++++++++++++++++++++ tests/unit/test_model_config.py | 92 ++++++ tests/unit/test_system_client.py | 271 ++++++++++++++++ tests/unit/test_utils.py | 143 +++++++++ 10 files changed, 1389 insertions(+) create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_agate_helper.py create mode 100644 tests/unit/test_connection_retries.py create mode 100644 tests/unit/test_core_dbt_utils.py create mode 100644 tests/unit/test_event_handler.py create mode 100644 tests/unit/test_helper_types.py create mode 100644 tests/unit/test_jinja.py create mode 100644 tests/unit/test_model_config.py create mode 100644 tests/unit/test_system_client.py create mode 100644 tests/unit/test_utils.py diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_agate_helper.py b/tests/unit/test_agate_helper.py new file mode 100644 index 00000000..9e63aa47 --- /dev/null +++ b/tests/unit/test_agate_helper.py @@ -0,0 +1,227 @@ +import unittest + +import agate + +from datetime import datetime +from decimal import Decimal +from isodate import tzinfo +import os +from shutil import rmtree +from tempfile import mkdtemp +from dbt.common.clients import agate_helper + +SAMPLE_CSV_DATA = """a,b,c,d,e,f,g +1,n,test,3.2,20180806T11:33:29.320Z,True,NULL +2,y,asdf,900,20180806T11:35:29.320Z,False,a string""" + +SAMPLE_CSV_BOM_DATA = "\ufeff" + SAMPLE_CSV_DATA + + +EXPECTED = [ + [ + 1, + "n", + "test", + Decimal("3.2"), + datetime(2018, 8, 6, 11, 33, 29, 320000, tzinfo=tzinfo.Utc()), + True, + None, + ], + [ + 2, + "y", + "asdf", + 900, + datetime(2018, 8, 6, 11, 35, 29, 320000, tzinfo=tzinfo.Utc()), + False, + "a string", + ], +] + + +EXPECTED_STRINGS = [ + ["1", "n", "test", "3.2", "20180806T11:33:29.320Z", "True", None], + ["2", "y", "asdf", "900", "20180806T11:35:29.320Z", "False", "a string"], +] + + +class TestAgateHelper(unittest.TestCase): + def setUp(self): + self.tempdir = mkdtemp() + + def tearDown(self): + rmtree(self.tempdir) + + def test_from_csv(self): + path = os.path.join(self.tempdir, "input.csv") + with open(path, "wb") as fp: + fp.write(SAMPLE_CSV_DATA.encode("utf-8")) + tbl = agate_helper.from_csv(path, ()) + self.assertEqual(len(tbl), len(EXPECTED)) + for idx, row in enumerate(tbl): + self.assertEqual(list(row), EXPECTED[idx]) + + def test_bom_from_csv(self): + path = os.path.join(self.tempdir, "input.csv") + with open(path, "wb") as fp: + fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8")) + tbl = agate_helper.from_csv(path, ()) + self.assertEqual(len(tbl), len(EXPECTED)) + for idx, row in enumerate(tbl): + self.assertEqual(list(row), EXPECTED[idx]) + + def test_from_csv_all_reserved(self): + path = os.path.join(self.tempdir, "input.csv") + with open(path, "wb") as fp: + fp.write(SAMPLE_CSV_DATA.encode("utf-8")) + tbl = agate_helper.from_csv(path, tuple("abcdefg")) + self.assertEqual(len(tbl), len(EXPECTED_STRINGS)) + for expected, row in zip(EXPECTED_STRINGS, tbl): + self.assertEqual(list(row), expected) + + def test_from_data(self): + column_names = ["a", "b", "c", "d", "e", "f", "g"] + data = [ + { + "a": "1", + "b": "n", + "c": "test", + "d": "3.2", + "e": "20180806T11:33:29.320Z", + "f": "True", + "g": "NULL", + }, + { + "a": "2", + "b": "y", + "c": "asdf", + "d": "900", + "e": "20180806T11:35:29.320Z", + "f": "False", + "g": "a string", + }, + ] + tbl = agate_helper.table_from_data(data, column_names) + self.assertEqual(len(tbl), len(EXPECTED)) + for idx, row in enumerate(tbl): + self.assertEqual(list(row), EXPECTED[idx]) + + def test_datetime_formats(self): + path = os.path.join(self.tempdir, "input.csv") + datetimes = [ + "20180806T11:33:29.000Z", + "20180806T11:33:29Z", + "20180806T113329Z", + ] + expected = datetime(2018, 8, 6, 11, 33, 29, 0, tzinfo=tzinfo.Utc()) + for dt in datetimes: + with open(path, "wb") as fp: + fp.write("a\n{}".format(dt).encode("utf-8")) + tbl = agate_helper.from_csv(path, ()) + self.assertEqual(tbl[0][0], expected) + + def test_merge_allnull(self): + t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c")) + t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c")) + result = agate_helper.merge_tables([t1, t2]) + self.assertEqual(result.column_names, ("a", "b", "c")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate_helper.Integer) + self.assertEqual(len(result), 4) + + def test_merge_mixed(self): + t1 = agate_helper.table_from_rows( + [(1, "a", None, None), (2, "b", None, None)], ("a", "b", "c", "d") + ) + t2 = agate_helper.table_from_rows( + [(3, "c", "dog", 1), (4, "d", "cat", 5)], ("a", "b", "c", "d") + ) + t3 = agate_helper.table_from_rows( + [(3, "c", None, 1.5), (4, "d", None, 3.5)], ("a", "b", "c", "d") + ) + + result = agate_helper.merge_tables([t1, t2]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + assert isinstance(result.column_types[3], agate_helper.Integer) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t1, t3]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate_helper.Integer) + assert isinstance(result.column_types[3], agate.data_types.Number) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t2, t3]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + assert isinstance(result.column_types[3], agate.data_types.Number) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t3, t2]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + assert isinstance(result.column_types[3], agate.data_types.Number) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t1, t2, t3]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + assert isinstance(result.column_types[3], agate.data_types.Number) + self.assertEqual(len(result), 6) + + def test_nocast_string_types(self): + # String fields should not be coerced into a representative type + # See: https://github.com/dbt-labs/dbt-core/issues/2984 + + column_names = ["a", "b", "c", "d", "e"] + result_set = [ + {"a": "0005", "b": "01T00000aabbccdd", "c": "true", "d": 10, "e": False}, + {"a": "0006", "b": "01T00000aabbccde", "c": "false", "d": 11, "e": True}, + ] + + tbl = agate_helper.table_from_data_flat(data=result_set, column_names=column_names) + self.assertEqual(len(tbl), len(result_set)) + + expected = [ + ["0005", "01T00000aabbccdd", "true", Decimal(10), False], + ["0006", "01T00000aabbccde", "false", Decimal(11), True], + ] + + for i, row in enumerate(tbl): + self.assertEqual(list(row), expected[i]) + + def test_nocast_bool_01(self): + # True and False values should not be cast to 1 and 0, and vice versa + # See: https://github.com/dbt-labs/dbt-core/issues/4511 + + column_names = ["a", "b"] + result_set = [ + {"a": True, "b": 1}, + {"a": False, "b": 0}, + ] + + tbl = agate_helper.table_from_data_flat(data=result_set, column_names=column_names) + self.assertEqual(len(tbl), len(result_set)) + + assert isinstance(tbl.column_types[0], agate.data_types.Boolean) + assert isinstance(tbl.column_types[1], agate_helper.Integer) + + expected = [ + [True, Decimal(1)], + [False, Decimal(0)], + ] + + for i, row in enumerate(tbl): + self.assertEqual(list(row), expected[i]) diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py new file mode 100644 index 00000000..c135696b --- /dev/null +++ b/tests/unit/test_connection_retries.py @@ -0,0 +1,59 @@ +import functools +import pytest +from requests.exceptions import RequestException +from dbt.common.exceptions import ConnectionError +from dbt.common.utils.connection import connection_exception_retry + + +def no_retry_fn(): + return "success" + + +class TestNoRetries: + def test_no_retry(self): + fn_to_retry = functools.partial(no_retry_fn) + result = connection_exception_retry(fn_to_retry, 3) + + expected = "success" + + assert result == expected + + +def no_success_fn(): + raise RequestException("You'll never pass") + return "failure" + + +class TestMaxRetries: + def test_no_retry(self): + fn_to_retry = functools.partial(no_success_fn) + + with pytest.raises(ConnectionError): + connection_exception_retry(fn_to_retry, 3) + + +def single_retry_fn(): + global counter + if counter == 0: + counter += 1 + raise RequestException("You won't pass this one time") + elif counter == 1: + counter += 1 + return "success on 2" + + return "How did we get here?" + + +class TestSingleRetry: + def test_no_retry(self): + global counter + counter = 0 + + fn_to_retry = functools.partial(single_retry_fn) + result = connection_exception_retry(fn_to_retry, 3) + expected = "success on 2" + + # We need to test the return value here, not just that it did not throw an error. + # If the value is not being passed it causes cryptic errors + assert result == expected + assert counter == 2 diff --git a/tests/unit/test_core_dbt_utils.py b/tests/unit/test_core_dbt_utils.py new file mode 100644 index 00000000..b3957f90 --- /dev/null +++ b/tests/unit/test_core_dbt_utils.py @@ -0,0 +1,73 @@ +import requests +import tarfile +import unittest + +from dbt.common.exceptions import ConnectionError +from dbt.common.utils.connection import connection_exception_retry + + +class TestCoreDbtUtils(unittest.TestCase): + def test_connection_exception_retry_none(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add(), 5) + self.assertEqual(1, counter) + + def test_connection_exception_retry_success_requests_exception(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_requests_exception(), 5) + self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry + + def test_connection_exception_retry_max(self): + Counter._reset() + with self.assertRaises(ConnectionError): + connection_exception_retry(lambda: Counter._add_with_exception(), 5) + self.assertEqual(6, counter) # 6 = original attempt plus 5 retries + + def test_connection_exception_retry_success_failed_untar(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_untar_exception(), 5) + self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry + + def test_connection_exception_retry_success_failed_eofexception(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_eof_exception(), 5) + self.assertEqual(2, counter) # 2 = original attempt returned EOFError, plus 1 retry + + +counter: int = 0 + + +class Counter: + def _add(): + global counter + counter += 1 + + # All exceptions that Requests explicitly raises inherit from + # requests.exceptions.RequestException so we want to make sure that raises plus one exception + # that inherit from it for sanity + def _add_with_requests_exception(): + global counter + counter += 1 + if counter < 2: + raise requests.exceptions.RequestException + + def _add_with_exception(): + global counter + counter += 1 + raise requests.exceptions.ConnectionError + + def _add_with_untar_exception(): + global counter + counter += 1 + if counter < 2: + raise tarfile.ReadError + + def _add_with_eof_exception(): + global counter + counter += 1 + if counter < 2: + raise EOFError + + def _reset(): + global counter + counter = 0 diff --git a/tests/unit/test_event_handler.py b/tests/unit/test_event_handler.py new file mode 100644 index 00000000..7067e3bf --- /dev/null +++ b/tests/unit/test_event_handler.py @@ -0,0 +1,40 @@ +import logging + +from dbt.common.events.base_types import EventLevel +from dbt.common.events.event_handler import DbtEventLoggingHandler, set_package_logging +from dbt.common.events.event_manager import TestEventManager + + +def test_event_logging_handler_emits_records_correctly(): + event_manager = TestEventManager() + handler = DbtEventLoggingHandler(event_manager=event_manager, level=logging.DEBUG) + log = logging.getLogger("test") + log.setLevel(logging.DEBUG) + log.addHandler(handler) + + log.debug("test") + log.info("test") + log.warning("test") + log.error("test") + log.exception("test") + log.critical("test") + assert len(event_manager.event_history) == 6 + assert event_manager.event_history[0][1] == EventLevel.DEBUG + assert event_manager.event_history[1][1] == EventLevel.INFO + assert event_manager.event_history[2][1] == EventLevel.WARN + assert event_manager.event_history[3][1] == EventLevel.ERROR + assert event_manager.event_history[4][1] == EventLevel.ERROR + assert event_manager.event_history[5][1] == EventLevel.ERROR + + +def test_set_package_logging_sets_level_correctly(): + event_manager = TestEventManager() + log = logging.getLogger("test") + set_package_logging("test", logging.DEBUG, event_manager) + log.debug("debug") + assert len(event_manager.event_history) == 1 + set_package_logging("test", logging.WARN, event_manager) + log.debug("debug 2") + assert len(event_manager.event_history) == 1 + log.warning("warning") + assert len(event_manager.event_history) == 3 # warning logs create two events diff --git a/tests/unit/test_helper_types.py b/tests/unit/test_helper_types.py new file mode 100644 index 00000000..59b05d90 --- /dev/null +++ b/tests/unit/test_helper_types.py @@ -0,0 +1,59 @@ +import pytest + +from dbt.common.helper_types import IncludeExclude, WarnErrorOptions +from dbt.common.dataclass_schema import ValidationError + + +class TestIncludeExclude: + def test_init_invalid(self): + with pytest.raises(ValidationError): + IncludeExclude(include="invalid") + + with pytest.raises(ValidationError): + IncludeExclude(include=["ItemA"], exclude=["ItemB"]) + + @pytest.mark.parametrize( + "include,exclude,expected_includes", + [ + ("all", [], True), + ("*", [], True), + ("*", ["ItemA"], False), + (["ItemA"], [], True), + (["ItemA", "ItemB"], [], True), + ], + ) + def test_includes(self, include, exclude, expected_includes): + include_exclude = IncludeExclude(include=include, exclude=exclude) + + assert include_exclude.includes("ItemA") == expected_includes + + +class TestWarnErrorOptions: + def test_init_invalid_error(self): + with pytest.raises(ValidationError): + WarnErrorOptions(include=["InvalidError"], valid_error_names=set(["ValidError"])) + + with pytest.raises(ValidationError): + WarnErrorOptions( + include="*", exclude=["InvalidError"], valid_error_names=set(["ValidError"]) + ) + + def test_init_invalid_error_default_valid_error_names(self): + with pytest.raises(ValidationError): + WarnErrorOptions(include=["InvalidError"]) + + with pytest.raises(ValidationError): + WarnErrorOptions(include="*", exclude=["InvalidError"]) + + def test_init_valid_error(self): + warn_error_options = WarnErrorOptions( + include=["ValidError"], valid_error_names=set(["ValidError"]) + ) + assert warn_error_options.include == ["ValidError"] + assert warn_error_options.exclude == [] + + warn_error_options = WarnErrorOptions( + include="*", exclude=["ValidError"], valid_error_names=set(["ValidError"]) + ) + assert warn_error_options.include == "*" + assert warn_error_options.exclude == ["ValidError"] diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py new file mode 100644 index 00000000..65520743 --- /dev/null +++ b/tests/unit/test_jinja.py @@ -0,0 +1,425 @@ +import unittest + +from dbt.common.clients.jinja import extract_toplevel_blocks +from dbt.common.exceptions import CompilationError + + +class TestBlockLexer(unittest.TestCase): + def test_basic(self): + body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' + block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" + blocks = extract_toplevel_blocks( + block_data, allowed_blocks={"mytype"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, body) + self.assertEqual(blocks[0].full_block, block_data) + + def test_multiple(self): + body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' + body_two = ( + "{{ config(bar=1)}}\r\nselect * from {% if foo %} thing " + "{% else %} other_thing {% endif %}" + ) + + block_data = ( + " {% mytype foo %}" + + body_one + + "{% endmytype %}" + + "\r\n{% othertype bar %}" + + body_two + + "{% endothertype %}" + ) + blocks = extract_toplevel_blocks( + block_data, allowed_blocks={"mytype", "othertype"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 2) + + def test_comments(self): + body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' + comment = "{# my comment #}" + block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" + blocks = extract_toplevel_blocks( + comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, body) + self.assertEqual(blocks[0].full_block, block_data) + + def test_evil_comments(self): + body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' + comment = "{# external comment {% othertype bar %} select * from thing.other_thing{% endothertype %} #}" + block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" + blocks = extract_toplevel_blocks( + comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, body) + self.assertEqual(blocks[0].full_block, block_data) + + def test_nested_comments(self): + body = '{# my comment #} {{ config(foo="bar") }}\r\nselect * from {# my other comment embedding {% endmytype %} #} this.that\r\n' + block_data = " \n\r\t{%- mytype foo %}" + body + "{% endmytype -%}" + comment = "{# external comment {% othertype bar %} select * from thing.other_thing{% endothertype %} #}" + blocks = extract_toplevel_blocks( + comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, body) + self.assertEqual(blocks[0].full_block, block_data) + + def test_complex_file(self): + blocks = extract_toplevel_blocks( + complex_snapshot_file, allowed_blocks={"mytype", "myothertype"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 3) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].full_block, "{% mytype foo %} some stuff {% endmytype %}") + self.assertEqual(blocks[0].contents, " some stuff ") + self.assertEqual(blocks[1].block_type_name, "mytype") + self.assertEqual(blocks[1].block_name, "bar") + self.assertEqual(blocks[1].full_block, bar_block) + self.assertEqual(blocks[1].contents, bar_block[16:-15].rstrip()) + self.assertEqual(blocks[2].block_type_name, "myothertype") + self.assertEqual(blocks[2].block_name, "x") + self.assertEqual(blocks[2].full_block, x_block.strip()) + self.assertEqual( + blocks[2].contents, + x_block[len("\n{% myothertype x %}") : -len("{% endmyothertype %}\n")], + ) + + def test_peaceful_macro_coexistence(self): + body = "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} {# my model #} {% a b %} test {% enda %}" + blocks = extract_toplevel_blocks( + body, allowed_blocks={"macro", "a"}, collect_raw_data=True + ) + self.assertEqual(len(blocks), 4) + self.assertEqual(blocks[0].full_block, "{# my macro #} ") + self.assertEqual(blocks[1].block_type_name, "macro") + self.assertEqual(blocks[1].block_name, "foo") + self.assertEqual(blocks[1].contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") + self.assertEqual(blocks[3].block_type_name, "a") + self.assertEqual(blocks[3].block_name, "b") + self.assertEqual(blocks[3].contents, " test ") + + def test_macro_with_trailing_data(self): + body = "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} {# my model #} {% a b %} test {% enda %} raw data so cool" + blocks = extract_toplevel_blocks( + body, allowed_blocks={"macro", "a"}, collect_raw_data=True + ) + self.assertEqual(len(blocks), 5) + self.assertEqual(blocks[0].full_block, "{# my macro #} ") + self.assertEqual(blocks[1].block_type_name, "macro") + self.assertEqual(blocks[1].block_name, "foo") + self.assertEqual(blocks[1].contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") + self.assertEqual(blocks[3].block_type_name, "a") + self.assertEqual(blocks[3].block_name, "b") + self.assertEqual(blocks[3].contents, " test ") + self.assertEqual(blocks[4].full_block, " raw data so cool") + + def test_macro_with_crazy_args(self): + body = """{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}cool{# block comment with {% endmacro %} in it #} stuff here {% endmacro %}""" + blocks = extract_toplevel_blocks(body, allowed_blocks={"macro"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "macro") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual( + blocks[0].contents, "cool{# block comment with {% endmacro %} in it #} stuff here " + ) + + def test_materialization_parse(self): + body = "{% materialization xxx, default %} ... {% endmaterialization %}" + blocks = extract_toplevel_blocks( + body, allowed_blocks={"materialization"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "materialization") + self.assertEqual(blocks[0].block_name, "xxx") + self.assertEqual(blocks[0].full_block, body) + + body = '{% materialization xxx, adapter="other" %} ... {% endmaterialization %}' + blocks = extract_toplevel_blocks( + body, allowed_blocks={"materialization"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "materialization") + self.assertEqual(blocks[0].block_name, "xxx") + self.assertEqual(blocks[0].full_block, body) + + def test_nested_not_ok(self): + # we don't allow nesting same blocks + body = "{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}" + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body, allowed_blocks={"myblock"}) + + def test_incomplete_block_failure(self): + fullbody = "{% myblock foo %} {% endmyblock %}" + for length in range(len("{% myblock foo %}"), len(fullbody) - 1): + body = fullbody[:length] + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body, allowed_blocks={"myblock"}) + + def test_wrong_end_failure(self): + body = "{% myblock foo %} {% endotherblock %}" + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) + + def test_comment_no_end_failure(self): + body = "{# " + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body) + + def test_comment_only(self): + body = "{# myblock #}" + blocks = extract_toplevel_blocks(body) + self.assertEqual(len(blocks), 1) + blocks = extract_toplevel_blocks(body, collect_raw_data=False) + self.assertEqual(len(blocks), 0) + + def test_comment_block_self_closing(self): + # test the case where a comment start looks a lot like it closes itself + # (but it doesn't in jinja!) + body = "{#} {% myblock foo %} {#}" + blocks = extract_toplevel_blocks(body, collect_raw_data=False) + self.assertEqual(len(blocks), 0) + + def test_embedded_self_closing_comment_block(self): + body = "{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, body) + self.assertEqual(blocks[0].contents, " {#}{% endmyblock %} {#}") + + def test_set_statement(self): + body = "{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_set_block(self): + body = "{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_crazy_set_statement(self): + body = '{% set x = (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}{% set y = otherthing("{% myblock foo %}") %}' + blocks = extract_toplevel_blocks( + body, allowed_blocks={"otherblock"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") + self.assertEqual(blocks[0].block_type_name, "otherblock") + + def test_do_statement(self): + body = "{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_deceptive_do_statement(self): + body = "{% do thing %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_do_block(self): + body = "{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks( + body, allowed_blocks={"do", "myblock"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 2) + self.assertEqual(blocks[0].contents, "thing.update()") + self.assertEqual(blocks[0].block_type_name, "do") + self.assertEqual(blocks[1].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_crazy_do_statement(self): + body = '{% do (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}{% do otherthing("{% myblock foo %}") %}{% myblock x %}hi{% endmyblock %}' + blocks = extract_toplevel_blocks( + body, allowed_blocks={"myblock", "otherblock"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 2) + self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") + self.assertEqual(blocks[0].block_type_name, "otherblock") + self.assertEqual(blocks[1].full_block, "{% myblock x %}hi{% endmyblock %}") + self.assertEqual(blocks[1].block_type_name, "myblock") + + def test_awful_jinja(self): + blocks = extract_toplevel_blocks( + if_you_do_this_you_are_awful, + allowed_blocks={"snapshot", "materialization"}, + collect_raw_data=False, + ) + self.assertEqual(len(blocks), 2) + self.assertEqual(len([b for b in blocks if b.block_type_name == "__dbt__data"]), 0) + self.assertEqual(blocks[0].block_type_name, "snapshot") + self.assertEqual( + blocks[0].contents, + "\n ".join( + [ + """{% set x = ("{% endsnapshot %}" + (40 * '%})')) %}""", + "{# {% endsnapshot %} #}", + "{% embedded %}", + " some block data right here", + "{% endembedded %}", + ] + ), + ) + self.assertEqual(blocks[1].block_type_name, "materialization") + self.assertEqual(blocks[1].contents, "\nhi\n") + + def test_quoted_endblock_within_block(self): + body = '{% myblock something -%} {% set x = ("{% endmyblock %}") %} {% endmyblock %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "myblock") + self.assertEqual(blocks[0].contents, '{% set x = ("{% endmyblock %}") %} ') + + def test_docs_block(self): + body = '{% docs __my_doc__ %} asdf {# nope {% enddocs %}} #} {% enddocs %} {% docs __my_other_doc__ %} asdf "{% enddocs %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) + self.assertEqual(len(blocks), 2) + self.assertEqual(blocks[0].block_type_name, "docs") + self.assertEqual(blocks[0].contents, " asdf {# nope {% enddocs %}} #} ") + self.assertEqual(blocks[0].block_name, "__my_doc__") + self.assertEqual(blocks[1].block_type_name, "docs") + self.assertEqual(blocks[1].contents, ' asdf "') + self.assertEqual(blocks[1].block_name, "__my_other_doc__") + + def test_docs_block_expr(self): + body = '{% docs more_doc %} asdf {{ "{% enddocs %}" ~ "}}" }}{% enddocs %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "docs") + self.assertEqual(blocks[0].contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') + self.assertEqual(blocks[0].block_name, "more_doc") + + def test_unclosed_model_quotes(self): + # test case for https://github.com/dbt-labs/dbt-core/issues/1533 + body = '{% model my_model -%} select * from "something"."something_else{% endmodel %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"model"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "model") + self.assertEqual(blocks[0].contents, 'select * from "something"."something_else') + self.assertEqual(blocks[0].block_name, "my_model") + + def test_if(self): + # if you conditionally define your macros/models, don't + body = "{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}" + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body) + + def test_if_innocuous(self): + body = "{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}" + blocks = extract_toplevel_blocks(body) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, body) + + def test_for(self): + # no for-loops over macros. + body = "{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}" + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body) + + def test_for_innocuous(self): + # no for-loops over macros. + body = "{% for x in range(10) %}{% something my_something %} adsf {% endsomething %}{% endfor %}" + blocks = extract_toplevel_blocks(body) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, body) + + def test_endif(self): + body = "{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}" + with self.assertRaises(CompilationError) as err: + extract_toplevel_blocks(body) + self.assertIn( + "Got an unexpected control flow end tag, got endif but never saw a preceeding if (@ 1:53)", + str(err.exception), + ) + + def test_if_endfor(self): + body = "{% if x %}...{% endfor %}{% endif %}" + with self.assertRaises(CompilationError) as err: + extract_toplevel_blocks(body) + self.assertIn( + "Got an unexpected control flow end tag, got endfor but expected endif next (@ 1:13)", + str(err.exception), + ) + + def test_if_endfor_newlines(self): + body = "{% if x %}\n ...\n {% endfor %}\n{% endif %}" + with self.assertRaises(CompilationError) as err: + extract_toplevel_blocks(body) + self.assertIn( + "Got an unexpected control flow end tag, got endfor but expected endif next (@ 3:4)", + str(err.exception), + ) + + +bar_block = """{% mytype bar %} +{# a comment + that inside it has + {% mytype baz %} +{% endmyothertype %} +{% endmytype %} +{% endmytype %} + {# +{% endmytype %}#} + +some other stuff + +{%- endmytype%}""" + +x_block = """ +{% myothertype x %} +before +{##} +and after +{% endmyothertype %} +""" + +complex_snapshot_file = ( + """ +{#some stuff {% mytype foo %} #} +{% mytype foo %} some stuff {% endmytype %} + +""" + + bar_block + + x_block +) + + +if_you_do_this_you_are_awful = """ +{#} here is a comment with a block inside {% block x %} asdf {% endblock %} {#} +{% do + set('foo="bar"') +%} +{% set x = ("100" + "hello'" + '%}') %} +{% snapshot something -%} + {% set x = ("{% endsnapshot %}" + (40 * '%})')) %} + {# {% endsnapshot %} #} + {% embedded %} + some block data right here + {% endembedded %} +{%- endsnapshot %} + +{% raw %} + {% set x = SYNTAX ERROR} +{% endraw %} + + +{% materialization whatever, adapter='thing' %} +hi +{% endmaterialization %} +""" diff --git a/tests/unit/test_model_config.py b/tests/unit/test_model_config.py new file mode 100644 index 00000000..2c746408 --- /dev/null +++ b/tests/unit/test_model_config.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass, field +from dbt.common.dataclass_schema import dbtClassMixin +from typing import List, Dict +from dbt.common.contracts.config.metadata import ShowBehavior +from dbt.common.contracts.config.base import MergeBehavior, CompareBehavior + + +@dataclass +class ThingWithMergeBehavior(dbtClassMixin): + default_behavior: int + appended: List[str] = field(metadata={"merge": MergeBehavior.Append}) + updated: Dict[str, int] = field(metadata={"merge": MergeBehavior.Update}) + clobbered: str = field(metadata={"merge": MergeBehavior.Clobber}) + keysappended: Dict[str, int] = field(metadata={"merge": MergeBehavior.DictKeyAppend}) + + +def test_merge_behavior_meta(): + existing = {"foo": "bar"} + initial_existing = existing.copy() + assert set(MergeBehavior) == { + MergeBehavior.Append, + MergeBehavior.Update, + MergeBehavior.Clobber, + MergeBehavior.DictKeyAppend, + } + for behavior in MergeBehavior: + assert behavior.meta() == {"merge": behavior} + assert behavior.meta(existing) == {"merge": behavior, "foo": "bar"} + assert existing == initial_existing + + +def test_merge_behavior_from_field(): + fields = [f[0] for f in ThingWithMergeBehavior._get_fields()] + fields = {name: f for f, name in ThingWithMergeBehavior._get_fields()} + assert set(fields) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} + assert MergeBehavior.from_field(fields["default_behavior"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields["appended"]) == MergeBehavior.Append + assert MergeBehavior.from_field(fields["updated"]) == MergeBehavior.Update + assert MergeBehavior.from_field(fields["clobbered"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields["keysappended"]) == MergeBehavior.DictKeyAppend + + +@dataclass +class ThingWithShowBehavior(dbtClassMixin): + default_behavior: int + hidden: str = field(metadata={"show_hide": ShowBehavior.Hide}) + shown: float = field(metadata={"show_hide": ShowBehavior.Show}) + + +def test_show_behavior_meta(): + existing = {"foo": "bar"} + initial_existing = existing.copy() + assert set(ShowBehavior) == {ShowBehavior.Hide, ShowBehavior.Show} + for behavior in ShowBehavior: + assert behavior.meta() == {"show_hide": behavior} + assert behavior.meta(existing) == {"show_hide": behavior, "foo": "bar"} + assert existing == initial_existing + + +def test_show_behavior_from_field(): + fields = [f[0] for f in ThingWithShowBehavior._get_fields()] + fields = {name: f for f, name in ThingWithShowBehavior._get_fields()} + assert set(fields) == {"default_behavior", "hidden", "shown"} + assert ShowBehavior.from_field(fields["default_behavior"]) == ShowBehavior.Show + assert ShowBehavior.from_field(fields["hidden"]) == ShowBehavior.Hide + assert ShowBehavior.from_field(fields["shown"]) == ShowBehavior.Show + + +@dataclass +class ThingWithCompareBehavior(dbtClassMixin): + default_behavior: int + included: float = field(metadata={"compare": CompareBehavior.Include}) + excluded: str = field(metadata={"compare": CompareBehavior.Exclude}) + + +def test_compare_behavior_meta(): + existing = {"foo": "bar"} + initial_existing = existing.copy() + assert set(CompareBehavior) == {CompareBehavior.Include, CompareBehavior.Exclude} + for behavior in CompareBehavior: + assert behavior.meta() == {"compare": behavior} + assert behavior.meta(existing) == {"compare": behavior, "foo": "bar"} + assert existing == initial_existing + + +def test_compare_behavior_from_field(): + fields = [f[0] for f in ThingWithCompareBehavior._get_fields()] + fields = {name: f for f, name in ThingWithCompareBehavior._get_fields()} + assert set(fields) == {"default_behavior", "included", "excluded"} + assert CompareBehavior.from_field(fields["default_behavior"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields["included"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields["excluded"]) == CompareBehavior.Exclude diff --git a/tests/unit/test_system_client.py b/tests/unit/test_system_client.py new file mode 100644 index 00000000..c54b98d4 --- /dev/null +++ b/tests/unit/test_system_client.py @@ -0,0 +1,271 @@ +import os +import shutil +import stat +import unittest +import tarfile +import pathspec +from pathlib import Path +from tempfile import mkdtemp, NamedTemporaryFile + +from dbt.common.exceptions import ExecutableError, WorkingDirectoryError +import dbt.common.clients.system + + +class SystemClient(unittest.TestCase): + def setUp(self): + super().setUp() + self.tmp_dir = mkdtemp() + self.profiles_path = "{}/profiles.yml".format(self.tmp_dir) + + def set_up_profile(self): + with open(self.profiles_path, "w") as f: + f.write("ORIGINAL_TEXT") + + def get_profile_text(self): + with open(self.profiles_path, "r") as f: + return f.read() + + def tearDown(self): + try: + shutil.rmtree(self.tmp_dir) + except Exception as e: # noqa: [F841] + pass + + def test__make_file_when_exists(self): + self.set_up_profile() + written = dbt.common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") + + self.assertFalse(written) + self.assertEqual(self.get_profile_text(), "ORIGINAL_TEXT") + + def test__make_file_when_not_exists(self): + written = dbt.common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") + + self.assertTrue(written) + self.assertEqual(self.get_profile_text(), "NEW_TEXT") + + def test__make_file_with_overwrite(self): + self.set_up_profile() + written = dbt.common.clients.system.make_file( + self.profiles_path, contents="NEW_TEXT", overwrite=True + ) + + self.assertTrue(written) + self.assertEqual(self.get_profile_text(), "NEW_TEXT") + + def test__make_dir_from_str(self): + test_dir_str = self.tmp_dir + "/test_make_from_str/sub_dir" + dbt.common.clients.system.make_directory(test_dir_str) + self.assertTrue(Path(test_dir_str).is_dir()) + + def test__make_dir_from_pathobj(self): + test_dir_pathobj = Path(self.tmp_dir + "/test_make_from_pathobj/sub_dir") + dbt.common.clients.system.make_directory(test_dir_pathobj) + self.assertTrue(test_dir_pathobj.is_dir()) + + +class TestRunCmd(unittest.TestCase): + """Test `run_cmd`. + + Don't mock out subprocess, in order to expose any OS-level differences. + """ + + not_a_file = "zzzbbfasdfasdfsdaq" + + def setUp(self): + self.tempdir = mkdtemp() + self.run_dir = os.path.join(self.tempdir, "run_dir") + self.does_not_exist = os.path.join(self.tempdir, "does_not_exist") + self.empty_file = os.path.join(self.tempdir, "empty_file") + if os.name == "nt": + self.exists_cmd = ["cmd", "/C", "echo", "hello"] + else: + self.exists_cmd = ["echo", "hello"] + + os.mkdir(self.run_dir) + with open(self.empty_file, "w") as fp: # noqa: [F841] + pass # "touch" + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test__executable_does_not_exist(self): + with self.assertRaises(ExecutableError) as exc: + dbt.common.clients.system.run_cmd(self.run_dir, [self.does_not_exist]) + + msg = str(exc.exception).lower() + + self.assertIn("path", msg) + self.assertIn("could not find", msg) + self.assertIn(self.does_not_exist.lower(), msg) + + def test__not_exe(self): + with self.assertRaises(ExecutableError) as exc: + dbt.common.clients.system.run_cmd(self.run_dir, [self.empty_file]) + + msg = str(exc.exception).lower() + if os.name == "nt": + # on windows, this means it's not an executable at all! + self.assertIn("not executable", msg) + else: + # on linux, this means you don't have executable permissions on it + self.assertIn("permissions", msg) + self.assertIn(self.empty_file.lower(), msg) + + def test__cwd_does_not_exist(self): + with self.assertRaises(WorkingDirectoryError) as exc: + dbt.common.clients.system.run_cmd(self.does_not_exist, self.exists_cmd) + msg = str(exc.exception).lower() + self.assertIn("does not exist", msg) + self.assertIn(self.does_not_exist.lower(), msg) + + def test__cwd_not_directory(self): + with self.assertRaises(WorkingDirectoryError) as exc: + dbt.common.clients.system.run_cmd(self.empty_file, self.exists_cmd) + + msg = str(exc.exception).lower() + self.assertIn("not a directory", msg) + self.assertIn(self.empty_file.lower(), msg) + + def test__cwd_no_permissions(self): + # it would be nice to add a windows test. Possible path to that is via + # `psexec` (to get SYSTEM privs), use `icacls` to set permissions on + # the directory for the test user. I'm pretty sure windows users can't + # create files that they themselves cannot access. + if os.name == "nt": + return + + # read-only -> cannot cd to it + os.chmod(self.run_dir, stat.S_IRUSR) + + with self.assertRaises(WorkingDirectoryError) as exc: + dbt.common.clients.system.run_cmd(self.run_dir, self.exists_cmd) + + msg = str(exc.exception).lower() + self.assertIn("permissions", msg) + self.assertIn(self.run_dir.lower(), msg) + + def test__ok(self): + out, err = dbt.common.clients.system.run_cmd(self.run_dir, self.exists_cmd) + self.assertEqual(out.strip(), b"hello") + self.assertEqual(err.strip(), b"") + + +class TestFindMatching(unittest.TestCase): + def setUp(self): + self.base_dir = mkdtemp() + self.tempdir = mkdtemp(dir=self.base_dir) + + def test_find_matching_lowercase_file_pattern(self): + with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir) as named_file: + file_path = os.path.dirname(named_file.name) + relative_path = os.path.basename(file_path) + out = dbt.common.clients.system.find_matching( + self.base_dir, + [relative_path], + "*.sql", + ) + expected_output = [ + { + "searched_path": relative_path, + "absolute_path": named_file.name, + "relative_path": os.path.basename(named_file.name), + "modification_time": out[0]["modification_time"], + } + ] + self.assertEqual(out, expected_output) + + def test_find_matching_uppercase_file_pattern(self): + with NamedTemporaryFile(prefix="sql-files", suffix=".SQL", dir=self.tempdir) as named_file: + file_path = os.path.dirname(named_file.name) + relative_path = os.path.basename(file_path) + out = dbt.common.clients.system.find_matching(self.base_dir, [relative_path], "*.sql") + expected_output = [ + { + "searched_path": relative_path, + "absolute_path": named_file.name, + "relative_path": os.path.basename(named_file.name), + "modification_time": out[0]["modification_time"], + } + ] + self.assertEqual(out, expected_output) + + def test_find_matching_file_pattern_not_found(self): + with NamedTemporaryFile(prefix="sql-files", suffix=".SQLT", dir=self.tempdir): + out = dbt.common.clients.system.find_matching(self.tempdir, [""], "*.sql") + self.assertEqual(out, []) + + def test_ignore_spec(self): + with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir): + out = dbt.common.clients.system.find_matching( + self.tempdir, + [""], + "*.sql", + pathspec.PathSpec.from_lines( + pathspec.patterns.GitWildMatchPattern, "sql-files*".splitlines() + ), + ) + self.assertEqual(out, []) + + def tearDown(self): + try: + shutil.rmtree(self.base_dir) + except Exception as e: # noqa: [F841] + pass + + +class TestUntarPackage(unittest.TestCase): + def setUp(self): + self.base_dir = mkdtemp() + self.tempdir = mkdtemp(dir=self.base_dir) + self.tempdest = mkdtemp(dir=self.base_dir) + + def tearDown(self): + try: + shutil.rmtree(self.base_dir) + except Exception as e: # noqa: [F841] + pass + + def test_untar_package_success(self): + # set up a valid tarball to test against + with NamedTemporaryFile( + prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False + ) as named_tar_file: + tar_file_full_path = named_tar_file.name + with NamedTemporaryFile(prefix="a", suffix=".txt", dir=self.tempdir) as file_a: + file_a.write(b"some text in the text file") + relative_file_a = os.path.basename(file_a.name) + with tarfile.open(fileobj=named_tar_file, mode="w:gz") as tar: + tar.addfile(tarfile.TarInfo(relative_file_a), open(file_a.name)) + + # now we test can test that we can untar the file successfully + assert tarfile.is_tarfile(tar.name) + dbt.common.clients.system.untar_package(tar_file_full_path, self.tempdest) + path = Path(os.path.join(self.tempdest, relative_file_a)) + assert path.is_file() + + def test_untar_package_failure(self): + # create a text file then rename it as a tar (so it's invalid) + with NamedTemporaryFile( + prefix="a", suffix=".txt", dir=self.tempdir, delete=False + ) as file_a: + file_a.write(b"some text in the text file") + txt_file_name = file_a.name + file_path = os.path.dirname(txt_file_name) + tar_file_path = os.path.join(file_path, "mypackage.2.tar.gz") + os.rename(txt_file_name, tar_file_path) + + # now that we're set up, test that untarring the file fails + with self.assertRaises(tarfile.ReadError) as exc: # noqa: [F841] + dbt.common.clients.system.untar_package(tar_file_path, self.tempdest) + + def test_untar_package_empty(self): + # create a tarball with nothing in it + with NamedTemporaryFile( + prefix="my-empty-package.2", suffix=".tar.gz", dir=self.tempdir + ) as named_file: + + # make sure we throw an error for the empty file + with self.assertRaises(tarfile.ReadError) as exc: + dbt.common.clients.system.untar_package(named_file.name, self.tempdest) + self.assertEqual("empty file", str(exc.exception)) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 00000000..a2108843 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,143 @@ +import unittest + +import dbt.exceptions +import dbt.common.utils + + +class TestDeepMerge(unittest.TestCase): + def test__simple_cases(self): + cases = [ + {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, + { + "args": [{}, {"b": 1}, {"a": 1}], + "expected": {"a": 1, "b": 1}, + "description": "three merges", + }, + ] + + for case in cases: + actual = dbt.common.utils.deep_merge(*case["args"]) + self.assertEqual( + case["expected"], + actual, + "failed on {} (actual {}, expected {})".format( + case["description"], actual, case["expected"] + ), + ) + + +class TestMerge(unittest.TestCase): + def test__simple_cases(self): + cases = [ + {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, + { + "args": [{}, {"b": 1}, {"a": 1}], + "expected": {"a": 1, "b": 1}, + "description": "three merges", + }, + ] + + for case in cases: + actual = dbt.common.utils.deep_merge(*case["args"]) + self.assertEqual( + case["expected"], + actual, + "failed on {} (actual {}, expected {})".format( + case["description"], actual, case["expected"] + ), + ) + + +class TestDeepMap(unittest.TestCase): + def setUp(self): + self.input_value = { + "foo": { + "bar": "hello", + "baz": [1, 90.5, "990", "89.9"], + }, + "nested": [ + { + "test": "90", + "other_test": None, + }, + { + "test": 400, + "other_test": 4.7e9, + }, + ], + } + + @staticmethod + def intify_all(value, _): + try: + return int(value) + except (TypeError, ValueError): + return -1 + + def test__simple_cases(self): + expected = { + "foo": { + "bar": -1, + "baz": [1, 90, 990, -1], + }, + "nested": [ + { + "test": 90, + "other_test": -1, + }, + { + "test": 400, + "other_test": 4700000000, + }, + ], + } + actual = dbt.common.utils.deep_map_render(self.intify_all, self.input_value) + self.assertEqual(actual, expected) + + actual = dbt.common.utils.deep_map_render(self.intify_all, expected) + self.assertEqual(actual, expected) + + @staticmethod + def special_keypath(value, keypath): + + if tuple(keypath) == ("foo", "baz", 1): + return "hello" + else: + return value + + def test__keypath(self): + expected = { + "foo": { + "bar": "hello", + # the only change from input is the second entry here + "baz": [1, "hello", "990", "89.9"], + }, + "nested": [ + { + "test": "90", + "other_test": None, + }, + { + "test": 400, + "other_test": 4.7e9, + }, + ], + } + actual = dbt.common.utils.deep_map_render(self.special_keypath, self.input_value) + self.assertEqual(actual, expected) + + actual = dbt.common.utils.deep_map_render(self.special_keypath, expected) + self.assertEqual(actual, expected) + + def test__noop(self): + actual = dbt.common.utils.deep_map_render(lambda x, _: x, self.input_value) + self.assertEqual(actual, self.input_value) + + def test_trivial(self): + cases = [[], {}, 1, "abc", None, True] + for case in cases: + result = dbt.common.utils.deep_map_render(lambda x, _: x, case) + self.assertEqual(result, case) + + with self.assertRaises(dbt.exceptions.DbtConfigError): + dbt.common.utils.deep_map_render(lambda x, _: x, {"foo": object()}) From 0443b52e27024e2cb8de9b39509d95a4a28d567e Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Fri, 5 Jan 2024 15:01:23 -0600 Subject: [PATCH 2/2] fix requirements, update MAKEFILE to work with hatch --- Makefile | 27 ++++++++++++++++++--------- dbt/common/ui.py | 12 ++++++++++++ pyproject.toml | 5 ++++- tests/unit/test_utils.py | 4 ++-- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 6ca5597e..298b696d 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,24 @@ .DEFAULT_GOAL:=help -.PHONY: dev_req -dev_req: ## Installs dbt-* packages in develop mode along with only development dependencies. - @\ - pip install -r dev-requirements.txt -.PHONY: dev -dev: dev_req ## Installs dbt-* packages in develop mode along with development dependencies and pre-commit. - @\ - pre-commit install +.PHONY: run install-hatch overwrite-pre-commit install test lint json_schema + +run: + export FORMAT_JSON_LOGS="1" + +install-hatch: + pip3 install hatch + +# This edits your local pre-commit hook file to use Hatch when executing. +overwrite-pre-commit: + hatch run dev-env:pre-commit install + hatch run dev-env:sed -i -e "s/exec /exec hatch run dev-env:/g" .git/hooks/pre-commit + +test: + export FORMAT_JSON_LOGS="1" && hatch -v run dev-env:pytest -n auto tests + +lint: + hatch run dev-env:pre-commit run --show-diff-on-failure --color=always --all-files .PHONY: proto_types proto_types: ## generates google protobuf python file from types.proto @@ -20,4 +30,3 @@ help: ## Show this help message. @echo @echo 'targets:' @grep -E '^[8+a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' - diff --git a/dbt/common/ui.py b/dbt/common/ui.py index d7665797..80eeb6c6 100644 --- a/dbt/common/ui.py +++ b/dbt/common/ui.py @@ -1,8 +1,20 @@ +import sys import textwrap from typing import Dict import colorama +# Colorama is needed for colored logs on Windows because we're using logger.info +# intead of print(). If the Windows env doesn't have a TERM var set or it is set to None +# (i.e. in the case of Git Bash on Windows- this emulates Unix), then it's safe to initialize +# Colorama with wrapping turned on which allows us to strip ANSI sequences from stdout. +# You can safely initialize Colorama for any OS and the coloring stays the same except +# when piped to another process for Linux and MacOS, then it loses the coloring. To combat +# that, we will just initialize Colorama when needed on Windows using a non-Unix terminal. + +if sys.platform == "win32" and (not os.getenv("TERM") or os.getenv("TERM") == "None"): + colorama.init(wrap=True) + COLORS: Dict[str, str] = { "red": colorama.Fore.RED, "green": colorama.Fore.GREEN, diff --git a/pyproject.toml b/pyproject.toml index 818bc1e9..f641b275 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,13 @@ classifiers = [ ] dependencies = [ "agate~=1.7.0", + "colorama>=0.3.9,<0.5", # TODO: major version 0 - should we use it? "jsonschema~=4.0", "Jinja2~=3.0", "mashumaro[msgpack]~=3.9", + "protobuf>=4.0.0", "python-dateutil~=2.0", + "requests<3.0.0", "typing-extensions~=4.4", ] @@ -106,4 +109,4 @@ disallow_untyped_defs = false profile = "black" [tool.black] -line-length = 120 \ No newline at end of file +line-length = 120 diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index a2108843..7b153b33 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,6 +1,6 @@ import unittest -import dbt.exceptions +import dbt.common.exceptions import dbt.common.utils @@ -139,5 +139,5 @@ def test_trivial(self): result = dbt.common.utils.deep_map_render(lambda x, _: x, case) self.assertEqual(result, case) - with self.assertRaises(dbt.exceptions.DbtConfigError): + with self.assertRaises(dbt.common.exceptions.DbtConfigError): dbt.common.utils.deep_map_render(lambda x, _: x, {"foo": object()})