diff --git a/.changes/unreleased/Under the Hood-20240208-100709.yaml b/.changes/unreleased/Under the Hood-20240208-100709.yaml new file mode 100644 index 00000000..2c824b39 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240208-100709.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add Invocation Context Support to MultiThreadedExecutor +time: 2024-02-08T10:07:09.584747-05:00 +custom: + Author: peterallenwebb + Issue: "75" diff --git a/dbt_common/context.py b/dbt_common/context.py index f7f4b7ec..d916e06d 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -26,7 +26,7 @@ def env_secrets(self) -> List[str]: _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR") -def _reliably_get_invocation_var() -> ContextVar: +def reliably_get_invocation_var() -> ContextVar: invocation_var: Optional[ContextVar] = next( (cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None ) @@ -38,11 +38,11 @@ def _reliably_get_invocation_var() -> ContextVar: def set_invocation_context(env: Mapping[str, str]) -> None: - invocation_var = _reliably_get_invocation_var() + invocation_var = reliably_get_invocation_var() invocation_var.set(InvocationContext(env)) def get_invocation_context() -> InvocationContext: - invocation_var = _reliably_get_invocation_var() + invocation_var = reliably_get_invocation_var() ctx = invocation_var.get() return ctx diff --git a/dbt_common/utils/executor.py b/dbt_common/utils/executor.py index 819a0e3a..0be40fcd 100644 --- a/dbt_common/utils/executor.py +++ b/dbt_common/utils/executor.py @@ -1,7 +1,10 @@ import concurrent.futures from contextlib import contextmanager +from contextvars import ContextVar from typing import Protocol, Optional +from dbt_common.context import get_invocation_context, reliably_get_invocation_var + class ConnectingExecutor(concurrent.futures.Executor): def submit_connected(self, adapter, conn_name, func, *args, **kwargs): @@ -60,8 +63,17 @@ class HasThreadingConfig(Protocol): threads: Optional[int] +def _thread_initializer(invocation_context: ContextVar) -> None: + invocation_var = reliably_get_invocation_var() + invocation_var.set(invocation_context) + + def executor(config: HasThreadingConfig) -> ConnectingExecutor: if config.args.single_threaded: return SingleThreadedExecutor() else: - return MultiThreadedExecutor(max_workers=config.threads) + return MultiThreadedExecutor( + max_workers=config.threads, + initializer=_thread_initializer, + initargs=(get_invocation_context(),), + )