diff --git a/.changes/unreleased/Under the Hood-20231108-163613.yaml b/.changes/unreleased/Under the Hood-20231108-163613.yaml new file mode 100644 index 00000000000..091c09bfe32 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231108-163613.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Consolidate deferral methods & flags +time: 2023-11-08T16:36:13.234324-05:00 +custom: + Author: jtcohen6 + Issue: 7965 8715 diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index ffc73323df8..706370609ca 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -113,6 +113,7 @@ def _get_params_by_source(ctx: Context, source_type: ParameterSource): def _assign_params( ctx: Context, params_assigned_from_default: set, + params_assigned_from_user: set, deprecated_env_vars: Dict[str, Callable], ): """Recursively adds all click params to flag object""" @@ -178,15 +179,30 @@ def _assign_params( object.__setattr__(self, flag_name, param_value) # Track default assigned params. - if is_default: + # For flags that are accepted at both 'parent' and 'child' levels, + # we need to track user-provided and default values across both, + # to support detection of mutually exclusive flags later on. + if not is_default: + params_assigned_from_user.add(param_name) + if param_name in params_assigned_from_default: + params_assigned_from_default.remove(param_name) + if is_default and param_name not in params_assigned_from_user: params_assigned_from_default.add(param_name) if ctx.parent: - _assign_params(ctx.parent, params_assigned_from_default, deprecated_env_vars) + _assign_params( + ctx.parent, + params_assigned_from_default, + params_assigned_from_user, + deprecated_env_vars, + ) + params_assigned_from_user = set() # type: Set[str] params_assigned_from_default = set() # type: Set[str] deprecated_env_vars: Dict[str, Callable] = {} - _assign_params(ctx, params_assigned_from_default, deprecated_env_vars) + _assign_params( + ctx, params_assigned_from_default, params_assigned_from_user, deprecated_env_vars + ) # Set deprecated_env_var_warnings to be fired later after events have been init. object.__setattr__( @@ -203,7 +219,10 @@ def _assign_params( invoked_subcommand.ignore_unknown_options = True invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv) _assign_params( - invoked_subcommand_ctx, params_assigned_from_default, deprecated_env_vars + invoked_subcommand_ctx, + params_assigned_from_default, + params_assigned_from_user, + deprecated_env_vars, ) if not project_flags: @@ -378,7 +397,11 @@ def add_fn(x): if k == "macro" and command == CliCommand.RUN_OPERATION: add_fn(v) # None is a Singleton, False is a Flyweight, only one instance of each. - elif v is None or v is False: + elif (v is None or v is False) and k not in ( + # These are None by default but they do not support --no-{flag} + "defer_state", + "log_format", + ): add_fn(f"--no-{spinal_cased}") elif v is True: add_fn(f"--{spinal_cased}") diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 3525de95f96..51bb3d6d9bd 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -121,11 +121,19 @@ def invoke(self, args: List[str], **kwargs) -> dbtRunnerResult: def global_flags(func): @p.cache_selected_only @p.debug + @p.defer + @p.deprecated_defer + @p.defer_state + @p.deprecated_favor_state @p.deprecated_print + @p.deprecated_state @p.enable_legacy_logger @p.fail_fast + @p.favor_state + @p.indirect_selection @p.log_cache_events @p.log_file_max_bytes + @p.log_format @p.log_format_file @p.log_level @p.log_level_file @@ -141,12 +149,15 @@ def global_flags(func): @p.record_timing_info @p.send_anonymous_usage_stats @p.single_threaded + @p.state @p.static_parser @p.use_colors @p.use_colors_file @p.use_experimental_parser @p.version @p.version_check + @p.warn_error + @p.warn_error_options @p.write_json @functools.wraps(func) def wrapper(*args, **kwargs): @@ -164,9 +175,6 @@ def wrapper(*args, **kwargs): ) @click.pass_context @global_flags -@p.warn_error -@p.warn_error_options -@p.log_format @p.show_resource_report def cli(ctx, **kwargs): """An ELT tool for managing your SQL transformations and data models. @@ -178,14 +186,9 @@ def cli(ctx, **kwargs): @cli.command("build") @click.pass_context @global_flags -@p.defer -@p.deprecated_defer @p.exclude -@p.favor_state -@p.deprecated_favor_state @p.full_refresh @p.include_saved_query -@p.indirect_selection @p.profile @p.profiles_dir @p.project_dir @@ -193,9 +196,6 @@ def cli(ctx, **kwargs): @p.select @p.selector @p.show -@p.state -@p.defer_state -@p.deprecated_state @p.store_failures @p.target @p.target_path @@ -257,11 +257,7 @@ def docs(ctx, **kwargs): @click.pass_context @global_flags @p.compile_docs -@p.defer -@p.deprecated_defer @p.exclude -@p.favor_state -@p.deprecated_favor_state @p.profile @p.profiles_dir @p.project_dir @@ -269,9 +265,6 @@ def docs(ctx, **kwargs): @p.selector @p.empty_catalog @p.static -@p.state -@p.defer_state -@p.deprecated_state @p.target @p.target_path @p.threads @@ -328,14 +321,9 @@ def docs_serve(ctx, **kwargs): @cli.command("compile") @click.pass_context @global_flags -@p.defer -@p.deprecated_defer @p.exclude -@p.favor_state -@p.deprecated_favor_state @p.full_refresh @p.show_output_format -@p.indirect_selection @p.introspect @p.profile @p.profiles_dir @@ -344,9 +332,6 @@ def docs_serve(ctx, **kwargs): @p.select @p.selector @p.inline -@p.state -@p.defer_state -@p.deprecated_state @p.compile_inject_ephemeral_ctes @p.target @p.target_path @@ -376,15 +361,10 @@ def compile(ctx, **kwargs): @cli.command("show") @click.pass_context @global_flags -@p.defer -@p.deprecated_defer @p.exclude -@p.favor_state -@p.deprecated_favor_state @p.full_refresh @p.show_output_format @p.show_limit -@p.indirect_selection @p.introspect @p.profile @p.profiles_dir @@ -392,9 +372,6 @@ def compile(ctx, **kwargs): @p.select @p.selector @p.inline -@p.state -@p.defer_state -@p.deprecated_state @p.target @p.target_path @p.threads @@ -514,7 +491,6 @@ def init(ctx, **kwargs): @click.pass_context @global_flags @p.exclude -@p.indirect_selection @p.models @p.output @p.output_keys @@ -524,9 +500,6 @@ def init(ctx, **kwargs): @p.resource_type @p.raw_select @p.selector -@p.state -@p.defer_state -@p.deprecated_state @p.target @p.target_path @p.vars @@ -582,10 +555,6 @@ def parse(ctx, **kwargs): @cli.command("run") @click.pass_context @global_flags -@p.defer -@p.deprecated_defer -@p.favor_state -@p.deprecated_favor_state @p.exclude @p.full_refresh @p.profile @@ -594,9 +563,6 @@ def parse(ctx, **kwargs): @p.empty @p.select @p.selector -@p.state -@p.defer_state -@p.deprecated_state @p.target @p.target_path @p.threads @@ -629,7 +595,6 @@ def run(ctx, **kwargs): @p.vars @p.profile @p.target -@p.state @p.threads @p.full_refresh @requires.postflight @@ -654,7 +619,6 @@ def retry(ctx, **kwargs): @cli.command("clone") @click.pass_context @global_flags -@p.defer_state @p.exclude @p.full_refresh @p.profile @@ -663,7 +627,6 @@ def retry(ctx, **kwargs): @p.resource_type @p.select @p.selector -@p.state # required @p.target @p.target_path @p.threads @@ -731,9 +694,6 @@ def run_operation(ctx, **kwargs): @p.select @p.selector @p.show -@p.state -@p.defer_state -@p.deprecated_state @p.target @p.target_path @p.threads @@ -760,19 +720,12 @@ def seed(ctx, **kwargs): @cli.command("snapshot") @click.pass_context @global_flags -@p.defer -@p.deprecated_defer @p.exclude -@p.favor_state -@p.deprecated_favor_state @p.profile @p.profiles_dir @p.project_dir @p.select @p.selector -@p.state -@p.defer_state -@p.deprecated_state @p.target @p.target_path @p.threads @@ -815,9 +768,6 @@ def source(ctx, **kwargs): @p.project_dir @p.select @p.selector -@p.state -@p.defer_state -@p.deprecated_state @p.target @p.target_path @p.threads @@ -851,20 +801,12 @@ def freshness(ctx, **kwargs): @cli.command("test") @click.pass_context @global_flags -@p.defer -@p.deprecated_defer @p.exclude -@p.favor_state -@p.deprecated_favor_state -@p.indirect_selection @p.profile @p.profiles_dir @p.project_dir @p.select @p.selector -@p.state -@p.defer_state -@p.deprecated_state @p.store_failures @p.target @p.target_path diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 1f99665af7e..3064a1baa81 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1406,7 +1406,8 @@ def sql(self) -> Optional[str]: if self.model.language == ModelLanguage.sql: # type: ignore[union-attr] # If the model is deferred and the adapter doesn't support zero-copy cloning, then select * from the prod # relation - if getattr(self.model, "defer_relation", None): + # TODO: avoid routing on args.which if possible + if getattr(self.model, "defer_relation", None) and self.config.args.which == "clone": # TODO https://github.com/dbt-labs/dbt-core/issues/7976 return f"select * from {self.model.defer_relation.relation_name or str(self.defer_relation)}" # type: ignore[union-attr] elif getattr(self.model, "extra_ctes_injected", None): diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 426b99d39b6..f0e090c95d9 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -1341,7 +1341,7 @@ def is_invalid_protected_ref( node.package_name != target_model.package_name and restrict_package_access ) - # Called by RunTask.defer_to_manifest + # Called by GraphRunnableTask.defer_to_manifest def merge_from_artifact( self, adapter, @@ -1370,6 +1370,13 @@ def merge_from_artifact( merged.add(unique_id) self.nodes[unique_id] = node.replace(deferred=True) + # for all other nodes, add 'defer_relation' + elif current and node.resource_type in refables and not node.is_ephemeral: + defer_relation = DeferRelation( + node.database, node.schema, node.alias, node.relation_name + ) + self.nodes[unique_id] = current.replace(defer_relation=defer_relation) + # Rebuild the flat_graph, which powers the 'graph' context variable, # now that we've deferred some nodes self.build_flat_graph() @@ -1378,25 +1385,6 @@ def merge_from_artifact( sample = list(islice(merged, 5)) fire_event(MergedFromState(num_merged=len(merged), sample=sample)) - # Called by CloneTask.defer_to_manifest - def add_from_artifact( - self, - other: "WritableManifest", - ) -> None: - """Update this manifest by *adding* information about each node's location - in the other manifest. - - Only non-ephemeral refable nodes are examined. - """ - refables = set(NodeType.refable()) - for unique_id, node in other.nodes.items(): - current = self.nodes.get(unique_id) - if current and (node.resource_type in refables and not node.is_ephemeral): - defer_relation = DeferRelation( - node.database, node.schema, node.alias, node.relation_name - ) - self.nodes[unique_id] = current.replace(defer_relation=defer_relation) - # Methods that were formerly in ParseResult def add_macro(self, source_file: SourceFile, macro: Macro): diff --git a/core/dbt/task/clone.py b/core/dbt/task/clone.py index 849eec1c8b1..ef595afef7f 100644 --- a/core/dbt/task/clone.py +++ b/core/dbt/task/clone.py @@ -1,15 +1,15 @@ import threading -from typing import AbstractSet, Any, List, Iterable, Set +from typing import AbstractSet, Any, List, Iterable, Set, Optional from dbt.adapters.base import BaseRelation from dbt.clients.jinja import MacroGenerator from dbt.context.providers import generate_runtime_model_context +from dbt.contracts.graph.manifest import WritableManifest from dbt.artifacts.run import RunStatus, RunResult from dbt.common.dataclass_schema import dbtClassMixin from dbt.common.exceptions import DbtInternalError, CompilationError from dbt.graph import ResourceTypeSelector from dbt.node_types import NodeType -from dbt.parser.manifest import write_manifest from dbt.task.base import BaseRunner from dbt.task.run import _validate_materialization_relations_dict from dbt.task.runnable import GraphRunnableTask @@ -94,6 +94,11 @@ class CloneTask(GraphRunnableTask): def raise_on_first_error(self): return False + def _get_deferred_manifest(self) -> Optional[WritableManifest]: + # Unlike other commands, 'clone' always requires a state manifest + # Load previous state, regardless of whether --defer flag has been set + return self._get_previous_state() + def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRelation]: if self.manifest is None: raise DbtInternalError("manifest was None in get_model_schemas") @@ -154,17 +159,3 @@ def get_node_selector(self) -> ResourceTypeSelector: def get_runner_type(self, _): return CloneRunner - - # Note that this is different behavior from --defer with other commands, which *merge* - # selected nodes from this manifest + unselected nodes from the other manifest - def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): - deferred_manifest = self._get_deferred_manifest() - if deferred_manifest is None: - return - if self.manifest is None: - raise DbtInternalError( - "Expected to defer to manifest, but there is no runtime manifest to defer from!" - ) - self.manifest.add_from_artifact(other=deferred_manifest) - # TODO: is it wrong to write the manifest here? I think it's right... - write_manifest(self.manifest, self.config.target_path) diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index 4d6e049b359..1c20a41e07e 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -1,7 +1,5 @@ import threading -from typing import AbstractSet, Optional -from dbt.contracts.graph.manifest import WritableManifest from dbt.artifacts.run import RunStatus, RunResult from dbt.common.events.base_types import EventLevel from dbt.common.events.functions import fire_event @@ -15,7 +13,7 @@ from dbt.graph import ResourceTypeSelector from dbt.node_types import NodeType -from dbt.parser.manifest import write_manifest, process_node +from dbt.parser.manifest import process_node from dbt.parser.sql import SqlBlockParser from dbt.task.base import BaseRunner from dbt.task.runnable import GraphRunnableTask @@ -101,26 +99,6 @@ def task_end_messages(self, results): ) ) - def _get_deferred_manifest(self) -> Optional[WritableManifest]: - return super()._get_deferred_manifest() if self.args.defer else None - - def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): - deferred_manifest = self._get_deferred_manifest() - if deferred_manifest is None: - return - if self.manifest is None: - raise DbtInternalError( - "Expected to defer to manifest, but there is no runtime manifest to defer from!" - ) - self.manifest.merge_from_artifact( - adapter=adapter, - other=deferred_manifest, - selected=selected_uids, - favor_state=bool(self.args.favor_state), - ) - # TODO: is it wrong to write the manifest here? I think it's right... - write_manifest(self.manifest, self.config.project_target_path) - def _runtime_initialize(self): if getattr(self.args, "inline", None): try: diff --git a/core/dbt/task/freshness.py b/core/dbt/task/freshness.py index 8301893add9..db4d9ca1075 100644 --- a/core/dbt/task/freshness.py +++ b/core/dbt/task/freshness.py @@ -171,10 +171,6 @@ def node_is_match(self, node): class FreshnessTask(GraphRunnableTask): - def defer_to_manifest(self, adapter, selected_uids): - # freshness don't defer - return - def result_path(self): if self.args.output: return os.path.realpath(self.args.output) diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index 5c5edf421a5..5b2270d141c 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -183,10 +183,6 @@ def selection_arg(self): else: return self.args.select - def defer_to_manifest(self, adapter, selected_uids): - # list don't defer - return - def get_node_selector(self): if self.manifest is None or self.graph is None: raise DbtInternalError("manifest and graph must be set to get perform node selection") diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 81f1b4922bd..ddeaf116cde 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -125,9 +125,24 @@ def get_selection_spec(self) -> SelectionSpec: def get_node_selector(self) -> NodeSelector: raise NotImplementedError(f"get_node_selector not implemented for task {type(self)}") - @abstractmethod def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): - raise NotImplementedError(f"defer_to_manifest not implemented for task {type(self)}") + deferred_manifest = self._get_deferred_manifest() + if deferred_manifest is None: + return + if self.manifest is None: + raise DbtInternalError( + "Expected to defer to manifest, but there is no runtime manifest to defer from!" + ) + self.manifest.merge_from_artifact( + adapter=adapter, + other=deferred_manifest, + selected=selected_uids, + favor_state=bool(self.args.favor_state), + ) + # We're rewriting the manifest because it's been mutated during merge_from_artifact. + # This is to reflect which nodes had been deferred to (= replaced with) their counterparts. + if self.args.write_json: + write_manifest(self.manifest, self.config.project_target_path) def get_graph_queue(self) -> GraphQueue: selector = self.get_node_selector() @@ -609,7 +624,7 @@ def get_result(self, results, elapsed_time, generated_at): def task_end_messages(self, results): print_run_end_messages(results) - def _get_deferred_manifest(self) -> Optional[WritableManifest]: + def _get_previous_state(self) -> Optional[WritableManifest]: state = self.previous_defer_state or self.previous_state if not state: raise DbtRuntimeError( @@ -619,3 +634,6 @@ def _get_deferred_manifest(self) -> Optional[WritableManifest]: if not state.manifest: raise DbtRuntimeError(f'Could not find manifest in --state path: "{state}"') return state.manifest + + def _get_deferred_manifest(self) -> Optional[WritableManifest]: + return self._get_previous_state() if self.args.defer else None diff --git a/core/dbt/task/seed.py b/core/dbt/task/seed.py index 31dbd241e56..8c9d5279385 100644 --- a/core/dbt/task/seed.py +++ b/core/dbt/task/seed.py @@ -62,10 +62,6 @@ def print_result_line(self, result): class SeedTask(RunTask): - def defer_to_manifest(self, adapter, selected_uids): - # seeds don't defer - return - def raise_on_first_error(self): return False diff --git a/tests/functional/dbt_runner/test_dbt_runner.py b/tests/functional/dbt_runner/test_dbt_runner.py index 20041f05952..40edcccae8d 100644 --- a/tests/functional/dbt_runner/test_dbt_runner.py +++ b/tests/functional/dbt_runner/test_dbt_runner.py @@ -23,6 +23,8 @@ def test_command_invalid_option(self, dbt: dbtRunner) -> None: def test_command_mutually_exclusive_option(self, dbt: dbtRunner) -> None: res = dbt.invoke(["--warn-error", "--warn-error-options", '{"include": "all"}', "deps"]) assert type(res.exception) == DbtUsageException + res = dbt.invoke(["deps", "--warn-error", "--warn-error-options", '{"include": "all"}']) + assert type(res.exception) == DbtUsageException def test_invalid_command(self, dbt: dbtRunner) -> None: res = dbt.invoke(["invalid-command"]) diff --git a/tests/functional/defer_state/test_defer_state.py b/tests/functional/defer_state/test_defer_state.py index 3a139f8aa47..102345fdf6e 100644 --- a/tests/functional/defer_state/test_defer_state.py +++ b/tests/functional/defer_state/test_defer_state.py @@ -5,7 +5,6 @@ import pytest -from dbt.cli.exceptions import DbtUsageException from dbt.contracts.results import RunStatus from dbt.exceptions import DbtRuntimeError from dbt.tests.util import run_dbt, write_file, rm_file @@ -105,11 +104,6 @@ def run_and_save_state(self, project_root, with_snapshot=False): class TestDeferStateUnsupportedCommands(BaseDeferState): - def test_unsupported_commands(self, project): - # make sure these commands don"t work with --defer - with pytest.raises(DbtUsageException): - run_dbt(["seed", "--defer"]) - def test_no_state(self, project): # no "state" files present, snapshot fails with pytest.raises(DbtRuntimeError): diff --git a/tests/unit/test_manifest.py b/tests/unit/test_manifest.py index 8268e988361..5a2053e6f86 100644 --- a/tests/unit/test_manifest.py +++ b/tests/unit/test_manifest.py @@ -1019,7 +1019,7 @@ def test_build_flat_graph(self): self.assertEqual(frozenset(node), REQUIRED_PARSED_NODE_KEYS) self.assertEqual(compiled_count, 2) - def test_add_from_artifact(self): + def test_merge_from_artifact(self): original_nodes = deepcopy(self.nested_nodes) other_nodes = deepcopy(self.nested_nodes) @@ -1041,7 +1041,8 @@ def test_add_from_artifact(self): original_manifest = Manifest(nodes=original_nodes) other_manifest = Manifest(nodes=other_nodes) - original_manifest.add_from_artifact(other_manifest.writable_manifest()) + adapter = mock.MagicMock() + original_manifest.merge_from_artifact(adapter, other_manifest.writable_manifest(), {}) # new node added should not be in original manifest assert "model.root.nested2" not in original_manifest.nodes